Skip to content

Commit ace0f04

Browse files
committed
cleanup
1 parent 5f50fb8 commit ace0f04

File tree

7 files changed

+79
-89
lines changed

7 files changed

+79
-89
lines changed

crates/stark-backend/src/gkr/prover.rs

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,13 @@ impl<'a, F: Field> FixedFirstHypercubeEqEvals<'a, F> {
5454

5555
for &y_i in y.iter().rev() {
5656
let (left, right) = evals.split_at_mut(curr_len);
57-
left.par_iter_mut().zip(right.par_iter_mut()).for_each(|(l, r)| {
58-
let tmp = *l * y_i;
59-
*r = tmp;
60-
*l -= tmp;
61-
});
57+
left.par_iter_mut()
58+
.zip(right.par_iter_mut())
59+
.for_each(|(l, r)| {
60+
let tmp = *l * y_i;
61+
*r = tmp;
62+
*l -= tmp;
63+
});
6264
curr_len *= 2;
6365
}
6466
evals
@@ -91,7 +93,7 @@ impl<F> Deref for FixedFirstHypercubeEqEvals<'_, F> {
9193
///
9294
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
9395
/// ```
94-
struct GkrMultivariatePolyOracle<'a, F: Clone> {
96+
struct GkrMultivariatePolyOracle<'a, F> {
9597
pub eq_evals: &'a FixedFirstHypercubeEqEvals<'a, F>,
9698
pub input_layer: Layer<F>,
9799
pub eq_fixed_var_correction: F,
@@ -114,12 +116,24 @@ impl<F: Field> MultivariatePolyOracle<F> for GkrMultivariatePolyOracle<'_, F> {
114116
Layer::LogUpGeneric {
115117
numerators,
116118
denominators,
117-
} => eval_logup_sum(self.eq_evals, numerators, denominators, n_terms, self.lambda),
119+
} => eval_logup_sum(
120+
self.eq_evals,
121+
numerators,
122+
denominators,
123+
n_terms,
124+
self.lambda,
125+
),
118126
};
119127

120128
eval_at_0 *= self.eq_fixed_var_correction;
121129
eval_at_2 *= self.eq_fixed_var_correction;
122-
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, self.eq_evals.y, n_variables)
130+
correct_sum_as_poly_in_first_variable(
131+
eval_at_0,
132+
eval_at_2,
133+
claim,
134+
self.eq_evals.y,
135+
n_variables,
136+
)
123137
}
124138

125139
fn fix_first_in_place(&mut self, alpha: F) {
@@ -197,14 +211,8 @@ fn eval_logup_sum<F: Field>(
197211
let (n2_0, d2_0) = (n1_0.double() - n0_0, d1_0.double() - d0_0);
198212
let (n2_1, d2_1) = (n1_1.double() - n0_1, d1_1.double() - d0_1);
199213

200-
let (num_t0, den_t0) = (
201-
n0_0 * d0_1 + n0_1 * d0_0,
202-
d0_0 * d0_1,
203-
);
204-
let (num_t2, den_t2) = (
205-
n2_0 * d2_1 + n2_1 * d2_0,
206-
d2_0 * d2_1,
207-
);
214+
let (num_t0, den_t0) = (n0_0 * d0_1 + n0_1 * d0_0, d0_0 * d0_1);
215+
let (num_t2, den_t2) = (n2_0 * d2_1 + n2_1 * d2_0, d2_0 * d2_1);
208216

209217
let eq = eq_evals[i];
210218
let eval_t0 = eq * (num_t0 + lambda * den_t0);
@@ -281,23 +289,25 @@ pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
281289
.collect();
282290

283291
let mut output_claims_by_instance = vec![None; n_instances];
284-
let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
292+
let mut layer_masks_by_instance = vec![vec![]; n_instances];
285293
let mut sumcheck_proofs = Vec::new();
286294

287295
let mut ood_point = Vec::new();
288296
let mut claims_to_verify_by_instance = vec![None; n_instances];
289297

290-
for layer in 0..n_layers {
291-
let n_remaining_layers = n_layers - layer;
298+
let mut output_instances_by_layer: Vec<Vec<usize>> = vec![Vec::new(); n_layers];
299+
for (instance, &instance_n_layers) in n_layers_by_instance.iter().enumerate() {
300+
let output_layer = n_layers - instance_n_layers;
301+
output_instances_by_layer[output_layer].push(instance);
302+
}
292303

304+
for layer in 0..n_layers {
293305
// Check all the instances for output layers.
294-
for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
295-
if n_layers_by_instance[instance] == n_remaining_layers {
296-
let output_layer = layers.next().unwrap();
297-
let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
298-
claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
299-
output_claims_by_instance[instance] = Some(output_layer_values);
300-
}
306+
for &instance in &output_instances_by_layer[layer] {
307+
let output_layer = layers_by_instance[instance].next().unwrap();
308+
let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
309+
claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
310+
output_claims_by_instance[instance] = Some(output_layer_values);
301311
}
302312

303313
// Seed the channel with layer claims.

crates/stark-backend/src/gkr/types.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-
use std::ops::Index;
2-
3-
use p3_maybe_rayon::prelude::*;
41
use p3_field::Field;
2+
use p3_maybe_rayon::prelude::*;
53
use serde::{Deserialize, Serialize};
64
use thiserror::Error;
75

86
use crate::{
9-
poly::{
10-
multi::{fold_mle_evals, Mle, MultivariatePolyOracle},
11-
uni::Fraction,
12-
},
7+
poly::multi::{fold_mle_evals, Mle, MultivariatePolyOracle},
138
sumcheck::{SumcheckError, SumcheckProof},
149
};
1510

@@ -207,10 +202,10 @@ impl<F: Field> Layer<F> {
207202
Self::LogUpGeneric {
208203
numerators,
209204
denominators,
210-
} => {
205+
} => {
211206
numerators.fix_first_in_place(x0);
212207
denominators.fix_first_in_place(x0);
213-
},
208+
}
214209
}
215210
}
216211
}

crates/stark-backend/src/interaction/gkr_log_up.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rayon::iter::IntoParallelRefIterator;
1717
use serde::{Deserialize, Serialize};
1818
use thiserror::Error;
1919

20+
use crate::interaction::SymbolicInteraction;
2021
use crate::{
2122
air_builders::symbolic::SymbolicConstraints,
2223
gkr,
@@ -27,7 +28,6 @@ use crate::{
2728
},
2829
rap::PermutationAirBuilderWithExposedValues,
2930
};
30-
use crate::interaction::SymbolicInteraction;
3131

3232
pub struct GkrLogUpPhase<F, EF, Challenger> {
3333
// FIXME: USE THIS IN POW
@@ -112,11 +112,7 @@ where
112112
_params_per_air: &[&Self::PartialProvingKey],
113113
trace_view_per_air: &[PairTraceView<'_, F>],
114114
) -> Option<(Self::PartialProof, RapPhaseProverData<EF>)> {
115-
let all_interactions = interactions_per_air
116-
.iter()
117-
.cloned()
118-
.flatten()
119-
.collect_vec();
115+
let all_interactions = interactions_per_air.iter().cloned().flatten().collect_vec();
120116
if all_interactions.is_empty() {
121117
return None;
122118
}

crates/stark-backend/src/interaction/gkr_log_up/prove.rs

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::iter;
22

3-
use itertools::{izip, Itertools};
3+
use itertools::Itertools;
44
use p3_challenger::FieldChallenger;
55
use p3_field::{ExtensionField, Field};
66
use p3_matrix::{
@@ -11,10 +11,7 @@ use p3_maybe_rayon::prelude::*;
1111
use p3_util::log2_strict_usize;
1212

1313
use crate::{
14-
air_builders::symbolic::{
15-
symbolic_expression::{SymbolicEvaluator, SymbolicExpression},
16-
SymbolicConstraints,
17-
},
14+
air_builders::symbolic::symbolic_expression::{SymbolicEvaluator, SymbolicExpression},
1815
gkr::{GkrArtifact, Layer, Layer::LogUpGeneric},
1916
interaction::{
2017
gkr_log_up::{num_interaction_dimensions, GkrAuxData, GkrLogUpPhase},
@@ -91,7 +88,9 @@ impl<EF: Field> GkrLogUpInstance<EF> {
9188
.iter()
9289
.chain(iter::once(&SymbolicExpression::Constant(b)))
9390
.zip(beta_pows)
94-
.fold(EF::ZERO, |acc, (expr, &beta)| acc + beta * evaluator.eval_expr(expr));
91+
.fold(EF::ZERO, |acc, (expr, &beta)| {
92+
acc + beta * evaluator.eval_expr(expr)
93+
});
9594

9695
sigma_row[col] = sigma;
9796
}
@@ -167,26 +166,23 @@ where
167166
) -> GkrAuxData<EF> {
168167
let ood_point = &gkr_artifact.ood_point;
169168

170-
let max_interactions = interactions_per_air
171-
.iter()
172-
.map(|v| v.len())
173-
.max()
174-
.unwrap();
169+
let max_interactions = interactions_per_air.iter().map(|v| v.len()).max().unwrap();
175170
let gamma_pows = gamma.powers().take(2 * max_interactions).collect_vec();
176171

177-
let trace_view_per_air_filtered =
178-
interactions_per_air
179-
.iter()
180-
.zip(trace_view_per_air.iter())
181-
.filter_map(|(interactions, view)| {
182-
if interactions.is_empty() {
183-
None
184-
} else {
185-
Some(view)
186-
}
187-
}).collect_vec();
188-
189-
let results: Vec<_> = trace_view_per_air_filtered.par_iter()
172+
let trace_view_per_air_filtered = interactions_per_air
173+
.iter()
174+
.zip(trace_view_per_air.iter())
175+
.filter_map(|(interactions, view)| {
176+
if interactions.is_empty() {
177+
None
178+
} else {
179+
Some(view)
180+
}
181+
})
182+
.collect_vec();
183+
184+
let results: Vec<_> = trace_view_per_air_filtered
185+
.par_iter()
190186
.zip(gkr_instances.par_iter())
191187
.zip(gkr_artifact.n_variables_by_instance.par_iter())
192188
.map(|((trace_view, gkr_instance), &n_vars)| {
@@ -222,8 +218,7 @@ where
222218
after_challenge_trace_per_air.push(None);
223219
exposed_values_per_air.push(None);
224220
} else {
225-
let
226-
(after_trace, exposed, count_mle, sigma_mle) = results_iter.next().unwrap();
221+
let (after_trace, exposed, count_mle, sigma_mle) = results_iter.next().unwrap();
227222

228223
for (count_mle_claim, sigma_mle_claim) in count_mle.iter().zip(sigma_mle.iter()) {
229224
challenger.observe_ext_element(*count_mle_claim);

crates/stark-backend/src/poly/multi.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
8989
UnivariatePolynomial::from_coeffs(vec![y0, slope])
9090
}
9191

92-
9392
fn fix_first_in_place(&mut self, alpha: F) {
9493
let midpoint = self.len() / 2;
9594
let (lhs_evals, rhs_evals) = self.split_at_mut(midpoint);
@@ -258,14 +257,8 @@ mod test {
258257
// -(1 - x_2) - 2 x_2 + 6 (1 - x_2) + 8 x_2 = x_2 + 5
259258
mle.fix_first_in_place(alpha);
260259

261-
assert_eq!(
262-
mle.eval(&[BabyBear::ZERO]),
263-
BabyBear::from_canonical_u32(5)
264-
);
265-
assert_eq!(
266-
mle.eval(&[BabyBear::ONE]),
267-
BabyBear::from_canonical_u32(6)
268-
);
260+
assert_eq!(mle.eval(&[BabyBear::ZERO]), BabyBear::from_canonical_u32(5));
261+
assert_eq!(mle.eval(&[BabyBear::ONE]), BabyBear::from_canonical_u32(6));
269262
}
270263

271264
#[test]

crates/stark-backend/src/poly/uni.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
//! Copied from starkware-libs/stwo under Apache-2.0 license.
22
use p3_field::Field;
33
use serde::{Deserialize, Serialize};
4-
use std::ops::MulAssign;
54
use std::{
65
iter::Sum,
76
ops::{Add, Deref, Mul, Neg, Sub},
@@ -85,7 +84,10 @@ impl<F: Field> UnivariatePolynomial<F> {
8584

8685
fn remove_trailing_zeroes(&mut self) {
8786
self.coeffs.truncate(
88-
self.coeffs.iter().rposition(|c| !c.is_zero()).map_or(0, |i| i + 1)
87+
self.coeffs
88+
.iter()
89+
.rposition(|c| !c.is_zero())
90+
.map_or(0, |i| i + 1),
8991
);
9092
}
9193

crates/stark-backend/src/sumcheck.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ where
8888
let eval_at_0 = round_poly.evaluate_at_zero();
8989
let eval_at_1 = round_poly.evaluate_at_one();
9090

91-
assert_eq!(
91+
debug_assert_eq!(
9292
eval_at_0 + eval_at_1,
9393
claim,
9494
"Round {round}, poly {i}: eval(0) + eval(1) != claim ({} != {claim})",
9595
eval_at_0 + eval_at_1,
9696
);
97-
assert!(
97+
debug_assert!(
9898
round_poly.degree() <= MAX_DEGREE,
9999
"Round {round}, poly {i}: degree {} > max {MAX_DEGREE}",
100100
round_poly.degree(),
@@ -115,17 +115,16 @@ where
115115

116116
let challenge = challenger.sample_ext_element();
117117

118-
claims.par_iter_mut().zip(this_round_polys
119-
.par_iter())
118+
claims
119+
.par_iter_mut()
120+
.zip(this_round_polys.par_iter())
120121
.for_each(|(claim, round_poly)| *claim = round_poly.evaluate(challenge));
121122

122-
polys
123-
.par_iter_mut()
124-
.for_each(|multivariate_poly| {
125-
if n_remaining_rounds == multivariate_poly.arity() {
126-
multivariate_poly.fix_first_in_place(challenge)
127-
}
128-
});
123+
polys.par_iter_mut().for_each(|multivariate_poly| {
124+
if n_remaining_rounds == multivariate_poly.arity() {
125+
multivariate_poly.fix_first_in_place(challenge)
126+
}
127+
});
129128

130129
round_polys.push(round_poly);
131130
evaluation_point.push(challenge);

0 commit comments

Comments
 (0)