1 | package tuffy.ground; |
2 | |
3 | |
4 | import java.sql.ResultSet; |
5 | import java.sql.SQLException; |
6 | import java.util.ArrayList; |
7 | import java.util.HashSet; |
8 | |
9 | import tuffy.db.RDB; |
10 | import tuffy.db.SQLMan; |
11 | import tuffy.infer.ds.GClause; |
12 | import tuffy.mln.Clause; |
13 | import tuffy.mln.Literal; |
14 | import tuffy.mln.MarkovLogicNetwork; |
15 | import tuffy.mln.Predicate; |
16 | import tuffy.mln.Term; |
17 | import tuffy.util.Config; |
18 | import tuffy.util.ExceptionMan; |
19 | import tuffy.util.StringMan; |
20 | import tuffy.util.Timer; |
21 | import tuffy.util.UIMan; |
22 | |
23 | /** |
24 | * This class handles the grounding process of MLN inference/learning |
25 | * with SQL queries. See our technical report at |
26 | * http://tuffguy.cs.wisc.edu/tuffy/tuffy-tech-report.pdf |
27 | * |
28 | * as well as prior works: |
29 | * http://alchemy.cs.washington.edu/papers/singla06a/singla06a.pdf |
30 | * http://alchemy.cs.washington.edu/papers/pdfs/shavlik-natarajan09.pdf |
31 | * |
 * Alchemy implements "lazy inference" with a one-step
 * look-ahead strategy for initial groundings;
 * we generalize it into a closure algorithm that avoids incremental
 * "activation" altogether.
36 | */ |
37 | |
38 | public class Grounding { |
39 | /** |
40 | * Relational database used for grounding. |
41 | */ |
42 | private RDB db; |
43 | |
44 | /** |
45 | * MLN to be grounded. |
46 | */ |
47 | private MarkovLogicNetwork mln; |
48 | |
49 | /** |
50 | * Number of active atoms. |
51 | */ |
52 | private int numAtoms; |
53 | |
54 | /** |
55 | * Number of active clauses. |
56 | */ |
57 | private int numClauses; |
58 | |
59 | /** |
60 | * Get the MLN object used for grounding. |
61 | */ |
62 | public MarkovLogicNetwork getMLN(){ |
63 | return mln; |
64 | } |
65 | |
/**
 * Builds a grounding worker for the given MLN. The worker is bound to
 * the database returned by {@code mln.getRDB()} before the MLN itself
 * is stored.
 *
 * @param mln the network to be grounded
 */
public Grounding(MarkovLogicNetwork mln){
    RDB backingDb = mln.getRDB();
    bindDB(backingDb);
    this.mln = mln;
}
73 | |
74 | /** |
75 | * Return the number of active atoms in the grounding result. |
76 | */ |
77 | public int getNumAtoms(){ |
78 | return numAtoms; |
79 | } |
80 | |
81 | /** |
82 | * Return the number of active clauses in the grounding result. |
83 | */ |
84 | public int getNumClauses(){ |
85 | return numClauses; |
86 | } |
87 | |
88 | private void createAtomTable(String rel){ |
89 | db.dropTable(rel); |
90 | // create atoms table |
91 | if(Config.gp == true){ |
92 | String sql = "CREATE TABLE " + rel + |
93 | "(" + |
94 | "tupleID bigint DEFAULT 0, " + |
95 | "atomID SERIAL, " + |
96 | "predID INT DEFAULT 0, " + |
97 | //"compID INT DEFAULT 0, " + |
98 | //"partID INT DEFAULT 0, " + |
99 | "blockID INT DEFAULT NULL, " + |
100 | "prob FLOAT DEFAULT NULL, " + |
101 | "truth BOOL DEFAULT FALSE " + |
102 | ") DISTRIBUTED BY (tupleID)"; |
103 | db.update(sql); |
104 | }else{ |
105 | String sql = "CREATE TABLE " + rel + |
106 | "(" + |
107 | "tupleID bigint DEFAULT 0, " + |
108 | "atomID SERIAL, " + |
109 | "predID INT DEFAULT 0, " + |
110 | //"compID INT DEFAULT 0, " + |
111 | //"partID INT DEFAULT 0, " + |
112 | "blockID INT DEFAULT NULL, " + |
113 | "prob FLOAT DEFAULT NULL, " + |
114 | "truth BOOL DEFAULT FALSE " + |
115 | ")"; |
116 | db.update(sql); |
117 | } |
118 | } |
119 | |
120 | private void createClauseTable(String rel){ |
121 | db.dropTable(rel); |
122 | ArrayList<String> fields = new ArrayList<String>(); |
123 | fields.add("cid SERIAL PRIMARY KEY"); |
124 | fields.add("lits INT[]"); |
125 | fields.add("weight FLOAT8"); |
126 | fields.add("fcid INT[]"); |
127 | fields.add("ffcid text[]"); |
128 | String sql = "CREATE TABLE " + rel + |
129 | StringMan.commaListParen(fields); |
130 | db.update(sql); |
131 | } |
132 | |
133 | |
134 | private void createActTables(){ |
135 | for(Predicate p : mln.getAllPred()){ |
136 | db.dropTable(p.getRelAct()); |
137 | String sql = "CREATE TABLE " + p.getRelAct() + |
138 | "(id bigint)"; |
139 | db.update(sql); |
140 | } |
141 | } |
142 | |
143 | private void destroyActTables(){ |
144 | for(Predicate p : mln.getAllPred()){ |
145 | db.dropTable(p.getRelAct()); |
146 | } |
147 | } |
148 | |
149 | |
150 | |
151 | /** |
152 | * Activate "soft evidence" atoms. |
153 | */ |
154 | private void activateSoftEvidence(){ |
155 | int cnt = 0; |
156 | UIMan.verbose(2, ">>> Activating soft evidence atoms..."); |
157 | for(Predicate p : mln.getAllPred()) { |
158 | String iql = "INSERT INTO " + p.getRelAct() + |
159 | " SELECT id FROM " + p.getRelName() + |
160 | " WHERE prior > 0 AND prior >= " + |
161 | Config.soft_evidence_activation_threshold + |
162 | " AND prior < 1 AND id NOT IN (SELECT id FROM " + p.getRelAct() + ")"; |
163 | db.update(iql); |
164 | cnt += db.getLastUpdateRowCount(); |
165 | } |
166 | if(cnt > 0){ |
167 | UIMan.verbose(2, "### active soft evidence = " + |
168 | UIMan.comma(cnt)); |
169 | } |
170 | } |
171 | |
172 | |
173 | /** |
174 | * Activate all the query atoms that are true in the training data. |
175 | * Used by learning. |
176 | */ |
177 | private void activateQueryAtoms(){ |
178 | int cnt = 0; |
179 | UIMan.verbose(2, ">>> Activating query atoms..."); |
180 | for(Predicate p : mln.getAllPred()) { |
181 | String iql = "INSERT INTO " + p.getRelAct() + |
182 | " SELECT id FROM " + p.getRelName() + |
183 | " WHERE ((club = 3)) AND truth = true " + |
184 | "AND id NOT IN (SELECT id FROM " + p.getRelAct() + ")"; |
185 | db.update(iql); |
186 | cnt += db.getLastUpdateRowCount(); |
187 | } |
188 | UIMan.verbose(2, "### active query atoms = " + UIMan.comma(cnt)); |
189 | } |
190 | |
191 | |
192 | private void activateUnknownAtoms(){ |
193 | int cnt = 0; |
194 | UIMan.verbose(2, ">>> Activating all unknown atoms..."); |
195 | for(Predicate p : mln.getAllPred()) { |
196 | String iql = "INSERT INTO " + p.getRelAct() + |
197 | " SELECT id FROM " + p.getRelName() + |
198 | " WHERE club < 2 " + |
199 | "AND id NOT IN (SELECT id FROM " + p.getRelAct() + ")"; |
200 | db.update(iql); |
201 | cnt += db.getLastUpdateRowCount(); |
202 | } |
203 | UIMan.verbose(2, "### active unknown atoms = " + UIMan.comma(cnt)); |
204 | } |
205 | |
206 | /** |
207 | * Bind to a database connection, and initialize global database |
208 | * objects. |
209 | */ |
210 | private void bindDB(RDB adb) { |
211 | db = adb; |
212 | |
213 | String sql; |
214 | sql = "CREATE OR REPLACE FUNCTION " + "unitNegativeClause" + |
215 | "(lits int[]) RETURNS INT AS $$\n" + |
216 | "BEGIN\n" + |
217 | "IF array_upper(lits, 1) > 1 THEN RETURN 0; END IF;\n" + |
218 | "RETURN lits[1];\n" + |
219 | "END;\n" + SQLMan.funcTail() + " IMMUTABLE"; |
220 | db.update(sql); |
221 | |
222 | /* |
223 | seqActiveName = SQLMan.seqName("active_atoms"); |
224 | db.dropSequence(seqActiveName); |
225 | db.commit(); |
226 | String sql = "CREATE SEQUENCE " + seqActiveName + |
227 | " START WITH 1"; |
228 | db.update(sql); |
229 | |
230 | sql = "CREATE OR REPLACE FUNCTION " + "convert_id" + |
231 | "(list INT[], oid INT, nid INT) RETURNS INT[] AS $$\n" + |
232 | "DECLARE\n" + |
233 | "nlist INT[]; \n" + |
234 | "BEGIN\n" + |
235 | "nlist := list;" + |
236 | "FOR i IN 1 .. array_upper(list,1) LOOP\n" + |
237 | "IF list[i]=oid THEN nlist[i]:=nid;" + |
238 | "ELSIF list[i]=-oid THEN nlist[i]:=-nid;" + |
239 | "END IF;" + |
240 | "END LOOP;\n" + |
241 | "RETURN UNIQ(SORT(nlist));\n" + |
242 | "END;\n" + SQLMan.funcTail() + " IMMUTABLE"; |
243 | db.update(sql); |
244 | */ |
245 | } |
246 | |
247 | |
248 | /** |
249 | * Construct the MRF. First compute the closure of active atoms, |
250 | * then active clauses. |
251 | */ |
252 | public void constructMRF(){ |
253 | UIMan.println(">>> Grounding..."); |
254 | |
255 | UIMan.verbose(1, ">>> Computing closure of active atoms..."); |
256 | String sql; |
257 | createActTables(); |
258 | activateSoftEvidence(); |
259 | if(Config.learning_mode) activateQueryAtoms(); |
260 | if(Config.mark_all_atoms_active){ |
261 | activateUnknownAtoms(); |
262 | }else{ |
263 | computeActiveAtoms(); |
264 | } |
265 | |
266 | UIMan.verbose(2, ">>> Gathering active atoms..."); |
267 | createAtomTable(mln.relAtoms); |
268 | numAtoms = populateAtomTable(mln.relAtoms); |
269 | UIMan.verbose(2, "### active atoms = " + numAtoms); |
270 | |
271 | UIMan.verbose(1, ">>> Computing active clauses..."); |
272 | String cbuffer = "mln" + mln.getID() + "_cbuffer"; |
273 | sql = "CREATE TABLE " + cbuffer + "(list INT[], weight FLOAT8, " + |
274 | "fcid INT, ffcid text)"; |
275 | db.update(sql); |
276 | this.computeActiveClauses(cbuffer); |
277 | this.addSoftEvidClauses(mln.relAtoms, cbuffer); |
278 | this.addKeyConstraintClauses(mln.relAtoms, cbuffer); |
279 | |
280 | createClauseTable(mln.relClauses); |
281 | numClauses = this.consolidateClauses(cbuffer, mln.relClauses); |
282 | |
283 | if(!Config.learning_mode){ |
284 | db.dropTable(cbuffer); |
285 | } |
286 | destroyActTables(); |
287 | |
288 | UIMan.println("### atoms = " + UIMan.comma(numAtoms) + "; clauses = " + UIMan.comma(numClauses)); |
289 | } |
290 | |
291 | private int populateAtomTable(String relAtoms){ |
292 | String sql; |
293 | for(Predicate p : mln.getAllPred()) { |
294 | if(p.isImmutable()) continue; |
295 | |
296 | sql = "INSERT INTO " + relAtoms + "(tupleID,predID,truth,prob) " + |
297 | "SELECT id," + p.getID() + ",truth,prior FROM " + p.getRelName() + |
298 | " WHERE id IN (SELECT id FROM " + p.getRelAct() + ")"; |
299 | db.update(sql); |
300 | sql = "UPDATE " + p.getRelName() + " pt SET atomID=NULL"; |
301 | db.update(sql); |
302 | sql = "UPDATE " + p.getRelName() + " pt SET atomID=ra.atomID " + |
303 | " FROM " + relAtoms + " ra " + |
304 | " WHERE ra.predID=" + p.getID() + " AND ra.tupleID=pt.id"; |
305 | |
306 | //UIMan.verbose(3, "----- popularting " + p + "\t" + db.explain(sql)); |
307 | |
308 | db.update(sql); |
309 | db.vacuum(p.getRelName()); |
310 | db.analyze(p.getRelName()); |
311 | |
312 | } |
313 | db.vacuum(relAtoms); |
314 | db.analyze(relAtoms); |
315 | int numVars = (int)db.countTuples(relAtoms); |
316 | return numVars; |
317 | } |
318 | |
319 | /** |
320 | * Compute the closure of active atoms. |
321 | * |
322 | * For a positive clause, active atoms are those with positive sense, |
323 | * plus those with negative sense but truth value may not be the |
324 | * default FALSE. |
325 | * Those with negative sense and default truth value will be true if |
326 | * we set default value of atoms as false, and therefore do not generate |
327 | * violated grounded clauses. |
328 | * |
329 | * For a negative clause, active atoms are those with negative sense. |
330 | * |
331 | * There are multiple rounds in this function. |
332 | * The goal of multiple rounds is compute a closure for all possible |
333 | * active atoms. For example, say we assume the default truth for |
334 | * atom is FALSE. Although at first round, we do not need |
335 | * to compute the groundings for negative literature $!p$ in a positive clause |
336 | * except those positive evidences (because they do not introduce any |
337 | * violations), other clauses may introduce active atoms with the same |
338 | * predicate $p$ in the active set. |
339 | * In this case, these introduced atoms can be flipped, so their truth |
340 | * values are not necessarily FALSE. Under this circumstance, the negative |
341 | * literals may also be FALSE. |
342 | * So there may be more groundings for the first clause, and therefore, |
343 | * more than one rounds is necessary to adjust that. |
344 | * |
345 | * If in the first round, no more atoms of predicate $p$ is introduced |
346 | * in active table, then according to the previous analysis, re-grounding |
347 | * of this clause is not necessary. |
348 | * |
349 | * @return number of active atoms |
350 | */ |
351 | private void computeActiveAtoms() { |
352 | boolean converged = false; |
353 | int cnt = 1; |
354 | int frontier = -2; |
355 | String relTemp = "temp_clauses"; |
356 | HashSet<Predicate> changedLastTime = new HashSet<Predicate>(); |
357 | HashSet<Predicate> changedThisTime = new HashSet<Predicate>(); |
358 | |
359 | for(Predicate p : mln.getAllPred()){ |
360 | changedLastTime.add(p); |
361 | db.analyze(p.getRelName()); |
362 | } |
363 | |
364 | while(!converged) { |
365 | |
366 | if(Config.gp){ |
367 | for(Predicate p : mln.getAllPred()) { |
368 | if(p.isImmutable()) continue; |
369 | |
370 | String sql2 = "UPDATE " + p.getRelName() + " pt SET atomID=NULL"; |
371 | db.update(sql2); |
372 | |
373 | sql2 = "UPDATE " + p.getRelName() + " pt SET atomID=-1 " + |
374 | " WHERE id IN (SELECT id FROM " + p.getRelAct() + ")"; |
375 | UIMan.verbose(3, sql2); |
376 | db.execute(sql2); |
377 | |
378 | } |
379 | } |
380 | |
381 | |
382 | |
383 | |
384 | UIMan.verboseInline(1, ">>> Round #" + (cnt++) + ":"); |
385 | UIMan.verbose(2, ""); |
386 | converged = true; |
387 | for(Clause c : mln.getRelevantClauses()) { |
388 | |
389 | HashSet<Boolean> possibleClausePos = new HashSet<Boolean>(); |
390 | if(c.hasEmbeddedWeight()){ |
391 | possibleClausePos.add(true); |
392 | possibleClausePos.add(false); |
393 | }else{ |
394 | possibleClausePos.add(c.isPositiveClause()); |
395 | } |
396 | |
397 | for(boolean posClause : possibleClausePos){ |
398 | |
399 | // optimization: check necessity |
400 | boolean worth = false, fresh = false; |
401 | for(Literal lit : c.getRegLiterals()){ |
402 | if((lit.getSense()==posClause) && !lit.getPred().isImmutable()){ |
403 | worth = true; |
404 | } |
405 | if(changedLastTime.contains(lit.getPred())) { |
406 | fresh = true; |
407 | } |
408 | } |
409 | if(!worth || !fresh) continue; |
410 | // ground could-be-violated clauses |
411 | ArrayList<String> clubs = new ArrayList<String>(); |
412 | ArrayList<String> ids = new ArrayList<String>(); |
413 | |
414 | ArrayList<String> actFrom = new ArrayList<String>(); |
415 | |
416 | for(Literal l : c.getRegLiterals()){ |
417 | |
418 | actFrom.add(l.getPred().getRelAct() + " at" + l.getIdx() |
419 | + " ON (at" + l.getIdx() + ".id = t" + l.getIdx() + ".id)"); |
420 | |
421 | if(l.getPred().isImmutable()) continue; |
422 | if(l.getSense() != posClause) continue; |
423 | ids.add("t" + l.getIdx() + ".id as id" + l.getIdx()); |
424 | clubs.add("t" + l.getIdx() + ".club as club" + l.getIdx()); |
425 | |
426 | } |
427 | |
428 | String sql; |
429 | |
430 | if(Config.gp == true){ |
431 | //TODO: SLOW |
432 | sql = "SELECT DISTINCT " + StringMan.commaList(clubs)+ ", " + |
433 | StringMan.commaList(ids) + " FROM " + |
434 | c.sqlFromList |
435 | + " WHERE "; |
436 | |
437 | }else{ |
438 | sql = "SELECT DISTINCT " + StringMan.commaList(clubs)+ ", " + |
439 | StringMan.commaList(ids) + " FROM " + |
440 | c.sqlFromList + " WHERE "; |
441 | ; |
442 | //+ |
443 | //" LEFT OUTER JOIN " + StringMan.join(" LEFT OUTER JOIN ", actFrom) + " WHERE "; |
444 | |
445 | } |
446 | |
447 | ArrayList<String> conds = new ArrayList<String>(); |
448 | // used to exclude all-evidence clauses |
449 | ArrayList<String> negActConds = new ArrayList<String>(); |
450 | |
451 | for(Literal lit : c.getRegLiterals()) { |
452 | Predicate p = lit.getPred(); |
453 | String rp = "t" + lit.getIdx(); |
454 | String sra = "(SELECT * FROM " + p.getRelAct() + ")"; |
455 | |
456 | ArrayList<String> iconds = new ArrayList<String>(); |
457 | String fc = (lit.getSense()?"'0'" : "'1'"); |
458 | iconds.add(rp + ".truth=" + fc); // explicit evidence |
459 | if(lit.getSense()){ |
460 | //TODO: double check |
461 | if(!lit.getPred().isCompletelySepcified()){ |
462 | iconds.add(rp + ".id IS NULL"); // implicit false evidence |
463 | } |
464 | } |
465 | |
466 | if(Config.gp == false){ |
467 | iconds.add(rp + ".id IN " + sra); // current active atoms |
468 | }else{ |
469 | //iconds.add(rp + ".id = at" + lit.getIdx() + ".id"); |
470 | // iconds.add("at" + lit.getIdx() + ".id <> NULL"); |
471 | iconds.add(rp + ".atomID = -1"); |
472 | } |
473 | |
474 | if(lit.getPred().isClosedWorld()){ |
475 | |
476 | }else{ |
477 | |
478 | if(lit.getSense() || !posClause) { |
479 | iconds.add("(" + rp + ".club < 2)"); // unknown truth |
480 | } |
481 | |
482 | if(!posClause && !lit.getSense()) { // negative clause, negative literal |
483 | negActConds.add("(" + rp + ".club < 2)"); |
484 | } |
485 | |
486 | } |
487 | |
488 | conds.add(SQLMan.orSelCond(iconds)); |
489 | } |
490 | if(!posClause) { |
491 | if(negActConds.isEmpty()) continue; |
492 | conds.add(SQLMan.orSelCond(negActConds)); |
493 | } |
494 | |
495 | /* |
496 | // discard same-variable-opposite-sense truism |
497 | for(Predicate p : c.getReferencedPredicates()) { |
498 | ArrayList<Literal> pos = new ArrayList<Literal>(); |
499 | ArrayList<Literal> neg = new ArrayList<Literal>(); |
500 | for(Literal lit : c.getLiteralsOfPredicate(p)) { |
501 | if(lit.getSense()) pos.add(lit); |
502 | else neg.add(lit); |
503 | } |
504 | if(pos.isEmpty() || neg.isEmpty()) continue; |
505 | for(Literal plit : pos) { |
506 | for(Literal nlit : neg) { |
507 | String pid = "t" + plit.getIdx() + ".id"; |
508 | String nid = "t" + nlit.getIdx() + ".id"; |
509 | ArrayList<String> oconds = new ArrayList<String>(); |
510 | oconds.add(pid + " IS NULL"); |
511 | oconds.add(nid + " IS NULL"); |
512 | oconds.add(pid + "<>" + nid); |
513 | conds.add(SQLMan.orSelCond(oconds)); |
514 | } |
515 | } |
516 | } |
517 | */ |
518 | |
519 | sql += SQLMan.andSelCond(conds); |
520 | if(!c.sqlWhereBindings.isEmpty()){ |
521 | sql += " AND " + c.sqlWhereBindings; |
522 | } |
523 | |
524 | if(c.hasEmbeddedWeight()){ |
525 | |
526 | String embedWeightTable = ""; |
527 | for(Literal l : c.getRegLiterals()){ |
528 | for(int k=0;k<l.getTerms().size();k++){ |
529 | if(l.getTerms().get(k).var().equals(c.getVarWeight())){ |
530 | embedWeightTable = "t" + l.getIdx() + "." + l.getPred().getArgs().get(k); |
531 | } |
532 | } |
533 | } |
534 | |
535 | if(posClause == true){ |
536 | |
537 | sql = sql + " AND " + embedWeightTable + " > 0 "; |
538 | }else{ |
539 | sql = sql + " AND " + embedWeightTable + " < 0 "; |
540 | } |
541 | } |
542 | |
543 | if(Config.verbose_level == 1) UIMan.print("."); |
544 | UIMan.verbose(2, ">>> Grounding " + c.toString()); |
545 | UIMan.verbose(3, sql); |
546 | //UIMan.verbose(3, db.explain(sql)); |
547 | |
548 | db.dropTable(relTemp); |
549 | sql = "CREATE TABLE " + relTemp + |
550 | " AS " + sql; |
551 | |
552 | /* |
553 | if(Config.gp == true){ |
554 | for(Predicate toa : c.getReferencedPredicates()){ |
555 | UIMan.verbose(3, ">>> Analyze " + toa.getName()); |
556 | db.analyze(toa.getRelName()); |
557 | db.commit(); |
558 | |
559 | try{ |
560 | db.execute("SELECT * FROM " + toa.getRelName() + " WHERE truth = '0' limit 1"); |
561 | }catch(Exception e){ |
562 | continue; |
563 | } |
564 | |
565 | } |
566 | }*/ |
567 | |
568 | db.update(sql); |
569 | |
570 | long ngc = db.countTuples(relTemp); |
571 | UIMan.verbose(2, " Created " + UIMan.comma(ngc) + " groundings"); |
572 | UIMan.verbose(3, ">>> Expanding active atoms..."); |
573 | // activate more atoms |
574 | boolean found = false; |
575 | for(Literal lit : c.getRegLiterals()) { |
576 | if(lit.getPred().isImmutable()) continue; |
577 | if(lit.getSense() != posClause) continue; |
578 | String iql = "INSERT INTO " + lit.getPred().getRelAct() + |
579 | " SELECT DISTINCT t.id" + lit.getIdx() + " FROM " + relTemp + |
580 | " t WHERE t.club" + lit.getIdx() + " <2" + |
581 | " EXCEPT " + |
582 | " SELECT * FROM " + lit.getPred().getRelAct(); |
583 | if(Config.verbose_level == 1) UIMan.print("."); |
584 | db.update(iql); |
585 | if(db.getLastUpdateRowCount() > 0) { |
586 | found = true; |
587 | UIMan.verbose(2, " Found " + |
588 | UIMan.comma(db.getLastUpdateRowCount()) + |
589 | " new active atoms for predicate [" + |
590 | lit.getPred().getName() + "]" ); |
591 | changedThisTime.add(lit.getPred()); |
592 | converged = false; |
593 | db.analyze(lit.getPred().getRelAct()); |
594 | } |
595 | } |
596 | if(!found){ |
597 | UIMan.verbose(2, " Found no new atoms."); |
598 | } |
599 | UIMan.verbose(2, ""); |
600 | } |
601 | } |
602 | |
603 | |
604 | int nmore = 0; |
605 | for(Predicate p : mln.getAllPred()){ |
606 | if(p.isImmutable()) continue; |
607 | if(p.hasMoreToGround()){ |
608 | nmore++; |
609 | break; |
610 | } |
611 | } |
612 | UIMan.verbose(1, ""); |
613 | if(nmore == 0) break; |
614 | changedLastTime = changedThisTime; |
615 | changedThisTime = new HashSet<Predicate>(); |
616 | -- frontier; |
617 | } |
618 | db.dropTable(relTemp); |
619 | } |
620 | |
621 | |
622 | /** |
623 | * Create the atom-clause incidence relation. |
624 | */ |
625 | @SuppressWarnings("unused") |
626 | private void computeIncidenceTable(String relClauses, String relIncidence){ |
627 | UIMan.println(">>> Computing incidence table..."); |
628 | db.dropTable(relIncidence); |
629 | String sql = "CREATE TABLE " + relIncidence + |
630 | "(cid INT, aid INT)"; |
631 | db.update(sql); |
632 | sql = "INSERT INTO " + relIncidence + "(cid, aid) " + |
633 | "SELECT cid, ABS(UNNEST(lits)) FROM " + relClauses; |
634 | db.update(sql); |
635 | UIMan.println("### pins = " + db.getLastUpdateRowCount()); |
636 | } |
637 | |
638 | /** |
639 | * Computes ground clauses activated by the current set |
640 | * of active atoms. |
641 | * |
642 | * Grounding the clause using set of active atoms. Then |
643 | * merge the weight of clauses with the same atom set. |
644 | * Another optimization is merging clauses with only one |
645 | * active atom whose sense is opposite. This is by |
646 | * reverse the sense of negative clause, together with the sense |
647 | * of the only literal. |
648 | * |
649 | * The grounding result is saved in table {@value Config#relClauses}. |
650 | * With the schema like |
651 | * <br/> |
652 | * +------+--------+---------+---------+------+-------+<br/> |
653 | * | lit | weight | (posWt) | (negWt) | fcid | ffcid |<br/> |
654 | * +------+--------+---------+---------+------+-------+<br/> |
655 | * <br/> |
656 | * where posWt and negWt depend on {@link Config#calcCostOffset}, |
657 | * and fcid {@link Config#track_clause_provenance}. |
658 | * |
659 | * @param retainInactiveAtoms set this to true when |
660 | * the original lazy inference is in use. Otherwise, i.e. the closure |
661 | * algorithm is in use, set it to false. |
662 | * |
663 | */ |
664 | private void computeActiveClauses(String cbuffer) { |
665 | Timer.start("totalgrounding"); |
666 | UIMan.verboseInline(1, ">>> Grounding clauses..."); |
667 | UIMan.verbose(2, ""); |
668 | double longestSec = 0; |
669 | Clause longestClause = null; |
670 | |
671 | String sql; |
672 | int totalclauses = 0; |
673 | ArrayList<Clause> relevantClauses = new ArrayList<Clause>(mln.getRelevantClauses()); |
674 | |
675 | int clsidx = 1; |
676 | int clstotal = relevantClauses.size(); |
677 | for(Clause c : relevantClauses) { |
678 | |
679 | HashSet<Boolean> possibleClausePos = new HashSet<Boolean>(); |
680 | if(c.hasEmbeddedWeight()){ |
681 | possibleClausePos.add(true); |
682 | possibleClausePos.add(false); |
683 | }else{ |
684 | possibleClausePos.add(c.isPositiveClause()); |
685 | } |
686 | |
687 | for(boolean posClause : possibleClausePos){ |
688 | |
689 | ArrayList<String> ids = new ArrayList<String>(); |
690 | ArrayList<String> conds = new ArrayList<String>(); |
691 | ArrayList<String> negActConds = new ArrayList<String>(); |
692 | |
693 | // discard irrelevant variables |
694 | for(Literal lit : c.getRegLiterals()) { |
695 | Predicate p = lit.getPred(); |
696 | String r = "t" + lit.getIdx(); |
697 | if(!p.isImmutable()) { |
698 | ids.add((lit.getSense()?"" : "-") + "(CASE WHEN " + |
699 | r + ".id IS NULL THEN 0 WHEN " + |
700 | r + ".atomID IS NULL THEN 0 ELSE " + |
701 | r + ".atomID END)"); |
702 | } |
703 | String fc = (lit.getSense()?"FALSE" : "TRUE"); |
704 | |
705 | ArrayList<String> iconds = new ArrayList<String>(); |
706 | String rp = r; |
707 | iconds.add(rp + ".truth=" + fc); // explicit evidence |
708 | if(lit.getSense()){ |
709 | //TODO: double check |
710 | if(!lit.getPred().isCompletelySepcified()){ |
711 | iconds.add(rp + ".id IS NULL"); // implicit false evidence |
712 | } |
713 | } |
714 | |
715 | if(lit.getSense() || !posClause) { |
716 | |
717 | if(lit.getPred().isClosedWorld()){ |
718 | //TODO: double check!!!!!!!!!!!!!! |
719 | if(lit.getPred().hasSoftEvidence()){ |
720 | iconds.add(r + ".atomID IS NOT NULL"); |
721 | } |
722 | |
723 | }else{ |
724 | if(Config.learning_mode){ |
725 | iconds.add(rp + ".club < 2 OR " + rp + ".club = 3"); |
726 | }else{ |
727 | iconds.add(rp + ".club < 2"); // unknown truth |
728 | } |
729 | } |
730 | if(!posClause) { |
731 | negActConds.add(r + ".atomID IS NOT NULL"); // active atom |
732 | } |
733 | }else { |
734 | iconds.add(r + ".atomID IS NOT NULL"); // active atoms |
735 | } |
736 | |
737 | if(!posClause && !lit.getSense()) { // negative clause, negative literal |
738 | if(lit.getPred().isClosedWorld()){ |
739 | |
740 | }else{ |
741 | if(Config.learning_mode){ |
742 | negActConds.add("(" + rp + ".club < 2 OR " + rp + ".club = 3" + ")"); |
743 | }else{ |
744 | negActConds.add("(" + rp + ".club < 2" + ")"); |
745 | } |
746 | } |
747 | |
748 | } |
749 | |
750 | conds.add(SQLMan.orSelCond(iconds)); |
751 | |
752 | } |
753 | if(ids.isEmpty()) continue; |
754 | if(!posClause) { |
755 | if(negActConds.isEmpty()) continue; |
756 | conds.add(SQLMan.orSelCond(negActConds)); |
757 | } |
758 | /* |
759 | // discard same-variable-opposite-sense truism |
760 | for(Predicate p : c.getReferencedPredicates()) { |
761 | ArrayList<Literal> pos = new ArrayList<Literal>(); |
762 | ArrayList<Literal> neg = new ArrayList<Literal>(); |
763 | for(Literal lit : c.getLiteralsOfPredicate(p)) { |
764 | if(lit.getSense()) pos.add(lit); |
765 | else neg.add(lit); |
766 | } |
767 | if(pos.isEmpty() || neg.isEmpty()) continue; |
768 | for(Literal plit : pos) { |
769 | for(Literal nlit : neg) { |
770 | String pid = "t" + plit.getIdx() + ".id"; |
771 | String nid = "t" + nlit.getIdx() + ".id"; |
772 | ArrayList<String> oconds = new ArrayList<String>(); |
773 | oconds.add(pid + " IS NULL"); |
774 | oconds.add(nid + " IS NULL"); |
775 | oconds.add(pid + "<>" + nid); |
776 | conds.add(SQLMan.orSelCond(oconds)); |
777 | } |
778 | } |
779 | } |
780 | */ |
781 | if(!c.hasExistentialQuantifiers()) { |
782 | sql = "SELECT " + |
783 | "UNIQ(SORT(ARRAY[" + StringMan.commaList(ids) + "]-0)) as list2, " + |
784 | c.getWeightExp() + " as weight2 "+ |
785 | // ffid |
786 | (!c.getWeightExp().contains("metaTable")? |
787 | (", CAST(0 as Integer) as ffid"): |
788 | (", metaTable.myid as ffid "))+ |
789 | " FROM " + |
790 | c.sqlFromList + " WHERE " + c.sqlWhereBindings; |
791 | if(!conds.isEmpty()) { |
792 | sql += " AND " + SQLMan.andSelCond(conds); |
793 | } |
794 | |
795 | if(posClause == true){ |
796 | |
797 | sql = sql + " AND " + c.getWeightExp() + " > 0 "; |
798 | }else{ |
799 | sql = sql + " AND " + c.getWeightExp() + " < 0 "; |
800 | } |
801 | |
802 | }else { |
803 | ArrayList<String> aggs = new ArrayList<String>(); |
804 | for(String ide : ids) { |
805 | aggs.add("array_agg(" + ide + ")"); |
806 | } |
807 | sql = "SELECT " + c.sqlPivotAttrsList + |
808 | (c.sqlPivotAttrsList.length() > 0 ? "," : "") + |
809 | " UNIQ(SORT("+ StringMan.join("+", aggs) +"-0)) as list2, " + |
810 | c.getWeightExp() + " as weight2 "+ |
811 | // ffid |
812 | (c.getWeightExp().contains("FLOAT8")? |
813 | (", CAST(0 as Integer) as ffid"): |
814 | (", metaTable.myid as ffid "))+ |
815 | " FROM " + |
816 | c.sqlFromList + " WHERE " + c.sqlWhereBindings; |
817 | if(!conds.isEmpty()) { |
818 | sql += " AND " + SQLMan.andSelCond(conds); |
819 | } |
820 | if(c.sqlPivotAttrsList.length() > 0){ |
821 | sql += " GROUP BY " + c.sqlPivotAttrsList + " , ffid"; |
822 | } |
823 | sql = "SELECT list2, weight2, ffid FROM " + |
824 | "(" + sql + ") tpivoted"; |
825 | |
826 | if(posClause == true){ |
827 | |
828 | sql = sql + " WHERE " + c.getWeightExp() + " > 0 "; |
829 | }else{ |
830 | sql = sql + " WHERE " + c.getWeightExp() + " < 0 "; |
831 | } |
832 | } |
833 | |
834 | boolean unifySoftUnitClauses = true; |
835 | if(unifySoftUnitClauses) { |
836 | sql = "SELECT (CASE WHEN unitNegativeClause(list2)>=0 THEN " + |
837 | "list2 ELSE array[-list2[1]] END) AS list, " + |
838 | "(CASE WHEN unitNegativeClause(list2)>=0 THEN weight2 " + |
839 | "ELSE -weight2 END) AS weight, " + |
840 | "(CASE WHEN unitNegativeClause(list2)>=0 THEN " + c.getId() + " " + |
841 | "ELSE -" + c.getId() + " END) AS fcid " + |
842 | // ffcid, for learning |
843 | ", (CASE WHEN unitNegativeClause(list2)>=0 THEN ('" + |
844 | c.getId() + ".' || ffid) " + |
845 | "ELSE ('-" + c.getId() + ".' || ffid) END) AS ffcid " + |
846 | // |
847 | "FROM (" + sql + ") as " + c.getName() + |
848 | " WHERE array_upper(list2,1)>=1"; |
849 | }else { |
850 | sql = "SELECT list2 AS list, " + |
851 | "weight2 AS weight, " + c.getId() + " AS ffcid " + |
852 | "FROM (" + sql + ") as " + c.getName() + |
853 | " WHERE array_upper(list2,1)>=1"; |
854 | } |
855 | if(Config.verbose_level == 1) UIMan.print("."); |
856 | UIMan.verbose(2, ">>> Grounding clause " + |
857 | (clsidx++) + " / " + clstotal + "\n" + |
858 | c.toString()); |
859 | UIMan.verbose(3, sql); |
860 | sql = "INSERT INTO " + cbuffer + "\n" + sql; |
861 | Timer.start("gnd"); |
862 | |
863 | /* |
864 | if(Config.gp == true){ |
865 | for(Predicate toa : c.getReferencedPredicates()){ |
866 | UIMan.verbose(3, ">>> Analyze " + toa.getName()); |
867 | db.analyze(toa.getRelName()); |
868 | db.commit(); |
869 | } |
870 | }*/ |
871 | |
872 | db.update(sql); |
873 | |
874 | // report stats |
875 | totalclauses += db.getLastUpdateRowCount(); |
876 | if(Timer.elapsedSeconds("gnd") > longestSec){ |
877 | longestClause = c; |
878 | longestSec = Timer.elapsedSeconds("gnd"); |
879 | } |
880 | UIMan.verbose(2, "### took " + Timer.elapsed("gnd")); |
881 | UIMan.verbose(2, "### new clauses = " + |
882 | UIMan.comma(db.getLastUpdateRowCount()) + |
883 | "; total = " + UIMan.comma(totalclauses) + "\n"); |
884 | } |
885 | } |
886 | if(longestClause != null){ |
887 | UIMan.verbose(3, "### Longest per-clause grounding time = " + longestSec + " sec, by"); |
888 | UIMan.verbose(3, longestClause.toString()); |
889 | } |
890 | if(Config.verbose_level == 1) UIMan.println("."); |
891 | UIMan.verbose(1, "### total grounding = " + Timer.elapsed("totalgrounding")); |
892 | } |
893 | |
894 | |
895 | private int consolidateClauses(String cbuffer, String relClauses){ |
896 | UIMan.verbose(1, ">>> Consolidating ground clauses..."); |
897 | String sql; |
898 | // combine equivalent classes |
899 | ArrayList<String> args = new ArrayList<String>(); |
900 | ArrayList<String> sels = new ArrayList<String>(); |
901 | args.add("lits"); |
902 | args.add("weight"); |
903 | sels.add("list"); |
904 | sels.add("sum(weight)"); |
905 | if(Config.track_clause_provenance){ |
906 | args.add("fcid"); |
907 | args.add("ffcid"); |
908 | sels.add("UNIQ(SORT(array_agg(fcid)))"); |
909 | sels.add("array_agg(ffcid)"); |
910 | } |
911 | sql = "INSERT INTO " + relClauses + |
912 | StringMan.commaListParen(args) + |
913 | " SELECT " + |
914 | StringMan.commaList(sels) + |
915 | "FROM " + cbuffer + |
916 | " GROUP BY list"; |
917 | Timer.start("gnd"); |
918 | db.update(sql); |
919 | UIMan.verbose(1, "### took " + Timer.elapsed("gnd")); |
920 | int numClauses = db.getLastUpdateRowCount(); |
921 | return numClauses; |
922 | } |
923 | |
924 | private void addSoftEvidClauses(String relAtoms, String cbuffer){ |
925 | UIMan.verbose(1, ">>> Adding unit clauses for soft evidence..."); |
926 | int cnt = 0; |
927 | String iql = "INSERT INTO " + cbuffer + "(list, weight, fcid, ffcid) " + |
928 | " SELECT array[atomID], " + |
929 | " (CASE WHEN prob>=1 THEN " + Config.hard_weight + |
930 | " WHEN prob<=0 THEN -" + Config.hard_weight + |
931 | " ELSE ln(prob / (1-prob)) END), " + |
932 | " 0, '0' " + |
933 | " FROM " + relAtoms + |
934 | " WHERE prob IS NOT NULL AND prob <> 0.5"; |
935 | db.update(iql); |
936 | cnt += db.getLastUpdateRowCount(); |
937 | UIMan.verbose(1, "### soft-evidence clauses = " + UIMan.comma(cnt)); |
938 | } |
939 | |
940 | private void addKeyConstraintClauses(String relAtoms, String cbuffer){ |
941 | |
942 | int fcid = mln.getAllNormalizedClauses().size(); |
943 | |
944 | for(Predicate p : mln.getAllPred()){ |
945 | |
946 | if(!p.hasDependentAttributes()){ |
947 | continue; |
948 | } |
949 | |
950 | fcid ++; |
951 | String ffcid = fcid + ".1"; |
952 | |
953 | ArrayList<String> whereList_key = new ArrayList<String>(); |
954 | ArrayList<String> whereList_label = new ArrayList<String>(); |
955 | |
956 | // goal: |
957 | //select array[-t0.atomid, -t1.atomid], Integer.Max, 1, 1.1 |
958 | //from pred_dwinner t0, pred_dwinner t1 |
959 | //where (t0.key1 = t1.key1 AND ...) AND (t0.label2 <> t1.label2 OR ... ) |
960 | //AND t0.atomid in (select atomid from mln0_atoms) |
961 | |
962 | for(String keyAttr : p.getKeyAttrs()){ |
963 | whereList_key.add("(t0." + keyAttr + "=" + "t1." + keyAttr + ")"); |
964 | } |
965 | |
966 | for(String labelAttr : p.getDependentAttrs()){ |
967 | whereList_label.add("(t0." + labelAttr + "<>" + "t1." + labelAttr + ")"); |
968 | } |
969 | |
970 | String sql = "INSERT INTO " + cbuffer + |
971 | " SELECT array[-t0.atomid, -t1.atomid], " + Integer.MAX_VALUE + ", " + fcid + ", '" + ffcid +"'" + |
972 | " FROM " + p.getRelName() + " t0, " + p.getRelName() + " t1 " + |
973 | " WHERE (" + StringMan.join(" AND ", whereList_key) + ") AND " + |
974 | "(" + StringMan.join(" OR ", whereList_label) + ") AND " + |
975 | " t0.atomid IN (SELECT atomid FROM " + relAtoms +") AND " + |
976 | " t1.atomid IN (SELECT atomid FROM " + relAtoms +")"; |
977 | |
978 | db.execute(sql); |
979 | |
980 | } |
981 | |
982 | |
983 | } |
984 | |
985 | /** |
986 | * |
987 | * An attempt of computing cost lower bounds with the probabilistic method. |
988 | */ |
989 | @SuppressWarnings("unused") |
990 | private void reportCostStats(String relClauses){ |
991 | try { |
992 | String sql = "SELECT SUM(ABS(weight)) FROM " + relClauses; |
993 | ResultSet rs = db.query(sql); |
994 | if(rs.next()){ |
995 | double tw = rs.getDouble(1); |
996 | UIMan.println("total weight = " + tw); |
997 | } |
998 | rs.close(); |
999 | |
1000 | sql = "SELECT * FROM " + relClauses; |
1001 | rs = db.query(sql); |
1002 | double[] posApp = new double[numAtoms+1]; |
1003 | double[] negApp = new double[numAtoms+1]; |
1004 | double[] freq = new double[numAtoms+1]; |
1005 | double[] probs = {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}; |
1006 | double[] ecosts = new double[probs.length]; |
1007 | double[] punsat = new double[probs.length]; |
1008 | while(rs.next()){ |
1009 | GClause c = new GClause(); |
1010 | c.parse(rs); |
1011 | for(int i=0; i<probs.length; i++){ |
1012 | punsat[i] = 1; |
1013 | } |
1014 | for(int lit : c.lits){ |
1015 | if(lit > 0){ |
1016 | if(c.weight>0) posApp[lit]+=c.weight/c.lits.length; |
1017 | else negApp[lit]-=c.weight*c.lits.length; |
1018 | for(int i=0; i<probs.length; i++){ |
1019 | punsat[i] = punsat[i] * (1-probs[i]); |
1020 | } |
1021 | } |
1022 | if(lit < 0){ |
1023 | if(c.weight>0) negApp[-lit]+=c.weight/c.lits.length; |
1024 | else posApp[-lit]-=c.weight*c.lits.length; |
1025 | for(int i=0; i<probs.length; i++){ |
1026 | punsat[i] = punsat[i] * (probs[i]); |
1027 | } |
1028 | } |
1029 | } |
1030 | if(c.weight > 0){ |
1031 | for(int i=0; i<probs.length; i++){ |
1032 | ecosts[i] += punsat[i] * c.weight; |
1033 | } |
1034 | }else{ |
1035 | for(int i=0; i<probs.length; i++){ |
1036 | ecosts[i] -= (1-punsat[i]) * c.weight; |
1037 | } |
1038 | } |
1039 | } |
1040 | rs.close(); |
1041 | |
1042 | for(int i=0; i<probs.length; i++){ |
1043 | System.out.println("For all x P[x] = " + probs[i] + " --> E[cost] = " + ecosts[i]); |
1044 | } |
1045 | |
1046 | for(int i=1; i<=numAtoms; i++){ |
1047 | if(posApp[i] + negApp[i] == 0) continue; |
1048 | freq[i] = posApp[i] / (double)(posApp[i] + negApp[i]); |
1049 | } |
1050 | |
1051 | rs = db.query(sql); |
1052 | double ocost = 0; |
1053 | while(rs.next()){ |
1054 | GClause c = new GClause(); |
1055 | c.parse(rs); |
1056 | double opunsat = 1; |
1057 | for(int lit : c.lits){ |
1058 | if(lit > 0){ |
1059 | opunsat *= 1 - freq[lit]; |
1060 | }else{ |
1061 | opunsat *= freq[-lit]; |
1062 | } |
1063 | } |
1064 | if(c.weight > 0){ |
1065 | ocost += opunsat * c.weight; |
1066 | }else{ |
1067 | ocost -= (1-opunsat) * c.weight; |
1068 | } |
1069 | } |
1070 | rs.close(); |
1071 | System.out.println("Heterogeneous probs --> E[cost] = " + ocost); |
1072 | |
1073 | } catch (SQLException e) { |
1074 | ExceptionMan.handle(e); |
1075 | } |
1076 | } |
1077 | |
1078 | |
1079 | } |