@@ -10,6 +10,7 @@ use super::transcript::{AppendToTranscript, ProofTranscript};
1010use super :: unipoly:: { CompressedUniPoly , UniPoly } ;
1111use itertools:: izip;
1212use merlin:: Transcript ;
13+ use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
1314use serde:: { Deserialize , Serialize } ;
1415use std:: cmp:: min;
1516
@@ -506,7 +507,7 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
506507 transcript : & mut Transcript ,
507508 ) -> ( Self , Vec < S > , Vec < S > )
508509 where
509- F : Fn ( & S , & S , & S , & S ) -> S ,
510+ F : Fn ( & S , & S , & S , & S ) -> S + std :: marker :: Sync ,
510511 {
511512 let ZERO = S :: field_zero ( ) ;
512513
@@ -569,87 +570,157 @@ impl<S: SpartanExtensionField> SumcheckInstanceProof<S> {
569570 } ;
570571
571572 let poly = {
572- let mut eval_point_0 = ZERO ;
573- let mut eval_point_2 = ZERO ;
574- let mut eval_point_3 = ZERO ;
575-
576- // We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
577- // So min(instance_len, num_proofs.len()) suffices
578- for p in 0 ..min ( instance_len, num_proofs. len ( ) ) {
579- if mode == MODE_X { num_cons[ p] = num_cons[ p] . div_ceil ( 2 ) ; }
580- // If q > num_proofs[p], the polynomials always evaluate to 0
581- if mode == MODE_Q { num_proofs[ p] = num_proofs[ p] . div_ceil ( 2 ) ; }
582- for q in 0 ..num_proofs[ p] {
583- for x in 0 ..num_cons[ p] {
584- // evaluate A, B, C, D on p, q, x
585- let ( poly_A_low, poly_A_high) = match mode {
586- MODE_X => (
587- poly_Ap[ p] * poly_Aq[ q] * poly_Ax[ 2 * x] ,
588- poly_Ap[ p] * poly_Aq[ q] * poly_Ax[ 2 * x + 1 ] ,
589- ) ,
590- MODE_Q => (
591- poly_Ap[ p] * poly_Aq[ 2 * q] * poly_Ax[ x] ,
592- poly_Ap[ p] * poly_Aq[ 2 * q + 1 ] * poly_Ax[ x] ,
593- ) ,
594- MODE_P => (
595- poly_Ap[ 2 * p] * poly_Aq[ q] * poly_Ax[ x] ,
596- poly_Ap[ 2 * p + 1 ] * poly_Aq[ q] * poly_Ax[ x] ,
597- ) ,
598- _ => unreachable ! ( )
599- } ;
600- let poly_B_low = poly_B. index_low ( p, q, 0 , x, mode) ;
601- let poly_B_high = poly_B. index_high ( p, q, 0 , x, mode) ;
602- let poly_C_low = poly_C. index_low ( p, q, 0 , x, mode) ;
603- let poly_C_high = poly_C. index_high ( p, q, 0 , x, mode) ;
604- let poly_D_low = poly_D. index_low ( p, q, 0 , x, mode) ;
605- let poly_D_high = poly_D. index_high ( p, q, 0 , x, mode) ;
606-
607- // eval 0: bound_func is A(low)
608- eval_point_0 = eval_point_0
609- + comb_func (
610- & poly_A_low,
611- & poly_B_low,
612- & poly_C_low,
613- & poly_D_low,
614- ) ; // Az[x, x, x, ..., 0]
615-
616- // eval 2: bound_func is -A(low) + 2*A(high)
617- let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
618- let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
619- let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
620- let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
621- eval_point_2 = eval_point_2
622- + comb_func (
623- & poly_A_bound_point,
624- & poly_B_bound_point,
625- & poly_C_bound_point,
626- & poly_D_bound_point,
627- ) ; // Az[x, x, ..., 2]
573+ if mode == MODE_X {
574+ // Multicore evaluation in MODE_X
575+ let mut eval_point_0 = ZERO ;
576+ let mut eval_point_2 = ZERO ;
577+ let mut eval_point_3 = ZERO ;
578+
579+ // We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
580+ // So min(instance_len, num_proofs.len()) suffices
581+ for p in 0 ..min ( instance_len, num_proofs. len ( ) ) {
582+ num_cons[ p] = num_cons[ p] . div_ceil ( 2 ) ;
583+ ( eval_point_0, eval_point_2, eval_point_3) = ( 0 ..num_proofs[ p] ) . into_par_iter ( ) . map ( |q| {
584+ let mut eval_point_0 = ZERO ;
585+ let mut eval_point_2 = ZERO ;
586+ let mut eval_point_3 = ZERO ;
587+ for x in 0 ..num_cons[ p] {
588+ // evaluate A, B, C, D on p, q, x
589+ let poly_A_low = poly_Ap[ p] * poly_Aq[ q] * poly_Ax[ 2 * x] ;
590+ let poly_A_high = poly_Ap[ p] * poly_Aq[ q] * poly_Ax[ 2 * x + 1 ] ;
591+ let poly_B_low = poly_B. index_low ( p, q, 0 , x, mode) ;
592+ let poly_B_high = poly_B. index_high ( p, q, 0 , x, mode) ;
593+ let poly_C_low = poly_C. index_low ( p, q, 0 , x, mode) ;
594+ let poly_C_high = poly_C. index_high ( p, q, 0 , x, mode) ;
595+ let poly_D_low = poly_D. index_low ( p, q, 0 , x, mode) ;
596+ let poly_D_high = poly_D. index_high ( p, q, 0 , x, mode) ;
597+
598+ // eval 0: bound_func is A(low)
599+ eval_point_0 = eval_point_0
600+ + comb_func (
601+ & poly_A_low,
602+ & poly_B_low,
603+ & poly_C_low,
604+ & poly_D_low,
605+ ) ; // Az[x, x, x, ..., 0]
606+
607+ // eval 2: bound_func is -A(low) + 2*A(high)
608+ let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
609+ let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
610+ let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
611+ let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
612+ eval_point_2 = eval_point_2
613+ + comb_func (
614+ & poly_A_bound_point,
615+ & poly_B_bound_point,
616+ & poly_C_bound_point,
617+ & poly_D_bound_point,
618+ ) ; // Az[x, x, ..., 2]
619+
620+ // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
621+ let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
622+ let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
623+ let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
624+ let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
625+ eval_point_3 = eval_point_3
626+ + comb_func (
627+ & poly_A_bound_point,
628+ & poly_B_bound_point,
629+ & poly_C_bound_point,
630+ & poly_D_bound_point,
631+ ) ; // Az[x, x, ..., 3]
632+ }
633+ ( eval_point_0, eval_point_2, eval_point_3)
634+ } ) . collect :: < Vec < ( S , S , S ) > > ( ) . into_iter ( ) . fold ( ( eval_point_0, eval_point_2, eval_point_3) , |( e0, e2, e3) , ( a0, a2, a3) | ( e0 + a0, e2 + a2, e3 + a3) ) ;
635+ }
628636
629- // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
630- let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
631- let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
632- let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
633- let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
634- eval_point_3 = eval_point_3
635- + comb_func (
636- & poly_A_bound_point,
637- & poly_B_bound_point,
638- & poly_C_bound_point,
639- & poly_D_bound_point,
640- ) ; // Az[x, x, ..., 3]
637+ let evals = vec ! [
638+ eval_point_0,
639+ claim_per_round - eval_point_0,
640+ eval_point_2,
641+ eval_point_3,
642+ ] ;
643+ let poly = UniPoly :: from_evals ( & evals) ;
644+ poly
645+ } else {
646+ // Singlecore evaluation in other Modes
647+ let mut eval_point_0 = ZERO ;
648+ let mut eval_point_2 = ZERO ;
649+ let mut eval_point_3 = ZERO ;
650+
651+ // We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2
652+ // So min(instance_len, num_proofs.len()) suffices
653+ for p in 0 ..min ( instance_len, num_proofs. len ( ) ) {
654+ // If q > num_proofs[p], the polynomials always evaluate to 0
655+ if mode == MODE_Q { num_proofs[ p] = num_proofs[ p] . div_ceil ( 2 ) ; }
656+ for q in 0 ..num_proofs[ p] {
657+ for x in 0 ..num_cons[ p] {
658+ // evaluate A, B, C, D on p, q, x
659+ let ( poly_A_low, poly_A_high) = match mode {
660+ MODE_Q => (
661+ poly_Ap[ p] * poly_Aq[ 2 * q] * poly_Ax[ x] ,
662+ poly_Ap[ p] * poly_Aq[ 2 * q + 1 ] * poly_Ax[ x] ,
663+ ) ,
664+ MODE_P => (
665+ poly_Ap[ 2 * p] * poly_Aq[ q] * poly_Ax[ x] ,
666+ poly_Ap[ 2 * p + 1 ] * poly_Aq[ q] * poly_Ax[ x] ,
667+ ) ,
668+ _ => unreachable ! ( )
669+ } ;
670+ let poly_B_low = poly_B. index_low ( p, q, 0 , x, mode) ;
671+ let poly_B_high = poly_B. index_high ( p, q, 0 , x, mode) ;
672+ let poly_C_low = poly_C. index_low ( p, q, 0 , x, mode) ;
673+ let poly_C_high = poly_C. index_high ( p, q, 0 , x, mode) ;
674+ let poly_D_low = poly_D. index_low ( p, q, 0 , x, mode) ;
675+ let poly_D_high = poly_D. index_high ( p, q, 0 , x, mode) ;
676+
677+ // eval 0: bound_func is A(low)
678+ eval_point_0 = eval_point_0
679+ + comb_func (
680+ & poly_A_low,
681+ & poly_B_low,
682+ & poly_C_low,
683+ & poly_D_low,
684+ ) ; // Az[x, x, x, ..., 0]
685+
686+ // eval 2: bound_func is -A(low) + 2*A(high)
687+ let poly_A_bound_point = poly_A_high + poly_A_high - poly_A_low;
688+ let poly_B_bound_point = poly_B_high + poly_B_high - poly_B_low;
689+ let poly_C_bound_point = poly_C_high + poly_C_high - poly_C_low;
690+ let poly_D_bound_point = poly_D_high + poly_D_high - poly_D_low;
691+ eval_point_2 = eval_point_2
692+ + comb_func (
693+ & poly_A_bound_point,
694+ & poly_B_bound_point,
695+ & poly_C_bound_point,
696+ & poly_D_bound_point,
697+ ) ; // Az[x, x, ..., 2]
698+
699+ // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
700+ let poly_A_bound_point = poly_A_bound_point + poly_A_high - poly_A_low;
701+ let poly_B_bound_point = poly_B_bound_point + poly_B_high - poly_B_low;
702+ let poly_C_bound_point = poly_C_bound_point + poly_C_high - poly_C_low;
703+ let poly_D_bound_point = poly_D_bound_point + poly_D_high - poly_D_low;
704+ eval_point_3 = eval_point_3
705+ + comb_func (
706+ & poly_A_bound_point,
707+ & poly_B_bound_point,
708+ & poly_C_bound_point,
709+ & poly_D_bound_point,
710+ ) ; // Az[x, x, ..., 3]
711+ }
641712 }
642713 }
643- }
644714
645- let evals = vec ! [
646- eval_point_0,
647- claim_per_round - eval_point_0,
648- eval_point_2,
649- eval_point_3,
650- ] ;
651- let poly = UniPoly :: from_evals ( & evals) ;
652- poly
715+ let evals = vec ! [
716+ eval_point_0,
717+ claim_per_round - eval_point_0,
718+ eval_point_2,
719+ eval_point_3,
720+ ] ;
721+ let poly = UniPoly :: from_evals ( & evals) ;
722+ poly
723+ }
653724 } ;
654725
655726 // append the prover's message to the transcript
0 commit comments