tuffy.infer
Class MRF

java.lang.Object
  extended by tuffy.infer.MRF

public class MRF
extends java.lang.Object

In-memory data structure representing an MRF.


Nested Class Summary
static class MRF.INIT_STRATEGY
           
 
Field Summary
protected  java.util.HashMap<java.lang.Integer,java.util.ArrayList<GClause>> adj
          Index from GAtom ID to GClause.
 java.util.HashMap<java.lang.Integer,GAtom> atoms
          Map from GAtom ID to GAtom object.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseNiNjViolationTallies
          This map records the tallies for calculating E(v_i*v_j).
 java.util.ArrayList<GClause> clauses
          Array of all GClause objects in this MRF.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseSatTallies
          This map records the total number of satisfactions for each clause.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseSquareVioTallies
          This map records the running sum of squared violation counts for each clause.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseVioTallies
          This map records the total number of violations for each clause.
protected  java.util.HashSet<java.lang.Integer> dirtyAtoms
          Atoms that have been flipped since the low-cost truth assignment was last saved.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfNiNjViolation
          This map records the estimated expectation E(v_i*v_j).
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSatisfication
          This map records the expectation of #satisfactions for each clause.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSquareViolation
          This map records the expectation of squared #violations for each clause.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfViolation
          This map records the expectation of #violations for each clause.
 long inferOps
           
protected  MRF.INIT_STRATEGY initStrategy
           
 KeyBlock keyBlock
           
 double lowCost
          Lowest cost ever seen.
 boolean ownsAllAtoms
           
protected  boolean sampleSatMode
          The flag indicating whether MCSAT is running WalkSAT or SampleSAT.
protected  int totalAlive
          Number of GClauses that are selected and therefore must be satisfied by the next SampleSAT invocation of MCSAT.
protected  double totalCost
          The total cost of this MRF under current atoms' truth setting.
protected  HashArray<GClause> unsat
          Array of unsatisfied GClauses under current atoms' truth setting.
 
Constructor Summary
MRF(MarkovLogicNetwork mln)
          Default constructor.
MRF(MarkovLogicNetwork mln, int partID, java.util.HashMap<java.lang.Integer,GAtom> gatoms)
           
 
Method Summary
 void addAtom(int aid)
          Add an atom into this MRF.
 void auditClauseViolations()
          Track ground clause violations to fo-clauses.
protected  void buildIndices()
          Build literal-->clauses index.
protected  double calcCosts()
          Compute total cost and per-atom delta cost.
 void calcExpViolation()
          Calculate the various expectations by filling the expectation-related HashMaps in this class.
 void discard()
          Discard all data structures, in the hope of facilitating faster GC.
protected  void enableAllClauses()
          Reset all clauses to be alive.
protected  void fixAtom(int aid, boolean t)
          Fix the truth value of an atom.
 java.util.HashSet<java.lang.Integer> getCoreAtoms()
           
 double getCost()
           
 MRF.INIT_STRATEGY getInitStrategy()
           
 MarkovLogicNetwork getMLN()
           
 void inferWalkSAT(int nTries, int nSteps)
          Run WalkSAT.
 void initMRF()
          Initialize the state of the MRF.
 void invalidateLowCost()
          Reset low-cost to infinity.
protected  boolean isAlwaysTrue(GClause gc)
          Test if a clause is always true no matter how we flip flippable atoms.
protected  boolean isTrueLit(int lit)
          Check if a given literal is true under current truth assignment.
 void mcsat(int numSamples, int numFlips)
          Execute the MC-SAT algorithm.
protected  boolean ownsAtom(int aid)
          Test if a given atom is "owned" by this MRF.
 double recalcCost()
          Recalculate total cost.
 void restoreLowTruth()
          Assign the recorded low-cost truth values to current truth values.
protected  int retainOnlyHardClauses()
          Kill soft clauses.
protected  int retainSomeGoodClauses()
          Retain a subset of currently satisfied clauses, according to the sampling method of MC-SAT.
protected  boolean sampleSAT(int nSteps)
          SampleSAT (with WalkSAT inside), used to uniformly sample a zero-cost world.
protected  void saveLowTruth(double cost)
          If the current truth assignment has the lowest cost seen so far, save it.
 void setInitStrategy(MRF.INIT_STRATEGY strategy)
           
protected  boolean testChance(double p)
          Coin flipping.
protected  void unfixAllAtoms()
          Unfix all atoms.
 void updateAtomMarginalProbs(int numSamples)
           
 void updateClauseVoiTallies()
          Update the number of violations of a clause.
 void updateClauseWeights(java.util.HashMap<java.lang.String,java.lang.Double> currentWeight)
          Change the weights of GClauses based on the updated weights of their Clauses.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

adj

protected java.util.HashMap<java.lang.Integer,java.util.ArrayList<GClause>> adj
Index from GAtom ID to GClause.


atoms

public java.util.HashMap<java.lang.Integer,GAtom> atoms
Map from GAtom ID to GAtom object.


clauseNiNjViolationTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseNiNjViolationTallies
This map records the tallies for calculating E(v_i*v_j).


clauses

public java.util.ArrayList<GClause> clauses
Array of all GClause objects in this MRF.


clauseSatTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseSatTallies
This map records the total number of satisfactions for each clause.


clauseSquareVioTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseSquareVioTallies
This map records the running sum of squared violation counts for each clause. Dividing this number by MCSAT#nClauseVioTallies gives the estimated expectation of squared #violations.


clauseVioTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseVioTallies
This map records the total number of violations for each clause. Dividing this number by MCSAT#nClauseVioTallies gives the estimated expectation of #violations.


dirtyAtoms

protected java.util.HashSet<java.lang.Integer> dirtyAtoms
Atoms that have been flipped since the low-cost truth assignment was last saved.


expectationOfNiNjViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfNiNjViolation
This map records the estimated expectation E(v_i*v_j). It is filled by MCSAT#calcExpViolation().


expectationOfSatisfication

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSatisfication
This map records the expectation of #satisfactions for each clause. It is filled by MCSAT#calcExpViolation().


expectationOfSquareViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSquareViolation
This map records the expectation of squared #violations for each clause. It is filled by MCSAT#calcExpViolation().


expectationOfViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfViolation
This map records the expectation of #violations for each clause. It is filled by MCSAT#calcExpViolation().


inferOps

public long inferOps

initStrategy

protected MRF.INIT_STRATEGY initStrategy

keyBlock

public KeyBlock keyBlock

lowCost

public double lowCost
Lowest cost ever seen.


ownsAllAtoms

public boolean ownsAllAtoms

sampleSatMode

protected boolean sampleSatMode
The flag indicating whether MCSAT is running WalkSAT or SampleSAT.


totalAlive

protected int totalAlive
Number of GClauses that are selected and therefore must be satisfied by the next SampleSAT invocation of MCSAT.


totalCost

protected double totalCost
The total cost of this MRF under current atoms' truth setting.


unsat

protected HashArray<GClause> unsat
Array of unsatisfied GClauses under current atoms' truth setting.

Constructor Detail

MRF

public MRF(MarkovLogicNetwork mln)
Default constructor. Performs essentially no initialization.


MRF

public MRF(MarkovLogicNetwork mln,
           int partID,
           java.util.HashMap<java.lang.Integer,GAtom> gatoms)
Parameters:
mln - the Markov logic network
partID - id of this MRF
gatoms - ground atoms

Method Detail

addAtom

public void addAtom(int aid)
Add an atom into this MRF.

Parameters:
aid - id of the atom

auditClauseViolations

public void auditClauseViolations()
Track ground clause violations to fo-clauses. Statistics are recorded on a per-fo-clause basis.

See Also:
Stats.reportMostViolatedClauses(tuffy.infer.MRF, int)

buildIndices

protected void buildIndices()
Build literal-->clauses index. Used by WalkSAT.
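A minimal sketch of such an index, assuming each ground clause is stored as an int[] of signed literals (+a for atom a, -a for its negation). The class LiteralIndexSketch and its method are hypothetical illustrations, not Tuffy's implementation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class LiteralIndexSketch {
    // Hypothetical sketch: maps each signed literal to the indices of the
    // clauses that contain it (assumed encoding: +a = atom a, -a = NOT a).
    static Map<Integer, List<Integer>> buildIndex(List<int[]> clauses) {
        Map<Integer, List<Integer>> index = new HashMap<>();
        for (int c = 0; c < clauses.size(); c++) {
            for (int lit : clauses.get(c)) {
                index.computeIfAbsent(lit, k -> new ArrayList<>()).add(c);
            }
        }
        return index;
    }
}
```

When WalkSAT flips atom a, only the clauses listed under a and -a can change status, so the rest of the MRF need not be rescanned.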


calcCosts

protected double calcCosts()
Compute total cost and per-atom delta cost. The delta cost of an atom is the change in the total cost if this atom is flipped.

Returns:
total cost
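To illustrate the delta cost, here is a sketch that computes both quantities in one pass. It assumes positive clause weights, that a violated clause costs its weight, signed-int literals, and that no atom appears twice in one clause. CostSketch is a hypothetical name; Tuffy's actual cost model (for example, its handling of negative weights) may differ.

```java
import java.util.List;

final class CostSketch {
    final double totalCost;
    final double[] delta; // delta[a] = change in total cost if atom a is flipped

    // Hypothetical sketch: clauses.get(i) holds signed literals, weights[i] > 0,
    // truth[a] is atom a's value (index 0 unused). Assumes no atom repeats in a clause.
    CostSketch(List<int[]> clauses, double[] weights, boolean[] truth) {
        double cost = 0;
        delta = new double[truth.length];
        for (int i = 0; i < clauses.size(); i++) {
            int[] lits = clauses.get(i);
            double w = weights[i];
            int nTrue = 0;
            int lastTrueAtom = -1;
            for (int lit : lits) {
                int atom = Math.abs(lit);
                if ((lit > 0) == truth[atom]) { // literal is true
                    nTrue++;
                    lastTrueAtom = atom;
                }
            }
            if (nTrue == 0) {
                cost += w; // violated clause pays its weight
                for (int lit : lits) delta[Math.abs(lit)] -= w; // any single flip satisfies it
            } else if (nTrue == 1) {
                delta[lastTrueAtom] += w; // flipping the sole true literal violates it
            }
        }
        totalCost = cost;
    }
}
```

A violated clause gives every one of its atoms a negative delta, while a clause with exactly one true literal penalizes flipping that literal's atom; clauses with two or more true literals contribute nothing.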

calcExpViolation

public void calcExpViolation()
Calculate the various expectations by filling the expectation-related HashMaps in this class.
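Assuming each expectation is the sample mean of its tally over the MC-SAT samples (consistent with the field descriptions, which divide a tally by the number of tally updates), the computation could look like this sketch. ExpectationSketch and numSamples are hypothetical; the variance helper shows one natural use of the squared-violation expectation.

```java
import java.util.HashMap;
import java.util.Map;

final class ExpectationSketch {
    // Hypothetical sketch: turns per-clause tallies accumulated over
    // numSamples samples into estimated expectations (sample means).
    static Map<String, Double> expectations(Map<String, Long> tallies, long numSamples) {
        if (numSamples <= 0) throw new IllegalArgumentException("numSamples must be positive");
        Map<String, Double> result = new HashMap<>();
        for (Map.Entry<String, Long> e : tallies.entrySet()) {
            result.put(e.getKey(), e.getValue() / (double) numSamples);
        }
        return result;
    }

    // Var(v) = E(v^2) - E(v)^2, from the squared-violation and violation expectations.
    static double variance(double expSquare, double exp) {
        return expSquare - exp * exp;
    }
}
```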


discard

public void discard()
Discard all data structures, in the hope of facilitating faster GC.


enableAllClauses

protected void enableAllClauses()
Reset all clauses to be alive.


fixAtom

protected void fixAtom(int aid,
                       boolean t)
Fix the truth value of an atom.

Parameters:
aid - id of the atom
t - truth value to be fixed

getCoreAtoms

public java.util.HashSet<java.lang.Integer> getCoreAtoms()

getCost

public double getCost()

getInitStrategy

public MRF.INIT_STRATEGY getInitStrategy()

getMLN

public MarkovLogicNetwork getMLN()

inferWalkSAT

public void inferWalkSAT(int nTries,
                         int nSteps)
Run WalkSAT.

Parameters:
nTries - number of tries
nSteps - number of steps per try
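For reference, a compact weighted WalkSAT loop with nTries random restarts of up to nSteps flips each. This sketch assumes signed-int literals, positive weights, non-empty clauses, and a noise probability for random-walk moves; WalkSatSketch and noise are hypothetical names. It recomputes costs naively instead of maintaining the incremental indices and delta costs that this class keeps.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

final class WalkSatSketch {
    // Hypothetical weighted WalkSAT; atoms are numbered 1..nAtoms.
    // Returns the lowest-cost assignment seen (null if nTries == 0).
    static boolean[] run(List<int[]> clauses, double[] w, int nAtoms,
                         int nTries, int nSteps, double noise, Random rng) {
        boolean[] best = null;
        double bestCost = Double.POSITIVE_INFINITY;
        for (int t = 0; t < nTries; t++) {
            boolean[] truth = new boolean[nAtoms + 1];
            for (int a = 1; a <= nAtoms; a++) truth[a] = rng.nextBoolean(); // random restart
            for (int s = 0; s <= nSteps; s++) {
                double cost = cost(clauses, w, truth);
                if (cost < bestCost) { bestCost = cost; best = truth.clone(); }
                if (cost == 0 || s == nSteps) break;
                List<int[]> unsat = new ArrayList<>();
                for (int[] c : clauses) if (!sat(c, truth)) unsat.add(c);
                int[] target = unsat.get(rng.nextInt(unsat.size()));
                int pick;
                if (rng.nextDouble() < noise) {
                    pick = Math.abs(target[rng.nextInt(target.length)]); // random-walk move
                } else {
                    pick = -1;
                    double pickCost = Double.POSITIVE_INFINITY;
                    for (int lit : target) { // greedy move: try each atom, keep the cheapest
                        int a = Math.abs(lit);
                        truth[a] = !truth[a];
                        double after = cost(clauses, w, truth);
                        truth[a] = !truth[a];
                        if (after < pickCost) { pickCost = after; pick = a; }
                    }
                }
                truth[pick] = !truth[pick];
            }
        }
        return best;
    }

    static boolean sat(int[] c, boolean[] truth) {
        for (int lit : c) if ((lit > 0) == truth[Math.abs(lit)]) return true;
        return false;
    }

    static double cost(List<int[]> clauses, double[] w, boolean[] truth) {
        double sum = 0;
        for (int i = 0; i < clauses.size(); i++) if (!sat(clauses.get(i), truth)) sum += w[i];
        return sum;
    }
}
```

The noise term matters: with purely greedy moves, tie-breaking can cycle between the same few states, whereas occasional random flips let the search escape.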

initMRF

public void initMRF()
Initialize the state of the MRF.


invalidateLowCost

public void invalidateLowCost()
Reset low-cost to infinity.


isAlwaysTrue

protected boolean isAlwaysTrue(GClause gc)
Test if a clause is always true no matter how we flip flippable atoms.

Parameters:
gc - the clause
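One sufficient condition, sketched under the assumption that an atom is non-flippable when it is fixed or not owned by this MRF: the clause already has a true literal on a non-flippable atom. AlwaysTrueSketch and the flippable set are hypothetical, and this check does not detect tautologies such as (a OR NOT a).

```java
import java.util.Set;

final class AlwaysTrueSketch {
    // Hypothetical sketch: the clause stays true under every flip of the
    // flippable atoms if some true literal sits on an atom that cannot flip.
    static boolean isAlwaysTrue(int[] clause, boolean[] truth, Set<Integer> flippable) {
        for (int lit : clause) {
            int atom = Math.abs(lit);
            boolean litTrue = (lit > 0) == truth[atom];
            if (litTrue && !flippable.contains(atom)) return true;
        }
        return false;
    }
}
```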

isTrueLit

protected boolean isTrueLit(int lit)
Check if a given literal is true under current truth assignment.

Parameters:
lit - the literal represented as an integer
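A sketch of one common integer encoding (the DIMACS-style convention), assumed here rather than confirmed by this page: +a denotes atom a, -a denotes its negation, and atom ids start at 1. LiteralSketch is a hypothetical name.

```java
final class LiteralSketch {
    // Hypothetical encoding: +a means "atom a is true", -a means "atom a is false".
    static boolean isTrueLit(int lit, boolean[] truth) {
        if (lit == 0) throw new IllegalArgumentException("0 is not a valid literal");
        boolean atomValue = truth[Math.abs(lit)];
        return lit > 0 ? atomValue : !atomValue;
    }
}
```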

mcsat

public void mcsat(int numSamples,
                  int numFlips)
Execute the MC-SAT algorithm.

Parameters:
numSamples - number of MC-SAT samples
numFlips - number of SampleSAT steps in each iteration

ownsAtom

protected boolean ownsAtom(int aid)
Test if a given atom is "owned" by this MRF. An atom may not belong to this MRF if this MRF represents a partition of a component that has multiple partitions.

Parameters:
aid - id of the atom

recalcCost

public double recalcCost()
Recalculate total cost.

Returns:
updated total cost

restoreLowTruth

public void restoreLowTruth()
Assign the recorded low-cost truth values to current truth values.


retainOnlyHardClauses

protected int retainOnlyHardClauses()
Kill soft clauses.

Returns:
the number of hard clauses

retainSomeGoodClauses

protected int retainSomeGoodClauses()
Retain a subset of currently satisfied clauses, according to the sampling method of MC-SAT.

Returns:
the number of retained clauses
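In the standard MC-SAT algorithm, each clause satisfied by the current state is retained with probability 1 - e^(-w), where w is its weight. The sketch below assumes positive weights, with hard clauses given infinite weight so they are always retained; McSatSelectionSketch is a hypothetical name, and Tuffy's handling of negative weights is not shown.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

final class McSatSelectionSketch {
    // Hypothetical sketch: returns indices of clauses kept alive for the next
    // SampleSAT call. A satisfied clause with weight w > 0 is kept with
    // probability 1 - exp(-w); w = +infinity gives probability 1.
    static List<Integer> retain(boolean[] satisfied, double[] weights, Random rng) {
        List<Integer> alive = new ArrayList<>();
        for (int i = 0; i < satisfied.length; i++) {
            if (!satisfied[i] || weights[i] <= 0) continue;
            double keepProb = 1.0 - Math.exp(-weights[i]);
            if (rng.nextDouble() < keepProb) alive.add(i);
        }
        return alive;
    }
}
```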

sampleSAT

protected boolean sampleSAT(int nSteps)
SampleSAT (with WalkSAT inside), used to uniformly sample a zero-cost world. WalkSAT is used as a SAT solver to find the first (quasi-)zero-cost world; simulated annealing (SA) moves are then performed stochastically to wander among such worlds.

Parameters:
nSteps - number of SampleSAT steps
Returns:
true iff a zero-cost world was reached
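The simulated-annealing moves can use a Metropolis-style acceptance test, sketched below under the assumption of a fixed temperature parameter. AnnealingStepSketch and temperature are hypothetical, and Tuffy's actual SA move may differ.

```java
import java.util.Random;

final class AnnealingStepSketch {
    // Hypothetical Metropolis acceptance test for one SA flip: always accept a
    // flip that does not increase cost; otherwise accept with probability
    // exp(-deltaCost / temperature). A non-positive temperature rejects uphill moves.
    static boolean acceptFlip(double deltaCost, double temperature, Random rng) {
        if (deltaCost <= 0) return true;
        if (temperature <= 0) return false;
        return rng.nextDouble() < Math.exp(-deltaCost / temperature);
    }
}
```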

saveLowTruth

protected void saveLowTruth(double cost)
If the current truth assignment has the lowest cost seen so far, save it.

Parameters:
cost - the current cost
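The dirtyAtoms field suggests an optimization for this method: only atoms flipped since the last save need to be copied. The sketch below is a hypothetical illustration of that bookkeeping, together with restoreLowTruth and invalidateLowCost; LowTruthSketch is not Tuffy's code. It relies on the invariant that every atom outside the dirty set already agrees with the saved low-cost assignment.

```java
import java.util.HashSet;
import java.util.Set;

final class LowTruthSketch {
    final boolean[] truth, lowTruth;
    final Set<Integer> dirty = new HashSet<>(); // atoms flipped since the last save
    double lowCost = Double.POSITIVE_INFINITY;

    LowTruthSketch(int nAtoms) {
        truth = new boolean[nAtoms + 1];
        lowTruth = new boolean[nAtoms + 1];
    }

    void flip(int atom) {
        truth[atom] = !truth[atom];
        dirty.add(atom);
    }

    // Save only on strict improvement, copying just the dirty atoms.
    void saveLowTruth(double cost) {
        if (cost >= lowCost) return;
        for (int a : dirty) lowTruth[a] = truth[a];
        dirty.clear();
        lowCost = cost;
    }

    // Copy the saved assignment back; afterwards nothing is dirty.
    void restoreLowTruth() {
        System.arraycopy(lowTruth, 0, truth, 0, truth.length);
        dirty.clear();
    }

    void invalidateLowCost() { lowCost = Double.POSITIVE_INFINITY; }
}
```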

setInitStrategy

public void setInitStrategy(MRF.INIT_STRATEGY strategy)

testChance

protected boolean testChance(double p)
Coin flipping.

Parameters:
p - probability of returning true
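A minimal sketch using java.util.Random; CoinSketch and its seeded constructor are hypothetical. Because Random.nextDouble() returns a value in [0, 1), p = 0 never returns true and p = 1 always does.

```java
import java.util.Random;

final class CoinSketch {
    private final Random rng;

    CoinSketch(long seed) { rng = new Random(seed); }

    // Returns true with probability p (p <= 0: never, p >= 1: always).
    boolean testChance(double p) {
        return rng.nextDouble() < p;
    }
}
```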

unfixAllAtoms

protected void unfixAllAtoms()
Unfix all atoms.


updateAtomMarginalProbs

public void updateAtomMarginalProbs(int numSamples)

updateClauseVoiTallies

public void updateClauseVoiTallies()
Update the number of violations of each clause. For each GClause, the tally can increase by at most 1 per MCSAT iteration. For a Clause, the tally can increase by more than 1, because more than one GClause may be associated with it.
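A sketch of the per-sample aggregation described above, assuming each ground clause records the name of its first-order clause, and assuming the squared tally accumulates the square of each sample's per-clause count. ViolationTallySketch and its fields are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

final class ViolationTallySketch {
    final Map<String, Long> vioTallies = new HashMap<>();
    final Map<String, Long> squareVioTallies = new HashMap<>();

    // Hypothetical sketch: ownerClause[i] names the first-order clause of ground
    // clause i; violated[i] says whether ground clause i is violated in this sample.
    void update(String[] ownerClause, boolean[] violated) {
        Map<String, Long> perSample = new HashMap<>();
        for (int i = 0; i < ownerClause.length; i++) {
            if (violated[i]) perSample.merge(ownerClause[i], 1L, Long::sum); // +1 per ground clause
        }
        for (Map.Entry<String, Long> e : perSample.entrySet()) {
            long n = e.getValue();
            vioTallies.merge(e.getKey(), n, Long::sum);
            squareVioTallies.merge(e.getKey(), n * n, Long::sum); // assumed: square of per-sample count
        }
    }
}
```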


updateClauseWeights

public void updateClauseWeights(java.util.HashMap<java.lang.String,java.lang.Double> currentWeight)
Change the weights of GClauses based on the updated weights of their Clauses. MCSAT will use the new weights. This method also automatically recalculates the atoms' flip costs and the set of unsatisfied GClauses.

Parameters:
currentWeight - the updated clause weights to be applied to this MCSAT instance