Skip to content

Commit efdc44a

Browse files
committed
Correct DensePolynomial in Spartan sumcheck
1 parent 23fe607 commit efdc44a

File tree

1 file changed

+53
-42
lines changed

1 file changed

+53
-42
lines changed

spartan_parallel/src/sumcheck.rs

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
7070

7171
impl<E: ExtensionField> SumcheckInstanceProof<E> {
7272
// _debug: remove native sumcheck prover
73-
/*
7473
pub fn prove_cubic<F>(
7574
claim: &E,
7675
num_rounds: usize,
@@ -93,13 +92,17 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
9392

9493
let len = poly_A.evaluations().len() / 2;
9594
for i in 0..len {
95+
let poly_A_vec = poly_A.get_ext_field_vec();
96+
let poly_B_vec = poly_B.get_ext_field_vec();
97+
let poly_C_vec = poly_C.get_ext_field_vec();
98+
9699
// eval 0: bound_func is A(low)
97-
eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
100+
eval_point_0 += comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]);
98101

99102
// eval 2: bound_func is -A(low) + 2*A(high)
100-
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
101-
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
102-
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
103+
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
104+
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
105+
let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i];
103106
eval_point_2 = eval_point_2
104107
+ comb_func(
105108
&poly_A_bound_point,
@@ -108,9 +111,9 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
108111
);
109112

110113
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
111-
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
112-
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
113-
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
114+
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
115+
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
116+
let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i];
114117

115118
eval_point_3 = eval_point_3
116119
+ comb_func(
@@ -130,17 +133,21 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
130133
let r_j = challenge_scalar(transcript, b"challenge_nextround");
131134
r.push(r_j);
132135
// bound all tables to the verifier's challenege
133-
poly_A.bound_poly_var_top(&r_j);
134-
poly_B.bound_poly_var_top(&r_j);
135-
poly_C.bound_poly_var_top(&r_j);
136+
poly_A.fix_variables_in_place(&[r_j]);
137+
poly_B.fix_variables_in_place(&[r_j]);
138+
poly_C.fix_variables_in_place(&[r_j]);
136139
e = poly.evaluate(&r_j);
137140
cubic_polys.push(poly.compress());
138141
}
139142

140143
(
141144
SumcheckInstanceProof::new(cubic_polys),
142145
r,
143-
vec![poly_A[0], poly_B[0], poly_C[0]],
146+
vec![
147+
poly_A.get_ext_field_vec()[0],
148+
poly_B.get_ext_field_vec()[0],
149+
poly_C.get_ext_field_vec()[0]
150+
],
144151
)
145152
}
146153

@@ -167,7 +174,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
167174
let (poly_A_vec_par, poly_B_vec_par, poly_C_par) = poly_vec_par;
168175
let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq;
169176

170-
//let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq;
171177
let mut e = *claim;
172178
let mut r: Vec<E> = Vec::new();
173179
let mut cubic_polys: Vec<CompressedUniPoly<E>> = Vec::new();
@@ -180,15 +186,18 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
180186
let mut eval_point_2 = E::ZERO;
181187
let mut eval_point_3 = E::ZERO;
182188

183-
let len = poly_A.len() / 2;
189+
let len = poly_A.evaluations().len() / 2;
184190
for i in 0..len {
191+
let poly_A_vec = poly_A.get_ext_field_vec();
192+
let poly_B_vec = poly_B.get_ext_field_vec();
193+
let poly_C_par_vec = poly_C_par.get_ext_field_vec();
185194
// eval 0: bound_func is A(low)
186-
eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C_par[i]);
195+
eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_par_vec[i]);
187196

188197
// eval 2: bound_func is -A(low) + 2*A(high)
189-
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
190-
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
191-
let poly_C_bound_point = poly_C_par[len + i] + poly_C_par[len + i] - poly_C_par[i];
198+
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
199+
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
200+
let poly_C_bound_point = poly_C_par_vec[len + i] + poly_C_par_vec[len + i] - poly_C_par_vec[i];
192201
eval_point_2 = eval_point_2
193202
+ comb_func(
194203
&poly_A_bound_point,
@@ -197,9 +206,9 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
197206
);
198207

199208
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
200-
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
201-
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
202-
let poly_C_bound_point = poly_C_bound_point + poly_C_par[len + i] - poly_C_par[i];
209+
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
210+
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
211+
let poly_C_bound_point = poly_C_bound_point + poly_C_par_vec[len + i] - poly_C_par_vec[i];
203212

204213
eval_point_3 = eval_point_3
205214
+ comb_func(
@@ -220,24 +229,27 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
220229
let mut eval_point_0 = E::ZERO;
221230
let mut eval_point_2 = E::ZERO;
222231
let mut eval_point_3 = E::ZERO;
223-
let len = poly_A.len() / 2;
232+
let len = poly_A.evaluations().len() / 2;
224233
for i in 0..len {
234+
let poly_A_vec = poly_A.get_ext_field_vec();
235+
let poly_B_vec = poly_B.get_ext_field_vec();
236+
let poly_C_vec = poly_C.get_ext_field_vec();
225237
// eval 0: bound_func is A(low)
226-
eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
238+
eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]);
227239
// eval 2: bound_func is -A(low) + 2*A(high)
228-
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
229-
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
230-
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
240+
let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i];
241+
let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i];
242+
let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i];
231243
eval_point_2 = eval_point_2
232244
+ comb_func(
233245
&poly_A_bound_point,
234246
&poly_B_bound_point,
235247
&poly_C_bound_point,
236248
);
237249
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
238-
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
239-
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
240-
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
250+
let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i];
251+
let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i];
252+
let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i];
241253
eval_point_3 = eval_point_3
242254
+ comb_func(
243255
&poly_A_bound_point,
@@ -273,41 +285,41 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
273285

274286
// bound all tables to the verifier's challenege
275287
for (poly_A, poly_B) in poly_A_vec_par.iter_mut().zip(poly_B_vec_par.iter_mut()) {
276-
poly_A.bound_poly_var_top(&r_j);
277-
poly_B.bound_poly_var_top(&r_j);
288+
poly_A.fix_variables_in_place(&[r_j]);
289+
poly_B.fix_variables_in_place(&[r_j]);
278290
}
279-
poly_C_par.bound_poly_var_top(&r_j);
291+
poly_C_par.fix_variables_in_place(&[r_j]);
280292

281293
for (poly_A, poly_B, poly_C) in izip!(
282294
poly_A_vec_seq.iter_mut(),
283295
poly_B_vec_seq.iter_mut(),
284296
poly_C_vec_seq.iter_mut()
285297
) {
286-
poly_A.bound_poly_var_top(&r_j);
287-
poly_B.bound_poly_var_top(&r_j);
288-
poly_C.bound_poly_var_top(&r_j);
298+
poly_A.fix_variables_in_place(&[r_j]);
299+
poly_B.fix_variables_in_place(&[r_j]);
300+
poly_C.fix_variables_in_place(&[r_j]);
289301
}
290302

291303
e = poly.evaluate(&r_j);
292304
cubic_polys.push(poly.compress());
293305
}
294306

295307
let poly_A_par_final = (0..poly_A_vec_par.len())
296-
.map(|i| poly_A_vec_par[i][0])
308+
.map(|i| poly_A_vec_par[i].get_ext_field_vec()[0])
297309
.collect();
298310
let poly_B_par_final = (0..poly_B_vec_par.len())
299-
.map(|i| poly_B_vec_par[i][0])
311+
.map(|i| poly_B_vec_par[i].get_ext_field_vec()[0])
300312
.collect();
301-
let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par[0]);
313+
let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par.get_ext_field_vec()[0]);
302314

303315
let poly_A_seq_final = (0..poly_A_vec_seq.len())
304-
.map(|i| poly_A_vec_seq[i][0])
316+
.map(|i| poly_A_vec_seq[i].get_ext_field_vec()[0])
305317
.collect();
306318
let poly_B_seq_final = (0..poly_B_vec_seq.len())
307-
.map(|i| poly_B_vec_seq[i][0])
319+
.map(|i| poly_B_vec_seq[i].get_ext_field_vec()[0])
308320
.collect();
309321
let poly_C_seq_final = (0..poly_C_vec_seq.len())
310-
.map(|i| poly_C_vec_seq[i][0])
322+
.map(|i| poly_C_vec_seq[i].get_ext_field_vec()[0])
311323
.collect();
312324
let claims_dotp = (poly_A_seq_final, poly_B_seq_final, poly_C_seq_final);
313325

@@ -318,7 +330,6 @@ impl<E: ExtensionField> SumcheckInstanceProof<E> {
318330
claims_dotp,
319331
)
320332
}
321-
*/
322333

323334
pub fn prove_cubic_disjoint_rounds<F>(
324335
claim: &E,

0 commit comments

Comments
 (0)