1 | package tuffy.infer; |
2 | |
3 | |
4 | import java.util.ArrayList; |
5 | import java.util.Collections; |
6 | import java.util.Random; |
7 | |
8 | import tuffy.ground.partition.Component; |
9 | import tuffy.ground.partition.Partition; |
10 | import tuffy.infer.MRF.INIT_STRATEGY; |
11 | import tuffy.infer.ds.GAtom; |
12 | import tuffy.util.Config; |
13 | import tuffy.util.MathMan; |
14 | /** |
15 | * Performing inference on one MRF component. |
16 | */ |
17 | public class InferComponent { |
18 | private Component comp; |
19 | //private MarkovLogicNetwork mln; |
20 | |
21 | private double lowCost = Double.MAX_VALUE; |
22 | |
23 | public double getCost(){ |
24 | return lowCost; |
25 | } |
26 | |
27 | public Component getComponent(){ |
28 | return comp; |
29 | } |
30 | |
31 | public InferComponent(Component comp){ |
32 | //this.mln = mln; |
33 | this.comp = comp; |
34 | } |
35 | |
36 | |
37 | |
38 | /** |
39 | * Run partition-aware MAP inference with the Gauss-Seidel scheme. |
40 | * |
41 | */ |
42 | public void inferMAP(int totalTries, int totalFlipsPerTry){ |
43 | int nflips, rounds; |
44 | if(comp.numParts() == 1){ |
45 | rounds = 1; |
46 | nflips = totalFlipsPerTry; |
47 | }else{ |
48 | rounds = Config.gauss_seidel_infer_rounds; |
49 | nflips = totalFlipsPerTry/Config.gauss_seidel_infer_rounds; |
50 | } |
51 | inferGaussSeidelMap(rounds, totalTries, nflips); |
52 | } |
53 | |
54 | /** |
55 | * Run partition-aware marginal inference with the Gauss-Seidel scheme. |
56 | * |
57 | */ |
58 | public double inferMarginal(int totalSamples, int totalFlipsPerSample){ |
59 | int nflips, rounds; |
60 | if(comp.numParts() == 1){ |
61 | rounds = 1; |
62 | nflips = totalFlipsPerSample; |
63 | }else{ |
64 | rounds = Config.gauss_seidel_infer_rounds; |
65 | nflips = totalFlipsPerSample/Config.gauss_seidel_infer_rounds; |
66 | } |
67 | |
68 | return inferGaussSeidelMarginal(rounds, totalSamples, nflips); |
69 | } |
70 | |
71 | |
72 | |
73 | private void initTruthRandom(){ |
74 | Random rand = new Random(); |
75 | for(GAtom a : comp.atoms.values()){ |
76 | a.lowTruth = a.truth = rand.nextBoolean(); |
77 | } |
78 | } |
79 | |
80 | private void setMrfInitStrategy(INIT_STRATEGY strategy){ |
81 | for(Partition p : comp.parts){ |
82 | if(p.mrf == null) continue; |
83 | p.mrf.setInitStrategy(strategy); |
84 | } |
85 | } |
86 | |
87 | /** |
88 | * Gauss-Seidel MAP inference scheme. Calls WalkSAT on each |
89 | * partition in a round-robin manner. |
90 | * |
91 | * @param rounds |
92 | * @param ntries |
93 | * @param nflips total number of flips per try in one round |
94 | */ |
95 | private double inferGaussSeidelMap(int rounds, int ntries, int nflips){ |
96 | initTruthRandom(); |
97 | saveLowLowTruth(); |
98 | setMrfInitStrategy(INIT_STRATEGY.COIN_FLIP); |
99 | ArrayList<Partition> iparts = null; |
100 | iparts = new ArrayList<Partition>(comp.parts); |
101 | Collections.shuffle(iparts); |
102 | for(int r=1; r<=rounds; r++){ |
103 | for(Partition p : iparts){ |
104 | if(p.mrf == null) continue; |
105 | p.mrf.invalidateLowCost(); |
106 | p.mrf.inferWalkSAT(ntries, MathMan.prorate(nflips, |
107 | ((double)p.numAtoms)/comp.numAtoms)); |
108 | p.mrf.restoreLowTruth(); |
109 | } |
110 | saveLowLowTruth(); |
111 | setMrfInitStrategy(INIT_STRATEGY.COPY_LOW); |
112 | Collections.shuffle(iparts); |
113 | } |
114 | setMrfInitStrategy(INIT_STRATEGY.COIN_FLIP); |
115 | restoreLowLowTruth(); |
116 | return lowCost; |
117 | } |
118 | |
119 | int totalSamples = 0; |
120 | |
121 | public int getTotalSamples(){ |
122 | return totalSamples; |
123 | } |
124 | |
125 | /** |
126 | * Gauss-Seidel MAP inference scheme. Calls WalkSAT on each |
127 | * partition in a round-robin manner. |
128 | * |
129 | * @param rounds |
130 | * @param ntries |
131 | * @param nflips total number of flips per try in one round |
132 | */ |
133 | private double inferGaussSeidelMarginal(int rounds, int nsamples, int nflips){ |
134 | if(!Config.snapshoting_so_do_not_do_init_flip){ |
135 | initTruthRandom(); |
136 | } |
137 | saveLowLowTruth(); |
138 | ArrayList<Partition> iparts = null; |
139 | iparts = new ArrayList<Partition>(comp.parts); |
140 | Collections.shuffle(iparts); |
141 | int rnsamples = Math.max(3, nsamples/rounds); |
142 | |
143 | double sumCost = 0; |
144 | |
145 | for(int r=1; r<=rounds; r++){ |
146 | double lastRoundCost = 0; |
147 | for(Partition p : iparts){ |
148 | if(p.mrf == null) continue; |
149 | p.mrf.invalidateLowCost(); |
150 | |
151 | int flips = MathMan.prorate(nflips, |
152 | ((double)p.numAtoms)/comp.numAtoms); |
153 | |
154 | double cc = p.mrf.mcsat(rnsamples, flips);//TODO |
155 | sumCost += cc; |
156 | p.mrf.updateAtomMarginalProbs(r*rnsamples); |
157 | } |
158 | Collections.shuffle(iparts); |
159 | } |
160 | |
161 | totalSamples = rnsamples * rounds; |
162 | return sumCost/totalSamples; |
163 | } |
164 | |
165 | private void saveLowLowTruth(){ |
166 | double cost = recalcCost(); |
167 | if(cost >= lowCost) return; |
168 | lowCost = cost; |
169 | for(GAtom a : comp.atoms.values()){ |
170 | a.lowlowTruth = a.lowTruth; |
171 | } |
172 | } |
173 | |
174 | |
175 | private void restoreLowLowTruth(){ |
176 | for(GAtom a : comp.atoms.values()){ |
177 | a.truth = a.lowlowTruth; |
178 | } |
179 | } |
180 | |
181 | /** |
182 | * Recalculate the cost on this component, which is the sum |
183 | * of the cost on the MRF of each partition. |
184 | * |
185 | * @see MRF#recalcCost() |
186 | */ |
187 | private double recalcCost(){ |
188 | double cost = 0; |
189 | for(Partition p : comp.parts){ |
190 | if(p.mrf == null) continue; |
191 | cost += p.mrf.recalcCost(); |
192 | } |
193 | return cost; |
194 | } |
195 | |
196 | |
197 | |
198 | } |