Skip to content

Commit 5db1a33

Browse files
committed
prof
1 parent 1b7bf93 commit 5db1a33

File tree

10 files changed

+122
-85
lines changed

10 files changed

+122
-85
lines changed

crates/stark-backend/src/air_builders/debug/check_constraints.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ pub fn check_constraints<R, SC>(
100100
});
101101
}
102102

103+
#[allow(dead_code)]
103104
fn check_gkr_log_up_adapter_constraints_for_row<SC: StarkGenericConfig>(
104105
after_challenge: &[RowMajorMatrixView<SC::Challenge>],
105106
challenges: &[Vec<SC::Challenge>],

crates/stark-backend/src/gkr/prover.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@ use p3_field::{ExtensionField, Field};
1111
use p3_maybe_rayon::prelude::*;
1212
use thiserror::Error;
1313

14-
use crate::{gkr::types::{GkrArtifact, GkrBatchProof, GkrMask, Layer}, poly::{
15-
multi::{hypercube_eq, Mle, MultivariatePolyOracle},
16-
uni::{random_linear_combination, UnivariatePolynomial},
17-
}, sumcheck, sumcheck::SumcheckArtifacts};
18-
use crate::utils::metrics_span;
14+
use crate::utils::{metrics_span, metrics_span_increment};
15+
use crate::{
16+
gkr::types::{GkrArtifact, GkrBatchProof, GkrMask, Layer},
17+
poly::{
18+
multi::{Mle, MultivariatePolyOracle},
19+
uni::{random_linear_combination, UnivariatePolynomial},
20+
},
21+
sumcheck,
22+
sumcheck::SumcheckArtifacts,
23+
};
1924

2025
/// For a given `y`, stores evaluations of [hypercube_eq](x, y) on all 2^{n-1} boolean hypercube
2126
/// points of the form `x = (0, x_2, ..., x_n)`.
@@ -279,10 +284,12 @@ pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
279284
let n_layers = *n_layers_by_instance.iter().max().unwrap();
280285

281286
// Evaluate all instance circuits and collect the layer values.
282-
let mut layers_by_instance: Vec<_> = input_layer_by_instance
283-
.into_par_iter()
284-
.map(|input_layer| gen_layers(input_layer).into_iter().rev())
285-
.collect();
287+
let mut layers_by_instance: Vec<_> = metrics_span("gkr_gen_layers_ms", || {
288+
input_layer_by_instance
289+
.into_par_iter()
290+
.map(|input_layer| gen_layers(input_layer).into_iter().rev())
291+
.collect()
292+
});
286293

287294
let mut output_claims_by_instance = vec![None; n_instances];
288295
let mut layer_masks_by_instance = vec![vec![]; n_instances];
@@ -297,6 +304,7 @@ pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
297304
output_instances_by_layer[output_layer].push(instance);
298305
}
299306

307+
#[allow(clippy::needless_range_loop)]
300308
for layer in 0..n_layers {
301309
// Check all the instances for output layers.
302310
for &instance in &output_instances_by_layer[layer] {
@@ -344,7 +352,7 @@ pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
344352
constant_poly_oracles,
345353
..
346354
},
347-
) = metrics_span("sumcheck_prove_batch_ms", || {
355+
) = metrics_span_increment("sumcheck_prove_batch_ms", || {
348356
sumcheck::prove_batch(
349357
sumcheck_claims,
350358
sumcheck_oracles,

crates/stark-backend/src/gkr/tests.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use openvm_stark_sdk::{
55
config::baby_bear_blake3::default_engine, engine::StarkEngine, utils::create_seeded_rng,
66
};
77
use p3_baby_bear::BabyBear;
8-
use p3_field::FieldAlgebra;
98
use rand::Rng;
109

1110
use crate::{

crates/stark-backend/src/interaction/fri_log_up.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::{
1919
},
2020
parizip,
2121
rap::PermutationAirBuilderWithExposedValues,
22-
utils::metrics_span,
22+
utils::{metrics_span, metrics_span_increment},
2323
};
2424

2525
pub struct FriLogUpPhase<F, Challenge, Challenger> {
@@ -114,14 +114,15 @@ where
114114
let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
115115
array::from_fn(|_| challenger.sample_ext_element::<Challenge>());
116116

117-
let after_challenge_trace_per_air = metrics_span("generate_perm_trace_time_ms", || {
118-
Self::generate_after_challenge_traces_per_air(
119-
&challenges,
120-
interactions_per_air,
121-
params_per_air,
122-
trace_view_per_air,
123-
)
124-
});
117+
let after_challenge_trace_per_air =
118+
metrics_span_increment("generate_perm_trace_time_ms", || {
119+
Self::generate_after_challenge_traces_per_air(
120+
&challenges,
121+
interactions_per_air,
122+
params_per_air,
123+
trace_view_per_air,
124+
)
125+
});
125126
let cumulative_sum_per_air = Self::extract_cumulative_sums(&after_challenge_trace_per_air);
126127

127128
// Challenger needs to observe what is exposed (cumulative_sums)
@@ -248,7 +249,7 @@ where
248249
parizip!(interactions_per_air, trace_view_per_air, params_per_air)
249250
.map(|(interactions, trace_view, params)| {
250251
Self::generate_after_challenge_trace(
251-
&interactions,
252+
interactions,
252253
trace_view,
253254
challenges,
254255
&params.interaction_partitions,

crates/stark-backend/src/interaction/gkr_log_up.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ use p3_field::{
1313
use p3_matrix::{dense::DenseMatrix, Matrix};
1414
use p3_maybe_rayon::prelude::*;
1515
use p3_util::{log2_ceil_usize, log2_strict_usize};
16-
use rayon::iter::IntoParallelRefIterator;
1716
use serde::{Deserialize, Serialize};
1817
use thiserror::Error;
1918

2019
use crate::interaction::SymbolicInteraction;
20+
use crate::utils::metrics_span;
2121
use crate::{
2222
air_builders::symbolic::SymbolicConstraints,
2323
gkr,
@@ -28,7 +28,6 @@ use crate::{
2828
},
2929
rap::PermutationAirBuilderWithExposedValues,
3030
};
31-
use crate::utils::metrics_span;
3231

3332
pub struct GkrLogUpPhase<F, EF, Challenger> {
3433
// FIXME: USE THIS IN POW
@@ -113,7 +112,7 @@ where
113112
_params_per_air: &[&Self::PartialProvingKey],
114113
trace_view_per_air: &[PairTraceView<'_, F>],
115114
) -> Option<(Self::PartialProof, RapPhaseProverData<EF>)> {
116-
let all_interactions = interactions_per_air.iter().cloned().flatten().collect_vec();
115+
let all_interactions = interactions_per_air.iter().flatten().collect_vec();
117116
if all_interactions.is_empty() {
118117
return None;
119118
}
@@ -143,16 +142,14 @@ where
143142
exposed_values_per_air,
144143
count_mle_claims_per_instance,
145144
sigma_mle_claims_per_instance,
146-
} = metrics_span("generate_perm_trace_time_ms", || {
147-
Self::generate_aux_per_air(
148-
challenger,
149-
interactions_per_air,
150-
trace_view_per_air,
151-
alpha,
152-
&gkr_instances,
153-
&gkr_artifact,
154-
)
155-
});
145+
} = Self::generate_aux_per_air(
146+
challenger,
147+
interactions_per_air,
148+
trace_view_per_air,
149+
alpha,
150+
&gkr_instances,
151+
&gkr_artifact,
152+
);
156153

157154
let mut challenges = vec![beta, alpha];
158155
challenges.extend_from_slice(&gkr_artifact.ood_point);

crates/stark-backend/src/interaction/gkr_log_up/prove.rs

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ use p3_util::log2_strict_usize;
1212

1313
use crate::{
1414
air_builders::symbolic::symbolic_expression::{SymbolicEvaluator, SymbolicExpression},
15-
gkr::{GkrArtifact, Layer, Layer::LogUpGeneric},
15+
gkr::{
16+
GkrArtifact,
17+
Layer::{self, LogUpGeneric},
18+
},
1619
interaction::{
1720
gkr_log_up::{num_interaction_dimensions, GkrAuxData, GkrLogUpPhase},
1821
trace::Evaluator,
1922
PairTraceView, SymbolicInteraction,
2023
},
24+
parizip,
2125
poly::multi::{hypercube_eq_over_y, Mle},
26+
utils::metrics_span,
2227
};
2328

2429
/// A struct that holds the interaction-related data for a single GKR instance:
@@ -148,7 +153,7 @@ where
148153
} else {
149154
Some(GkrLogUpInstance::from_interactions(
150155
trace_view,
151-
&interactions,
156+
interactions,
152157
beta_pows,
153158
))
154159
}
@@ -181,29 +186,30 @@ where
181186
})
182187
.collect_vec();
183188

184-
let results: Vec<_> = trace_view_per_air_filtered
185-
.par_iter()
186-
.zip(gkr_instances.par_iter())
187-
.zip(gkr_artifact.n_variables_by_instance.par_iter())
188-
.map(|((trace_view, gkr_instance), &n_vars)| {
189-
let height = trace_view.partitioned_main[0].height();
190-
let log_height = log2_strict_usize(height);
191-
192-
let instance_ood = &ood_point[ood_point.len() - n_vars..];
193-
let (_, r) = instance_ood.split_at(n_vars - log_height);
194-
debug_assert_eq!(r.len(), log_height);
195-
196-
let (after_challenge_trace, exposed_values, count_mle_claims, sigma_mle_claims) =
197-
Self::generate_aux(gkr_instance, &gamma_pows, r);
198-
199-
(
200-
after_challenge_trace,
201-
exposed_values,
202-
count_mle_claims,
203-
sigma_mle_claims,
204-
)
205-
})
206-
.collect();
189+
let results: Vec<_> = parizip!(
190+
trace_view_per_air_filtered,
191+
gkr_instances,
192+
&gkr_artifact.n_variables_by_instance
193+
)
194+
.map(|(trace_view, gkr_instance, &n_vars)| {
195+
let height = trace_view.partitioned_main[0].height();
196+
let log_height = log2_strict_usize(height);
197+
198+
let instance_ood = &ood_point[ood_point.len() - n_vars..];
199+
let (_, r) = instance_ood.split_at(n_vars - log_height);
200+
debug_assert_eq!(r.len(), log_height);
201+
202+
let (after_challenge_trace, exposed_values, count_mle_claims, sigma_mle_claims) =
203+
Self::generate_aux(gkr_instance, &gamma_pows, r);
204+
205+
(
206+
after_challenge_trace,
207+
exposed_values,
208+
count_mle_claims,
209+
sigma_mle_claims,
210+
)
211+
})
212+
.collect();
207213

208214
let mut results_iter = results.into_iter();
209215

@@ -257,32 +263,38 @@ where
257263
.map(|row| Mle::eval_slice(row, r))
258264
.collect();
259265

260-
let s_at_rows: Vec<EF> = gkr_instance
261-
.counts
262-
.par_row_slices()
263-
.zip(gkr_instance.sigmas.par_row_slices())
264-
.map(|(count_row, sigma_row)| {
265-
itertools::interleave(count_row, sigma_row)
266-
.zip(gamma_pows)
267-
.fold(EF::ZERO, |acc, (&val, &gamma_pow)| acc + gamma_pow * val)
268-
})
269-
.collect();
270-
271-
let eqs_at_r = hypercube_eq_over_y(r);
272-
273-
let mut partial_sum = EF::ZERO;
274-
let mut after_challenge_trace_data = Vec::with_capacity(eqs_at_r.len() * 3);
275-
for (eq_at_r, s_at_row) in eqs_at_r.iter().zip(s_at_rows.iter()) {
276-
after_challenge_trace_data.push(*eq_at_r);
277-
after_challenge_trace_data.push(*s_at_row);
278-
after_challenge_trace_data.push(partial_sum);
266+
let (after_challenge_trace, exposed_values) =
267+
metrics_span("generate_perm_trace_time_ms", || {
268+
let s_at_rows: Vec<EF> = gkr_instance
269+
.counts
270+
.par_row_slices()
271+
.zip(gkr_instance.sigmas.par_row_slices())
272+
.map(|(count_row, sigma_row)| {
273+
itertools::interleave(count_row, sigma_row)
274+
.zip(gamma_pows)
275+
.fold(EF::ZERO, |acc, (&val, &gamma_pow)| acc + gamma_pow * val)
276+
})
277+
.collect();
278+
279+
// TODO: Precompute these per height.
280+
let eqs_at_r = hypercube_eq_over_y(r);
281+
282+
let mut partial_sum = EF::ZERO;
283+
let mut after_challenge_trace_data = Vec::with_capacity(eqs_at_r.len() * 3);
284+
for (eq_at_r, s_at_row) in eqs_at_r.iter().zip(s_at_rows.iter()) {
285+
after_challenge_trace_data.push(*eq_at_r);
286+
after_challenge_trace_data.push(*s_at_row);
287+
after_challenge_trace_data.push(partial_sum);
288+
289+
partial_sum += *eq_at_r * *s_at_row;
290+
}
291+
let after_challenge_trace = RowMajorMatrix::new(after_challenge_trace_data, 3);
292+
(after_challenge_trace, vec![partial_sum])
293+
});
279294

280-
partial_sum += *eq_at_r * *s_at_row;
281-
}
282-
let after_challenge_trace = RowMajorMatrix::new(after_challenge_trace_data, 3);
283295
(
284296
after_challenge_trace,
285-
vec![partial_sum],
297+
exposed_values,
286298
count_mle_claims,
287299
sigma_mle_claims,
288300
)

crates/stark-backend/src/interaction/utils.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
use p3_field::FieldAlgebra;
2+
use std::borrow::Borrow;
23

34
use super::Interaction;
45

56
/// Returns [beta^0, beta^1, ..., beta^{max_num_fields}]
67
/// where max_num_fields is the maximum length of `fields` in any interaction.
78
pub fn generate_betas<AF: FieldAlgebra, E>(
89
beta: AF,
9-
all_interactions: &[Interaction<E>],
10+
all_interactions: &[impl Borrow<Interaction<E>>],
1011
) -> Vec<AF> {
1112
let max_fields_len = all_interactions
1213
.iter()
13-
.map(|interaction| interaction.message.len())
14+
.map(|interaction| interaction.borrow().message.len())
1415
.max()
1516
.unwrap_or(0);
1617

crates/stark-backend/src/poly/uni.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl<F: Field> UnivariatePolynomial<F> {
3232
}
3333

3434
/// Returns the first entry of `coeffs`, or `F::ZERO` when `coeffs` is empty.
pub fn evaluate_at_zero(&self) -> F {
35-
self.coeffs.get(0).copied().unwrap_or(F::ZERO)
35+
*self.coeffs.first().unwrap_or(&F::ZERO)
3636
}
3737

3838
pub fn evaluate_at_one(&self) -> F {

crates/stark-backend/src/sumcheck.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ where
120120
.zip(this_round_polys.par_iter())
121121
.for_each(|(claim, round_poly)| *claim = round_poly.evaluate(challenge));
122122

123+
// TODO: This can be optimized if we keep track of the active polynomials.
123124
polys.par_iter_mut().for_each(|multivariate_poly| {
124125
if n_remaining_rounds == multivariate_poly.arity() {
125126
multivariate_poly.fix_first_in_place(challenge)

crates/stark-backend/src/utils.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ pub fn metrics_span<R, F: FnOnce() -> R>(name: impl Into<Cow<'static, str>>, f:
5959
}
6060
}
6161

62+
/// Runs the closure `f` and returns its result. When the feature `"bench-metrics"` is
63+
/// enabled, the time `f` took, in milliseconds, is added to the [`gauge`](metrics::gauge)
64+
/// named `name` via `increment`. When the feature is disabled, only `f` is run.
65+
#[allow(unused_variables)] // `name` is not used when "bench-metrics" is disabled
66+
pub fn metrics_span_increment<R, F: FnOnce() -> R>(name: impl Into<Cow<'static, str>>, f: F) -> R {
67+
cfg_if! {
68+
if #[cfg(feature = "bench-metrics")] {
69+
let start = std::time::Instant::now();
70+
let res = f();
71+
metrics::gauge!(name.into()).increment(start.elapsed().as_millis() as f64); // whole milliseconds, added to the gauge
72+
res
73+
} else {
74+
f()
75+
}
76+
}
77+
}
78+
6279
#[macro_export]
6380
#[cfg(feature = "parallel")]
6481
macro_rules! parizip {

0 commit comments

Comments
 (0)