| 1 | package felix.operator; |
| 2 | |
| 3 | |
| 4 | import felix.dstruct.FelixClause; |
| 5 | import felix.dstruct.FelixPredicate; |
| 6 | import felix.dstruct.FelixQuery; |
| 7 | import felix.dstruct.StatOperator; |
| 8 | import felix.dstruct.FelixPredicate.FPProperty; |
| 9 | import felix.parser.FelixCommandOptions; |
| 10 | import felix.util.FelixConfig; |
| 11 | import felix.util.FelixUIMan; |
| 12 | |
| 13 | |
| 14 | import java.util.HashMap; |
| 15 | import java.util.HashSet; |
| 16 | |
| 17 | |
| 18 | |
| 19 | import tuffy.db.RDB; |
| 20 | import tuffy.ground.Grounding; |
| 21 | import tuffy.ground.KBMC; |
| 22 | import tuffy.ground.partition.PartitionScheme; |
| 23 | import tuffy.infer.DataMover; |
| 24 | import tuffy.infer.InferPartitioned; |
| 25 | import tuffy.infer.MRF; |
| 26 | import tuffy.mln.Clause; |
| 27 | import tuffy.mln.MarkovLogicNetwork; |
| 28 | import tuffy.mln.Predicate; |
| 29 | import tuffy.ra.ConjunctiveQuery; |
| 30 | import tuffy.ra.Expression; |
| 31 | import tuffy.util.Config; |
| 32 | import tuffy.util.Settings; |
| 33 | import tuffy.util.Timer; |
| 34 | import tuffy.util.UIMan; |
| 35 | |
| 36 | |
| 37 | /** |
| 38 | * A Tuffy operator in Felix. |
| 39 | * @author Ce Zhang |
| 40 | * |
| 41 | */ |
| 42 | public class TUFFYOperator extends StatOperator{ |
| 43 | |
| 44 | /** |
| 45 | * The grounding worker of this Tuffy operator. |
| 46 | */ |
| 47 | Grounding grounding; |
| 48 | |
| 49 | /** |
| 50 | * Markov logic network used by this Tuffy operator. |
| 51 | */ |
| 52 | MarkovLogicNetwork mln; |
| 53 | |
| 54 | /** |
| 55 | * The constructor of TUFFYOperator. |
| 56 | * @param _fq Felix query. |
| 57 | * @param _goalPredicates target predicates of this coref operator. |
| 58 | * @param _opt Command line options of this Felix run. |
| 59 | */ |
| 60 | public TUFFYOperator(FelixQuery _fq, HashSet<FelixPredicate> _goalPredicates, |
| 61 | FelixCommandOptions _opt) { |
| 62 | super(_fq, _goalPredicates, _opt); |
| 63 | this.type = OPType.TUFFY; |
| 64 | this.precedence = 1; |
| 65 | |
| 66 | for(FelixClause fc : this.allRelevantFelixClause){ |
| 67 | |
| 68 | int nOpen = 0; |
| 69 | |
| 70 | for(Predicate p : fc.getReferencedPredicates()){ |
| 71 | |
| 72 | if(!p.isClosedWorld() || p.getName().endsWith("_map")){ |
| 73 | this.commonCandidate.add(fq.getPredByName(p.getName())); |
| 74 | } |
| 75 | } |
| 76 | |
| 77 | } |
| 78 | |
| 79 | } |
| 80 | |
| 81 | boolean prepared = false; |
| 82 | |
| 83 | /** |
| 84 | * Prepares operator for execution. |
| 85 | */ |
| 86 | @Override |
| 87 | public void prepare() { |
| 88 | |
| 89 | if(!prepared){ |
| 90 | |
| 91 | UIMan.println(">>> Start Running " + this); |
| 92 | |
| 93 | try { |
| 94 | |
| 95 | |
| 96 | } catch (Exception e) { |
| 97 | e.printStackTrace(); |
| 98 | } |
| 99 | |
| 100 | prepared = true; |
| 101 | } |
| 102 | |
| 103 | } |
| 104 | |
| 105 | |
| 106 | /** |
| 107 | * Executes operator. |
| 108 | */ |
| 109 | @Override |
| 110 | public void run() { |
| 111 | |
| 112 | UIMan.println(">>> Start Running " + this); |
| 113 | |
| 114 | db = RDB.getRDBbyConfig(Config.db_schema); |
| 115 | |
| 116 | Timer.start("Timer - TUFFY - " + this.getId()); |
| 117 | |
| 118 | mln = new MarkovLogicNetwork(); |
| 119 | |
| 120 | mln.setDB(db); |
| 121 | |
| 122 | HashSet<Predicate> registeredPredicates = new HashSet<Predicate>(); |
| 123 | |
| 124 | HashMap<Predicate, Boolean> oriClosedWorld = new HashMap<Predicate, Boolean>(); |
| 125 | |
| 126 | |
| 127 | for(FelixPredicate fp : this.inputPredicates){ |
| 128 | oriClosedWorld.put(fp, fp.isClosedWorld()); |
| 129 | fp.setClosedWorld(true); |
| 130 | if(this.isMarginal){ |
| 131 | if(fp.isCorefPredicate == false && fp.isCorefMapPredicate == false){ |
| 132 | if(!options.useDualDecomposition || FelixConfig.isFirstRunOfDD){ |
| 133 | fp.setHasSoftEvidence(true); |
| 134 | }else{ |
| 135 | fp.setHasSoftEvidence(false); |
| 136 | } |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | } |
| 141 | |
| 142 | for(Predicate fp : this.outputPredicates){ |
| 143 | oriClosedWorld.put(fp, fp.isClosedWorld()); |
| 144 | fp.setClosedWorld(this.currentState); |
| 145 | fp.setCompeletelySpecified(this.currentState); |
| 146 | |
| 147 | } |
| 148 | |
| 149 | for(FelixPredicate fp : this.dd_CommonOutput){ |
| 150 | |
| 151 | fp.setHasSoftEvidence(true); |
| 152 | fp.setCompeletelySpecified(true); |
| 153 | oriClosedWorld.put(fp, fp.isClosedWorld()); |
| 154 | fp.setClosedWorld(!this.currentState); |
| 155 | |
| 156 | } |
| 157 | |
| 158 | HashSet<FelixClause> allClause = allRelevantFelixClause; |
| 159 | allClause.addAll(this.dd_PriorClauses); |
| 160 | |
| 161 | for(FelixClause fc : this.allRelevantFelixClause){ |
| 162 | |
| 163 | boolean isNotUseful = true; |
| 164 | |
| 165 | Clause cloned = fc.clone(); |
| 166 | |
| 167 | if(this.clauseConstraints.get(fc) != null){ |
| 168 | |
| 169 | for(Expression e : this.clauseConstraints.get(fc)){ |
| 170 | cloned.getConstraints().add(e.clone()); |
| 171 | } |
| 172 | } |
| 173 | |
| 174 | mln.registerClause(cloned); |
| 175 | |
| 176 | for(Predicate fp : fc.getReferencedPredicates()){ |
| 177 | if(registeredPredicates.contains(fp)){ |
| 178 | continue; |
| 179 | } |
| 180 | fp.getRelatedClauses().clear(); |
| 181 | mln.registerPred(fp); |
| 182 | |
| 183 | registeredPredicates.add(fp); |
| 184 | } |
| 185 | |
| 186 | } |
| 187 | |
| 188 | for(ConjunctiveQuery sr : fq.getScopingRules()){ |
| 189 | mln.registerScopingRule(sr.clone()); |
| 190 | } |
| 191 | |
| 192 | mln.normalizeClauses(); |
| 193 | |
| 194 | for(Predicate p : mln.getAllPred()){ |
| 195 | p.setDB(db); |
| 196 | } |
| 197 | mln.finalizeClauseDefinitions(db); |
| 198 | |
| 199 | |
| 200 | //mln.prepareDB4Mobius(db); |
| 201 | |
| 202 | //mln.cleanUnknownPredTables(); |
| 203 | |
| 204 | KBMC kbmc = new KBMC(mln); |
| 205 | kbmc.run(); |
| 206 | mln.applyAllScopes(); |
| 207 | FelixUIMan.println(1, ">>> Marking queries..."); |
| 208 | mln.storeAllQueries(); |
| 209 | |
| 210 | mln.getDB().commit(); |
| 211 | //mln.getDB().close(); |
| 212 | //mln.setDB(RDB.getRDBbyConfig(Config.db_schema)); |
| 213 | |
| 214 | grounding = new Grounding(mln); |
| 215 | grounding.constructMRF(); |
| 216 | |
| 217 | DataMover dmover = new DataMover(mln); |
| 218 | |
| 219 | if(options.maxFlips == 0){ |
| 220 | options.maxFlips = 10 * grounding.getNumAtoms(); |
| 221 | } |
| 222 | if(options.maxTries == 0){ |
| 223 | options.maxTries = 1; |
| 224 | } |
| 225 | |
| 226 | MRF mrf = null; |
| 227 | |
| 228 | if(options.disablePartition){ |
| 229 | if(!options.marginal || options.dual){ |
| 230 | UIMan.println(">>> Running MAP inference..."); |
| 231 | String mapfout = options.fout; |
| 232 | if(options.dual) mapfout += ".map"; |
| 233 | |
| 234 | UIMan.println(" Loading MRF from DB to RAM..."); |
| 235 | mrf = dmover.loadMrfFromDb(mln.relAtoms, mln.relClauses); |
| 236 | mrf.inferWalkSAT(options.maxTries, options.maxFlips); |
| 237 | dmover.flushAtomStates(mrf.atoms.values(), mln.relAtoms, true); |
| 238 | |
| 239 | UIMan.println("### Best answer has cost " + UIMan.decimalRound(2,mrf.lowCost)); |
| 240 | //UIMan.println(">>> Writing answer to file: " + mapfout); |
| 241 | //dmover.dumpTruthToFile(mln.relAtoms, mapfout); |
| 242 | } |
| 243 | |
| 244 | if(options.marginal || options.dual){ |
| 245 | UIMan.println(">>> Running marginal inference..."); |
| 246 | String mfout = options.fout; |
| 247 | if(options.dual) mfout += ".marginal"; |
| 248 | |
| 249 | if(mrf == null){ |
| 250 | mrf = new MRF(mln); |
| 251 | dmover.loadMrfFromDb(mrf, mln.relAtoms, mln.relClauses); |
| 252 | } |
| 253 | |
| 254 | double sumCost = mrf.mcsat(options.mcsatSamples, options.maxFlips); |
| 255 | dmover.flushAtomStates(mrf.atoms.values(), mln.relAtoms); |
| 256 | |
| 257 | UIMan.println("### Average Cost = " + UIMan.decimalRound(2,sumCost/options.mcsatSamples)); |
| 258 | |
| 259 | //UIMan.println(">>> Writing answer to file: " + mfout); |
| 260 | //dmover.dumpProbsToFile(mln.relAtoms, mfout); |
| 261 | } |
| 262 | }else{ |
| 263 | |
| 264 | InferPartitioned ip = new InferPartitioned(grounding, dmover); |
| 265 | PartitionScheme pmap = ip.getPartitionScheme(); |
| 266 | int ncomp = pmap.numComponents(); |
| 267 | int nbuck = ip.getNumBuckets(); |
| 268 | String sdata = UIMan.comma(ncomp) + (ncomp > 1 ? " components" : "component"); |
| 269 | sdata += " (grouped into "; |
| 270 | sdata += UIMan.comma(nbuck) + (nbuck > 1 ? " buckets" : " bucket)"); |
| 271 | |
| 272 | |
| 273 | Settings settings = new Settings(); |
| 274 | Double fpa = ((double)options.maxFlips)/grounding.getNumAtoms(); |
| 275 | |
| 276 | if(!options.marginal || options.dual){ |
| 277 | UIMan.println(">>> Running MAP inference on " + sdata); |
| 278 | String mapfout = options.fout; |
| 279 | if(options.dual) mapfout += ".map"; |
| 280 | |
| 281 | settings.put("task", "MAP"); |
| 282 | settings.put("ntries", new Integer(options.maxTries)); |
| 283 | settings.put("flipsPerAtom", fpa); |
| 284 | double lowCost = ip.infer(settings); |
| 285 | |
| 286 | UIMan.println("### Best answer has cost " + UIMan.decimalRound(2,lowCost)); |
| 287 | //UIMan.println(">>> Writing answer to file: " + mapfout); |
| 288 | //dmover.dumpTruthToFile(mln.relAtoms, mapfout); |
| 289 | } |
| 290 | |
| 291 | if(options.marginal || options.dual){ |
| 292 | UIMan.println(">>> Running marginal inference on " + sdata); |
| 293 | String mfout = options.fout; |
| 294 | if(options.dual) mfout += ".marginal"; |
| 295 | |
| 296 | settings.put("task", "MARGINAL"); |
| 297 | settings.put("nsamples", new Integer(options.mcsatSamples)); |
| 298 | settings.put("flipsPerAtom", fpa); |
| 299 | double aveCost = ip.infer(settings); |
| 300 | |
| 301 | UIMan.println("### Average Cost = " + UIMan.decimalRound(2,aveCost)); |
| 302 | |
| 303 | //UIMan.println(">>> Writing answer to file: " + mfout); |
| 304 | //dmover.dumpProbsToFile(mln.relAtoms, mfout); |
| 305 | } |
| 306 | |
| 307 | } |
| 308 | |
| 309 | this.belongsToBucket.addMLNRelTable(mln.relAtoms); |
| 310 | |
| 311 | for(Predicate fp : oriClosedWorld.keySet()){ |
| 312 | fp.setClosedWorld(oriClosedWorld.get(fp)); |
| 313 | } |
| 314 | |
| 315 | FelixUIMan.println(0,0,">>> {" + this + "} uses " + Timer.elapsed("Timer - TUFFY - " + this.getId())); |
| 316 | FelixUIMan.println(0, 0, ""); |
| 317 | |
| 318 | //System.out.println("--------------[Timer - TUFFY - " + this.getId() + "]: " + |
| 319 | // Timer.elapsed("Timer - TUFFY - " + this.getId())); |
| 320 | |
| 321 | db.close(); |
| 322 | |
| 323 | if(!options.useDualDecomposition){ |
| 324 | this.belongsToBucket.runNextOperatorInBucket(); |
| 325 | } |
| 326 | |
| 327 | } |
| 328 | |
| 329 | @Override |
| 330 | public String explain() { |
| 331 | return null; |
| 332 | } |
| 333 | |
| 334 | @Override |
| 335 | public void learn() { |
| 336 | |
| 337 | } |
| 338 | |
| 339 | } |