Skip to content

Commit 1520dbe

Browse files
committed
Simple multicore impl for sumcheck
1 parent 16b8e1c commit 1520dbe

File tree

1 file changed

+149
-78
lines changed

1 file changed

+149
-78
lines changed

spartan_parallel/src/sumcheck.rs

Lines changed: 149 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use super::transcript::{AppendToTranscript, ProofTranscript};
1010
use super::unipoly::{CompressedUniPoly, UniPoly};
1111
use itertools::izip;
1212
use merlin::Transcript;
13+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
1314
use serde::{Deserialize, Serialize};
1415
use std::cmp::min;
1516

@@ -506,7 +507,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
506507
transcript: &mut Transcript,
507508
) -> (Self, Vec<S>, Vec<S>)
508509
where
509-
F: Fn(&S, &S, &S, &S) -> S,
510+
F: Fn(&S, &S, &S, &S) -> S + std::marker::Sync,
510511
{
511512
let ZERO = S::field_zero();
512513

@@ -569,87 +570,157 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
569570
};
570571

571572
let poly = {
572-
let mut eval_point_0 = ZERO;
573-
let mut eval_point_2 = ZERO;
574-
let mut eval_point_3 = ZERO;
575-
576-
// We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
577-
// So min(instance_len, num_proofs.len()) suffices
578-
for p in 0..min(instance_len, num_proofs.len()) {
579-
if mode == MODE_X { num_cons[p] = num_cons[p].div_ceil(2); }
580-
// If q > num_proofs[p], the polynomials always evaluate to 0
581-
if mode == MODE_Q { num_proofs[p] = num_proofs[p].div_ceil(2); }
582-
for q in 0..num_proofs[p] {
583-
for x in 0..num_cons[p] {
584-
// evaluate A, B, C, D on p, q, x
585-
let (poly_A_low, poly_A_high) = match mode {
586-
MODE_X => (
587-
poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x],
588-
poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x + 1],
589-
),
590-
MODE_Q => (
591-
poly_Ap[p] * poly_Aq[2 * q] * poly_Ax[x],
592-
poly_Ap[p] * poly_Aq[2 * q + 1] * poly_Ax[x],
593-
),
594-
MODE_P => (
595-
poly_Ap[2 * p] * poly_Aq[q] * poly_Ax[x],
596-
poly_Ap[2 * p + 1] * poly_Aq[q] * poly_Ax[x],
597-
),
598-
_ => unreachable!()
599-
};
600-
let poly_B_low = poly_B.index_low(p, q, 0, x, mode);
601-
let poly_B_high = poly_B.index_high(p, q, 0, x, mode);
602-
let poly_C_low = poly_C.index_low(p, q, 0, x, mode);
603-
let poly_C_high = poly_C.index_high(p, q, 0, x, mode);
604-
let poly_D_low = poly_D.index_low(p, q, 0, x, mode);
605-
let poly_D_high = poly_D.index_high(p, q, 0, x, mode);
606-
607-
// eval 0: bound_func is A(low)
608-
eval_point_0 = eval_point_0
609-
+ comb_func(
610-
&poly_A_low,
611-
&poly_B_low,
612-
&poly_C_low,
613-
&poly_D_low,
614-
); // Az[x, x, x, ..., 0]
615-
616-
// eval 2: bound_func is -A(low) + 2*A(high)
617-
let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
618-
let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
619-
let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
620-
let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
621-
eval_point_2 = eval_point_2
622-
+ comb_func(
623-
&poly_A_bound_point,
624-
&poly_B_bound_point,
625-
&poly_C_bound_point,
626-
&poly_D_bound_point,
627-
); // Az[x, x, ..., 2]
573+
if mode == MODE_X {
574+
// Multicore evaluation in MODE_X
575+
let mut eval_point_0 = ZERO;
576+
let mut eval_point_2 = ZERO;
577+
let mut eval_point_3 = ZERO;
578+
579+
// We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
580+
// So min(instance_len, num_proofs.len()) suffices
581+
for p in 0..min(instance_len, num_proofs.len()) {
582+
num_cons[p] = num_cons[p].div_ceil(2);
583+
(eval_point_0, eval_point_2, eval_point_3) = (0..num_proofs[p]).into_par_iter().map(|q| {
584+
let mut eval_point_0 = ZERO;
585+
let mut eval_point_2 = ZERO;
586+
let mut eval_point_3 = ZERO;
587+
for x in 0..num_cons[p] {
588+
// evaluate A, B, C, D on p, q, x
589+
let poly_A_low = poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x];
590+
let poly_A_high = poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x + 1];
591+
let poly_B_low = poly_B.index_low(p, q, 0, x, mode);
592+
let poly_B_high = poly_B.index_high(p, q, 0, x, mode);
593+
let poly_C_low = poly_C.index_low(p, q, 0, x, mode);
594+
let poly_C_high = poly_C.index_high(p, q, 0, x, mode);
595+
let poly_D_low = poly_D.index_low(p, q, 0, x, mode);
596+
let poly_D_high = poly_D.index_high(p, q, 0, x, mode);
597+
598+
// eval 0: bound_func is A(low)
599+
eval_point_0 = eval_point_0
600+
+ comb_func(
601+
&poly_A_low,
602+
&poly_B_low,
603+
&poly_C_low,
604+
&poly_D_low,
605+
); // Az[x, x, x, ..., 0]
606+
607+
// eval 2: bound_func is -A(low) + 2*A(high)
608+
let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
609+
let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
610+
let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
611+
let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
612+
eval_point_2 = eval_point_2
613+
+ comb_func(
614+
&poly_A_bound_point,
615+
&poly_B_bound_point,
616+
&poly_C_bound_point,
617+
&poly_D_bound_point,
618+
); // Az[x, x, ..., 2]
619+
620+
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
621+
let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
622+
let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
623+
let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
624+
let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
625+
eval_point_3 = eval_point_3
626+
+ comb_func(
627+
&poly_A_bound_point,
628+
&poly_B_bound_point,
629+
&poly_C_bound_point,
630+
&poly_D_bound_point,
631+
); // Az[x, x, ..., 3]
632+
}
633+
(eval_point_0, eval_point_2, eval_point_3)
634+
}).collect::<Vec<(S, S, S)>>().into_iter().fold((eval_point_0, eval_point_2, eval_point_3), |(e0, e2, e3), (a0, a2, a3)| (e0 + a0, e2 + a2, e3 + a3));
635+
}
628636

629-
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
630-
let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
631-
let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
632-
let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
633-
let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
634-
eval_point_3 = eval_point_3
635-
+ comb_func(
636-
&poly_A_bound_point,
637-
&poly_B_bound_point,
638-
&poly_C_bound_point,
639-
&poly_D_bound_point,
640-
); // Az[x, x, ..., 3]
637+
let evals = vec![
638+
eval_point_0,
639+
claim_per_round - eval_point_0,
640+
eval_point_2,
641+
eval_point_3,
642+
];
643+
let poly = UniPoly::from_evals(&evals);
644+
poly
645+
} else {
646+
// Singlecore evaluation in other Modes
647+
let mut eval_point_0 = ZERO;
648+
let mut eval_point_2 = ZERO;
649+
let mut eval_point_3 = ZERO;
650+
651+
// We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
652+
// So min(instance_len, num_proofs.len()) suffices
653+
for p in 0..min(instance_len, num_proofs.len()) {
654+
// If q > num_proofs[p], the polynomials always evaluate to 0
655+
if mode == MODE_Q { num_proofs[p] = num_proofs[p].div_ceil(2); }
656+
for q in 0..num_proofs[p] {
657+
for x in 0..num_cons[p] {
658+
// evaluate A, B, C, D on p, q, x
659+
let (poly_A_low, poly_A_high) = match mode {
660+
MODE_Q => (
661+
poly_Ap[p] * poly_Aq[2 * q] * poly_Ax[x],
662+
poly_Ap[p] * poly_Aq[2 * q + 1] * poly_Ax[x],
663+
),
664+
MODE_P => (
665+
poly_Ap[2 * p] * poly_Aq[q] * poly_Ax[x],
666+
poly_Ap[2 * p + 1] * poly_Aq[q] * poly_Ax[x],
667+
),
668+
_ => unreachable!()
669+
};
670+
let poly_B_low = poly_B.index_low(p, q, 0, x, mode);
671+
let poly_B_high = poly_B.index_high(p, q, 0, x, mode);
672+
let poly_C_low = poly_C.index_low(p, q, 0, x, mode);
673+
let poly_C_high = poly_C.index_high(p, q, 0, x, mode);
674+
let poly_D_low = poly_D.index_low(p, q, 0, x, mode);
675+
let poly_D_high = poly_D.index_high(p, q, 0, x, mode);
676+
677+
// eval 0: bound_func is A(low)
678+
eval_point_0 = eval_point_0
679+
+ comb_func(
680+
&poly_A_low,
681+
&poly_B_low,
682+
&poly_C_low,
683+
&poly_D_low,
684+
); // Az[x, x, x, ..., 0]
685+
686+
// eval 2: bound_func is -A(low) + 2*A(high)
687+
let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
688+
let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
689+
let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
690+
let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
691+
eval_point_2 = eval_point_2
692+
+ comb_func(
693+
&poly_A_bound_point,
694+
&poly_B_bound_point,
695+
&poly_C_bound_point,
696+
&poly_D_bound_point,
697+
); // Az[x, x, ..., 2]
698+
699+
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
700+
let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
701+
let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
702+
let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
703+
let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
704+
eval_point_3 = eval_point_3
705+
+ comb_func(
706+
&poly_A_bound_point,
707+
&poly_B_bound_point,
708+
&poly_C_bound_point,
709+
&poly_D_bound_point,
710+
); // Az[x, x, ..., 3]
711+
}
641712
}
642713
}
643-
}
644714

645-
let evals = vec![
646-
eval_point_0,
647-
claim_per_round - eval_point_0,
648-
eval_point_2,
649-
eval_point_3,
650-
];
651-
let poly = UniPoly::from_evals(&evals);
652-
poly
715+
let evals = vec![
716+
eval_point_0,
717+
claim_per_round - eval_point_0,
718+
eval_point_2,
719+
eval_point_3,
720+
];
721+
let poly = UniPoly::from_evals(&evals);
722+
poly
723+
}
653724
};
654725

655726
// append the prover's message to the transcript

0 commit comments

Comments
 (0)