1 | package tuffy.ground.partition; |
2 | |
3 | |
4 | |
5 | import java.io.BufferedWriter; |
6 | import java.io.File; |
7 | import java.io.FileInputStream; |
8 | import java.io.FileOutputStream; |
9 | import java.io.OutputStreamWriter; |
10 | import java.sql.ResultSet; |
11 | import java.util.ArrayList; |
12 | import java.util.Collections; |
13 | import java.util.HashMap; |
14 | import java.util.HashSet; |
15 | |
16 | |
17 | import org.postgresql.PGConnection; |
18 | |
19 | import tuffy.db.RDB; |
20 | import tuffy.ground.Grounding; |
21 | import tuffy.mln.MarkovLogicNetwork; |
22 | import tuffy.util.Config; |
23 | import tuffy.util.ExceptionMan; |
24 | import tuffy.util.UIMan; |
25 | import tuffy.util.UnionFind; |
26 | |
27 | /** |
28 | * Utilities for partitioning the MRF generated by the grounding process. |
29 | * |
30 | */ |
31 | public class Partitioning { |
32 | private MarkovLogicNetwork mln = null; |
33 | private RDB db = null; |
34 | |
35 | /** |
36 | * Construct a partitioning worker based on the grounding |
37 | * result. |
38 | * @param g the grounding worker |
39 | */ |
40 | public Partitioning(Grounding g){ |
41 | mln = g.getMLN(); |
42 | db = mln.getRDB(); |
43 | } |
44 | |
45 | |
	/**
	 * Estimated RAM footprint, in bytes, charged to each ground atom: its
	 * share of the MRF's bookkeeping collections plus the GAtom object
	 * itself. The constants are rough empirical estimates.
	 */
	private double perAtomWeight(){
		return
		16.0 // MRF collections
		+ 48.0; // GAtom
	}
51 | |
	/**
	 * Estimated RAM footprint, in bytes, of a ground clause with
	 * {@code nlits} literals, amortized evenly over its literals so that
	 * the clause's cost can be charged to the atoms' weights. For example,
	 * a 3-literal clause costs (4 + 4 + 32 + 3*4) / 3 = 52/3, i.e. about
	 * 17.3 bytes per literal.
	 *
	 * @param nlits number of literals in the clause; assumed positive
	 */
	private double perClauseSharedWeight(int nlits){
		return
		(4.0 // atom-clause index
		+ 4.0 // unsat
		+ 32.0 + nlits * 4 // GClause object
		) / nlits;
	}
59 | |
60 | |
61 | /** |
62 | * Agglomeratively cluster the atoms in the MRF into partitions |
63 | * while heuristically minimizing the (weighted) cut size. The heuristic |
64 | * is to scan through all clauses in the descending-abs(weight) order. |
65 | * |
66 | * @param ramBudgetPerPartition size bound of a partition, roughly proportional to |
67 | * the number of atoms and clauses in each partition |
68 | * |
69 | * @return a partitioning scheme |
70 | */ |
71 | public PartitionScheme partitionAtoms(double ramBudgetPerPartition){ |
72 | UIMan.verbose(1, ">>> Partitioning atoms..."); |
73 | double maxPartRAM = ramBudgetPerPartition; |
74 | try { |
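			// scan inside one transaction; with autocommit off the PostgreSQL
			// JDBC driver can stream large result sets instead of buffering them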
75 | db.disableAutoCommitForNow(); |
76 | // read in set of atoms |
77 | ArrayList<Integer> atoms = new ArrayList<Integer>(); |
78 | String sql = "SELECT atomID FROM " + mln.relAtoms; |
79 | ResultSet rs = db.query(sql); |
80 | while(rs.next()){ |
81 | atoms.add(rs.getInt(1)); |
82 | } |
83 | rs.close(); |
84 | |
			// amortize the estimated memory usage onto individual atoms:
			// each atom gets a base cost plus a share of each incident clause
86 | HashMap<Integer, Double> atomWts = new HashMap<Integer, Double>(); |
87 | for(int a : atoms){ |
88 | atomWts.put(a, perAtomWeight()); |
89 | } |
90 | sql = "SELECT lits FROM " + mln.relClauses; |
91 | rs = db.query(sql); |
92 | while(rs.next()){ |
93 | Integer[] lits = (Integer[])rs.getArray("lits").getArray(); |
94 | double delta = perClauseSharedWeight(lits.length); |
95 | for(int lit : lits){ |
96 | int a = Math.abs(lit); |
97 | double w = atomWts.get(a) + delta; |
98 | atomWts.put(a, w); |
99 | } |
100 | } |
101 | rs.close(); |
102 | |
103 | // first scan, compute components and parts |
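			// ufcomp tracks the true connected components (it always unions),
			// whereas ufpart tracks partitions: it unions two clusters only if
			// the merged cluster would still fit within the RAM budget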
104 | UnionFind<Integer> ufpart = new UnionFind<Integer>(); |
105 | UnionFind<Integer> ufcomp = new UnionFind<Integer>(); |
106 | ufpart.makeUnionFind(atoms, atomWts); |
107 | ufcomp.makeUnionFind(atoms, atomWts); |
108 | |
109 | sql = "SELECT lits FROM " + mln.relClauses + |
110 | " ORDER BY ABS(weight) DESC"; |
111 | rs = db.query(sql); |
112 | while(rs.next()){ |
113 | Integer[] lits = (Integer[])rs.getArray("lits").getArray(); |
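				// unit clauses touch a single atom and can never be cut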
114 | if(lits.length <= 1) continue; |
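				// estimate the size the merged partition would have if all
				// partitions touched by this clause were unioned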
115 | double afterSize = 0; |
116 | HashSet<Integer> roots = new HashSet<Integer>(); |
117 | for(int lit : lits){ |
118 | Integer root = ufpart.getRoot(Math.abs(lit)); |
119 | if(!roots.contains(root)){ |
120 | afterSize += ufpart.clusterWeight(root); |
121 | roots.add(root); |
122 | } |
123 | } |
				Integer c1 = Math.abs(lits[0]);
				for(int i=1; i<lits.length; i++){
					ufcomp.union(c1, Math.abs(lits[i]));
					// merge partitions only while the estimated merged size
					// stays within the per-partition RAM budget
					if(afterSize <= maxPartRAM){
						ufpart.union(c1, Math.abs(lits[i]));
					}
				}
132 | } |
133 | rs.close(); |
134 | |
			// materialize Component and Partition objects from the union-find roots
136 | HashMap<Integer, Component> compMap = new HashMap<Integer, Component>(); |
137 | for(Integer r : ufcomp.getRoots()){ |
138 | int size = ufcomp.clusterSize(r); |
139 | double wt = ufcomp.clusterWeight(r); |
140 | Component comp = new Component(); |
141 | comp.numAtoms = size; |
142 | comp.ramSize = wt; |
143 | comp.rep = r; |
144 | compMap.put(r, comp); |
145 | } |
146 | |
147 | HashMap<Integer, Partition> partMap = new HashMap<Integer, Partition>(); |
148 | for(Integer r : ufpart.getRoots()){ |
149 | int size = ufpart.clusterSize(r); |
150 | double wt = ufpart.clusterWeight(r); |
151 | Partition part = new Partition(); |
152 | part.numAtoms = size; |
153 | part.ramSize = wt; |
154 | partMap.put(r, part); |
155 | Component comp = compMap.get(ufcomp.getRoot(r)); |
156 | comp.parts.add(part); |
157 | part.parentComponent = comp; |
158 | } |
159 | |
			// second scan: aggregate per-component and per-partition statistics
			// and identify clauses that are cut across partitions
			sql = "SELECT lits, weight FROM " + mln.relClauses;
162 | rs = db.query(sql); |
163 | while(rs.next()){ |
164 | Integer[] lits = (Integer[])rs.getArray("lits").getArray(); |
165 | double weight = Math.abs(rs.getDouble("weight")); |
166 | Component comp = compMap.get(ufcomp.getRoot(Math.abs(lits[0]))); |
167 | comp.totalWeight += weight; |
168 | comp.numClauses ++; |
169 | comp.numPins += lits.length; |
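				// collect the distinct partitions touched by this clause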
170 | HashSet<Partition> pset = new HashSet<Partition>(); |
171 | for(int lit : lits){ |
172 | int a = Math.abs(lit); |
173 | Partition part = partMap.get(ufpart.getRoot(a)); |
174 | if(!pset.contains(part)){ |
175 | pset.add(part); |
176 | part.numIncidentClauses ++; |
177 | } |
178 | } |
179 | if(pset.size() > 1){ |
180 | comp.totalCutWeight += weight; |
181 | comp.numCutClauses ++; |
182 | } |
183 | } |
184 | rs.close(); |
185 | db.restoreAutoCommitState(); |
186 | |
187 | // assign IDs to components and parts |
188 | ArrayList<Component> comps = new ArrayList<Component>(compMap.values()); |
189 | Collections.sort(comps); |
190 | int partID = 0; |
191 | for(int i=0; i<comps.size(); i++){ |
192 | Component comp = comps.get(i); |
193 | comp.id = i+1; |
194 | Collections.sort(comp.parts); |
195 | for(int j=0; j<comp.parts.size(); j++){ |
196 | Partition part = comp.parts.get(j); |
197 | part.id = ++partID; |
198 | } |
199 | } |
200 | PartitionScheme pmap = new PartitionScheme(comps); |
201 | |
			// dump the atom -> (component, partition) map to a file, then
			// bulk-load it into the database via PostgreSQL COPY
203 | File fbuf = new File(Config.getLoadingDir(), "partition_map" + mln.getID()); |
204 | BufferedWriter writer = new BufferedWriter(new OutputStreamWriter |
205 | (new FileOutputStream(fbuf),"UTF8")); |
206 | for(int a : atoms){ |
207 | int compid = compMap.get(ufcomp.getRoot(a)).id; |
208 | int partid = partMap.get(ufpart.getRoot(a)).id; |
209 | writer.append(a + " , " + compid + " , " + partid + "\n"); |
210 | } |
211 | writer.close(); |
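			// (re)create the atom-partition table and bulk-load the map via COPY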
212 | db.dropTable(mln.relAtomPart); |
213 | sql = "CREATE TABLE " + mln.relAtomPart + "(atomID INT, compID INT, partID INT)"; |
214 | db.update(sql); |
215 | FileInputStream in = new FileInputStream(fbuf); |
216 | PGConnection con = (PGConnection)db.getConnection(); |
217 | sql = "COPY " + mln.relAtomPart + " FROM STDIN CSV"; |
218 | con.getCopyAPI().copyIn(sql, in); |
219 | in.close(); |
220 | return pmap; |
221 | } catch (Exception e) { |
222 | ExceptionMan.handle(e); |
223 | } |
224 | return null; |
225 | } |
226 | |
	/**
	 * Partition the whole MRF: cluster the atoms with
	 * {@link #partitionAtoms(double)}, then assign clauses to partitions
	 * with {@link #partitionClauses(PartitionScheme)}.
	 *
	 * @param maxPartitionSize RAM budget per partition, as in
	 * {@link #partitionAtoms(double)}
	 * @return the resulting partitioning scheme
	 */
	public PartitionScheme partitionMRF(double maxPartitionSize){
		PartitionScheme pmap = partitionAtoms(maxPartitionSize);
		partitionClauses(pmap);
		return pmap;
	}
232 | |
233 | /** |
234 | * Given a partitioning scheme, partition the data accordingly. |
235 | * |
236 | */ |
237 | public void partitionClauses(PartitionScheme pmap){ |
238 | int numParts = pmap.numParts(); |
239 | String sql; |
240 | UIMan.verbose(1, ">>> Partitioning the MRF into " + numParts + " parts..."); |
241 | db.dropTable(mln.relClausePart); |
242 | sql = "CREATE TABLE " + mln.relClausePart + |
243 | "(cid INT, partID INT, compID INT)"; |
244 | db.update(sql); |
245 | |
246 | sql = "INSERT INTO " + mln.relClausePart + "(cid, partID, compID) " + |
247 | "SELECT rc.cid, ra.partID, ra.compID FROM " + |
248 | mln.relClauses + " rc, " + mln.relAtomPart + " ra " + |
249 | " WHERE ABS(_random_element(rc.lits))=ra.atomID"; |
250 | db.update(sql); |
251 | |
252 | } |
253 | |
254 | } |