Skip to content

Commit 485beec

Browse files
committed
Change parallelism scheme for bound_poly_vars_rq
1 parent 36b8d6c commit 485beec

File tree

1 file changed

+75
-34
lines changed

1 file changed

+75
-34
lines changed

spartan_parallel/src/custom_dense_mlpoly.rs

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const MODE_P: usize = 1;
99
const MODE_Q: usize = 2;
1010
const MODE_W: usize = 3;
1111
const MODE_X: usize = 4;
12+
const NUM_MULTI_THREAD_CORES: usize = 8;
1213

1314
// Customized Dense ML Polynomials for Data-Parallelism
1415
// These Dense ML Polys are aimed for space-efficiency by removing the 0s for invalid (p, q, w, x) quadruple
@@ -206,49 +207,71 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
206207
}
207208

208209
// Bound the entire "q" section to r_q in reverse
209-
pub fn bound_poly_vars_rq(&mut self,
210+
pub fn bound_poly_vars_rq(
211+
&mut self,
210212
r_q: &[S],
211213
) {
212-
let ONE = S::field_one();
213-
let num_instances = min(self.num_instances, self.Z.len());
214+
let Z = std::mem::take(&mut self.Z);
214215

215-
self.Z = (0..num_instances)
216-
.into_par_iter()
217-
.map(|p| {
216+
self.Z = Z
217+
.into_iter()
218+
.enumerate()
219+
.map(|(p, mut inst)| {
218220
let num_proofs = self.num_proofs[p];
219-
let num_witness_secs = min(self.num_witness_secs, self.Z[p][0].len());
221+
let dist_size = num_proofs / min(num_proofs, NUM_MULTI_THREAD_CORES); // distributed number of proofs on each thread
222+
let num_threads = num_proofs / dist_size;
223+
224+
// To perform rigorous parallelism, both num_proofs and # threads must be powers of 2
225+
// # threads must fully divide num_proofs for even distribution
226+
assert!(num_proofs & (num_proofs - 1) == 0);
227+
assert!(num_threads & (num_threads - 1) == 0);
228+
229+
// Determine parallelism levels
230+
let levels = num_proofs.trailing_zeros() as usize; // total layers
231+
let sub_levels = dist_size.trailing_zeros() as usize; // parallelism layers
232+
let final_levels = num_threads.trailing_zeros() as usize; // single core final layers
233+
let left_over_q_len = r_q.len() - levels; // if r_q.len() > log2(num_proofs)
234+
235+
// single proof matrix dimension W x X
236+
let num_witness_secs = min(self.num_witness_secs, inst[0].len());
220237
let num_inputs = self.num_inputs[p];
221238

222-
let wit = (0..num_witness_secs).into_par_iter().map(|w| {
223-
(0..num_inputs).into_par_iter().map(|x| {
224-
let mut np = num_proofs;
225-
let mut x_fold = (0..num_proofs).map(|q| self.Z[p][q][w][x]).collect::<Vec<S>>();
226-
for r in r_q {
227-
if np == 1 {
228-
x_fold[0] *= ONE - *r;
229-
} else {
230-
np /= 2;
231-
for q in 0..np {
232-
x_fold[q] = x_fold[2 * q] + *r * (x_fold[2 * q + 1] - x_fold[2 * q]);
233-
}
234-
}
239+
if sub_levels > 0 {
240+
let thread_split_inst = (0..num_threads)
241+
.map(|_| {
242+
inst.split_off(inst.len() - dist_size)
243+
})
244+
.rev()
245+
.collect::<Vec<Vec<Vec<Vec<S>>>>>();
246+
247+
inst = thread_split_inst
248+
.into_par_iter()
249+
.map(|mut chunk| {
250+
fold(&mut chunk, r_q, 0, 1, sub_levels, num_witness_secs, num_inputs);
251+
chunk
252+
})
253+
.collect::<Vec<Vec<Vec<Vec<S>>>>>()
254+
.into_iter().flatten().collect()
255+
}
256+
257+
if final_levels > 0 {
258+
// aggregate the final result from sub-threads outputs using a single core
259+
fold(&mut inst, r_q, 0, dist_size, final_levels, num_witness_secs, num_inputs);
260+
}
261+
262+
if left_over_q_len > 0 {
263+
// the series of random challenges exceeds the total number of variables
264+
let c = r_q[(r_q.len() - left_over_q_len)..r_q.len()].iter().fold(S::field_one(), |acc, n| acc * (S::field_one() - *n));
265+
for w in 0..inst[0].len() {
266+
for x in 0..inst[0][0].len() {
267+
inst[0][w][x] *= c;
235268
}
269+
}
270+
}
236271

237-
x_fold
238-
}).collect::<Vec<Vec<S>>>()
239-
}).collect::<Vec<Vec<Vec<S>>>>();
240-
241-
(0..num_proofs)
242-
.into_par_iter()
243-
.map(|q| {
244-
(0..wit.len()).map(|w| {
245-
(0..wit[w].len()).map(|x| {
246-
wit[w][x][q]
247-
}).collect::<Vec<S>>()
248-
}).collect::<Vec<Vec<S>>>()
249-
}).collect::<Vec<Vec<Vec<S>>>>()
272+
inst
250273
}).collect::<Vec<Vec<Vec<Vec<S>>>>>();
251-
274+
252275
self.max_num_proofs /= 2usize.pow(r_q.len() as u32);
253276
}
254277

@@ -304,4 +327,22 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
304327
}
305328
DensePolynomial::new(Z_poly)
306329
}
330+
}
331+
332+
fn fold<S: SpartanExtensionField>(proofs: &mut Vec<Vec<Vec<S>>>, r_q: &[S], idx: usize, step: usize, lvl: usize, w: usize, x: usize) {
333+
if lvl > 0 {
334+
fold(proofs, r_q, 2 * idx, step, lvl - 1, w, x);
335+
fold(proofs, r_q, 2 * idx + step, step, lvl - 1, w, x);
336+
337+
let r1 = S::field_one() - r_q[lvl];
338+
let r2 = r_q[lvl];
339+
340+
(0..w).for_each(|w| {
341+
(0..x).for_each(|x| {
342+
proofs[idx][w][x] = r1 * proofs[idx * 2][w][x] + r2 * proofs[idx * 2 + step][w][x];
343+
});
344+
});
345+
} else {
346+
// level 0. do nothing
347+
}
307348
}

0 commit comments

Comments
 (0)