Skip to content

Commit 86c0c4c

Browse files
committed
wip
1 parent 39738f1 commit 86c0c4c

File tree

3 files changed

+104
-116
lines changed

3 files changed

+104
-116
lines changed

crates/stark-backend/src/interaction/gkr_log_up.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,17 @@ where
123123
let beta_pows = generate_betas(beta, &all_interactions);
124124

125125
// Build GKR instances.
126-
let gkr_instances: Vec<_> =
127-
Self::build_gkr_instances(trace_view_per_air, interactions_per_air, alpha, &beta_pows);
126+
let gkr_instances: Vec<_> = metrics_span("gkr_build_instances_ms", || {
127+
Self::build_gkr_instances(trace_view_per_air, interactions_per_air, alpha, &beta_pows)
128+
});
128129

129130
// Construct input layers and run GKR proof.
130-
let input_layers: Vec<_> = gkr_instances
131-
.par_iter()
132-
.map(|gkr_instance| gkr_instance.build_gkr_input_layer())
133-
.collect();
131+
let input_layers: Vec<_> = metrics_span("build_gkr_input_layer_ms", || {
132+
gkr_instances
133+
.par_iter()
134+
.map(|gkr_instance| gkr_instance.build_gkr_input_layer())
135+
.collect()
136+
});
134137

135138
let (gkr_proof, gkr_artifact) = metrics_span("gkr_prove_batch_ms", || {
136139
gkr::prove_batch(challenger, input_layers)
@@ -142,14 +145,16 @@ where
142145
exposed_values_per_air,
143146
numer_mle_claims_per_instance,
144147
denom_mle_claims_per_instance,
145-
} = Self::generate_aux_per_air(
146-
challenger,
147-
interactions_per_air,
148-
trace_view_per_air,
149-
gamma,
150-
&gkr_instances,
151-
&gkr_artifact,
152-
);
148+
} = metrics_span("gkr_generate_aux", || {
149+
Self::generate_aux_per_air(
150+
challenger,
151+
interactions_per_air,
152+
trace_view_per_air,
153+
gamma,
154+
&gkr_instances,
155+
&gkr_artifact,
156+
)
157+
});
153158

154159
let mut challenges = vec![alpha, beta, gamma];
155160
challenges.extend_from_slice(&gkr_artifact.ood_point);

crates/stark-backend/src/interaction/gkr_log_up/prove.rs

Lines changed: 83 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::{
2424
},
2525
parizip,
2626
poly::multi::{hypercube_eq_over_y, Mle},
27-
utils::metrics_span_increment,
27+
utils::metrics_span,
2828
};
2929

3030
/// A struct that holds the interaction-related data for a single GKR instance:
@@ -165,68 +165,75 @@ where
165165
.unwrap();
166166
let gamma_pows = gamma.powers().take(2 * max_interactions).collect_vec();
167167

168-
let trace_view_per_air_filtered = interactions_per_air
168+
let ood_point_per_instance: Vec<_> = interactions_per_air
169169
.iter()
170170
.zip(trace_view_per_air.iter())
171-
.filter_map(|(interactions, view)| {
172-
if interactions.is_empty() {
173-
None
174-
} else {
175-
Some(view)
176-
}
171+
.filter_map(|(interactions, view)| (!interactions.is_empty()).then_some(view))
172+
.zip(&gkr_artifact.n_variables_by_instance)
173+
.map(|(trace_view, &n_vars)| {
174+
let height = trace_view.partitioned_main[0].height();
175+
let log_height = log2_strict_usize(height);
176+
let instance_ood = &ood_point[ood_point.len() - n_vars..];
177+
let (_, r) = instance_ood.split_at(n_vars - log_height);
178+
r
177179
})
178-
.collect_vec();
179-
180-
let results: Vec<_> = parizip!(
181-
trace_view_per_air_filtered,
182-
gkr_instances,
183-
&gkr_artifact.n_variables_by_instance
184-
)
185-
.map(|(trace_view, gkr_instance, &n_vars)| {
186-
let height = trace_view.partitioned_main[0].height();
187-
let log_height = log2_strict_usize(height);
188-
189-
let instance_ood = &ood_point[ood_point.len() - n_vars..];
190-
let (_, r) = instance_ood.split_at(n_vars - log_height);
191-
debug_assert_eq!(r.len(), log_height);
180+
.collect();
192181

193-
let (after_challenge_trace, exposed_values, numer_mle_claims, denom_mle_claims) =
194-
Self::generate_aux(gkr_instance, &gamma_pows, r);
182+
let numer_mle_claims_per_instance: Vec<Vec<_>> =
183+
parizip!(&ood_point_per_instance, gkr_instances,)
184+
.map(|(r, gkr_instance)| {
185+
gkr_instance
186+
.numerators
187+
.transpose()
188+
.par_rows_mut()
189+
.map(|row| Mle::eval_slice_with_buf(row, r))
190+
.collect()
191+
})
192+
.collect();
193+
194+
let denom_mle_claims_per_instance: Vec<Vec<_>> =
195+
parizip!(&ood_point_per_instance, gkr_instances,)
196+
.map(|(r, gkr_instance)| {
197+
gkr_instance
198+
.denominators
199+
.transpose()
200+
.par_rows_mut()
201+
.map(|row| Mle::eval_slice_with_buf(row, r))
202+
.collect()
203+
})
204+
.collect();
205+
206+
for (numer_mle, denom_mle) in numer_mle_claims_per_instance
207+
.iter()
208+
.zip(denom_mle_claims_per_instance.iter())
209+
{
210+
for (numer_mle_claim, denom_mle_claim) in numer_mle.iter().zip(denom_mle.iter()) {
211+
challenger.observe_ext_element(*numer_mle_claim);
212+
challenger.observe_ext_element(*denom_mle_claim);
213+
}
214+
}
195215

196-
(
197-
after_challenge_trace,
198-
exposed_values,
199-
numer_mle_claims,
200-
denom_mle_claims,
201-
)
202-
})
203-
.collect();
216+
let results: Vec<_> = metrics_span("generate_perm_trace_time_ms", || {
217+
parizip!(&ood_point_per_instance, gkr_instances)
218+
.map(|(r, gkr_instance)| Self::generate_perm_trace(gkr_instance, &gamma_pows, r))
219+
.collect()
220+
});
204221

205222
let mut results_iter = results.into_iter();
206223

207-
let n_airs = interactions_per_air.len();
208-
let mut after_challenge_trace_per_air = Vec::with_capacity(n_airs);
209-
let mut exposed_values_per_air = Vec::with_capacity(n_airs);
210-
let mut numer_mle_claims_per_instance = Vec::with_capacity(n_airs);
211-
let mut denom_mle_claims_per_instance = Vec::with_capacity(n_airs);
212-
213-
for interactions in interactions_per_air.iter() {
214-
if interactions.is_empty() {
215-
after_challenge_trace_per_air.push(None);
216-
exposed_values_per_air.push(None);
217-
} else {
218-
let (after_trace, exposed, numer_mle, denom_mle) = results_iter.next().unwrap();
224+
let (after_challenge_trace_per_air, exposed_values_per_air): (Vec<_>, Vec<_>) =
225+
interactions_per_air
226+
.iter()
227+
.map(|interactions| {
228+
if interactions.is_empty() {
229+
(None, None)
230+
} else {
231+
let (after_trace, exposed) = results_iter.next().unwrap();
232+
(Some(after_trace), Some(exposed))
233+
}
234+
})
235+
.unzip();
219236

220-
for (numer_mle_claim, denom_mle_claim) in numer_mle.iter().zip(denom_mle.iter()) {
221-
challenger.observe_ext_element(*numer_mle_claim);
222-
challenger.observe_ext_element(*denom_mle_claim);
223-
}
224-
after_challenge_trace_per_air.push(Some(after_trace));
225-
exposed_values_per_air.push(Some(exposed));
226-
numer_mle_claims_per_instance.push(numer_mle);
227-
denom_mle_claims_per_instance.push(denom_mle);
228-
}
229-
}
230237
GkrAuxData {
231238
after_challenge_trace_per_air,
232239
exposed_values_per_air,
@@ -235,59 +242,35 @@ where
235242
}
236243
}
237244

238-
fn generate_aux(
245+
fn generate_perm_trace(
239246
gkr_instance: &GkrLogUpInstance<EF>,
240247
gamma_pows: &[EF],
241248
r: &[EF],
242-
) -> (DenseMatrix<EF>, Vec<EF>, Vec<EF>, Vec<EF>) {
243-
let numer_mle_claims: Vec<_> = gkr_instance
249+
) -> (DenseMatrix<EF>, Vec<EF>) {
250+
let s_at_rows: Vec<EF> = gkr_instance
244251
.numerators
245-
.transpose()
246-
.par_rows_mut()
247-
.map(|row| Mle::eval_slice(row, r))
248-
.collect();
249-
250-
let denom_mle_claims: Vec<_> = gkr_instance
251-
.denominators
252-
.transpose()
253-
.par_rows_mut()
254-
.map(|row| Mle::eval_slice(row, r))
252+
.par_row_slices()
253+
.zip(gkr_instance.denominators.par_row_slices())
254+
.map(|(count_row, sigma_row)| {
255+
itertools::interleave(count_row, sigma_row)
256+
.zip(gamma_pows)
257+
.fold(EF::ZERO, |acc, (&val, &gamma_pow)| acc + gamma_pow * val)
258+
})
255259
.collect();
256260

257-
let (after_challenge_trace, exposed_values) =
258-
metrics_span_increment("generate_perm_trace_time_ms", || {
259-
let s_at_rows: Vec<EF> = gkr_instance
260-
.numerators
261-
.par_row_slices()
262-
.zip(gkr_instance.denominators.par_row_slices())
263-
.map(|(count_row, sigma_row)| {
264-
itertools::interleave(count_row, sigma_row)
265-
.zip(gamma_pows)
266-
.fold(EF::ZERO, |acc, (&val, &gamma_pow)| acc + gamma_pow * val)
267-
})
268-
.collect();
261+
// TODO: Precompute these per height.
262+
let eqs_at_r = hypercube_eq_over_y(r);
269263

270-
// TODO: Precompute these per height.
271-
let eqs_at_r = hypercube_eq_over_y(r);
264+
let mut partial_sum = EF::ZERO;
265+
let mut after_challenge_trace_data = Vec::with_capacity(eqs_at_r.len() * 3);
266+
for (eq_at_r, s_at_row) in eqs_at_r.iter().zip(s_at_rows.iter()) {
267+
after_challenge_trace_data.push(*eq_at_r);
268+
after_challenge_trace_data.push(*s_at_row);
269+
after_challenge_trace_data.push(partial_sum);
272270

273-
let mut partial_sum = EF::ZERO;
274-
let mut after_challenge_trace_data = Vec::with_capacity(eqs_at_r.len() * 3);
275-
for (eq_at_r, s_at_row) in eqs_at_r.iter().zip(s_at_rows.iter()) {
276-
after_challenge_trace_data.push(*eq_at_r);
277-
after_challenge_trace_data.push(*s_at_row);
278-
after_challenge_trace_data.push(partial_sum);
279-
280-
partial_sum += *eq_at_r * *s_at_row;
281-
}
282-
let after_challenge_trace = RowMajorMatrix::new(after_challenge_trace_data, 3);
283-
(after_challenge_trace, vec![partial_sum])
284-
});
285-
286-
(
287-
after_challenge_trace,
288-
exposed_values,
289-
numer_mle_claims,
290-
denom_mle_claims,
291-
)
271+
partial_sum += *eq_at_r * *s_at_row;
272+
}
273+
let after_challenge_trace = RowMajorMatrix::new(after_challenge_trace_data, 3);
274+
(after_challenge_trace, vec![partial_sum])
292275
}
293276
}

crates/stark-backend/src/poly/multi.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ impl<F: Field> Mle<F> {
4444
/// Evaluates the multilinear extension at `point`.
4545
pub fn eval(&self, point: &[F]) -> F {
4646
let mut buf = self.evals.clone();
47-
Self::eval_slice(&mut buf, point)
47+
Self::eval_slice_with_buf(&mut buf, point)
4848
}
4949

5050
/// Evaluates the multilinear extension given by hypercube evaluations `evals` at `point`.
@@ -54,7 +54,7 @@ impl<F: Field> Mle<F> {
5454
/// # Panics
5555
///
5656
/// Panics if `evals.len() != 2^point.len()`
57-
pub fn eval_slice(evals: &mut [F], point: &[F]) -> F {
57+
pub fn eval_slice_with_buf(evals: &mut [F], point: &[F]) -> F {
5858
let buf = evals;
5959
let n = point.len();
6060
let mut len = buf.len();

0 commit comments

Comments
 (0)