1 | package tuffy.infer; |
2 | |
3 | |
4 | |
5 | import java.util.ArrayList; |
6 | import java.util.HashMap; |
7 | |
8 | import tuffy.db.RDB; |
9 | import tuffy.ground.Grounding; |
10 | import tuffy.ground.partition.Bucket; |
11 | import tuffy.ground.partition.Component; |
12 | import tuffy.ground.partition.Partition; |
13 | import tuffy.ground.partition.PartitionScheme; |
14 | import tuffy.ground.partition.Partitioning; |
15 | import tuffy.mln.MarkovLogicNetwork; |
16 | import tuffy.util.Config; |
17 | import tuffy.util.Settings; |
18 | import tuffy.util.Timer; |
19 | import tuffy.util.UIMan; |
20 | |
21 | /** |
22 | * Scheduler of partition-aware inference. |
23 | * |
24 | */ |
25 | public class InferPartitioned { |
26 | MarkovLogicNetwork mln; |
27 | DataMover dmover; |
28 | RDB db; |
29 | Grounding grounding; |
30 | Partitioning parting; |
31 | PartitionScheme pmap; |
32 | ArrayList<Bucket> wholeBuckets = new ArrayList<Bucket>(); |
33 | HashMap<Component, ArrayList<Bucket>> partBuckets = |
34 | new HashMap<Component, ArrayList<Bucket>>(); |
35 | |
36 | public PartitionScheme getPartitionScheme(){ |
37 | return pmap; |
38 | } |
39 | |
40 | |
41 | public InferPartitioned(Grounding g, DataMover dmover){ |
42 | grounding = g; |
43 | mln = g.getMLN(); |
44 | db = mln.getRDB(); |
45 | this.dmover = dmover; |
46 | partition(); |
47 | } |
48 | |
49 | /** |
50 | * Partition the MRF produced by the grounding process. |
51 | */ |
52 | private void partition(){ |
53 | parting = new Partitioning(grounding); |
54 | UIMan.println(">>> Partitioning MRF..."); |
55 | pmap = parting.partitionMRF(Config.partition_size_bound); |
56 | UIMan.verbose(2, pmap.getStats()); |
57 | groupPartitionsIntoBuckets(); |
58 | int ncomp = pmap.numComponents(); |
59 | int npart = pmap.numParts(); |
60 | int nbuck = getNumBuckets(); |
61 | String sp = "### " + ncomp + " components; " + npart + " partitions; " + nbuck + " buckets"; |
62 | |
63 | UIMan.println(sp); |
64 | |
65 | UIMan.verbose(1, sp); |
66 | } |
67 | |
68 | public int getNumBuckets(){ |
69 | int nb = wholeBuckets.size(); |
70 | for(Component c : partBuckets.keySet()){ |
71 | nb += partBuckets.get(c).size(); |
72 | } |
73 | return nb; |
74 | } |
75 | |
76 | /** |
77 | * Group components/partitions to enable efficient batch loading and parallel inference. |
78 | */ |
79 | private void groupPartitionsIntoBuckets(){ |
80 | for(Component c : pmap.components){ |
81 | if(c.size() <= Config.ram_size){ |
82 | boolean taken = false; |
83 | for(Bucket z : wholeBuckets){ |
84 | if(z.size() + c.size() <= Config.ram_size){ |
85 | taken = true; |
86 | z.addComponent(c); |
87 | } |
88 | } |
89 | if(!taken){ |
90 | Bucket z = new Bucket(db, pmap); |
91 | z.addComponent(c); |
92 | wholeBuckets.add(z); |
93 | } |
94 | }else{ |
95 | ArrayList<Bucket> zones = new ArrayList<Bucket>(); |
96 | Bucket z = new Bucket(db, pmap); |
97 | zones.add(z); |
98 | for(Partition p : c.parts){ |
99 | if(z.size() + p.size() <= Config.ram_size){ |
100 | z.addPart(p); |
101 | }else{ |
102 | z = new Bucket(db, pmap); |
103 | z.addPart(p); |
104 | zones.add(z); |
105 | } |
106 | } |
107 | partBuckets.put(c, zones); |
108 | } |
109 | } |
110 | } |
111 | |
112 | /** |
113 | * Run partition-aware MAP inference. |
114 | */ |
115 | public double infer(Settings s){ |
116 | |
117 | double cost = 0; |
118 | int numberOfSnap = 0; |
119 | |
120 | if(s.getString("task").equals("MAP")){ |
121 | |
122 | |
123 | cost = 0; |
124 | numberOfSnap = 1; |
125 | |
126 | for(Bucket z : wholeBuckets){ |
127 | UIMan.println(">>> Processing " + z); |
128 | UIMan.println(" Loading data..."); |
129 | z.load(mln); |
130 | InferBucket ib = new InferBucket(z); |
131 | //UIMan.verbose(1, " [Settings]"); |
132 | //UIMan.verbose(1, s.toString()); |
133 | UIMan.println(" Running inference with " + ib.getNumThreads() + " thread(s)..."); |
134 | ib.infer(s); |
135 | UIMan.verbose(1, " Flushing states..."); |
136 | ib.flushAtomStates(dmover, mln.relAtoms); |
137 | cost += ib.getCost(); |
138 | z.discard(); |
139 | } |
140 | |
141 | // large components requires some swapping |
142 | for(Component c : partBuckets.keySet()){ |
143 | double licost = Double.MAX_VALUE; |
144 | ArrayList<Bucket> zones = partBuckets.get(c); |
145 | for(int t=1; t<=Config.gauss_seidel_infer_rounds; t++){ |
146 | double icost = 0; |
147 | for(Bucket z : zones){ |
148 | UIMan.println(">>> Processing " + z); |
149 | UIMan.println(" Loading data..."); |
150 | z.load(mln); |
151 | InferBucket ib = new InferBucket(z); |
152 | if(t==1){ |
153 | ib.setMrfInitStrategy(tuffy.infer.MRF.INIT_STRATEGY.COIN_FLIP); |
154 | }else{ |
155 | ib.setMrfInitStrategy(tuffy.infer.MRF.INIT_STRATEGY.COPY_LOW); |
156 | } |
157 | UIMan.println(" Running inference with " + ib.getNumThreads() + " thread(s)..."); |
158 | ib.infer(s); |
159 | UIMan.verbose(1, " Flushing states..."); |
160 | ib.flushAtomStates(dmover, mln.relAtoms); |
161 | icost += ib.getCost(); |
162 | if(!Config.snapshot_mode){ |
163 | z.discard(); |
164 | } |
165 | } |
166 | if(icost < licost){ |
167 | licost = icost; |
168 | } |
169 | } |
170 | cost += licost; |
171 | } |
172 | |
173 | |
174 | }else{ |
175 | |
176 | cost = 0; |
177 | |
178 | int beginTime = (int) Timer.elapsedSeconds(); |
179 | |
180 | // small components that fit into memory can be sovled in one shot |
181 | |
182 | int nsamples = s.getInt("nsamples"); |
183 | |
184 | if(Config.snapshot_mode){ |
185 | s.put("nsamples", 100); |
186 | } |
187 | |
188 | for(int i=0; i< nsamples; i+= s.getInt("nsamples")){ |
189 | |
190 | numberOfSnap ++; |
191 | |
192 | int curTime = (int) Timer.elapsedSeconds(); |
193 | |
194 | Config.currentSampledNumber += s.getInt("nsamples"); |
195 | |
196 | if(i != 0){ |
197 | Config.snapshoting_so_do_not_do_init_flip = true; |
198 | } |
199 | |
200 | UIMan.println(">>> MCSAT FOR SAMPLES " + i + " ~ " + (i+s.getInt("nsamples"))); |
201 | |
202 | for(Bucket z : wholeBuckets){ |
203 | UIMan.println(">>> Processing " + z); |
204 | UIMan.println(" Loading data..."); |
205 | if(!Config.snapshoting_so_do_not_do_init_flip){ |
206 | z.load(mln); |
207 | } |
208 | InferBucket ib = new InferBucket(z); |
209 | //UIMan.verbose(1, " [Settings]"); |
210 | //UIMan.verbose(1, s.toString()); |
211 | UIMan.println(" Running inference with " + ib.getNumThreads() + " thread(s)..."); |
212 | ib.infer(s); |
213 | UIMan.verbose(1, " Flushing states..."); |
214 | ib.flushAtomStates(dmover, mln.relAtoms); |
215 | cost += ib.getCost(); |
216 | if(!Config.snapshot_mode){ |
217 | z.discard(); |
218 | } |
219 | } |
220 | |
221 | // large components requires some swapping |
222 | for(Component c : partBuckets.keySet()){ |
223 | double licost = Double.MAX_VALUE; |
224 | ArrayList<Bucket> zones = partBuckets.get(c); |
225 | for(int t=1; t<=Config.gauss_seidel_infer_rounds; t++){ |
226 | double icost = 0; |
227 | for(Bucket z : zones){ |
228 | UIMan.println(">>> Processing " + z); |
229 | UIMan.println(" Loading data..."); |
230 | if(Config.snapshoting_so_do_not_do_init_flip){ |
231 | z.load(mln); |
232 | } |
233 | InferBucket ib = new InferBucket(z); |
234 | if(t==1){ |
235 | ib.setMrfInitStrategy(tuffy.infer.MRF.INIT_STRATEGY.COIN_FLIP); |
236 | }else{ |
237 | ib.setMrfInitStrategy(tuffy.infer.MRF.INIT_STRATEGY.COPY_LOW); |
238 | } |
239 | UIMan.println(" Running inference with " + ib.getNumThreads() + " thread(s)..."); |
240 | ib.infer(s); |
241 | UIMan.verbose(1, " Flushing states..."); |
242 | ib.flushAtomStates(dmover, mln.relAtoms); |
243 | icost += ib.getCost(); |
244 | if(!Config.snapshot_mode){ |
245 | z.discard(); |
246 | } |
247 | } |
248 | //if(icost < licost){ |
249 | // licost = icost; |
250 | //} |
251 | licost += icost; |
252 | } |
253 | cost += licost/Config.gauss_seidel_infer_rounds; |
254 | } |
255 | |
256 | int endTime = (int) Timer.elapsedSeconds(); |
257 | beginTime += endTime - curTime; |
258 | |
259 | if(Config.snapshot_mode){ |
260 | dmover.dumpProbsToFile(mln.relAtoms, Config.dir_out + "/snapshots-" + beginTime + "s"); |
261 | } |
262 | |
263 | if(beginTime > Config.timeout){ |
264 | System.out.println("!!! TIME OUT AT " + (beginTime) + " sec."); |
265 | break; |
266 | } |
267 | } |
268 | |
269 | } |
270 | |
271 | return cost/numberOfSnap; |
272 | } |
273 | |
274 | } |