Skip to content

Commit 4fa7446

Browse files
committed
do more in-place
1 parent f73f1f2 commit 4fa7446

File tree

5 files changed

+22
-108
lines changed

5 files changed

+22
-108
lines changed

crates/stark-backend/src/gkr/prover.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,20 +125,14 @@ impl<F: Field> MultivariatePolyOracle<F> for GkrMultivariatePolyOracle<'_, F> {
125125
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
126126
}
127127

128-
fn partial_evaluation(self, alpha: F) -> Self {
128+
fn fix_first_in_place(&mut self, alpha: F) {
129129
if self.has_zero_arity() {
130-
return self;
130+
return;
131131
}
132132

133133
let z0 = self.eq_evals.y[self.eq_evals.y.len() - self.arity()];
134-
let eq_fixed_var_correction = self.eq_fixed_var_correction * (alpha * z0 + (F::ONE - alpha) * (F::ONE - z0));
135-
136-
Self {
137-
eq_evals: self.eq_evals,
138-
eq_fixed_var_correction,
139-
input_layer: self.input_layer.fix_first_variable(alpha),
140-
lambda: self.lambda,
141-
}
134+
self.eq_fixed_var_correction *= alpha * z0 + (F::ONE - alpha) * (F::ONE - z0);
135+
self.input_layer.fix_first_variable_in_place(alpha);
142136
}
143137
}
144138

crates/stark-backend/src/gkr/tests.rs

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -166,79 +166,3 @@ fn test_logup_with_generic_trace() -> Result<(), GkrError<BabyBear>> {
166166
);
167167
Ok(())
168168
}
169-
170-
#[test]
171-
fn test_logup_with_singles_trace() -> Result<(), GkrError<BabyBear>> {
172-
const N: usize = 1 << 5;
173-
type F = BabyBear;
174-
175-
let mut rng = create_seeded_rng();
176-
let denominator_values = (0..N).map(|_| rng.gen()).collect_vec();
177-
let sum: Fraction<F> = denominator_values
178-
.iter()
179-
.map(|&d| Fraction::new(F::ONE, d))
180-
.sum();
181-
let denominators = Mle::from_vec(denominator_values);
182-
let top_layer = Layer::LogUpSingles {
183-
denominators: denominators.clone(),
184-
};
185-
186-
let engine = default_engine();
187-
let (proof, _) = gkr::prove_batch(&mut engine.new_challenger(), vec![top_layer]);
188-
189-
let GkrArtifact {
190-
ood_point,
191-
claims_to_verify_by_instance,
192-
n_variables_by_instance: _,
193-
} = gkr::partially_verify_batch(vec![Gate::LogUp], &proof, &mut engine.new_challenger())?;
194-
195-
assert_eq!(claims_to_verify_by_instance.len(), 1);
196-
assert_eq!(proof.output_claims_by_instance.len(), 1);
197-
assert_eq!(
198-
claims_to_verify_by_instance[0],
199-
[F::ONE, denominators.eval(&ood_point)]
200-
);
201-
assert_eq!(
202-
proof.output_claims_by_instance[0],
203-
[sum.numerator, sum.denominator]
204-
);
205-
Ok(())
206-
}
207-
208-
#[test]
209-
fn test_logup_with_multiplicities_trace() -> Result<(), GkrError<BabyBear>> {
210-
const N: usize = 1 << 5;
211-
let mut rng = create_seeded_rng();
212-
let numerator_values = (0..N).map(|_| rng.gen::<BabyBear>()).collect_vec();
213-
let denominator_values = (0..N).map(|_| rng.gen::<BabyBear>()).collect_vec();
214-
let sum: Fraction<BabyBear> = zip(&numerator_values, &denominator_values)
215-
.map(|(&n, &d)| Fraction::new(n, d))
216-
.sum();
217-
let numerators = Mle::from_vec(numerator_values);
218-
let denominators = Mle::from_vec(denominator_values);
219-
let top_layer = Layer::LogUpMultiplicities {
220-
numerators: numerators.clone(),
221-
denominators: denominators.clone(),
222-
};
223-
224-
let engine = default_engine();
225-
let (proof, _) = gkr::prove_batch(&mut engine.new_challenger(), vec![top_layer]);
226-
227-
let GkrArtifact {
228-
ood_point,
229-
claims_to_verify_by_instance,
230-
n_variables_by_instance: _,
231-
} = gkr::partially_verify_batch(vec![Gate::LogUp], &proof, &mut engine.new_challenger())?;
232-
233-
assert_eq!(claims_to_verify_by_instance.len(), 1);
234-
assert_eq!(proof.output_claims_by_instance.len(), 1);
235-
assert_eq!(
236-
claims_to_verify_by_instance[0],
237-
[numerators.eval(&ood_point), denominators.eval(&ood_point)]
238-
);
239-
assert_eq!(
240-
proof.output_claims_by_instance[0],
241-
[sum.numerator, sum.denominator]
242-
);
243-
Ok(())
244-
}

crates/stark-backend/src/gkr/types.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,19 +197,19 @@ impl<F: Field> Layer<F> {
197197
}
198198

199199
/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
200-
pub fn fix_first_variable(self, x0: F) -> Self {
200+
pub fn fix_first_variable_in_place(&mut self, x0: F) {
201201
if self.n_variables() == 0 {
202-
return self;
202+
return;
203203
}
204204

205205
match self {
206-
Self::GrandProduct(mle) => Self::GrandProduct(mle.partial_evaluation(x0)),
206+
Self::GrandProduct(mle) => mle.fix_first_in_place(x0),
207207
Self::LogUpGeneric {
208208
numerators,
209209
denominators,
210-
} => Self::LogUpGeneric {
211-
numerators: numerators.partial_evaluation(x0),
212-
denominators: denominators.partial_evaluation(x0),
210+
} => {
211+
numerators.fix_first_in_place(x0);
212+
denominators.fix_first_in_place(x0);
213213
},
214214
}
215215
}

crates/stark-backend/src/poly/multi.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub trait MultivariatePolyOracle<F>: Send + Sync {
1919
fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F>;
2020

2121
/// Returns the multivariate polynomial `h(x_2, ..., x_n) = g(alpha, x_2, ..., x_n)`.
22-
fn partial_evaluation(self, alpha: F) -> Self;
22+
fn fix_first_in_place(&mut self, alpha: F);
2323
}
2424

2525
/// Multilinear extension of the function defined on the boolean hypercube.
@@ -90,7 +90,7 @@ impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
9090
}
9191

9292

93-
fn partial_evaluation(mut self, alpha: F) -> Self {
93+
fn fix_first_in_place(&mut self, alpha: F) {
9494
let midpoint = self.len() / 2;
9595
let (lhs_evals, rhs_evals) = self.split_at_mut(midpoint);
9696
lhs_evals
@@ -100,7 +100,6 @@ impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
100100
*lhs += alpha * (rhs - *lhs);
101101
});
102102
self.evals.truncate(midpoint);
103-
self
104103
}
105104
}
106105

@@ -254,17 +253,17 @@ mod test {
254253
BabyBear::from_canonical_u32(4),
255254
];
256255
// (1 - x_1)(1 - x_2) + 2 (1 - x_1) x_2 + 3 x_1 (1 - x_2) + 4 x_1 x_2
257-
let mle = Mle::from_vec(evals);
256+
let mut mle = Mle::from_vec(evals);
258257
let alpha = BabyBear::from_canonical_u32(2);
259258
// -(1 - x_2) - 2 x_2 + 6 (1 - x_2) + 8 x_2 = x_2 + 5
260-
let partial_eval = mle.partial_evaluation(alpha);
259+
mle.fix_first_in_place(alpha);
261260

262261
assert_eq!(
263-
partial_eval.eval(&[BabyBear::ZERO]),
262+
mle.eval(&[BabyBear::ZERO]),
264263
BabyBear::from_canonical_u32(5)
265264
);
266265
assert_eq!(
267-
partial_eval.eval(&[BabyBear::ONE]),
266+
mle.eval(&[BabyBear::ONE]),
268267
BabyBear::from_canonical_u32(6)
269268
);
270269
}

crates/stark-backend/src/sumcheck.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,13 @@ where
121121
.map(|round_poly| round_poly.evaluate(challenge))
122122
.collect();
123123

124-
polys = polys
125-
.into_par_iter()
126-
.map(|multivariate_poly| {
127-
if n_remaining_rounds != multivariate_poly.arity() {
128-
multivariate_poly
129-
} else {
130-
multivariate_poly.partial_evaluation(challenge)
124+
polys
125+
.par_iter_mut()
126+
.for_each(|multivariate_poly| {
127+
if n_remaining_rounds == multivariate_poly.arity() {
128+
multivariate_poly.fix_first_in_place(challenge)
131129
}
132-
})
133-
.collect();
130+
});
134131

135132
round_polys.push(round_poly);
136133
evaluation_point.push(challenge);

0 commit comments

Comments
 (0)