Skip to content

Commit

Permalink
Correct DensePolynomial in Spartan sumcheck
Browse files Browse the repository at this point in the history
  • Loading branch information
darth-cy committed Feb 12, 2025
1 parent 23fe607 commit efdc44a
Showing 1 changed file with 53 additions and 42 deletions.
95 changes: 53 additions & 42 deletions spartan_parallel/src/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {

impl<E: ExtensionField> SumcheckInstanceProof<E> {
// _debug: remove native sumcheck prover
/*
pub fn prove_cubic<F>(
claim: &E,
num_rounds: usize,
Expand All @@ -93,13 +92,17 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {

let len = poly_A.evaluations().len() / 2;
for i in 0..len {
let poly_A_vec = poly_A.get_ext_field_vec();
let poly_B_vec = poly_B.get_ext_field_vec();
let poly_C_vec = poly_C.get_ext_field_vec();

// eval 0: bound_func is A(low)
eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
eval_point_0 += comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i];
eval_point_2 = eval_point_2
+ comb_func(
&poly_A_bound_point,
Expand All @@ -108,9 +111,9 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i];

eval_point_3 = eval_point_3
+ comb_func(
Expand All @@ -130,17 +133,21 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
let r_j = challenge_scalar(transcript, b"challenge_nextround");
r.push(r_j);
// bound all tables to the verifier's challenege
poly_A.bound_poly_var_top(&r_j);
poly_B.bound_poly_var_top(&r_j);
poly_C.bound_poly_var_top(&r_j);
poly_A.fix_variables_in_place(&[r_j]);
poly_B.fix_variables_in_place(&[r_j]);
poly_C.fix_variables_in_place(&[r_j]);
e = poly.evaluate(&r_j);
cubic_polys.push(poly.compress());
}

(
SumcheckInstanceProof::new(cubic_polys),
r,
vec![poly_A[0], poly_B[0], poly_C[0]],
vec![
poly_A.get_ext_field_vec()[0],
poly_B.get_ext_field_vec()[0],
poly_C.get_ext_field_vec()[0]
],
)
}

Expand All @@ -167,7 +174,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
let (poly_A_vec_par, poly_B_vec_par, poly_C_par) = poly_vec_par;
let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq;

//let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq;
let mut e = *claim;
let mut r: Vec<E> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<E>> = Vec::new();
Expand All @@ -180,15 +186,18 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
let mut eval_point_2 = E::ZERO;
let mut eval_point_3 = E::ZERO;

let len = poly_A.len() / 2;
let len = poly_A.evaluations().len() / 2;
for i in 0..len {
let poly_A_vec = poly_A.get_ext_field_vec();
let poly_B_vec = poly_B.get_ext_field_vec();
let poly_C_par_vec = poly_C_par.get_ext_field_vec();
// eval 0: bound_func is A(low)
eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C_par[i]);
eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_par_vec[i]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_par[len + i] + poly_C_par[len + i] - poly_C_par[i];
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_par_vec[len + i] + poly_C_par_vec[len + i] - poly_C_par_vec[i];
eval_point_2 = eval_point_2
+ comb_func(
&poly_A_bound_point,
Expand All @@ -197,9 +206,9 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C_par[len + i] - poly_C_par[i];
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_bound_point + poly_C_par_vec[len + i] - poly_C_par_vec[i];

eval_point_3 = eval_point_3
+ comb_func(
Expand All @@ -220,24 +229,27 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
let mut eval_point_0 = E::ZERO;
let mut eval_point_2 = E::ZERO;
let mut eval_point_3 = E::ZERO;
let len = poly_A.len() / 2;
let len = poly_A.evaluations().len() / 2;
for i in 0..len {
let poly_A_vec = poly_A.get_ext_field_vec();
let poly_B_vec = poly_B.get_ext_field_vec();
let poly_C_vec = poly_C.get_ext_field_vec();
// eval 0: bound_func is A(low)
eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i];
eval_point_2 = eval_point_2
+ comb_func(
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i];
eval_point_3 = eval_point_3
+ comb_func(
&poly_A_bound_point,
Expand Down Expand Up @@ -273,41 +285,41 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {

// bound all tables to the verifier's challenege
for (poly_A, poly_B) in poly_A_vec_par.iter_mut().zip(poly_B_vec_par.iter_mut()) {
poly_A.bound_poly_var_top(&r_j);
poly_B.bound_poly_var_top(&r_j);
poly_A.fix_variables_in_place(&[r_j]);
poly_B.fix_variables_in_place(&[r_j]);
}
poly_C_par.bound_poly_var_top(&r_j);
poly_C_par.fix_variables_in_place(&[r_j]);

for (poly_A, poly_B, poly_C) in izip!(
poly_A_vec_seq.iter_mut(),
poly_B_vec_seq.iter_mut(),
poly_C_vec_seq.iter_mut()
) {
poly_A.bound_poly_var_top(&r_j);
poly_B.bound_poly_var_top(&r_j);
poly_C.bound_poly_var_top(&r_j);
poly_A.fix_variables_in_place(&[r_j]);
poly_B.fix_variables_in_place(&[r_j]);
poly_C.fix_variables_in_place(&[r_j]);
}

e = poly.evaluate(&r_j);
cubic_polys.push(poly.compress());
}

let poly_A_par_final = (0..poly_A_vec_par.len())
.map(|i| poly_A_vec_par[i][0])
.map(|i| poly_A_vec_par[i].get_ext_field_vec()[0])
.collect();
let poly_B_par_final = (0..poly_B_vec_par.len())
.map(|i| poly_B_vec_par[i][0])
.map(|i| poly_B_vec_par[i].get_ext_field_vec()[0])
.collect();
let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par[0]);
let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par.get_ext_field_vec()[0]);

let poly_A_seq_final = (0..poly_A_vec_seq.len())
.map(|i| poly_A_vec_seq[i][0])
.map(|i| poly_A_vec_seq[i].get_ext_field_vec()[0])
.collect();
let poly_B_seq_final = (0..poly_B_vec_seq.len())
.map(|i| poly_B_vec_seq[i][0])
.map(|i| poly_B_vec_seq[i].get_ext_field_vec()[0])
.collect();
let poly_C_seq_final = (0..poly_C_vec_seq.len())
.map(|i| poly_C_vec_seq[i][0])
.map(|i| poly_C_vec_seq[i].get_ext_field_vec()[0])
.collect();
let claims_dotp = (poly_A_seq_final, poly_B_seq_final, poly_C_seq_final);

Expand All @@ -318,7 +330,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
claims_dotp,
)
}
*/

pub fn prove_cubic_disjoint_rounds<F>(
claim: &E,
Expand Down

0 comments on commit efdc44a

Please sign in to comment.