1 | package tuffy.infer; |
2 | |
3 | |
4 | import java.util.ArrayList; |
5 | |
6 | import tuffy.ground.partition.Bucket; |
7 | import tuffy.ground.partition.Component; |
8 | import tuffy.ground.partition.Partition; |
9 | import tuffy.infer.MRF.INIT_STRATEGY; |
10 | import tuffy.infer.ds.GAtom; |
11 | import tuffy.util.Config; |
12 | import tuffy.util.ExceptionMan; |
13 | import tuffy.util.Settings; |
14 | import tuffy.util.UIMan; |
15 | /** |
16 | * |
17 | * A bucket of inference tasks that can run in parallel. Currently, each task |
18 | * corresponds to an MRF component so that the components can be processed in parallel. |
19 | */ |
20 | public class InferBucket{ |
21 | private Bucket bucket; |
22 | private int numThreads = 2; |
23 | private double cost; |
24 | |
25 | private int totalSamples = 0; |
26 | |
27 | Config.TUFFY_INFERENCE_TASK task = null; |
28 | Settings settings = null; |
29 | |
30 | public void infer(Settings s){ |
31 | settings = s; |
32 | task = Config.TUFFY_INFERENCE_TASK.valueOf(s.getString("task")); |
33 | this.runInferParallel(); |
34 | } |
35 | |
36 | |
37 | public InferBucket(Bucket bucket){ |
38 | this.bucket = bucket; |
39 | numThreads = Config.getNumThreads(); |
40 | } |
41 | |
42 | public void flushAtomStates(DataMover dmover, String relAtoms){ |
43 | ArrayList<GAtom> gatoms = new ArrayList<GAtom>(); |
44 | for(Component c: bucket.getComponents()){ |
45 | if(c.atoms == null) continue; |
46 | gatoms.addAll(c.atoms.values()); |
47 | } |
48 | dmover.flushAtomStates(gatoms, relAtoms); |
49 | } |
50 | |
51 | public void setMrfInitStrategy(INIT_STRATEGY strategy){ |
52 | for(Partition p : bucket.getPartitions()){ |
53 | if(p.mrf == null) continue; |
54 | p.mrf.setInitStrategy(strategy); |
55 | } |
56 | } |
57 | |
58 | /** |
59 | * Get the cost after inference. |
60 | */ |
61 | public double getCost(){ |
62 | return cost; |
63 | } |
64 | |
65 | public double getSamples(){ |
66 | return this.totalSamples; |
67 | } |
68 | |
69 | private Object sentinel = new Object(); |
70 | |
71 | /** |
72 | * Add up the cost. |
73 | * |
74 | * @see CompWorker#run() |
75 | * @param c |
76 | */ |
77 | public void addCost(double c){ |
78 | synchronized(sentinel){ |
79 | cost += c; |
80 | } |
81 | } |
82 | |
83 | /** |
84 | * Add up the cost. |
85 | * |
86 | * @see CompWorker#run() |
87 | * @param c |
88 | */ |
89 | public void addSamples(int c){ |
90 | synchronized(sentinel){ |
91 | this.totalSamples += c; |
92 | } |
93 | } |
94 | |
95 | /** |
96 | * A worker thread that runs inference on one component at a time. |
97 | * Used for parallelization in a producer-consumer model. |
98 | */ |
99 | public static class CompWorker extends Thread{ |
100 | InferBucket ibucket; |
101 | Config.TUFFY_INFERENCE_TASK task; |
102 | Settings settings; |
103 | |
104 | public CompWorker(InferBucket bucket){ |
105 | this.ibucket = bucket; |
106 | settings = bucket.settings; |
107 | task = bucket.task; |
108 | } |
109 | |
110 | public void run(){ |
111 | while(true){ |
112 | Component comp = ibucket.getTask(); |
113 | if(comp == null) return; |
114 | //UIMan.println(this + " is processing " + comp + " with " + comp.numAtoms + " atoms and " + comp.numClauses); |
115 | InferComponent ic = new InferComponent(comp); |
116 | switch(task){ |
117 | case MAP: |
118 | int ntries = (Integer)(settings.get("ntries")); |
119 | double flipsPerAtom = (Double)(settings.get("flipsPerAtom")); |
120 | int nflips = (int)(flipsPerAtom * comp.numAtoms); |
121 | //UIMan.println("flips = " + nflips + "; tries = " + ntries); |
122 | ic.inferMAP(ntries, nflips); |
123 | ibucket.addCost(ic.getCost()); |
124 | break; |
125 | case MARGINAL: |
126 | int nsamples = (Integer)(settings.get("nsamples")); |
127 | flipsPerAtom = (Double)(settings.get("flipsPerAtom")); |
128 | nflips = (int)(flipsPerAtom * comp.numAtoms); |
129 | ibucket.addCost(ic.inferMarginal(nsamples, nflips)); |
130 | break; |
131 | } |
132 | } |
133 | } |
134 | } |
135 | |
136 | /** |
137 | * The queue of components to be processed |
138 | */ |
139 | private ArrayList<Component> q = null; |
140 | |
141 | /** |
142 | * Get the next unprocessed component in the queue |
143 | * |
144 | * @see CompWorker#run() |
145 | */ |
146 | public Component getTask(){ |
147 | synchronized(sentinel){ |
148 | if(q.isEmpty()) return null; |
149 | Component c = q.remove(q.size()-1); |
150 | //UIMan.println(" got " + c + " from the queue"); |
151 | return c; |
152 | } |
153 | } |
154 | |
155 | /** |
156 | * Solve the components in parallel. |
157 | * @param ntries |
158 | * @param nflips |
159 | */ |
160 | private void runInferParallel(){ |
161 | cost = 0; |
162 | q = new ArrayList<Component>(); |
163 | q.addAll(bucket.getComponents()); |
164 | ArrayList<CompWorker> workers = new ArrayList<CompWorker>(); |
165 | |
166 | if(numThreads > 1){ |
167 | UIMan.setSilent(true); |
168 | } |
169 | for(int i=0; i<numThreads; i++){ |
170 | CompWorker t = new CompWorker(this); |
171 | workers.add(t); |
172 | t.start(); |
173 | } |
174 | for(CompWorker t : workers){ |
175 | try { |
176 | t.join(); |
177 | } catch (InterruptedException e) { |
178 | ExceptionMan.handle(e); |
179 | } |
180 | } |
181 | UIMan.setSilent(false); |
182 | } |
183 | |
184 | |
185 | public void setNumThreads(int numThreads) { |
186 | this.numThreads = numThreads; |
187 | } |
188 | |
189 | |
190 | public int getNumThreads() { |
191 | return numThreads; |
192 | } |
193 | |
194 | |
195 | |
196 | } |