1 | package felix.operator; |
2 | |
3 | import java.io.BufferedWriter; |
4 | import java.io.File; |
5 | import java.io.FileInputStream; |
6 | import java.io.FileOutputStream; |
7 | import java.io.OutputStreamWriter; |
8 | import java.sql.SQLException; |
9 | import java.util.ArrayList; |
10 | import java.util.Arrays; |
11 | import java.util.Collections; |
12 | import java.util.HashMap; |
13 | import java.util.HashSet; |
14 | import java.util.List; |
15 | |
16 | import org.postgresql.PGConnection; |
17 | |
18 | |
19 | import tuffy.db.RDB; |
20 | import tuffy.mln.Predicate; |
21 | import tuffy.mln.Type; |
22 | import tuffy.ra.ConjunctiveQuery; |
23 | import tuffy.ra.Expression; |
24 | import tuffy.ra.Function; |
25 | import tuffy.util.Config; |
26 | import tuffy.util.ExceptionMan; |
27 | import tuffy.util.StringMan; |
28 | import tuffy.util.Timer; |
29 | import tuffy.util.UIMan; |
30 | import tuffy.util.UnionFind; |
31 | import felix.dstruct.DataMovementOperator; |
32 | import felix.dstruct.FelixPredicate; |
33 | import felix.dstruct.FelixQuery; |
34 | import felix.dstruct.StatOperator; |
35 | import felix.dstruct.FelixPredicate.FPProperty; |
36 | import felix.parser.FelixCommandOptions; |
37 | import felix.util.FelixConfig; |
38 | import felix.util.FelixUIMan; |
39 | |
/**
 * A COREF (coreference / correlation-clustering) operator in Felix.
 * Clusters the domain of a binary target predicate P(type, type) using
 * hard/soft positive and negative edge rules, materializing the result
 * either pairwise or as a compact node-to-cluster "_map" relation.
 * @author Ce Zhang
 *
 */
public class COREFOperator extends StatOperator{

    /**
     * DMOs for soft positive edges (one per soft rule).
     */
    ArrayList<DataMovementOperator> softPosDMOs = new ArrayList<DataMovementOperator>();

    /**
     * DMOs for soft negative edges.
     * NOTE(review): this list is cleared in prepare() but never populated —
     * soft negative rules appear to be folded into softPosDMOs; confirm.
     */
    ArrayList<DataMovementOperator> softNegDMOs = new ArrayList<DataMovementOperator>();

    /**
     * DMOs for hard positive edges (must-link constraints).
     */
    ArrayList<DataMovementOperator> hardPosDMOs = new ArrayList<DataMovementOperator>();

    /**
     * DMOs for hard negative edges (cannot-link constraints).
     */
    ArrayList<DataMovementOperator> hardNegDMOs = new ArrayList<DataMovementOperator>();

    /**
     * DMO for efficient representation of the NODE_CLASS special rule
     * (maps each node to its class).
     */
    DataMovementOperator nodeClassRule = null;

    /**
     * DMO for efficient representation of the CLASS_TAGS special rule
     * (maps each class to its set of tags).
     */
    DataMovementOperator classTagsRule = null;

    /**
     * DMO for retrieving the node's domain (distinct constants of the
     * predicate's first argument type).
     */
    DataMovementOperator nodeListDMO = null;

    /**
     * Union of all hard-neg DMOs, bound on the first argument
     * (binding mask "100" — presumably; see prepareDMO).
     */
    DataMovementOperator hardNegDMO1 = null;

    /**
     * Union of all hard-neg DMOs, bound on the second argument
     * (binding mask "010" — presumably; see prepareDMO).
     */
    DataMovementOperator hardNegDMO2 = null;

    /**
     * Union of all hard-pos DMOs, fully unbound (mask "000") — scanned
     * once up front to union must-link components.
     */
    DataMovementOperator hardPosDMO = null;

    /**
     * Union of all soft-pos DMOs, bound on the first argument (mask "100").
     */
    DataMovementOperator softPosDMO1 = null;

    /**
     * Union of all soft-pos DMOs, bound on the second argument (mask "010").
     */
    DataMovementOperator softPosDMO2 = null;

    /**
     * Whether to represent clustering results using the pairwise representation.
     * Note that setting this parameter to true will cause quadratic numbers of
     * result tuples (every pair within every cluster).
     */
    public boolean usePairwiseRepresentation = false;

    /**
     * Target predicate of this Coref operator.
     */
    FelixPredicate corefHead;

    /**
     * Node -> class map built from the NODE_CLASS special rule
     * (syntax sugar for coref); null when the rule is absent.
     */
    HashMap<Integer, Integer> nodeClass = null;

    /**
     * Class -> tag-set map built from the CLASS_TAGS special rule
     * (syntax sugar for coref); null when the rule is absent.
     */
    HashMap<Integer, HashSet<Integer>> classTags = null;

    /**
     * The constructor of COREFOperator. Marks every goal predicate as a
     * coref predicate and assigns this operator type/precedence.
     * @param _fq Felix query.
     * @param _goalPredicates target predicates of this coref operator.
     * @param _opt Command line options of this Felix run.
     */
    public COREFOperator(FelixQuery _fq, HashSet<FelixPredicate> _goalPredicates,
            FelixCommandOptions _opt) {
        super(_fq, _goalPredicates, _opt);
        for(FelixPredicate p : _goalPredicates){
            p.isCorefPredicate = true;
        }
        this.type = OPType.COREF;
        this.precedence = 10;
    }
144 | |
145 | /** |
146 | * Get the size of domain on which clustering is conducted. |
147 | * @param p clustering predicate |
148 | * @return |
149 | */ |
150 | public int getDomainSize(Predicate p){ |
151 | return p.getTypeAt(0).size()/this.partitionedInto; |
152 | } |
153 | |
154 | |
    /**
     * Generate the Data Movement Operators (DMOs) used by this Coref operator.
     * Builds the "_map" linear-representation predicate if absent, wraps each
     * conjunctive rule in a DMO classified as hard-pos / hard-neg / soft,
     * creates the node-domain DMO and the unioned edge DMOs, and finally
     * wires up the NODE_CLASS / CLASS_TAGS special rules.
     * @param rules rules defining this operator.
     */
    public void prepareDMO(HashSet<ConjunctiveQuery> rules){

        // build linear representation table for this operator:
        // <head>_map(node, clusterLabel) stores each node's cluster assignment
        if(fq.getPredByName(corefHead.getName() + "_map") == null){
            FelixPredicate pmap = new FelixPredicate(corefHead.getName() + "_map", true);
            pmap.appendArgument(corefHead.getTypeAt(0));
            pmap.appendArgument(corefHead.getTypeAt(1));
            pmap.prepareDB(db);
            fq.addFelixPredicate(pmap);
        }


        // generate one Data Movement Operator per conjunctive query
        for(ConjunctiveQuery cq : rules){

            /*
            if(cq.sourceClause.hasEmbeddedWeight()){
                Expression e = new Expression(Function.GreaterThan);
                Expression tmpe1 = Expression.exprVariableBinding(cq.sourceClause.getVarWeight());
                Expression tmpe = Expression.exprConstInteger(0);
                e.addArgument(tmpe1);
                e.addArgument(tmpe);
                e.changeName = false;

                cq.addConstraint(e);
            }*/

            DataMovementOperator dmo = new DataMovementOperator(db, this);
            dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
                    new ArrayList<String>(Arrays.asList("weight")) );
            // allow binding on the head's first term when the DMO is probed per-node
            dmo.whichToBound.add(cq.head.getTerms().get(0).toString());

            // hard positive rule (must-link)
            if(cq.sourceClause.isHardClause() && cq.getWeight() > 0){
                dmo.predictedBB = 0;
                dmo.PredictedFF = 1;
                dmo.PredictedBF = 0;
                hardPosDMOs.add(dmo);
                allDMOs.add(dmo);
            }
            // hard negative rule (cannot-link)
            else if(cq.sourceClause.isHardClause() && cq.getWeight() < 0){
                dmo.predictedBB = 0;
                dmo.PredictedFF = 1;
                dmo.PredictedBF = this.getDomainSize(corefHead);
                hardNegDMOs.add(dmo);
                allDMOs.add(dmo);
            }
            // soft rule — all non-hard clauses are treated as soft-positive edges
            //TODO
            // else if( (!cq.sourceClause.isHardClause() && cq.getWeight() > 0) ||
            //      cq.sourceClause.hasEmbeddedWeight()
            //      ){
            else if( (!cq.sourceClause.isHardClause() )){
                dmo.predictedBB = 0;
                dmo.PredictedFF = 0;
                dmo.PredictedBF = this.getDomainSize(corefHead);
                softPosDMOs.add(dmo);
                allDMOs.add(dmo);
            }else{
                UIMan.warn("The following rule is ignored in the COREFOperator!\n" + cq);
            }

        }

        // generate the operator-internal DMOs

        // first, the DMO used to fetch the node domain (distinct constants
        // of the first argument's type)
        DataMovementOperator dmo = new DataMovementOperator(db, this);
        dmo.predictedBB = 0; dmo.PredictedFF = 1; dmo.PredictedBF = 0;
        dmo.logicQueryPlan.addQuery(db.getPrepareStatement(
                "SELECT DISTINCT constantID FROM " + corefHead.getTypeAt(0).getRelName()),
                new ArrayList<String>(Arrays.asList("constantID")), new ArrayList<String>());
        dmo.allowOptimization = false;
        nodeListDMO = dmo;
        allDMOs.add(dmo);

        // second, the union of all hard-pos DMOs, fully unbound ("000"),
        // scanned once in cluster() to union the must-link components
        if(this.hardPosDMOs.size() > 0){
            this.hardPosDMO = DataMovementOperator.UnionAll(db, this,
                    this.hardPosDMOs, "000", new ArrayList<Integer>());
            this.hardPosDMO.isIntermediaDMO = true;
            allDMOs.add(hardPosDMO);
        }

        // third, the unions of all soft-pos DMOs, one bound on each endpoint
        if(this.softPosDMOs.size() > 0){

            this.softPosDMO1 = DataMovementOperator.UnionAll(db, this,
                    this.softPosDMOs, "100", new ArrayList<Integer>());
            this.softPosDMO1.isIntermediaDMO = true;
            allDMOs.add(softPosDMO1);

            /*
            DataMovementOperator groupedSoftPosDMO1 = new DataMovementOperator(db, this);
            groupedSoftPosDMO1.predictedBB = 0; groupedSoftPosDMO1.PredictedFF = 1; groupedSoftPosDMO1.PredictedBF = 0;
            groupedSoftPosDMO1.logicQueryPlan.addQuery(db.getPrepareStatement(
                    "SELECT " + softPosDMO1.finalSelList.get(0) + ","
                            + softPosDMO1.finalSelList.get(1) + ","
                            + "sum(" + softPosDMO1.finalSelList.get(0) + ") as sumweight "
                            + "FROM " + softPosDMO1.getAllFreeViewName() + " "
                            + "GROUP BY " + softPosDMO1.finalSelList.get(0) + ","
                            + softPosDMO1.finalSelList.get(1) + " "
                            + "WHERE " + "sumweight > 0"
                            + " AND " + softPosDMO1.finalSelList.get(0) + " = ?"),
                    softPosDMO1.finalSelList, new ArrayList<String>());
            groupedSoftPosDMO1.allowOptimization = false;
            allDMOs.add(groupedSoftPosDMO1);
            */


            this.softPosDMO2 = DataMovementOperator.UnionAll(db, this,
                    this.softPosDMOs, "010", new ArrayList<Integer>());
            this.softPosDMO2.isIntermediaDMO = true;
            allDMOs.add(softPosDMO2);

            /*
            DataMovementOperator groupedSoftPosDMO2 = new DataMovementOperator(db, this);
            groupedSoftPosDMO2.predictedBB = 0; groupedSoftPosDMO2.PredictedFF = 1; groupedSoftPosDMO2.PredictedBF = 0;
            groupedSoftPosDMO2.logicQueryPlan.addQuery(db.getPrepareStatement(
                    "SELECT " + softPosDMO2.finalSelList.get(0) + ","
                            + softPosDMO2.finalSelList.get(1) + ","
                            + "sum(" + softPosDMO2.finalSelList.get(0) + ") as sumweight "
                            + "FROM " + softPosDMO2.getAllFreeViewName() + " "
                            + "GROUP BY " + softPosDMO2.finalSelList.get(0) + ","
                            + softPosDMO2.finalSelList.get(1) + " "
                            + "WHERE " + "sumweight > 0"
                            + " AND " + softPosDMO2.finalSelList.get(0) + " = ?"),
                    softPosDMO1.finalSelList, new ArrayList<String>());
            groupedSoftPosDMO2.allowOptimization = false;
            allDMOs.add(groupedSoftPosDMO2);
            */
        }

        // fourth, the unions of all hard-neg DMOs, one bound on each endpoint
        if(this.hardNegDMOs.size() > 0){
            this.hardNegDMO1 = DataMovementOperator.UnionAll(db, this,
                    this.hardNegDMOs, "100", new ArrayList<Integer>());
            this.hardNegDMO1.isIntermediaDMO = true;
            allDMOs.add(hardNegDMO1);

            this.hardNegDMO2 = DataMovementOperator.UnionAll(db, this,
                    this.hardNegDMOs, "010", new ArrayList<Integer>());
            this.hardNegDMO2.isIntermediaDMO = true;
            allDMOs.add(hardNegDMO2);
        }

        // then process the special clustering rules (NODE_CLASS / CLASS_TAGS)
        for(ConjunctiveQuery cq : fq.getSpecialClusteringRules(this.corefHead.getName())){

            if(cq.type == ConjunctiveQuery.CLUSTERING_RULE_TYPE.NODE_CLASS){

                nodeClassRule = new DataMovementOperator(db, this);
                nodeClassRule.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
                        new ArrayList<String>() );
                nodeClassRule.whichToBound.add(cq.head.getTerms().get(0).toString());
                nodeClassRule.allowOptimization = false;
                allDMOs.add(nodeClassRule);

            }else if(cq.type == ConjunctiveQuery.CLUSTERING_RULE_TYPE.CLASS_TAGS){
                classTagsRule = new DataMovementOperator(db, this);
                classTagsRule.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
                        new ArrayList<String>() );
                classTagsRule.whichToBound.add(cq.head.getTerms().get(0).toString());
                classTagsRule.allowOptimization = false;
                allDMOs.add(classTagsRule);

            }else{
                ExceptionMan.die("No special rules other than NODE_CLASS " +
                        "and CLASS_TAGS are supported in Felix!");
            }

        }


    }
335 | |
    // Set once prepare() has completed; currently informational only, since
    // the guard that used it is commented out below.
    boolean prepared = false;

    /**
     * Prepares the operator for execution: clears any previously built DMOs,
     * opens a database connection, resolves the (single) target predicate,
     * translates the clauses into factor-graph edge queries, and builds all
     * Data Movement Operators.
     */
    @Override
    public void prepare() {

        softPosDMOs.clear();

        softNegDMOs.clear();

        hardPosDMOs.clear();

        hardNegDMOs.clear();

        allDMOs.clear();
        //if(!prepared){

        db = RDB.getRDBbyConfig(Config.db_schema);

        corefHead = this.getTargetPredicateIfHasOnlyOne();
        HashSet<ConjunctiveQuery> rules =
                this.translateFelixClasesIntoFactorGraphEdgeQueries(corefHead, true,
                        this.inputPredicateScope,
                        FPProperty.NON_RECUR,
                        FPProperty.CHAIN_RECUR,
                        FPProperty.OTHER_RECUR);

        this.prepareDMO(rules);
        prepared = true;
        //}
    }
369 | |
    /**
     * Executes the operator: runs the clustering pass, reports timing, and
     * (outside of dual decomposition) hands control to the next operator in
     * the bucket before committing and closing the connection.
     */
    @Override
    public void run() {
        UIMan.print(">>> Start Running " + this);

        try{

            this.isMarginal = belongsToBucket.isMarginal();

            Timer.start("Coref-Op" + this.getId());

            if(corefHead == null){
                throw new Exception("The head of this Coref operator is NULL.");
            }

            cluster();

            //this.oriMLN.dumpMapAnswerForPredicate(options.fout+"_coref_" + headPredicate.getName() + "_op" + id,
            //      headPredicate, false);

            FelixUIMan.println(0,0, "\n>>> {" + this + "} uses " + Timer.elapsed("Coref-Op" + this.getId()));

            if(!options.useDualDecomposition){
                this.belongsToBucket.runNextOperatorInBucket();
            }

            // NOTE(review): commit/close sit inside the try block — if cluster()
            // throws, the connection is never closed. Consider a finally block.
            db.commit();
            db.close();

            // TODO!!!!!!!!!!!!!!!
            // db.close();

        } catch (Exception e) {
            // NOTE(review): failures are only printed, not propagated — the
            // scheduler will treat this operator as having finished. Confirm
            // this best-effort behavior is intended.
            e.printStackTrace();
        }

    }
409 | |
    /**
     * Explanation output is not implemented for the COREF operator;
     * callers always receive null.
     */
    @Override
    public String explain() {
        return null;
    }
415 | |
416 | /** |
417 | * Clustering worker. |
418 | */ |
419 | public void cluster() throws Exception{ |
420 | |
421 | procSepcialRules(); |
422 | |
423 | // get domain for the clustering predicate |
424 | // We assume P(type,type) |
425 | ArrayList<Integer> nodes = new ArrayList<Integer>(); |
426 | this.nodeListDMO.execute(null, new ArrayList<Integer>()); |
427 | while(this.nodeListDMO.next()){ |
428 | nodes.add(this.nodeListDMO.getNext(1)); |
429 | } |
430 | |
431 | FelixUIMan.println(2,0,"#nodes = " + nodes.size()); |
432 | |
433 | Collections.shuffle(nodes); |
434 | HashMap<Integer, Integer> rankMap = new HashMap<Integer, Integer>(); |
435 | |
436 | ArrayList<Integer> ranks = new ArrayList<Integer>(); |
437 | for(int i=0; i<nodes.size(); i++){ |
438 | rankMap.put(nodes.get(i), i); |
439 | ranks.add(i); |
440 | } |
441 | |
442 | UnionFind<Integer> clusters = new UnionFind<Integer>(); |
443 | clusters.makeUnionFind(ranks); |
444 | ranks = null; |
445 | |
446 | Timer.start("clustering"); |
447 | int ct = 0; |
448 | int edges = 0; |
449 | |
450 | HashMap<Integer, HashSet<Integer>> hardClusters = new HashMap<Integer, HashSet<Integer>>(); |
451 | |
452 | for(int i=0; i<nodes.size(); i++){ |
453 | HashSet<Integer> s = new HashSet<Integer>(); |
454 | s.add(nodes.get(i)); |
455 | hardClusters.put(i, s); |
456 | } |
457 | |
458 | if(this.hardPosDMO != null){ |
459 | //UIMan.println(">>> Processing hard positive edges..."); |
460 | Timer.start("hardpos"); |
461 | db.disableAutoCommitForNow(); |
462 | |
463 | this.hardPosDMO.execute(null, new ArrayList<Integer>()); |
464 | int cnt = 0; |
465 | while(this.hardPosDMO.next()){ |
466 | cnt ++; |
467 | if(cnt % 100000000 == 0){ |
468 | // UIMan.print("*"); |
469 | FelixUIMan.println(2,0, "# hard edges: " + cnt); |
470 | } |
471 | Integer i = rankMap.get(this.hardPosDMO.getNext(1)); |
472 | Integer j = rankMap.get(this.hardPosDMO.getNext(2)); |
473 | i = clusters.getRoot(i); |
474 | j = clusters.getRoot(j); |
475 | if(i == j) continue; |
476 | if(i < j){ |
477 | hardClusters.get(i).addAll(hardClusters.get(j)); |
478 | hardClusters.remove(j); |
479 | }else{ |
480 | hardClusters.get(j).addAll(hardClusters.get(i)); |
481 | hardClusters.remove(i); |
482 | } |
483 | clusters.unionByValue(i, j); |
484 | } |
485 | |
486 | int x = 0; |
487 | for(int y : hardClusters.keySet()){ |
488 | x += hardClusters.get(y).size(); |
489 | } |
490 | |
491 | db.restoreAutoCommitState(); |
492 | //Timer.printElapsed("hardpos"); |
493 | } |
494 | |
495 | |
496 | for(int i=0; i<nodes.size(); i++){ |
497 | |
498 | int s = nodes.get(i); |
499 | if(ct % 100000 == 0){ |
500 | //UIMan.print("."); |
501 | FelixUIMan.println(2,0,ct + "/" + nodes.size() +" : " + |
502 | Timer.elapsed("clustering") + " edges : " + edges); |
503 | int nc = clusters.getNumClusters(); |
504 | FelixUIMan.println(2,0,"#clusters = " + nc); |
505 | } |
506 | ct ++; |
507 | |
508 | Integer root = clusters.getRoot(i); |
509 | |
510 | if(root != i) continue; |
511 | HashSet<Integer> classesInCluster = null; |
512 | |
513 | if(nodeClassRule != null && classTagsRule != null){ |
514 | classesInCluster = new HashSet<Integer>(); |
515 | for(int n : hardClusters.get(root)){ |
516 | classesInCluster.add(nodeClass.get(n)); |
517 | } |
518 | } |
519 | |
520 | List<Integer> wl = this.retrieveNeighbors(s); |
521 | HashSet<Integer> brokenByHardNeg = new HashSet<Integer>(); |
522 | brokenByHardNeg.addAll(this.retrieveHardNegEdges(s)); |
523 | |
524 | edges += wl.size(); |
525 | |
526 | //if(couldLinkPairwiseDMO != null){ |
527 | // group.addAll(hardClusters.get(i)); |
528 | //} |
529 | HashSet<Integer> merged = new HashSet<Integer>(); |
530 | merged.add(root); |
531 | |
532 | for(int t : wl){ |
533 | |
534 | int j = rankMap.get(t); |
535 | Integer r2 = clusters.getRoot(j); |
536 | if(j <= i || j != r2 || r2 == root) continue; |
537 | if(brokenByHardNeg.contains(t)) continue; |
538 | |
539 | if(nodeClassRule != null && classTagsRule != null){ |
540 | HashSet<Integer> classesInOtherCluster = new HashSet<Integer>(); |
541 | for(int n : hardClusters.get(r2)){ |
542 | classesInOtherCluster.add(nodeClass.get(n)); |
543 | } |
544 | classesInOtherCluster.removeAll(classesInCluster); |
545 | boolean compatible = true; |
546 | if(!classesInOtherCluster.isEmpty()){ |
547 | labf: |
548 | for(int x : classesInCluster){ |
549 | for(int y : classesInOtherCluster){ |
550 | HashSet<Integer> xt = classTags.get(x); |
551 | HashSet<Integer> yt = classTags.get(y); |
552 | xt.retainAll(yt); |
553 | if(xt.isEmpty()){ |
554 | compatible = false; |
555 | break labf; |
556 | } |
557 | } |
558 | } |
559 | } |
560 | if(compatible){ |
561 | classesInCluster.addAll(classesInOtherCluster); |
562 | }else{ |
563 | continue; |
564 | } |
565 | } |
566 | |
567 | clusters.unionByValue(root, r2); |
568 | brokenByHardNeg.addAll(this.retrieveHardNegEdges(t)); |
569 | merged.add(r2); |
570 | } |
571 | } |
572 | |
573 | int nc = clusters.getNumClusters(); |
574 | FelixUIMan.println(2,0,"# clusters = " + nc); |
575 | FelixUIMan.println(2,0,"# edges : " + edges); |
576 | |
577 | this.dumpAnswerToDBTable(corefHead, clusters, nodes); |
578 | } |
579 | |
580 | /** |
581 | * Get soft-pos neighbors of a given node. |
582 | * @param m1 |
583 | * @return |
584 | * @throws SQLException |
585 | */ |
586 | public List<Integer> retrieveNeighbors(Integer m1) throws SQLException{ |
587 | |
588 | ArrayList<Integer> ret = new ArrayList<Integer>(); |
589 | HashSet<Integer> ns = new HashSet<Integer>(); |
590 | |
591 | if(this.softPosDMO1 != null){ |
592 | |
593 | this.softPosDMO1.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
594 | |
595 | while(this.softPosDMO1.next()){ |
596 | Integer i = softPosDMO1.getNext(1) + softPosDMO1.getNext(2) - m1; |
597 | if(!ns.contains(i)){ |
598 | ns.add(i); |
599 | ret.add(i); |
600 | } |
601 | } |
602 | } |
603 | |
604 | if(this.softPosDMO2 != null){ |
605 | |
606 | this.softPosDMO2.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
607 | |
608 | while(this.softPosDMO2.next()){ |
609 | Integer i = softPosDMO2.getNext(1) + softPosDMO2.getNext(2) - m1; |
610 | if(!ns.contains(i)){ |
611 | ns.add(i); |
612 | ret.add(i); |
613 | } |
614 | } |
615 | } |
616 | |
617 | return ret; |
618 | } |
619 | |
620 | /** |
621 | * Get hard-neg neighbors of a given node. |
622 | * @param m1 |
623 | * @return |
624 | * @throws SQLException |
625 | */ |
626 | public List<Integer> retrieveHardNegEdges(Integer m1) throws SQLException{ |
627 | |
628 | ArrayList<Integer> ret = new ArrayList<Integer>(); |
629 | HashSet<Integer> ns = new HashSet<Integer>(); |
630 | |
631 | if(this.hardNegDMO1 != null){ |
632 | |
633 | this.hardNegDMO1.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
634 | |
635 | while(this.hardNegDMO1.next()){ |
636 | Integer i = hardNegDMO1.getNext(1) + hardNegDMO1.getNext(2) - m1; |
637 | if(!ns.contains(i)){ |
638 | ns.add(i); |
639 | ret.add(i); |
640 | } |
641 | } |
642 | } |
643 | |
644 | if(this.hardNegDMO2 != null){ |
645 | |
646 | this.hardNegDMO2.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
647 | |
648 | while(this.hardNegDMO2.next()){ |
649 | Integer i = hardNegDMO2.getNext(1) + hardNegDMO2.getNext(2) - m1; |
650 | if(!ns.contains(i)){ |
651 | ns.add(i); |
652 | ret.add(i); |
653 | } |
654 | } |
655 | } |
656 | |
657 | return ret; |
658 | } |
659 | |
660 | /** |
661 | * Process Tag and Class rules. |
662 | */ |
663 | private void procSepcialRules(){ |
664 | //System.out.println(">>> Processing special rules..."); |
665 | if(nodeClassRule != null){ |
666 | nodeClass = new HashMap<Integer, Integer>(); |
667 | nodeClassRule.execute(null, new ArrayList<Integer>()); |
668 | while(nodeClassRule.next()){ |
669 | int a = nodeClassRule.getNext(1); |
670 | int b = nodeClassRule.getNext(2); |
671 | nodeClass.put(a, b); |
672 | } |
673 | } |
674 | |
675 | if(classTagsRule != null){ |
676 | classTags = new HashMap<Integer, HashSet<Integer>>(); |
677 | classTagsRule.execute(null, new ArrayList<Integer>()); |
678 | while(classTagsRule.next()){ |
679 | int a = classTagsRule.getNext(1); |
680 | int b = classTagsRule.getNext(2); |
681 | if(!classTags.containsKey(a)){ |
682 | classTags.put(a, new HashSet<Integer>()); |
683 | } |
684 | classTags.get(a).add(b); |
685 | } |
686 | |
687 | } |
688 | |
689 | } |
690 | |
    /**
     * Dump answers to a database table (or create a view for it).
     * Writes the clustering to a CSV loading file and bulk-loads it via
     * PostgreSQL COPY. In pairwise mode the target relation gets one row per
     * in-cluster pair; otherwise the compact "_map" relation gets one
     * (node, clusterLabel) row and the target predicate becomes a self-join
     * view over it. Under dual decomposition the result is additionally
     * copied into each shared-output predicate's table.
     * @param p clustering predicate.
     * @param clusters clustering result (union-find over rank indices).
     * @param nodes domain on which clustering is conducted (rank -> constant).
     */
    public void dumpAnswerToDBTable(Predicate p, UnionFind<Integer> clusters, ArrayList<Integer> nodes){

        File loadingFile = new File(Config.getLoadingDir(), "loading_cg_" + p.getRelName() + "_op" + this.getId());
        //p.nextTupleID = 0;

        Predicate pmap = fq.getPredByName(p.getName() + "_map");
        //pmap.nextTupleID = 0;

        String relLinear = pmap.getRelName();

        try {
            BufferedWriter loadingFileWriter = new BufferedWriter(new OutputStreamWriter
                    (new FileOutputStream(loadingFile),"UTF8"));

            // translate the union-find partition (over rank indices) into
            // clusterLabel (node constant) -> member node constants
            HashMap<Integer, Integer> map = clusters.getPartitionMap();
            HashMap<Integer, HashSet<Integer>> label2mention = new HashMap<Integer, HashSet<Integer>>();
            for(int k : map.keySet()){
                int c = map.get(k);
                k = nodes.get(k);
                c = nodes.get(c);
                HashSet<Integer> set = label2mention.get(c);
                if(set == null){
                    set = new HashSet<Integer>();
                    label2mention.put(c, set);
                }
                set.add(k);
            }

            if(this.usePairwiseRepresentation){

                // one CSV row (truth, club, node1, node2) per in-cluster pair
                // — quadratic in cluster size
                for(Integer clusterID : label2mention.keySet()){
                    if(label2mention.get(clusterID).size() == 0){
                        continue;
                    }

                    for(Integer node1 : label2mention.get(clusterID)){
                        for(Integer node2 : label2mention.get(clusterID)){

                            ArrayList<String> parts = new ArrayList<String>();
                            //parts.add(Integer.toString(p.nextTupleID(p.nextTupleID++)));
                            parts.add("TRUE");
                            // club value differs under dual decomposition
                            if(options.useDualDecomposition){
                                parts.add(Integer.toString(1));
                            }else{
                                parts.add(Integer.toString(2));
                            }


                            parts.add(Integer.toString(node1));
                            parts.add(Integer.toString(node2));

                            loadingFileWriter.append(StringMan.join(",", parts) + "\n");

                        }
                    }
                }
            }else{
                // compact representation: one (truth, club, node, clusterLabel)
                // row per node; the label is the smallest member constant
                for(Integer clusterID : label2mention.keySet()){
                    if(label2mention.get(clusterID).size() == 0){
                        continue;
                    }

                    int newClusterID = clusterID;
                    for(Integer node1 : label2mention.get(clusterID)){
                        if(node1 < newClusterID){
                            newClusterID = node1;
                        }
                    }

                    for(Integer node1 : label2mention.get(clusterID)){

                        ArrayList<String> parts = new ArrayList<String>();
                        //parts.add(Integer.toString(p.nextTupleID(pmap.nextTupleID++)));
                        //parts.add("TRUE");

                        // NOTE(review): both branches are identical (TRUE, 2) —
                        // either the DD branch was meant to differ (cf. the
                        // pairwise case above, which writes club=1) or the
                        // conditional is vestigial; confirm before simplifying.
                        if(options.useDualDecomposition){
                            parts.add("TRUE");
                            parts.add(Integer.toString(2));
                        }else{
                            parts.add("TRUE");
                            parts.add(Integer.toString(2));
                        }

                        parts.add(Integer.toString(node1));
                        parts.add(Integer.toString(newClusterID));

                        loadingFileWriter.append(StringMan.join(",", parts) + "\n");

                    }
                }
            }
            loadingFileWriter.close();

            if(this.usePairwiseRepresentation){

                // bulk-load the pairwise rows directly into p's relation.
                // NOTE(review): `in` leaks if copyIn throws; consider try/finally.
                FileInputStream in = new FileInputStream(loadingFile);
                PGConnection con = (PGConnection)db.getConnection();

                String sql;
                //String sql = "DELETE FROM " + p.getRelName();
                //db.update(sql);
                //db.vacuum(p.getRelName());

                sql = "COPY " + p.getRelName() + "(truth, club, " + StringMan.commaList(p.getArgs()) + " ) FROM STDIN CSV";
                con.getCopyAPI().copyIn(sql, in);
                in.close();
                p.isCurrentlyView = false;

            }else{

                // bulk-load into the linear "_map" relation, then rebuild p
                // as a self-join view over it
                FileInputStream in = new FileInputStream(loadingFile);
                PGConnection con = (PGConnection)db.getConnection();

                String sql;
                db.dropView(p.getRelName());
                db.dropTable(p.getRelName());

                //sql = "DELETE FROM " + relLinear;
                //db.update(sql);
                //db.vacuum(relLinear);

                if(options.useDualDecomposition){

                    // under DD only the first run seeds the map table
                    if(FelixConfig.isFirstRunOfDD){
                        sql = "COPY " + relLinear + "(truth, club," + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
                        con.getCopyAPI().copyIn(sql, in);
                        in.close();
                        //p.setHasSoftEvidence(true);
                    }
                }else{
                    sql = "COPY " + relLinear + "(truth, club, " + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
                    con.getCopyAPI().copyIn(sql, in);
                    in.close();
                }

                // index the map table on both columns for the self-join view
                db.dropIndex(relLinear + "_label_idx");
                sql = "CREATE INDEX " + relLinear + "_label_idx on " +
                        relLinear + "(" + pmap.getArgs().get(1) + ")";
                db.update(sql);

                db.dropIndex(relLinear + "_node_idx");
                sql = "CREATE INDEX " + relLinear + "_node_idx on " +
                        relLinear + "(" + pmap.getArgs().get(0) + ")";
                db.update(sql);

                db.analyze(relLinear);

                // sequence supplies synthetic tuple IDs for the view
                db.dropSequence(p.getRelName()+"_seq");
                sql = "CREATE SEQUENCE " + p.getRelName()+"_seq" +
                        " START WITH 1";
                db.update(sql);

                // p becomes a view producing every in-cluster pair on demand
                sql = "CREATE VIEW " + p.getRelName() + " AS SELECT nextval('" + p.getRelName()+"_seq')" +
                        "::integer AS id," +
                        "TRUE::boolean AS truth, NULL::float as prior, " +
                        "2::integer as club, NULL::integer as atomID, " +
                        "t1."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(0) +
                        ", t2."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(1) +
                        " FROM " + relLinear + " t1, " + relLinear + " t2" +
                        " WHERE t1."+pmap.getArgs().get(1)+"=t2."+pmap.getArgs().get(1)+"";
                p.isCurrentlyView = true;
                //p.hasSoftEvidence = false;
                p.setHasSoftEvidence(false);
                db.update(sql);

                ((FelixPredicate)p).viewDef = sql + "";

                // under DD, mirror the result into every shared-output table
                if(options.useDualDecomposition){
                    for(FelixPredicate fp : this.dd_CommonOutput){
                        if(!fp.getName().equals(this.corefHead.getName())
                                && !fp.getName().equals(this.corefHead.corefMAPPredicate.getName())){
                            ExceptionMan.die("COREF 868: There must be something wrong with the parser!");
                        }

                        in = new FileInputStream(loadingFile);
                        String tableName;
                        String viewName;

                        // resolve the backing table for this shared predicate
                        if(fp.isCorefMapPredicate){
                            tableName = this.dd_commonOutputPredicate_2_tableName.get(fp);
                        }else{
                            tableName = this.dd_commonOutputPredicate_2_tableName.get(fp.corefMAPPredicate);
                        }

                        // and the pairwise view name (may be null)
                        if(fp.isCorefMapPredicate){
                            viewName = this.dd_commonOutputPredicate_2_tableName.get(fp.oriCorefPredicate);
                        }else{
                            viewName = this.dd_commonOutputPredicate_2_tableName.get(fp);
                        }

                        if(viewName != null)
                            db.dropView(viewName);
                        //db.dropTable(tableName);

                        sql = "DELETE from " + tableName;
                        db.execute(sql);

                        sql = "COPY " + tableName + "(truth, club, " + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
                        con.getCopyAPI().copyIn(sql, in);
                        in.close();

                        db.execute("UPDATE " + tableName + " SET prior = 1");

                        db.dropIndex(tableName + "_label_idx");
                        sql = "CREATE INDEX " + tableName + "_label_idx on " +
                                tableName + "(" + pmap.getArgs().get(1) + ")";
                        db.update(sql);

                        db.dropIndex(tableName + "_node_idx");
                        sql = "CREATE INDEX " + tableName + "_node_idx on " +
                                tableName + "(" + pmap.getArgs().get(0) + ")";
                        db.update(sql);

                        db.analyze(tableName);

                        if(viewName != null){

                            db.dropSequence(viewName+"_seq");
                            sql = "CREATE SEQUENCE " + viewName+"_seq" +
                                    " START WITH 1";
                            db.update(sql);

                            sql = "CREATE VIEW " + viewName + " AS SELECT nextval('" + viewName+"_seq')" +
                                    "::integer AS id," +
                                    "TRUE::boolean AS truth, 1 as prior, " +
                                    "2::integer as club, NULL::integer as atomID, " +
                                    "t1."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(0) +
                                    ", t2."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(1) +
                                    " FROM " + tableName + " t1, " + tableName + " t2" +
                                    " WHERE t1."+pmap.getArgs().get(1)+"=t2."+pmap.getArgs().get(1)+"";
                            db.update(sql);



                        }

                    }
                }

            }

        } catch (Exception e) {
            ExceptionMan.handle(e);
        }

    }
943 | |
944 | |
    /**
     * Weight learning is not supported by the COREF operator; this is a
     * deliberate no-op.
     */
    @Override
    public void learn() {

    }

}