Skip to content

Commit 63df013

Browse files
committed
Interactions prove/verify via rational sumcheck
1 parent d2788c7 commit 63df013

31 files changed

+1892
-372
lines changed

crates/stark-backend/src/air_builders/debug/check_constraints.rs

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1-
use itertools::izip;
1+
use itertools::{izip, Itertools};
22
use p3_air::BaseAir;
33
use p3_field::{Field, FieldAlgebra};
4-
use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair, Matrix};
4+
use p3_matrix::{
5+
dense::{RowMajorMatrix, RowMajorMatrixView},
6+
stack::VerticalPair,
7+
Matrix,
8+
};
59
use p3_maybe_rayon::prelude::*;
10+
use p3_util::log2_strict_usize;
611

712
use crate::{
813
air_builders::debug::DebugConstraintBuilder,
914
config::{StarkGenericConfig, Val},
1015
interaction::{
1116
debug::{generate_logical_interactions, LogicalInteractions},
12-
InteractionType, RapPhaseSeqKind, SymbolicInteraction,
17+
gkr_log_up::fold_multilinear_lagrange_col_constraints,
18+
InteractionType, RapPhaseSeq, RapPhaseSeqKind, SymbolicInteraction,
1319
},
1420
rap::{PartitionedBaseAir, Rap},
1521
};
@@ -69,7 +75,7 @@ pub fn check_constraints<R, SC>(
6975
.iter()
7076
.map(|mat| (mat.row_slice(i), mat.row_slice(i_next)))
7177
.collect::<Vec<_>>();
72-
let after_challenge = after_challenge_row_pair
78+
let after_challenge_pairs = after_challenge_row_pair
7379
.iter()
7480
.map(|(local, next)| {
7581
VerticalPair::new(
@@ -87,7 +93,7 @@ pub fn check_constraints<R, SC>(
8793
RowMajorMatrixView::new_row(preprocessed_next.as_slice()),
8894
),
8995
partitioned_main,
90-
after_challenge,
96+
after_challenge: after_challenge_pairs,
9197
challenges,
9298
public_values,
9399
exposed_values_after_challenge,
@@ -106,9 +112,53 @@ pub fn check_constraints<R, SC>(
106112
}
107113

108114
rap.eval(&mut builder);
115+
116+
if matches!(SC::RapPhaseSeq::KIND, RapPhaseSeqKind::GkrLogUp) {
117+
check_gkr_log_up_adapter_constraints_for_row::<SC>(after_challenge, challenges, i);
118+
}
109119
});
110120
}
111121

122+
fn check_gkr_log_up_adapter_constraints_for_row<SC: StarkGenericConfig>(
123+
after_challenge: &[RowMajorMatrixView<SC::Challenge>],
124+
challenges: &[Vec<SC::Challenge>],
125+
i: usize,
126+
) {
127+
if after_challenge.is_empty() {
128+
return;
129+
}
130+
assert_eq!(after_challenge.len(), 1);
131+
assert_eq!(challenges.len(), 1);
132+
133+
let after_challenge = &after_challenge[0];
134+
let challenges = &challenges[0];
135+
let height = after_challenge.height();
136+
137+
let log_height = log2_strict_usize(height);
138+
let indices = std::iter::once(i).chain((0..log_height).map(|j| (i + (1 << j)) % height));
139+
let after_challenge_window = RowMajorMatrix::new(
140+
indices.flat_map(|i| after_challenge.row(i)).collect_vec(),
141+
after_challenge.width(),
142+
);
143+
144+
let r = &challenges[challenges.len() - log_height..];
145+
let mut accumulator = SC::Challenge::ZERO;
146+
let alpha = SC::Challenge::TWO;
147+
148+
let is_cyclic_row = (0..=log_height)
149+
.map(|k| SC::Challenge::from_bool(i & ((1 << (log_height - k)) - 1) == 0))
150+
.collect_vec();
151+
fold_multilinear_lagrange_col_constraints(
152+
&mut accumulator,
153+
alpha,
154+
&after_challenge_window,
155+
&is_cyclic_row,
156+
r,
157+
0,
158+
);
159+
assert_eq!(accumulator, SC::Challenge::ZERO);
160+
}
161+
112162
pub fn check_logup<F: Field>(
113163
air_names: &[String],
114164
interactions: &[&[SymbolicInteraction<F>]],

crates/stark-backend/src/air_builders/debug/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ where
221221
impl<SC: StarkGenericConfig> InteractionPhaseAirBuilder for DebugConstraintBuilder<'_, SC> {
222222
fn finalize_interactions(&mut self) {}
223223

224-
fn interaction_chunk_size(&self) -> usize {
225-
0
224+
fn interaction_chunk_size(&self) -> Option<usize> {
225+
None
226226
}
227227

228228
fn rap_phase_seq_kind(&self) -> RapPhaseSeqKind {

crates/stark-backend/src/air_builders/symbolic/mod.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub struct SymbolicConstraints<F> {
4343
/// the prover for after challenge trace generation, and some partial
4444
/// information may be used by the verifier.
4545
///
46+
/// FIXME[zach]: False for GKR; there, some constraints depend on height of the trace and therefore cannot be precomputed.
4647
/// **However**, any contributions to the quotient polynomial from
4748
/// logup are already included in `constraints` and do not need to
4849
/// be separately calculated from `interactions`.
@@ -110,7 +111,7 @@ where
110111
num_challenges_to_sample,
111112
num_exposed_values_after_challenge,
112113
rap_phase_seq_kind,
113-
interaction_chunk_size,
114+
Some(interaction_chunk_size),
114115
);
115116
Rap::eval(rap, &mut builder);
116117
builder
@@ -127,7 +128,7 @@ pub struct SymbolicRapBuilder<F> {
127128
exposed_values_after_challenge: Vec<Vec<SymbolicVariable<F>>>,
128129
constraints: Vec<SymbolicExpression<F>>,
129130
interactions: Vec<SymbolicInteraction<F>>,
130-
interaction_chunk_size: usize,
131+
interaction_chunk_size: Option<usize>,
131132
rap_phase_seq_kind: RapPhaseSeqKind,
132133
trace_width: TraceWidth,
133134
}
@@ -142,7 +143,7 @@ impl<F: Field> SymbolicRapBuilder<F> {
142143
num_challenges_to_sample: &[usize],
143144
num_exposed_values_after_challenge: &[usize],
144145
rap_phase_seq_kind: RapPhaseSeqKind,
145-
interaction_chunk_size: usize,
146+
interaction_chunk_size: Option<usize>,
146147
) -> Self {
147148
let preprocessed_width = width.preprocessed.unwrap_or(0);
148149
let prep_values = [0, 1]
@@ -398,7 +399,15 @@ impl<F: Field> InteractionPhaseAirBuilder for SymbolicRapBuilder<F> {
398399
assert!(self.challenges.is_empty());
399400
assert!(self.exposed_values_after_challenge.is_empty());
400401

401-
let perm_width = num_interactions.div_ceil(self.interaction_chunk_size) + 1;
402+
let perm_width = match self.rap_phase_seq_kind {
403+
RapPhaseSeqKind::StarkLogUp => {
404+
let interaction_chunk_size = self
405+
.interaction_chunk_size()
406+
.expect("interaction chunk size should be set for StarkLogUp");
407+
num_interactions.div_ceil(interaction_chunk_size) + 1
408+
}
409+
RapPhaseSeqKind::GkrLogUp => 2,
410+
};
402411
self.after_challenge = Self::new_after_challenge(&[perm_width]);
403412

404413
let phases_shapes = self.rap_phase_seq_kind.shape();
@@ -410,7 +419,7 @@ impl<F: Field> InteractionPhaseAirBuilder for SymbolicRapBuilder<F> {
410419
}
411420
}
412421

413-
fn interaction_chunk_size(&self) -> usize {
422+
fn interaction_chunk_size(&self) -> Option<usize> {
414423
self.interaction_chunk_size
415424
}
416425

crates/stark-backend/src/config.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,15 @@ pub type PackedChallenge<SC> =
9191
#[derive(Debug)]
9292
pub struct StarkConfig<Pcs, RapPhaseSeq, Challenge, Challenger> {
9393
pcs: Pcs,
94-
rap_phase: RapPhaseSeq,
94+
rap_phase_seq: RapPhaseSeq,
9595
_phantom: PhantomData<(Challenge, Challenger)>,
9696
}
9797

9898
impl<Pcs, RapPhaseSeq, Challenge, Challenger> StarkConfig<Pcs, RapPhaseSeq, Challenge, Challenger> {
99-
pub const fn new(pcs: Pcs, rap_phase: RapPhaseSeq) -> Self {
99+
pub const fn new(pcs: Pcs, rap_phase_seq: RapPhaseSeq) -> Self {
100100
Self {
101101
pcs,
102-
rap_phase,
102+
rap_phase_seq,
103103
_phantom: PhantomData,
104104
}
105105
}
@@ -128,7 +128,7 @@ where
128128
&self.pcs
129129
}
130130
fn rap_phase_seq(&self) -> &Self::RapPhaseSeq {
131-
&self.rap_phase
131+
&self.rap_phase_seq
132132
}
133133
}
134134

crates/stark-backend/src/gkr/prover.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77

88
use itertools::Itertools;
99
use p3_challenger::FieldChallenger;
10-
use p3_field::Field;
10+
use p3_field::{ExtensionField, Field};
1111
use thiserror::Error;
1212

1313
use crate::{
@@ -327,10 +327,10 @@ pub struct NotConstantPolyError;
327327
///
328328
/// The input layers should be committed to the channel before calling this function.
329329
// GKR algorithm: <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> (page 64)
330-
pub fn prove_batch<F: Field>(
330+
pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
331331
challenger: &mut impl FieldChallenger<F>,
332-
input_layer_by_instance: Vec<Layer<F>>,
333-
) -> (GkrBatchProof<F>, GkrArtifact<F>) {
332+
input_layer_by_instance: Vec<Layer<EF>>,
333+
) -> (GkrBatchProof<EF>, GkrArtifact<EF>) {
334334
let n_instances = input_layer_by_instance.len();
335335
let n_layers_by_instance = input_layer_by_instance
336336
.iter()
@@ -366,12 +366,14 @@ pub fn prove_batch<F: Field>(
366366

367367
// Seed the channel with layer claims.
368368
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
369-
challenger.observe_slice(claims_to_verify);
369+
for claim in claims_to_verify {
370+
challenger.observe_ext_element(*claim);
371+
}
370372
}
371373

372374
let eq_evals = HypercubeEqEvals::eval(&ood_point);
373-
let sumcheck_alpha = challenger.sample();
374-
let instance_lambda = challenger.sample();
375+
let sumcheck_alpha: EF = challenger.sample_ext_element();
376+
let instance_lambda: EF = challenger.sample_ext_element();
375377

376378
let mut sumcheck_oracles = Vec::new();
377379
let mut sumcheck_claims = Vec::new();
@@ -385,7 +387,7 @@ pub fn prove_batch<F: Field>(
385387
sumcheck_oracles.push(GkrMultivariatePolyOracle {
386388
eq_evals: &eq_evals,
387389
input_layer: layer,
388-
eq_fixed_var_correction: F::ONE,
390+
eq_fixed_var_correction: EF::ONE,
389391
lambda: instance_lambda,
390392
});
391393
sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
@@ -417,12 +419,14 @@ pub fn prove_batch<F: Field>(
417419
// Seed the channel with the layer masks.
418420
for (&instance, mask) in zip(&sumcheck_instances, &masks) {
419421
for column in mask.columns() {
420-
challenger.observe_slice(column);
422+
for el in column {
423+
challenger.observe_ext_element(*el);
424+
}
421425
}
422426
layer_masks_by_instance[instance].push(mask.clone());
423427
}
424428

425-
let challenge = challenger.sample();
429+
let challenge: EF = challenger.sample_ext_element();
426430
ood_point = sumcheck_ood_point;
427431
ood_point.push(challenge);
428432

@@ -501,7 +505,7 @@ pub fn correct_sum_as_poly_in_first_variable<F: Field>(
501505
let r_at_2 = f_at_2 * hypercube_eq(&[F::TWO], &[y[n - k]]) * a_const;
502506

503507
// Interpolate.
504-
UnivariatePolynomial::from_interpolation(&[
508+
UnivariatePolynomial::from_points(&[
505509
(F::ZERO, r_at_0),
506510
(F::ONE, r_at_1),
507511
(F::TWO, r_at_2),

crates/stark-backend/src/gkr/tests.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ fn test_batch() -> Result<(), GkrError<BabyBear>> {
2020
let engine = default_engine();
2121
let mut rng = create_seeded_rng();
2222

23-
let col0 = Mle::new((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
24-
let col1 = Mle::new((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
23+
let col0 = Mle::from_vec((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
24+
let col1 = Mle::from_vec((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
2525

2626
let product0 = col0.iter().copied().product();
2727
let product1 = col1.iter().copied().product();
@@ -62,8 +62,8 @@ fn test_batch_with_different_sizes() -> Result<(), GkrError<BabyBear>> {
6262
const LOG_N0: usize = 5;
6363
const LOG_N1: usize = 7;
6464

65-
let col0 = Mle::new((0..1 << LOG_N0).map(|_| rng.gen()).collect());
66-
let col1 = Mle::new((0..1 << LOG_N1).map(|_| rng.gen()).collect());
65+
let col0 = Mle::from_vec((0..1 << LOG_N0).map(|_| rng.gen()).collect());
66+
let col1 = Mle::from_vec((0..1 << LOG_N1).map(|_| rng.gen()).collect());
6767

6868
let product0 = col0.iter().copied().product();
6969
let product1 = col1.iter().copied().product();
@@ -106,7 +106,7 @@ fn test_grand_product() -> Result<(), GkrError<BabyBear>> {
106106

107107
let values = (0..N).map(|_| rng.gen()).collect_vec();
108108
let product = values.iter().copied().product();
109-
let col = Mle::<BabyBear>::new(values);
109+
let col = Mle::<BabyBear>::from_vec(values);
110110
let input_layer = Layer::GrandProduct(col.clone());
111111
let (proof, _) = gkr::prove_batch(&mut engine.new_challenger(), vec![input_layer]);
112112

@@ -137,8 +137,8 @@ fn test_logup_with_generic_trace() -> Result<(), GkrError<BabyBear>> {
137137
let sum: Fraction<F> = zip(&numerator_values, &denominator_values)
138138
.map(|(&n, &d)| Fraction::new(n, d))
139139
.sum();
140-
let numerators = Mle::<F>::new(numerator_values);
141-
let denominators = Mle::<F>::new(denominator_values);
140+
let numerators = Mle::<F>::from_vec(numerator_values);
141+
let denominators = Mle::<F>::from_vec(denominator_values);
142142
let top_layer = Layer::LogUpGeneric {
143143
numerators: numerators.clone(),
144144
denominators: denominators.clone(),
@@ -178,7 +178,7 @@ fn test_logup_with_singles_trace() -> Result<(), GkrError<BabyBear>> {
178178
.iter()
179179
.map(|&d| Fraction::new(F::ONE, d))
180180
.sum();
181-
let denominators = Mle::new(denominator_values);
181+
let denominators = Mle::from_vec(denominator_values);
182182
let top_layer = Layer::LogUpSingles {
183183
denominators: denominators.clone(),
184184
};
@@ -214,8 +214,8 @@ fn test_logup_with_multiplicities_trace() -> Result<(), GkrError<BabyBear>> {
214214
let sum: Fraction<BabyBear> = zip(&numerator_values, &denominator_values)
215215
.map(|(&n, &d)| Fraction::new(n, d))
216216
.sum();
217-
let numerators = Mle::new(numerator_values);
218-
let denominators = Mle::new(denominator_values);
217+
let numerators = Mle::from_vec(numerator_values);
218+
let denominators = Mle::from_vec(denominator_values);
219219
let top_layer = Layer::LogUpMultiplicities {
220220
numerators: numerators.clone(),
221221
denominators: denominators.clone(),

crates/stark-backend/src/gkr/types.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ops::Index;
22

33
use p3_field::Field;
4+
use serde::{Deserialize, Serialize};
45
use thiserror::Error;
56

67
use crate::{
@@ -12,6 +13,7 @@ use crate::{
1213
};
1314

1415
/// Batch GKR proof.
16+
#[derive(Clone, Serialize, Deserialize)]
1517
pub struct GkrBatchProof<F> {
1618
/// Sum-check proof for each layer.
1719
pub sumcheck_proofs: Vec<SumcheckProof<F>>,
@@ -22,6 +24,7 @@ pub struct GkrBatchProof<F> {
2224
}
2325

2426
/// Values of interest obtained from the execution of the GKR protocol.
27+
#[derive(Clone, Serialize, Deserialize)]
2528
pub struct GkrArtifact<F> {
2629
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
2730
pub ood_point: Vec<F>,
@@ -32,7 +35,7 @@ pub struct GkrArtifact<F> {
3235
}
3336

3437
/// Stores two evaluations of each column in a GKR layer.
35-
#[derive(Debug, Clone)]
38+
#[derive(Debug, Clone, Serialize, Deserialize)]
3639
pub struct GkrMask<F> {
3740
columns: Vec<[F; 2]>,
3841
}
@@ -169,7 +172,7 @@ impl<F: Field> Layer<F> {
169172
.chunks_exact(2) // Process in chunks of 2 elements
170173
.map(|chunk| chunk[0] * chunk[1]) // Multiply each pair
171174
.collect();
172-
Layer::GrandProduct(Mle::new(res))
175+
Layer::GrandProduct(Mle::from_vec(res))
173176
}
174177

175178
fn next_logup_layer(numerators: MleExpr<'_, F>, denominators: &Mle<F>) -> Layer<F> {
@@ -186,8 +189,8 @@ impl<F: Field> Layer<F> {
186189
}
187190

188191
Layer::LogUpGeneric {
189-
numerators: Mle::new(next_numerators),
190-
denominators: Mle::new(next_denominators),
192+
numerators: Mle::from_vec(next_numerators),
193+
denominators: Mle::from_vec(next_denominators),
191194
}
192195
}
193196

0 commit comments

Comments
 (0)