| 1 | package felix.operator; |
| 2 | |
| 3 | |
| 4 | import java.io.BufferedWriter; |
| 5 | import java.io.FileInputStream; |
| 6 | import java.io.FileWriter; |
| 7 | import java.util.ArrayList; |
| 8 | import java.util.Arrays; |
| 9 | import java.util.HashMap; |
| 10 | import java.util.HashSet; |
| 11 | import java.util.Set; |
| 12 | |
| 13 | import org.postgresql.PGConnection; |
| 14 | |
| 15 | import tuffy.db.RDB; |
| 16 | import tuffy.mln.Predicate; |
| 17 | import tuffy.ra.ConjunctiveQuery; |
| 18 | import tuffy.util.Config; |
| 19 | import tuffy.util.ExceptionMan; |
| 20 | import tuffy.util.StringMan; |
| 21 | import tuffy.util.Timer; |
| 22 | import tuffy.util.UIMan; |
| 23 | import felix.dstruct.DataMovementOperator; |
| 24 | import felix.dstruct.FelixPredicate; |
| 25 | import felix.dstruct.FelixQuery; |
| 26 | import felix.dstruct.StatOperator; |
| 27 | import felix.dstruct.FelixPredicate.FPProperty; |
| 28 | import felix.parser.FelixCommandOptions; |
| 29 | import felix.util.FelixConfig; |
| 30 | import felix.util.FelixStringMan; |
| 31 | import felix.util.FelixUIMan; |
| 32 | |
| 33 | /** |
| 34 | * A CRF operator in Felix. |
| 35 | * @author Ce Zhang |
| 36 | * |
| 37 | */ |
| 38 | public class CRFOperator extends StatOperator{ |
| 39 | |
| 40 | /** |
| 41 | * Target predicate of this CRF operator. |
| 42 | */ |
| 43 | FelixPredicate crfHead; |
| 44 | |
| 45 | /** |
| 46 | * Mapping from a label's signature (the values of its label fields joined by |
| 47 | * {@link CRFOperator#array2str}) to a new label ID. New label IDs start from 0, so that |
| 48 | * CRF inference can use arrays indexed by label ID. |
| 49 | */ |
| 50 | HashMap<String, Integer> label2ID = new HashMap<String, Integer>(); |
| 51 | |
| 52 | /** |
| 53 | * The inverse map of {@link CRFOperator#label2ID}. Here String[] means |
| 54 | * there can be multiple fields corresponding to CRF's labeling field. |
| 55 | */ |
| 56 | HashMap<Integer, String[]> id2Label = new HashMap<Integer, String[]>(); |
| 57 | |
| 58 | /** |
| 59 | * All DataMovementOperators used as LR rules. |
| 60 | */ |
| 61 | ArrayList<DataMovementOperator> lrDMOs = new ArrayList<DataMovementOperator>(); |
| 62 | |
| 63 | /** |
| 64 | * The DataMovementOperator used as CRF Chain rule. |
| 65 | */ |
| 66 | DataMovementOperator crfDMO = null; |
| 67 | |
| 68 | /** |
| 69 | * The DataMovementOperator which is the union of all LR rules. |
| 70 | */ |
| 71 | DataMovementOperator lrDMO = null; |
| 72 | |
| 73 | /** |
| 74 | * The DataMovementOperator representing the table/view for all unigram features. |
| 75 | */ |
| 76 | DataMovementOperator unigramDMO = null; |
| 77 | |
| 78 | /** |
| 79 | * The DataMovementOperator representing the table/view for all bigram features. |
| 80 | */ |
| 81 | DataMovementOperator bigramDMO = null; |
| 82 | |
| 83 | /** |
| 84 | * The DataMovementOperator representing the table/view for all possible labels. |
| 85 | */ |
| 86 | DataMovementOperator labelDomainDMO = null; |
| 87 | |
| 88 | /** |
| 89 | * The DataMovementOperator fetching a partition of grounded literals (which may |
| 90 | * be a sequence or multiple sequences). |
| 91 | */ |
| 92 | DataMovementOperator getAllPossiblePartitioningDMO = null; |
| 93 | |
| 94 | /** |
| 95 | * The DataMovementOperator fetching all bigram features for a given sequence. |
| 96 | */ |
| 97 | DataMovementOperator getBigramFeaturesForPartitioningDMO = null; |
| 98 | |
| 99 | /** |
| 100 | * The DataMovementOperator fetching all unigram features for a given sequence |
| 101 | */ |
| 102 | DataMovementOperator getUnigramFeaturesForPartitioningDMO = null; |
| 103 | |
| 104 | /** |
| 105 | * The DataMovementOperator fetching all bigram features. |
| 106 | */ |
| 107 | DataMovementOperator getAllBigramFeaturesDMO = null; |
| 108 | |
| 109 | /** |
| 110 | * The DataMovementOperator fetching all unigram features. |
| 111 | */ |
| 112 | DataMovementOperator getAllUnigramFeaturesDMO = null; |
| 113 | |
| 114 | public int nRuns = 0; |
| 115 | |
/**
 * Creates a CRF operator and tags it with the CRF operator type and
 * precedence 5.
 *
 * @param _fq Felix query.
 * @param _goalPredicates target predicates of this CRF operator.
 * @param _opt command line options of this Felix run.
 */
public CRFOperator(FelixQuery _fq, HashSet<FelixPredicate> _goalPredicates,
        FelixCommandOptions _opt) {
    super(_fq, _goalPredicates, _opt);
    this.precedence = 5;
    this.type = OPType.CRF;
}
| 128 | |
| 129 | boolean prepared = false; |
| 130 | |
| 131 | /** |
| 132 | * Prepares operator for execution. |
| 133 | */ |
| 134 | @Override |
| 135 | public void prepare() { |
| 136 | |
| 137 | this.lrDMOs.clear(); |
| 138 | allDMOs.clear(); |
| 139 | |
| 140 | //if(!prepared){ |
| 141 | |
| 142 | db = RDB.getRDBbyConfig(Config.db_schema); |
| 143 | |
| 144 | crfHead = this.getTargetPredicateIfHasOnlyOne(); |
| 145 | HashSet<ConjunctiveQuery> chainQueries = |
| 146 | this.translateFelixClasesIntoFactorGraphEdgeQueries(crfHead, false, this.inputPredicateScope, FPProperty.CHAIN_RECUR); |
| 147 | HashSet<ConjunctiveQuery> lrQueries = |
| 148 | this.translateFelixClasesIntoFactorGraphEdgeQueries(crfHead, false, this.inputPredicateScope, FPProperty.NON_RECUR); |
| 149 | |
| 150 | this.prepareDMO(lrQueries, chainQueries); |
| 151 | |
| 152 | prepared = true; |
| 153 | //} |
| 154 | |
| 155 | } |
| 156 | |
| 157 | /** |
| 158 | * Return the signature of a string array as a string. |
| 159 | * @param _array |
| 160 | * @return |
| 161 | */ |
| 162 | String array2str(String[] _array){ |
| 163 | String ret = ""; |
| 164 | for(String s : _array){ |
| 165 | ret = ret + ":" + s; |
| 166 | } |
| 167 | return ret; |
| 168 | } |
| 169 | |
| 170 | /** |
| 171 | * Executes operator. |
| 172 | */ |
| 173 | @Override |
| 174 | public void run() { |
| 175 | |
| 176 | nRuns ++; |
| 177 | |
| 178 | if(!this.options.useDualDecomposition){ |
| 179 | crfHead.setHasSoftEvidence(true); |
| 180 | } |
| 181 | |
| 182 | UIMan.println(">>> Start Running " + this); |
| 183 | |
| 184 | Timer.start("CRF-Operator-" + crfHead.getName() + "-" + this.getId()); |
| 185 | |
| 186 | try{ |
| 187 | BufferedWriter bw = new BufferedWriter( |
| 188 | new FileWriter(Config.getLoadingDir() + "/_loading_crf_" + |
| 189 | crfHead.getName() + "_op" + this.getId())); |
| 190 | |
| 191 | this.labelDomainDMO.execute(null, new ArrayList<Integer>()); |
| 192 | int labelID = 0; |
| 193 | int nLabelFileds = crfHead.getLabelFieldsArgs().size(); |
| 194 | |
| 195 | while(this.labelDomainDMO.next()){ |
| 196 | String[] currLabel = new String[nLabelFileds]; |
| 197 | |
| 198 | for(int i=0;i<nLabelFileds;i++){ |
| 199 | currLabel[i] = this.labelDomainDMO.getNext(i+1).toString(); |
| 200 | } |
| 201 | |
| 202 | this.id2Label.put(labelID, currLabel); |
| 203 | this.label2ID.put(array2str(currLabel), labelID); |
| 204 | labelID++; |
| 205 | } |
| 206 | |
| 207 | if(id2Label.size() == 0){ |
| 208 | return; |
| 209 | } |
| 210 | |
| 211 | |
| 212 | if(crfHead.getCRFPartitionFields() != null){ |
| 213 | |
| 214 | db.disableAutoCommitForNow(); |
| 215 | this.fastInfer(bw); |
| 216 | db.commit(); |
| 217 | db.restoreAutoCommitState(); |
| 218 | |
| 219 | }else{ |
| 220 | UIMan.warn("Your CRF Rule Cannot be partitioned into different components..."); |
| 221 | |
| 222 | ExceptionMan.die("Rewritting your rules may be a better idea... " + |
| 223 | "Or simply use -noCRF option"); |
| 224 | |
| 225 | /** |
| 226 | db.disableAutoCommitForNow(); |
| 227 | this.slowInfer(bw); |
| 228 | db.commit(); |
| 229 | db.restoreAutoCommitState(); |
| 230 | **/ |
| 231 | } |
| 232 | |
| 233 | bw.close(); |
| 234 | |
| 235 | |
| 236 | FileInputStream in = new FileInputStream(Config.getLoadingDir() + |
| 237 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
| 238 | |
| 239 | PGConnection con = (PGConnection) db.getConnection(); |
| 240 | |
| 241 | String sql; |
| 242 | |
| 243 | if(options.useDualDecomposition){ |
| 244 | for(FelixPredicate fp : this.dd_CommonOutput){ |
| 245 | if(!fp.getName().equals(this.crfHead.getName())){ |
| 246 | ExceptionMan.die("ERROR: I am not fuzzy-LR/CRF/COREF!!! Contact us!!!"); |
| 247 | continue; |
| 248 | } |
| 249 | |
| 250 | |
| 251 | in = new FileInputStream(Config.getLoadingDir() + |
| 252 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
| 253 | String tableName = this.dd_commonOutputPredicate_2_tableName.get(fp); |
| 254 | |
| 255 | |
| 256 | sql = "COPY " + tableName + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
| 257 | con.getCopyAPI().copyIn(sql, in); |
| 258 | in.close(); |
| 259 | |
| 260 | } |
| 261 | |
| 262 | if(FelixConfig.isFirstRunOfDD){ |
| 263 | in = new FileInputStream(Config.getLoadingDir() + |
| 264 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
| 265 | |
| 266 | sql = "COPY " + crfHead.getRelName() + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
| 267 | con.getCopyAPI().copyIn(sql, in); |
| 268 | in.close(); |
| 269 | } |
| 270 | crfHead.isCurrentlyView = false; |
| 271 | |
| 272 | }else{ |
| 273 | |
| 274 | sql = "COPY " + crfHead.getRelName() + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
| 275 | con.getCopyAPI().copyIn(sql, in); |
| 276 | in.close(); |
| 277 | crfHead.isCurrentlyView = false; |
| 278 | |
| 279 | } |
| 280 | |
| 281 | FelixUIMan.println(0,0,"\n>>> {" + this + "} uses " + Timer.elapsed("CRF-Operator-" + crfHead.getName() + "-" + this.getId())); |
| 282 | |
| 283 | db.close(); |
| 284 | |
| 285 | if(!options.useDualDecomposition){ |
| 286 | this.belongsToBucket.runNextOperatorInBucket(); |
| 287 | } |
| 288 | |
| 289 | }catch(Exception e){ |
| 290 | e.printStackTrace(); |
| 291 | } |
| 292 | |
| 293 | } |
| 294 | |
| 295 | @Override |
| 296 | public String explain() { |
| 297 | //TODO: |
| 298 | return null; |
| 299 | } |
| 300 | |
| 301 | /** |
| 302 | * Conduct CRF infer WITH knowledge about partitioning which is parsed statically |
| 303 | * from the input program. |
| 304 | * |
| 305 | * @param bw Buffered writer to dump results. |
| 306 | */ |
| 307 | public void fastInfer(BufferedWriter bw){ |
| 308 | |
| 309 | //get the whole sequence in db based on provided seqHead |
| 310 | |
| 311 | try { |
| 312 | |
| 313 | if(this.getAllUnigramFeaturesDMO != null){ |
| 314 | this.getAllUnigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
| 315 | this.getAllUnigramFeaturesDMO.next(); |
| 316 | }else{ |
| 317 | //TODO: |
| 318 | return; |
| 319 | } |
| 320 | |
| 321 | boolean reachUnigramEnd = false; |
| 322 | |
| 323 | this.getAllPossiblePartitioningDMO.execute(null, new ArrayList<Integer>()); |
| 324 | int ctt = 0; |
| 325 | while(this.getAllPossiblePartitioningDMO.next()){ |
| 326 | |
| 327 | //System.out.println(ctt++); |
| 328 | |
| 329 | String partSignature = ""; |
| 330 | ArrayList<Integer> bindings = new ArrayList<Integer>(); |
| 331 | ArrayList<String> toSig = new ArrayList<String>(); |
| 332 | for(String s : crfHead.getCRFPartitionFields()){ |
| 333 | toSig.add(this.getAllPossiblePartitioningDMO.getNext(s) + ""); |
| 334 | bindings.add(this.getAllPossiblePartitioningDMO.getNext(s)); |
| 335 | } |
| 336 | partSignature = StringMan.commaList(toSig); |
| 337 | |
| 338 | Sequence imSeq = new Sequence(crfHead, null, id2Label, label2ID); |
| 339 | |
| 340 | String[] currLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
| 341 | String[] prevLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
| 342 | |
| 343 | while(true){ |
| 344 | |
| 345 | if(reachUnigramEnd == true){ |
| 346 | break; |
| 347 | } |
| 348 | |
| 349 | String curPartSignature = ""; |
| 350 | toSig = new ArrayList<String>(); |
| 351 | for(String s : crfHead.getCRFPartitionFields()){ |
| 352 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(s) + ""); |
| 353 | } |
| 354 | curPartSignature = StringMan.commaList(toSig); |
| 355 | |
| 356 | if(curPartSignature.equals(partSignature)){ |
| 357 | |
| 358 | String currSignature = ""; |
| 359 | int lct = 0; |
| 360 | toSig = new ArrayList<String>(); |
| 361 | for(int i=0;i<crfHead.arity();i++){ |
| 362 | if(crfHead.getLabelPositions().contains(i)){ |
| 363 | currLabel[lct++] = this.getAllUnigramFeaturesDMO.getNext(i+1).toString(); |
| 364 | toSig.add("%s"); |
| 365 | }else{ |
| 366 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(i+1).toString()); |
| 367 | } |
| 368 | } |
| 369 | currSignature = StringMan.commaList(toSig); |
| 370 | |
| 371 | Double weight = this.getAllUnigramFeaturesDMO.getNextDouble(crfHead.arity()+1); |
| 372 | |
| 373 | imSeq.registerNodeIfNotExist(currSignature); |
| 374 | imSeq.registerUnigramFeatures(currSignature, this.label2ID.get(array2str(currLabel)), weight); |
| 375 | |
| 376 | if(this.getAllUnigramFeaturesDMO.next() == null){ |
| 377 | reachUnigramEnd = true; |
| 378 | break; |
| 379 | } |
| 380 | |
| 381 | }else{ |
| 382 | break; |
| 383 | } |
| 384 | } |
| 385 | |
| 386 | this.getBigramFeaturesForPartitioningDMO.execute(null, bindings); |
| 387 | |
| 388 | /* |
| 389 | System.err.println(); |
| 390 | System.err.println(this.crfDMO.getAllFreeViewName()); |
| 391 | System.err.println(this.crfDMO.physicalQueryPlan.objectConjunctiveQuery); |
| 392 | System.err.println(this.crfDMO.physicalQueryPlan.objectPreparedStatement); |
| 393 | System.err.println(this.getBigramFeaturesForPartitioningDMO.physicalQueryPlan.objectPreparedStatement); |
| 394 | */ |
| 395 | |
| 396 | while(getBigramFeaturesForPartitioningDMO.next()){ |
| 397 | |
| 398 | String prevSignature = ""; |
| 399 | |
| 400 | toSig = new ArrayList<String>(); |
| 401 | int lct = 0; |
| 402 | for(int i=0;i<crfHead.arity();i++){ |
| 403 | if(crfHead.getLabelPositions().contains(i)){ |
| 404 | prevLabel[lct++] = this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString(); |
| 405 | toSig.add("%s"); |
| 406 | }else{ |
| 407 | toSig.add(this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString()); |
| 408 | } |
| 409 | } |
| 410 | prevSignature = StringMan.commaList(toSig); |
| 411 | |
| 412 | String currSignature = ""; |
| 413 | lct = 0; |
| 414 | toSig = new ArrayList<String>(); |
| 415 | for(int i=crfHead.arity();i< 2* crfHead.arity();i++){ |
| 416 | if(crfHead.getLabelPositions().contains(i - crfHead.arity())){ |
| 417 | currLabel[lct++] = this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString(); |
| 418 | toSig.add("%s"); |
| 419 | }else{ |
| 420 | toSig.add(this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString()); |
| 421 | } |
| 422 | } |
| 423 | currSignature = StringMan.commaList(toSig); |
| 424 | |
| 425 | Double weight = getBigramFeaturesForPartitioningDMO.getNextDouble(2*crfHead.arity()+1); |
| 426 | |
| 427 | imSeq.registerNodeIfNotExist(prevSignature); |
| 428 | imSeq.registerNodeIfNotExist(currSignature); |
| 429 | imSeq.registerBigramFeatures(prevSignature, currSignature, label2ID.get(array2str(prevLabel)), |
| 430 | label2ID.get(array2str(currLabel)), weight); |
| 431 | } |
| 432 | |
| 433 | |
| 434 | |
| 435 | imSeq.infer(); |
| 436 | imSeq.dumpAnswers(bw); |
| 437 | |
| 438 | } |
| 439 | |
| 440 | } catch (Exception e) { |
| 441 | ExceptionMan.die("unconsistent value!"); |
| 442 | e.printStackTrace(); |
| 443 | } |
| 444 | } |
| 445 | |
| 446 | /** |
| 447 | * Conduct CRF infer WITHOUT knowledge about partitioning which is parsed statically |
| 448 | * from the input program. |
| 449 | * |
| 450 | * @deprecated |
| 451 | * |
| 452 | * @param bw Buffered writer to dump results. |
| 453 | */ |
| 454 | public void slowInfer(BufferedWriter bw){ |
| 455 | |
| 456 | //get the whole sequence in db based on provided seqHead |
| 457 | |
| 458 | try { |
| 459 | |
| 460 | if(this.getAllUnigramFeaturesDMO != null){ |
| 461 | this.getAllUnigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
| 462 | } |
| 463 | if(this.getAllBigramFeaturesDMO != null){ |
| 464 | this.getAllBigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
| 465 | } |
| 466 | |
| 467 | Sequence imSeq = new Sequence(crfHead, null, id2Label, label2ID); |
| 468 | String[] currLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
| 469 | String[] prevLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
| 470 | ArrayList<String> toSig = new ArrayList<String>(); |
| 471 | |
| 472 | if(this.getAllUnigramFeaturesDMO != null){ |
| 473 | while(this.getAllUnigramFeaturesDMO.next()){ |
| 474 | |
| 475 | String currSignature = ""; |
| 476 | int lct = 0; |
| 477 | toSig = new ArrayList<String>(); |
| 478 | for(int i=0;i<crfHead.arity();i++){ |
| 479 | if(crfHead.getLabelPositions().contains(i)){ |
| 480 | currLabel[lct++] = this.getAllUnigramFeaturesDMO.getNext(i+1).toString(); |
| 481 | toSig.add("%s"); |
| 482 | }else{ |
| 483 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(i+1).toString()); |
| 484 | } |
| 485 | } |
| 486 | currSignature = StringMan.commaList(toSig); |
| 487 | |
| 488 | Double weight = this.getAllUnigramFeaturesDMO.getNextDouble(crfHead.arity()+1); |
| 489 | |
| 490 | imSeq.registerNodeIfNotExist(currSignature); |
| 491 | imSeq.registerUnigramFeatures(currSignature, this.label2ID.get(array2str(currLabel)), weight); |
| 492 | } |
| 493 | } |
| 494 | |
| 495 | |
| 496 | |
| 497 | if(this.getAllBigramFeaturesDMO != null){ |
| 498 | while(this.getAllBigramFeaturesDMO.next()){ |
| 499 | |
| 500 | String prevSignature = ""; |
| 501 | |
| 502 | toSig = new ArrayList<String>(); |
| 503 | int lct = 0; |
| 504 | for(int i=0;i<crfHead.arity();i++){ |
| 505 | if(crfHead.getLabelPositions().contains(i)){ |
| 506 | prevLabel[lct++] = this.getAllBigramFeaturesDMO.getNext(i+1).toString(); |
| 507 | toSig.add("%s"); |
| 508 | }else{ |
| 509 | toSig.add(this.getAllBigramFeaturesDMO.getNext(i+1).toString()); |
| 510 | } |
| 511 | } |
| 512 | prevSignature = StringMan.commaList(toSig); |
| 513 | |
| 514 | String currSignature = ""; |
| 515 | lct = 0; |
| 516 | toSig = new ArrayList<String>(); |
| 517 | for(int i=crfHead.arity();i< 2* crfHead.arity();i++){ |
| 518 | if(crfHead.getLabelPositions().contains(i - crfHead.arity())){ |
| 519 | currLabel[lct++] = this.getAllBigramFeaturesDMO.getNext(i+1).toString(); |
| 520 | toSig.add("%s"); |
| 521 | }else{ |
| 522 | toSig.add(this.getAllBigramFeaturesDMO.getNext(i+1).toString()); |
| 523 | } |
| 524 | } |
| 525 | currSignature = StringMan.commaList(toSig); |
| 526 | |
| 527 | Double weight = getAllBigramFeaturesDMO.getNextDouble(2*crfHead.arity()+1); |
| 528 | |
| 529 | imSeq.registerNodeIfNotExist(prevSignature); |
| 530 | imSeq.registerNodeIfNotExist(currSignature); |
| 531 | imSeq.registerBigramFeatures(prevSignature, currSignature, label2ID.get(array2str(prevLabel)), |
| 532 | label2ID.get(array2str(currLabel)), weight); |
| 533 | |
| 534 | } |
| 535 | } |
| 536 | |
| 537 | imSeq.infer(); |
| 538 | imSeq.dumpAnswers(bw); |
| 539 | |
| 540 | |
| 541 | } catch (Exception e) { |
| 542 | e.printStackTrace(); |
| 543 | } |
| 544 | } |
| 545 | |
| 546 | /** |
| 547 | * Generate Data Movement Operator used by this CRF Operator. |
| 548 | * @param rules rules defining this operator. |
| 549 | */ |
| 550 | public void prepareDMO(HashSet<ConjunctiveQuery> lrQueries, HashSet<ConjunctiveQuery> chainQueries){ |
| 551 | |
| 552 | try { |
| 553 | |
| 554 | // DMO for LR rules |
| 555 | for(ConjunctiveQuery cq : lrQueries){ |
| 556 | |
| 557 | DataMovementOperator dmo = new DataMovementOperator(db, this); |
| 558 | dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(), |
| 559 | new ArrayList<String>(Arrays.asList("weight")) ); |
| 560 | |
| 561 | dmo.predictedBB = 0; |
| 562 | dmo.PredictedFF = 1; |
| 563 | dmo.PredictedBF = 0; |
| 564 | |
| 565 | dmo.allowOptimization = false; |
| 566 | |
| 567 | allDMOs.add(dmo); |
| 568 | lrDMOs.add(dmo); |
| 569 | } |
| 570 | |
| 571 | // DMOs for CRF chain rules |
| 572 | for(ConjunctiveQuery cq : chainQueries){ |
| 573 | |
| 574 | DataMovementOperator dmo = new DataMovementOperator(db, this); |
| 575 | dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(), |
| 576 | new ArrayList<String>(Arrays.asList("weight")) ); |
| 577 | |
| 578 | dmo.predictedBB = 0; |
| 579 | dmo.PredictedFF = 0; |
| 580 | dmo.PredictedBF = this.getPartitionSize(); |
| 581 | |
| 582 | dmo.allowOptimization = true; |
| 583 | |
| 584 | allDMOs.add(dmo); |
| 585 | crfDMO = dmo; |
| 586 | } |
| 587 | |
| 588 | |
| 589 | //the DMO for the union of all LR DMOs |
| 590 | if(this.lrDMOs.size() > 0){ |
| 591 | this.lrDMO = DataMovementOperator.UnionAll(db, this, |
| 592 | this.lrDMOs, StringMan.zeros(this.crfHead.arity() + 1), new ArrayList<Integer>()); |
| 593 | allDMOs.add(lrDMO); |
| 594 | |
| 595 | //the DMO for all unigram features. |
| 596 | this.unigramDMO = new DataMovementOperator(db, this); |
| 597 | this.unigramDMO.allowOptimization = false; |
| 598 | this.unigramDMO.asView = false; |
| 599 | this.unigramDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
| 600 | "SELECT " + StringMan.commaList(lrDMO.selListFromRule) + ", sum(weight) AS sumweight " + |
| 601 | " FROM " + lrDMO.getAllFreeViewName() + " GROUP BY " + StringMan.commaList(lrDMO.selListFromRule) + |
| 602 | (crfHead.getCRFPartitionFields() == null ? "" : |
| 603 | " ORDER BY " + StringMan.commaList(crfHead.getCRFPartitionFields())) ), |
| 604 | lrDMO.selListFromRule, new ArrayList<String>(Arrays.asList("sumweight"))); |
| 605 | allDMOs.add(this.unigramDMO); |
| 606 | } |
| 607 | |
| 608 | //the DMO for the label domain |
| 609 | this.labelDomainDMO = new DataMovementOperator(db, this); |
| 610 | this.labelDomainDMO.allowOptimization = false; |
| 611 | ArrayList<String> fields = crfHead.getLabelFieldsTypeTable(); |
| 612 | ArrayList<String> fieldsViewName = new ArrayList<String>(); |
| 613 | for(int i=0;i<fields.size();i++){ |
| 614 | fieldsViewName.add("t" + i + "." + "constantid AS l" + i); |
| 615 | fields.set(i, fields.get(i) + " t" + i); |
| 616 | } |
| 617 | this.labelDomainDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
| 618 | "SELECT " + StringMan.commaList(fieldsViewName) +" FROM " + FelixStringMan.commaList(fields)), |
| 619 | crfHead.getLabelFieldsArgs(), |
| 620 | new ArrayList<String>()); |
| 621 | allDMOs.add(this.labelDomainDMO); |
| 622 | |
| 623 | // check whether we can use partitioning, or we must use a slow |
| 624 | // version which detects the head for each sequence using recursive SQL. |
| 625 | if(crfHead.getCRFPartitionFields() == null){ |
| 626 | |
| 627 | }else{ |
| 628 | |
| 629 | if(this.unigramDMO != null){ |
| 630 | this.getAllPossiblePartitioningDMO = new DataMovementOperator(db, this); |
| 631 | this.getAllPossiblePartitioningDMO.allowOptimization = false; |
| 632 | this.getAllPossiblePartitioningDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
| 633 | "SELECT DISTINCT * FROM ( " + |
| 634 | //"SELECT " + StringMan.commaList(crfHead.getCRFPartitionFields()) |
| 635 | // + " FROM " + crfDMO.getAllFreeViewName() + " UNION " + |
| 636 | " SELECT " + FelixStringMan.commaList(crfHead.getCRFPartitionFields()) |
| 637 | + " FROM " + unigramDMO.getAllFreeViewName() + ") nt ORDER BY " + |
| 638 | FelixStringMan.commaList(crfHead.getCRFPartitionFields())), crfHead.getCRFPartitionFields(), |
| 639 | new ArrayList<String>()); |
| 640 | allDMOs.add(this.getAllPossiblePartitioningDMO); |
| 641 | }else{ |
| 642 | this.getAllPossiblePartitioningDMO = new DataMovementOperator(db, this); |
| 643 | this.getAllPossiblePartitioningDMO.allowOptimization = false; |
| 644 | this.getAllPossiblePartitioningDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
| 645 | "SELECT DISTINCT * FROM ( SELECT " + FelixStringMan.commaList(crfHead.getCRFPartitionFields()) |
| 646 | + " FROM " + crfDMO.getAllFreeViewName() + ") nt ORDER BY " + |
| 647 | FelixStringMan.commaList(crfHead.getCRFPartitionFields())), crfHead.getCRFPartitionFields(), |
| 648 | new ArrayList<String>()); |
| 649 | allDMOs.add(this.getAllPossiblePartitioningDMO); |
| 650 | } |
| 651 | |
| 652 | |
| 653 | this.getBigramFeaturesForPartitioningDMO = DataMovementOperator.Select(db, this, |
| 654 | this.crfDMO, crfHead.getCRFPartitionFields()); |
| 655 | this.getBigramFeaturesForPartitioningDMO.isIntermediaDMO = true; |
| 656 | this.getBigramFeaturesForPartitioningDMO.hasKnownFetchingOrder = true; |
| 657 | allDMOs.add(this.getBigramFeaturesForPartitioningDMO); |
| 658 | |
| 659 | } |
| 660 | |
| 661 | this.getAllBigramFeaturesDMO = DataMovementOperator.Select(db, this, |
| 662 | this.crfDMO, new ArrayList<String>()); |
| 663 | this.getAllBigramFeaturesDMO.isIntermediaDMO = true; |
| 664 | this.getAllBigramFeaturesDMO.hasKnownFetchingOrder = true; |
| 665 | allDMOs.add(this.getAllBigramFeaturesDMO); |
| 666 | |
| 667 | if(this.unigramDMO != null){ |
| 668 | this.getAllUnigramFeaturesDMO = DataMovementOperator.SelectOrderBy(db, this, |
| 669 | this.unigramDMO, new ArrayList<String>(), |
| 670 | " ORDER BY " + StringMan.commaList(crfHead.getCRFPartitionFields())); |
| 671 | this.getAllUnigramFeaturesDMO.isIntermediaDMO = true; |
| 672 | this.getAllUnigramFeaturesDMO.hasKnownFetchingOrder = true; |
| 673 | allDMOs.add(this.getAllUnigramFeaturesDMO); |
| 674 | } |
| 675 | |
| 676 | } catch (Exception e) { |
| 677 | e.printStackTrace(); |
| 678 | } |
| 679 | } |
| 680 | |
| 681 | /** |
| 682 | * Estimate the number of sequences. |
| 683 | * @return |
| 684 | */ |
| 685 | public int getPartitionSize(){ |
| 686 | if(crfHead.getCRFPartitionFields() == null){ |
| 687 | return 1; |
| 688 | } |
| 689 | |
| 690 | int maxSingleField = -1; |
| 691 | for(int i=0;i<crfHead.getArgs().size();i++){ |
| 692 | if(crfHead.getCRFPartitionFields().contains(crfHead.getArgs().get(i))){ |
| 693 | // just an estimate |
| 694 | if(maxSingleField < crfHead.getTypeAt(i).size()){ |
| 695 | maxSingleField = crfHead.getTypeAt(i).size(); |
| 696 | } |
| 697 | } |
| 698 | } |
| 699 | if(maxSingleField > 0){ |
| 700 | return maxSingleField/this.partitionedInto; |
| 701 | } |
| 702 | |
| 703 | return Integer.MAX_VALUE; |
| 704 | } |
| 705 | |
| 706 | |
| 707 | /** |
| 708 | * Returns sum of given log numbers. |
| 709 | * @param logX |
| 710 | * @param logY |
| 711 | * @return |
| 712 | */ |
| 713 | //ACKNOWLEDGE: FROM https://facwiki.cs.byu.edu/nlp/index.php/Log_Domain_Computations |
| 714 | //COPYRIGHT OF THIS FUNCTION BELONGS TO ITS ORIGINAL AUTHOR |
| 715 | public static double logAdd(double logX, double logY) { |
| 716 | |
| 717 | if (logY > logX) { |
| 718 | double temp = logX; |
| 719 | logX = logY; |
| 720 | logY = temp; |
| 721 | } |
| 722 | |
| 723 | if (logX == Double.NEGATIVE_INFINITY) { |
| 724 | return logX; |
| 725 | } |
| 726 | |
| 727 | double negDiff = logY - logX; |
| 728 | //if (negDiff < -1000000) { |
| 729 | // return logX; |
| 730 | //} |
| 731 | |
| 732 | return logX + java.lang.Math.log(1.0 + java.lang.Math.exp(negDiff)); |
| 733 | } |
| 734 | |
| 735 | |
/**
 * Class for a node in the {@link Sequence}. All arrays are sized by the
 * number of labels.
 */
class Node{
    /** Bigram weights, indexed as {@code [previous label][current label]}; start at 0. */
    public double[][] prev2currBigram;
    /** Unigram weight per label; starts at 0. */
    public double[] currUnigram;

    // Per-label work arrays; scores start at negative infinity and
    // arg-max pointers at -1.
    public double[] forwardSum;
    public double[] backwardSum;
    public double[] currentMax;
    public int[] prevArgMax;
    public int[] nextArgMax;

    /**
     * @param nOfLabel number of labels, i.e. the length of every array.
     */
    public Node(int nOfLabel){
        // Newly allocated double arrays are already filled with 0.
        prev2currBigram = new double[nOfLabel][nOfLabel];
        currUnigram = new double[nOfLabel];

        forwardSum = new double[nOfLabel];
        backwardSum = new double[nOfLabel];
        currentMax = new double[nOfLabel];
        Arrays.fill(forwardSum, Double.NEGATIVE_INFINITY);
        Arrays.fill(backwardSum, Double.NEGATIVE_INFINITY);
        Arrays.fill(currentMax, Double.NEGATIVE_INFINITY);

        prevArgMax = new int[nOfLabel];
        nextArgMax = new int[nOfLabel];
        Arrays.fill(prevArgMax, -1);
        Arrays.fill(nextArgMax, -1);
    }
}
| 775 | |
| 776 | /** |
| 777 | * In-memory representation of a CRF chain. This class |
| 778 | * supports infer (both marginal and MAP) and dumps results to file. |
| 779 | * |
| 780 | * @author Ce Zhang |
| 781 | * |
| 782 | */ |
| 783 | class Sequence{ |
| 784 | |
| 785 | // although I personally think using string instead of integer id will not |
| 786 | // be so slow (because the sequence is normally short), we may like to try a |
| 787 | // pure-integer version if we think current Viterbi is slow... |
| 788 | |
| 789 | /** |
| 790 | * The predicate to be labeled. |
| 791 | */ |
| 792 | Predicate pred = null; |
| 793 | |
| 794 | /** |
| 795 | * The signature of the root node. |
| 796 | */ |
| 797 | String rootSignature = ""; |
| 798 | |
| 799 | /** |
| 800 | * The signature of the last node |
| 801 | */ |
| 802 | String lastSignature = ""; |
| 803 | |
| 804 | /** |
| 805 | * Map from signature to Node object. |
| 806 | */ |
| 807 | HashMap<String, Node> signature2Node = new HashMap<String, Node>(); |
| 808 | |
| 809 | /** |
| 810 | * Set of all root nodes in this sequence. |
| 811 | */ |
| 812 | HashSet<String> roots = new HashSet<String>(); |
| 813 | |
| 814 | /** |
| 815 | * Set of all last nodes in this sequence. |
| 816 | */ |
| 817 | HashSet<String> lasts = new HashSet<String>(); |
| 818 | |
| 819 | /** |
| 820 | * The optimal labels for the last nodes, which are used in MAP inference. |
| 821 | */ |
| 822 | HashMap<String, Integer> last2maxArg = new HashMap<String, Integer>(); |
| 823 | |
| 824 | /** |
| 825 | * Map from one node to the next node in the chain. |
| 826 | */ |
| 827 | HashMap<String, String> next = new HashMap<String, String>(); |
| 828 | |
| 829 | /** |
| 830 | * Map from one node to the previous node in the chain. |
| 831 | */ |
| 832 | HashMap<String, String> prev = new HashMap<String, String>(); |
| 833 | |
| 834 | /** |
| 835 | * See {@link CRFOperator#id2Label}. |
| 836 | */ |
| 837 | HashMap<Integer, String[]> id2Label = new HashMap<Integer, String[]>(); |
| 838 | |
| 839 | /** |
| 840 | * See {@link CRFOperator#label2ID}. |
| 841 | */ |
| 842 | HashMap<String, Integer> label2ID = new HashMap<String, Integer>(); |
| 843 | |
/**
 * The constructor. When a root signature is given, a node for it is
 * registered right away.
 *
 * @param _p the predicate to be labeled.
 * @param _rootSignature the root of this sequence, which can be null (in this case,
 * this class will find roots before infer).
 * @param _id2Label map from label ID to label fields; its size is the number of labels.
 * @param _label2ID map from label signature to label ID.
 */
public Sequence(Predicate _p , String _rootSignature, HashMap<Integer, String[]> _id2Label, HashMap<String, Integer> _label2ID){
    this.pred = _p;
    this.rootSignature = _rootSignature;

    if (_rootSignature != null) {
        this.signature2Node.put(_rootSignature, new Node(_id2Label.size()));
    }

    this.id2Label = _id2Label;
    this.label2ID = _label2ID;
}
| 861 | |
| 862 | |
| 863 | /** |
| 864 | * Add a node in this sequence with a given signature. |
| 865 | * @param _signature |
| 866 | */ |
| 867 | public void registerNodeIfNotExist(String _signature){ |
| 868 | |
| 869 | if(signature2Node.containsKey(_signature)){ |
| 870 | return; |
| 871 | } |
| 872 | |
| 873 | signature2Node.put(_signature, new Node(this.id2Label.size())); |
| 874 | } |
| 875 | |
| 876 | /** |
| 877 | * Add a bigram feature for a node with signature _currSignature and label _currLabel. |
| 878 | * @param _prevSignature |
| 879 | * @param _currSignature |
| 880 | * @param _prevLabel |
| 881 | * @param _currLabel |
| 882 | * @param _weight |
| 883 | */ |
| 884 | public void registerBigramFeatures(String _prevSignature, String _currSignature, int _prevLabel, int _currLabel, Double _weight){ |
| 885 | |
| 886 | assert this.signature2Node.containsKey(_prevSignature); |
| 887 | assert this.signature2Node.containsKey(_currSignature); |
| 888 | |
| 889 | if(next.containsKey(_prevSignature)){ |
| 890 | assert next.get(_prevSignature).equals(_currSignature); |
| 891 | }else{ |
| 892 | next.put(_prevSignature, _currSignature); |
| 893 | } |
| 894 | |
| 895 | if(prev.containsKey(_currSignature)){ |
| 896 | assert prev.get(_currSignature).equals(_prevSignature); |
| 897 | }else{ |
| 898 | prev.put(_currSignature, _prevSignature); |
| 899 | } |
| 900 | |
| 901 | Node tmpNode = this.signature2Node.get(_currSignature); |
| 902 | tmpNode.prev2currBigram[_prevLabel][_currLabel] += _weight; |
| 903 | |
| 904 | } |
| 905 | |
| 906 | /** |
| 907 | * Add a unigram feature for a node with signature _currSignature and label _currLabel. |
| 908 | * @param _currSignature |
| 909 | * @param _currLabel |
| 910 | * @param _weight |
| 911 | */ |
| 912 | public void registerUnigramFeatures(String _currSignature, int _currLabel, Double _weight){ |
| 913 | assert this.signature2Node.containsKey(_currSignature); |
| 914 | |
| 915 | Node tmpNode = this.signature2Node.get(_currSignature); |
| 916 | tmpNode.currUnigram[_currLabel] = _weight; |
| 917 | } |
| 918 | |
| 919 | /** |
| 920 | * Infer on this sequence. |
| 921 | */ |
| 922 | public void infer(){ |
| 923 | |
| 924 | //find root |
| 925 | if(rootSignature == null){ |
| 926 | Set<String> nodes = new HashSet<String>(); |
| 927 | nodes.addAll(this.signature2Node.keySet()); |
| 928 | |
| 929 | while(true){ |
| 930 | if(nodes.size() == 0){ |
| 931 | break; |
| 932 | } |
| 933 | |
| 934 | String rootCandidate = nodes.iterator().next(); |
| 935 | while(true){ |
| 936 | if(this.next.get(rootCandidate) != null){ |
| 937 | rootCandidate = this.next.get(rootCandidate); |
| 938 | }else{ |
| 939 | break; |
| 940 | } |
| 941 | } |
| 942 | |
| 943 | while(true){ |
| 944 | nodes.remove(rootCandidate); |
| 945 | if(this.prev.get(rootCandidate) != null){ |
| 946 | rootCandidate = this.prev.get(rootCandidate); |
| 947 | }else{ |
| 948 | rootSignature = rootCandidate; |
| 949 | roots.add(rootSignature); |
| 950 | break; |
| 951 | } |
| 952 | } |
| 953 | |
| 954 | } |
| 955 | }else{ |
| 956 | roots.add(rootSignature); |
| 957 | } |
| 958 | |
| 959 | |
| 960 | //forward |
| 961 | for(String sssss : roots){ |
| 962 | this.rootSignature = sssss; |
| 963 | String current = this.rootSignature; |
| 964 | while(true){ |
| 965 | if(current.equals(rootSignature)){ |
| 966 | Node n = this.signature2Node.get(current); |
| 967 | // foreach label |
| 968 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 969 | n.forwardSum[i] = n.currUnigram[i]; |
| 970 | n.currentMax[i] = n.currUnigram[i]; |
| 971 | } |
| 972 | }else{ |
| 973 | Node n = this.signature2Node.get(current); |
| 974 | Node p = this.signature2Node.get(this.prev.get(current)); |
| 975 | |
| 976 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 977 | int maxArg = -1; |
| 978 | double maxValue = Double.NEGATIVE_INFINITY; |
| 979 | for(int j = 0; j< p.forwardSum.length; j++){ |
| 980 | n.forwardSum[i] = logAdd(n.prev2currBigram[j][i] + n.currUnigram[i] + p.forwardSum[j], n.forwardSum[i]); |
| 981 | double tmp = n.prev2currBigram[j][i] + n.currUnigram[i] + p.currentMax[j]; |
| 982 | if( tmp > maxValue ){ |
| 983 | maxValue = tmp; |
| 984 | maxArg = j; |
| 985 | } |
| 986 | } |
| 987 | n.prevArgMax[i] = maxArg; |
| 988 | n.currentMax[i] = maxValue; |
| 989 | } |
| 990 | } |
| 991 | if(!this.next.containsKey(current)){ |
| 992 | |
| 993 | Node n = this.signature2Node.get(current); |
| 994 | this.lastSignature = current; |
| 995 | |
| 996 | int maxArg = -1; |
| 997 | double maxValue = Double.NEGATIVE_INFINITY; |
| 998 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 999 | if(n.currentMax[i] > maxValue){ |
| 1000 | maxArg = i; |
| 1001 | maxValue = n.currentMax[i]; |
| 1002 | } |
| 1003 | } |
| 1004 | |
| 1005 | this.lasts.add(current); |
| 1006 | this.last2maxArg.put(current, maxArg); |
| 1007 | break; |
| 1008 | } |
| 1009 | current = this.next.get(current); |
| 1010 | } |
| 1011 | |
| 1012 | //backward |
| 1013 | String last = current; |
| 1014 | while(true){ |
| 1015 | if(current.equals(last)){ |
| 1016 | Node n = this.signature2Node.get(current); |
| 1017 | // foreach label |
| 1018 | for(int i = 0; i < n.backwardSum.length; i++){ |
| 1019 | n.backwardSum[i] = n.currUnigram[i]; |
| 1020 | } |
| 1021 | }else{ |
| 1022 | Node p = this.signature2Node.get(current); |
| 1023 | Node n = this.signature2Node.get(this.next.get(current)); |
| 1024 | |
| 1025 | for(int i = 0; i < p.backwardSum.length; i++){ |
| 1026 | for(int j = 0; j< n.backwardSum.length; j++){ |
| 1027 | p.backwardSum[i] = logAdd(n.prev2currBigram[i][j] + n.currUnigram[j] + n.backwardSum[j], p.backwardSum[i]); |
| 1028 | } |
| 1029 | } |
| 1030 | } |
| 1031 | |
| 1032 | if(!this.prev.containsKey(current)){ |
| 1033 | break; |
| 1034 | } |
| 1035 | current = this.prev.get(current); |
| 1036 | } |
| 1037 | } |
| 1038 | } |
| 1039 | |
| 1040 | /** |
| 1041 | * Dump answers to the given buffered writer. These answers are in a format that can be |
| 1042 | * COPY into postgres table directly. |
| 1043 | * @param bw |
| 1044 | */ |
| 1045 | public void dumpAnswers(BufferedWriter bw){ |
| 1046 | try{ |
| 1047 | |
| 1048 | if(isMarginal || (FelixConfig.isFirstRunOfDD && dd_commonOutputPredicate_2_tableName.containsKey(crfHead) )){ |
| 1049 | for(String sssss : roots){ |
| 1050 | this.rootSignature = sssss; |
| 1051 | String current = this.rootSignature; |
| 1052 | while(true){ |
| 1053 | |
| 1054 | if(current.equals(rootSignature)){ |
| 1055 | Node n = this.signature2Node.get(current); |
| 1056 | double sum = Double.NEGATIVE_INFINITY; |
| 1057 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 1058 | sum = logAdd(n.currUnigram[i] + n.backwardSum[i], sum); |
| 1059 | } |
| 1060 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 1061 | double prob = n.currUnigram[i] + n.backwardSum[i] - sum; |
| 1062 | prob = Math.exp(prob); |
| 1063 | if(prob > Config.soft_evidence_activation_threshold){ |
| 1064 | ArrayList<String> parts = new ArrayList<String>(); |
| 1065 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
| 1066 | parts.add("TRUE"); |
| 1067 | parts.add(Double.toString(prob)); |
| 1068 | |
| 1069 | if(options.useDualDecomposition){ |
| 1070 | parts.add(Integer.toString(2)); |
| 1071 | }else{ |
| 1072 | parts.add(Integer.toString(2)); |
| 1073 | } |
| 1074 | // parts.add("1");//this is for vote |
| 1075 | |
| 1076 | //String tmp = current.replaceAll("\\?", id2Label.get(i)); |
| 1077 | String tmp = String.format(current, (Object[]) id2Label.get(i)); |
| 1078 | |
| 1079 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
| 1080 | } |
| 1081 | } |
| 1082 | |
| 1083 | }else{ |
| 1084 | Node n = this.signature2Node.get(current); |
| 1085 | Node p = this.signature2Node.get(this.prev.get(current)); |
| 1086 | |
| 1087 | double sum = Double.NEGATIVE_INFINITY; |
| 1088 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 1089 | for(int j = 0; j< p.forwardSum.length; j++){ |
| 1090 | sum = logAdd(p.forwardSum[j] + n.currUnigram[i] + n.prev2currBigram[j][i] |
| 1091 | + n.backwardSum[i], sum); |
| 1092 | } |
| 1093 | } |
| 1094 | |
| 1095 | for(int i = 0; i < n.forwardSum.length; i++){ |
| 1096 | double marginal = 0; |
| 1097 | |
| 1098 | for(int j = 0; j< p.forwardSum.length; j++){ |
| 1099 | marginal = logAdd( p.forwardSum[j] + n.currUnigram[i] + n.prev2currBigram[j][i] + n.backwardSum[i], |
| 1100 | marginal ); |
| 1101 | } |
| 1102 | |
| 1103 | double prob = Math.exp(marginal - sum); |
| 1104 | |
| 1105 | if(prob > Config.soft_evidence_activation_threshold){ |
| 1106 | ArrayList<String> parts = new ArrayList<String>(); |
| 1107 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
| 1108 | parts.add("TRUE"); |
| 1109 | parts.add(Double.toString(prob)); |
| 1110 | |
| 1111 | if(options.useDualDecomposition){ |
| 1112 | parts.add(Integer.toString(2)); |
| 1113 | }else{ |
| 1114 | parts.add(Integer.toString(2)); |
| 1115 | } |
| 1116 | // parts.add("1");//this is for vote |
| 1117 | |
| 1118 | String tmp = String.format(current, (Object[]) id2Label.get(i)); |
| 1119 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
| 1120 | } |
| 1121 | |
| 1122 | } |
| 1123 | |
| 1124 | } |
| 1125 | |
| 1126 | if(!this.next.containsKey(current)){ |
| 1127 | break; |
| 1128 | } |
| 1129 | current = this.next.get(current); |
| 1130 | } |
| 1131 | } |
| 1132 | }else{ |
| 1133 | |
| 1134 | for(String sssss : lasts){ |
| 1135 | |
| 1136 | String current = sssss; |
| 1137 | |
| 1138 | int toDump = this.last2maxArg.get(sssss); |
| 1139 | |
| 1140 | while(true){ |
| 1141 | |
| 1142 | Node n = this.signature2Node.get(current); |
| 1143 | |
| 1144 | ArrayList<String> parts = new ArrayList<String>(); |
| 1145 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
| 1146 | parts.add("TRUE"); |
| 1147 | parts.add(""); |
| 1148 | |
| 1149 | if(options.useDualDecomposition){ |
| 1150 | parts.add(Integer.toString(2)); |
| 1151 | }else{ |
| 1152 | parts.add(Integer.toString(2)); |
| 1153 | } |
| 1154 | // parts.add("1");//this is for vote |
| 1155 | |
| 1156 | String tmp = String.format(current, (Object[]) id2Label.get(toDump)); |
| 1157 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
| 1158 | |
| 1159 | if(!this.prev.containsKey(current)){ |
| 1160 | break; |
| 1161 | } |
| 1162 | |
| 1163 | toDump = n.prevArgMax[toDump]; |
| 1164 | current = this.prev.get(current); |
| 1165 | } |
| 1166 | } |
| 1167 | |
| 1168 | } |
| 1169 | }catch(Exception e){ |
| 1170 | e.printStackTrace(); |
| 1171 | } |
| 1172 | |
| 1173 | } |
| 1174 | } |
| 1175 | |
| 1176 | @Override |
| 1177 | public void learn() { |
| 1178 | |
| 1179 | } |
| 1180 | |
| 1181 | } |