Skip to content

Commit ec8458a

Browse files
committed
Faster polyeval
1 parent 1520dbe commit ec8458a

File tree

3 files changed

+87
-37
lines changed

3 files changed

+87
-37
lines changed

spartan_parallel/src/dense_mlpoly.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use super::random::RandomTape;
77
use super::transcript::ProofTranscript;
88
use core::ops::Index;
99
use merlin::Transcript;
10+
use rayon::{iter::ParallelIterator, slice::ParallelSliceMut};
1011
use serde::{Deserialize, Serialize};
11-
use std::collections::HashMap;
12+
use std::{cmp::min, collections::HashMap};
1213

1314
#[cfg(feature = "multicore")]
1415
use rayon::prelude::*;
@@ -247,6 +248,57 @@ impl<S: SpartanExtensionField> DensePolynomial<S> {
247248
self.len = n;
248249
}
249250

251+
fn fold_r(proofs: &mut [S], r: &[S], step: usize, mut l: usize) {
252+
for r in r {
253+
let r1 = S::field_one() - r.clone();
254+
let r2 = r.clone();
255+
256+
l = l.div_ceil(2);
257+
(0..l).for_each(|i| {
258+
proofs[i * step] = r1 * proofs[2 * i * step] + r2 * proofs[(2 * i + 1) * step];
259+
});
260+
}
261+
}
262+
263+
// returns Z(r) in O(n) time
264+
pub fn evaluate_and_consume_parallel(&mut self, r: &[S]) -> S {
265+
assert_eq!(r.len(), self.get_num_vars());
266+
let mut inst = std::mem::take(&mut self.Z);
267+
268+
let len = self.len;
269+
let dist_size = len / min(len, rayon::current_num_threads().next_power_of_two()); // distributed number of proofs on each thread
270+
let num_threads = len / dist_size;
271+
272+
// To perform rigorous parallelism, both len and # threads must be powers of 2
273+
// # threads must fully divide num_proofs for even distribution
274+
assert_eq!(len, len.next_power_of_two());
275+
assert_eq!(num_threads, num_threads.next_power_of_two());
276+
277+
// Determine parallelism levels
278+
let levels = len.log_2(); // total layers
279+
let sub_levels = dist_size.log_2(); // parallel layers
280+
let final_levels = num_threads.log_2(); // single core final layers
281+
// Divide r into sub and final
282+
let sub_r = &r[0..sub_levels];
283+
let final_r = &r[sub_levels..levels];
284+
285+
if sub_levels > 0 {
286+
inst = inst
287+
.par_chunks_mut(dist_size)
288+
.map(|chunk| {
289+
Self::fold_r(chunk, sub_r, 1, dist_size);
290+
chunk.to_vec()
291+
})
292+
.flatten().collect()
293+
}
294+
295+
if final_levels > 0 {
296+
// aggregate the final result from sub-threads outputs using a single core
297+
Self::fold_r(&mut inst, final_r, dist_size, num_threads);
298+
}
299+
inst[0]
300+
}
301+
250302
// returns Z(r) in O(n) time
251303
pub fn evaluate(&self, r: &[S]) -> S {
252304
// r must have a value for each variable

spartan_parallel/src/r1csproof.rs

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
8282
evals_ABC: &mut DensePolynomialPqx<S>,
8383
evals_z: &mut DensePolynomialPqx<S>,
8484
transcript: &mut Transcript,
85-
) -> (SumcheckInstanceProof<S>, Vec<S>, Vec<S>) {
85+
) -> (SumcheckInstanceProof<S>, Vec<S>, Vec<S>, Vec<Vec<S>>) {
8686
let comb_func = |poly_A_comp: &S, poly_B_comp: &S, poly_C_comp: &S| -> S {
8787
*poly_A_comp * *poly_B_comp * *poly_C_comp
8888
};
89-
let (sc_proof_phase_two, r, claims) = SumcheckInstanceProof::<S>::prove_cubic_disjoint_rounds(
89+
let (sc_proof_phase_two, r, claims, claimed_vars_at_ry) = SumcheckInstanceProof::<S>::prove_cubic_disjoint_rounds(
9090
claim,
9191
num_rounds,
9292
num_rounds_y_max,
@@ -102,7 +102,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
102102
transcript,
103103
);
104104

105-
(sc_proof_phase_two, r, claims)
105+
(sc_proof_phase_two, r, claims, claimed_vars_at_ry)
106106
}
107107

108108
fn protocol_name() -> &'static [u8] {
@@ -344,7 +344,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
344344

345345
// Sumcheck 2: (rA + rB + rC) * Z * eq(p) = e
346346
let timer_tmp = Timer::new("prove_sum_check");
347-
let (sc_proof_phase2, ry_rev, _claims_phase2) = R1CSProof::prove_phase_two(
347+
let (sc_proof_phase2, ry_rev, _claims_phase2, claimed_vars_at_ry) = R1CSProof::prove_phase_two(
348348
num_rounds_y + num_rounds_w + num_rounds_p,
349349
num_rounds_y,
350350
num_rounds_w,
@@ -378,6 +378,10 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
378378
let timer_polyeval = Timer::new("polyeval");
379379

380380
// For every possible wit_sec.num_inputs, compute ry_factor = prodX(1 - ryX)...
381+
let mut rq_factors = vec![ONE; num_rounds_q + 1];
382+
for i in 0..num_rounds_q {
383+
rq_factors[i + 1] = rq_factors[i] * (ONE - rq[i]);
384+
}
381385
let mut ry_factors = vec![ONE; num_rounds_y + 1];
382386
for i in 0..num_rounds_y {
383387
ry_factors[i + 1] = ry_factors[i] * (ONE - ry[i]);
@@ -388,42 +392,26 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
388392
let mut num_inputs_list = Vec::new();
389393
// List of all evaluations
390394
let mut Zr_list = Vec::new();
391-
// List of evaluations separated by witness_secs
395+
// Obtain list of evaluations separated by witness_secs
396+
// Note: eval_vars_at_ry_list and raw_eval_vars_at_ry_list are W * P but claimed_vars_at_ry_list is P * W, and
397+
// raw_eval_vars_at_ry_list does not multiply rq_factor and ry_factor
392398
let mut eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs];
393-
let mut raw_eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs]; // Does not multiply ry_factor
399+
let mut raw_eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs];
394400
for i in 0..num_witness_secs {
395401
let w = witness_secs[i];
396402
let wit_sec_num_instance = w.w_mat.len();
397-
eval_vars_at_ry_list.push(Vec::new());
398-
399403
for p in 0..wit_sec_num_instance {
400404
if w.num_inputs[p] > 1 {
401405
poly_list.push(&w.poly_w[p]);
402406
num_proofs_list.push(w.w_mat[p].len());
403407
num_inputs_list.push(w.num_inputs[p]);
404-
// Depending on w.num_inputs[p], ry_short can be two different values
405-
let ry_short = {
406-
// if w.num_inputs[p] >= num_inputs, need to pad 0's to the front of ry
407-
if w.num_inputs[p] >= max_num_inputs {
408-
let ry_pad = vec![ZERO; w.num_inputs[p].log_2() - max_num_inputs.log_2()];
409-
[ry_pad, ry.clone()].concat()
410-
}
411-
// Else ry_short is the last w.num_inputs[p].log_2() entries of ry
412-
// thus, to obtain the actual ry, need to multiply by (1 - ry0)(1 - ry1)..., which is ry_factors[num_rounds_y - w.num_inputs[p]]
413-
else {
414-
ry[num_rounds_y - w.num_inputs[p].log_2()..].to_vec()
415-
}
416-
};
417-
let rq_short = rq[num_rounds_q - num_proofs_list[num_proofs_list.len() - 1].log_2()..].to_vec();
418-
let r = &[rq_short, ry_short.clone()].concat();
419-
let eval_vars_at_ry = poly_list[poly_list.len() - 1].evaluate(r);
420-
Zr_list.push(eval_vars_at_ry);
421-
if w.num_inputs[p] >= max_num_inputs {
422-
eval_vars_at_ry_list[i].push(eval_vars_at_ry);
423-
} else {
424-
eval_vars_at_ry_list[i].push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]);
425-
}
426-
raw_eval_vars_at_ry_list[i].push(eval_vars_at_ry);
408+
// Find out the extra q and y padding to remove in raw_eval_vars_at_ry_list
409+
let rq_pad_inv = rq_factors[num_rounds_q - num_proofs[p].log_2()].invert().unwrap();
410+
let ry_pad_inv = if w.num_inputs[p] >= max_num_inputs { ONE } else { ry_factors[num_rounds_y - w.num_inputs[p].log_2()].invert().unwrap() };
411+
eval_vars_at_ry_list[i].push(claimed_vars_at_ry[p][i] * rq_pad_inv); // I don't know why need to divide by rq and later multiply it back, but it doesn't work without this
412+
let claimed_vars_at_ry_no_pad = claimed_vars_at_ry[p][i] * rq_pad_inv * ry_pad_inv;
413+
Zr_list.push(claimed_vars_at_ry_no_pad);
414+
raw_eval_vars_at_ry_list[i].push(claimed_vars_at_ry_no_pad);
427415
} else {
428416
eval_vars_at_ry_list[i].push(ZERO);
429417
raw_eval_vars_at_ry_list[i].push(ZERO);
@@ -491,16 +479,15 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
491479
};
492480
let mut eval_vars_comb =
493481
(0..num_witness_secs).fold(ZERO, |s, i| s + prefix_list[i] * e(i));
494-
for q in 0..(num_rounds_q - num_proofs[p].log_2()) {
495-
eval_vars_comb = eval_vars_comb * (ONE - rq[q]);
496-
}
482+
eval_vars_comb *= rq_factors[num_rounds_q - num_proofs[p].log_2()];
497483
eval_vars_comb_list.push(eval_vars_comb);
498484
}
499485
timer_polyeval.stop();
500486

501487
let poly_vars = DensePolynomial::new(eval_vars_comb_list);
502488
let eval_vars_at_ry = poly_vars.evaluate(&rp);
503-
489+
// prove the final step of sum-check #2
490+
// Deferred to verifier
504491
timer_prove.stop();
505492

506493
(

spartan_parallel/src/sumcheck.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
332332
poly_C: &mut DensePolynomialPqx<S>,
333333
comb_func: F,
334334
transcript: &mut Transcript,
335-
) -> (Self, Vec<S>, Vec<S>)
335+
) -> (Self, Vec<S>, Vec<S>, Vec<Vec<S>>)
336336
where
337337
F: Fn(&S, &S, &S) -> S,
338338
{
@@ -353,6 +353,8 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
353353
let mut witness_secs_len = num_rounds_w.pow2();
354354
let mut instance_len: usize = num_rounds_p.pow2();
355355

356+
// Every variable bound to ry
357+
let mut claimed_vars_at_ry = Vec::new();
356358
for j in 0..num_rounds {
357359
/* For debugging only */
358360
/* If the value is not 0, the instance / input is wrong */
@@ -385,6 +387,14 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
385387
} else {
386388
MODE_P
387389
};
390+
if j == num_rounds_y_max {
391+
for p in 0..poly_C.num_instances {
392+
claimed_vars_at_ry.push(Vec::new());
393+
for w in 0..poly_C.num_witness_secs {
394+
claimed_vars_at_ry[p].push(poly_C.index(p, 0, w, 0));
395+
}
396+
}
397+
}
388398

389399
if inputs_len > 1 {
390400
inputs_len /= 2
@@ -486,6 +496,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
486496
poly_B.index(0, 0, 0, 0),
487497
poly_C.index(0, 0, 0, 0),
488498
],
499+
claimed_vars_at_ry,
489500
)
490501
}
491502

0 commit comments

Comments
 (0)