Skip to content

Commit de86ea9

Browse files
committed
feat: interactions via rational sumcheck
1 parent 884f8e6 commit de86ea9

32 files changed

+2313
-349
lines changed

crates/stark-backend/src/air_builders/debug/check_constraints.rs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
use itertools::izip;
1+
use itertools::{izip, Itertools};
22
use p3_air::BaseAir;
33
use p3_field::{Field, FieldAlgebra};
4-
use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair, Matrix};
4+
use p3_matrix::{
5+
dense::{RowMajorMatrix, RowMajorMatrixView},
6+
stack::VerticalPair,
7+
Matrix,
8+
};
59
use p3_maybe_rayon::prelude::*;
10+
use p3_util::log2_strict_usize;
611

712
use crate::{
813
air_builders::debug::DebugConstraintBuilder,
914
config::{StarkGenericConfig, Val},
1015
interaction::{
1116
debug::{generate_logical_interactions, LogicalInteractions},
17+
gkr_log_up::fold_multilinear_lagrange_col_constraints,
1218
RapPhaseSeqKind, SymbolicInteraction,
1319
},
1420
rap::{PartitionedBaseAir, Rap},
@@ -87,9 +93,53 @@ pub fn check_constraints<R, SC>(
8793
}
8894

8995
rap.eval(&mut builder);
96+
97+
// if matches!(SC::RapPhaseSeq::KIND, RapPhaseSeqKind::GkrLogUp) {
98+
// check_gkr_log_up_adapter_constraints_for_row::<SC>(after_challenge, challenges, i);
99+
// }
90100
});
91101
}
92102

103+
fn check_gkr_log_up_adapter_constraints_for_row<SC: StarkGenericConfig>(
104+
after_challenge: &[RowMajorMatrixView<SC::Challenge>],
105+
challenges: &[Vec<SC::Challenge>],
106+
i: usize,
107+
) {
108+
if after_challenge.is_empty() {
109+
return;
110+
}
111+
assert_eq!(after_challenge.len(), 1);
112+
assert_eq!(challenges.len(), 1);
113+
114+
let after_challenge = &after_challenge[0];
115+
let challenges = &challenges[0];
116+
let height = after_challenge.height();
117+
118+
let log_height = log2_strict_usize(height);
119+
let indices = std::iter::once(i).chain((0..log_height).map(|j| (i + (1 << j)) % height));
120+
let after_challenge_window = RowMajorMatrix::new(
121+
indices.flat_map(|i| after_challenge.row(i)).collect_vec(),
122+
after_challenge.width(),
123+
);
124+
125+
let r = &challenges[challenges.len() - log_height..];
126+
let mut accumulator = SC::Challenge::ZERO;
127+
let alpha = SC::Challenge::TWO;
128+
129+
let is_cyclic_row = (0..=log_height)
130+
.map(|k| SC::Challenge::from_bool(i & ((1 << (log_height - k)) - 1) == 0))
131+
.collect_vec();
132+
fold_multilinear_lagrange_col_constraints(
133+
&mut accumulator,
134+
alpha,
135+
&after_challenge_window,
136+
&is_cyclic_row,
137+
r,
138+
0,
139+
);
140+
assert_eq!(accumulator, SC::Challenge::ZERO);
141+
}
142+
93143
pub fn check_logup<F: Field>(
94144
air_names: &[String],
95145
interactions: &[Vec<SymbolicInteraction<F>>],
@@ -98,7 +148,7 @@ pub fn check_logup<F: Field>(
98148
public_values: &[Vec<F>],
99149
) {
100150
let mut logical_interactions = LogicalInteractions::<F>::default();
101-
for (air_idx, (interactions, preprocessed, partitioned_main, public_values)) in
151+
for (air_idx, (interactions, &preprocessed, partitioned_main, public_values)) in
102152
izip!(interactions, preprocessed, partitioned_main, public_values).enumerate()
103153
{
104154
generate_logical_interactions(

crates/stark-backend/src/air_builders/symbolic/mod.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub struct SymbolicConstraints<F> {
4141
/// the prover for after challenge trace generation, and some partial
4242
/// information may be used by the verifier.
4343
///
44+
/// FIXME[zach]: False for GKR; there, some constraints depend on height of the trace and therefore cannot be precomputed.
4445
/// **However**, any contributions to the quotient polynomial from
4546
/// logup are already included in `constraints` and do not need to
4647
/// be separately calculated from `interactions`.
@@ -399,15 +400,18 @@ impl<F: Field> InteractionPhaseAirBuilder for SymbolicRapBuilder<F> {
399400
assert!(self.challenges.is_empty());
400401
assert!(self.exposed_values_after_challenge.is_empty());
401402

402-
if self.rap_phase_seq_kind == RapPhaseSeqKind::FriLogUp {
403-
let interaction_partitions =
404-
find_interaction_chunks(&self.interactions, self.max_constraint_degree)
405-
.interaction_partitions();
406-
let num_chunks = interaction_partitions.len();
407-
self.interaction_partitions.replace(interaction_partitions);
408-
let perm_width = num_chunks + 1;
409-
self.after_challenge = Self::new_after_challenge(&[perm_width]);
410-
}
403+
let perm_width = match self.rap_phase_seq_kind {
404+
RapPhaseSeqKind::FriLogUp => {
405+
let interaction_partitions =
406+
find_interaction_chunks(&self.interactions, self.max_constraint_degree)
407+
.interaction_partitions();
408+
let num_chunks = interaction_partitions.len();
409+
self.interaction_partitions.replace(interaction_partitions);
410+
num_chunks + 1
411+
}
412+
RapPhaseSeqKind::GkrLogUp => 2,
413+
};
414+
self.after_challenge = Self::new_after_challenge(&[perm_width]);
411415

412416
let phases_shapes = self.rap_phase_seq_kind.shape();
413417
let phase_shape = phases_shapes.first().unwrap();

crates/stark-backend/src/config.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,15 @@ pub type PackedChallenge<SC> =
9393
#[derive(Debug)]
9494
pub struct StarkConfig<Pcs, RapPhaseSeq, Challenge, Challenger> {
9595
pcs: Pcs,
96-
rap_phase: RapPhaseSeq,
96+
rap_phase_seq: RapPhaseSeq,
9797
_phantom: PhantomData<(Challenge, Challenger)>,
9898
}
9999

100100
impl<Pcs, RapPhaseSeq, Challenge, Challenger> StarkConfig<Pcs, RapPhaseSeq, Challenge, Challenger> {
101-
pub const fn new(pcs: Pcs, rap_phase: RapPhaseSeq) -> Self {
101+
pub const fn new(pcs: Pcs, rap_phase_seq: RapPhaseSeq) -> Self {
102102
Self {
103103
pcs,
104-
rap_phase,
104+
rap_phase_seq,
105105
_phantom: PhantomData,
106106
}
107107
}
@@ -132,7 +132,7 @@ where
132132
&self.pcs
133133
}
134134
fn rap_phase_seq(&self) -> &Self::RapPhaseSeq {
135-
&self.rap_phase
135+
&self.rap_phase_seq
136136
}
137137
}
138138

crates/stark-backend/src/gkr/prover.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77

88
use itertools::Itertools;
99
use p3_challenger::FieldChallenger;
10-
use p3_field::Field;
10+
use p3_field::{ExtensionField, Field};
1111
use thiserror::Error;
1212

1313
use crate::{
@@ -327,10 +327,10 @@ pub struct NotConstantPolyError;
327327
///
328328
/// The input layers should be committed to the channel before calling this function.
329329
// GKR algorithm: <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> (page 64)
330-
pub fn prove_batch<F: Field>(
330+
pub fn prove_batch<F: Field, EF: ExtensionField<F>>(
331331
challenger: &mut impl FieldChallenger<F>,
332-
input_layer_by_instance: Vec<Layer<F>>,
333-
) -> (GkrBatchProof<F>, GkrArtifact<F>) {
332+
input_layer_by_instance: Vec<Layer<EF>>,
333+
) -> (GkrBatchProof<EF>, GkrArtifact<EF>) {
334334
let n_instances = input_layer_by_instance.len();
335335
let n_layers_by_instance = input_layer_by_instance
336336
.iter()
@@ -366,12 +366,14 @@ pub fn prove_batch<F: Field>(
366366

367367
// Seed the channel with layer claims.
368368
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
369-
challenger.observe_slice(claims_to_verify);
369+
for claim in claims_to_verify {
370+
challenger.observe_ext_element(*claim);
371+
}
370372
}
371373

372374
let eq_evals = HypercubeEqEvals::eval(&ood_point);
373-
let sumcheck_alpha = challenger.sample();
374-
let instance_lambda = challenger.sample();
375+
let sumcheck_alpha: EF = challenger.sample_ext_element();
376+
let instance_lambda: EF = challenger.sample_ext_element();
375377

376378
let mut sumcheck_oracles = Vec::new();
377379
let mut sumcheck_claims = Vec::new();
@@ -385,7 +387,7 @@ pub fn prove_batch<F: Field>(
385387
sumcheck_oracles.push(GkrMultivariatePolyOracle {
386388
eq_evals: &eq_evals,
387389
input_layer: layer,
388-
eq_fixed_var_correction: F::ONE,
390+
eq_fixed_var_correction: EF::ONE,
389391
lambda: instance_lambda,
390392
});
391393
sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
@@ -417,12 +419,14 @@ pub fn prove_batch<F: Field>(
417419
// Seed the channel with the layer masks.
418420
for (&instance, mask) in zip(&sumcheck_instances, &masks) {
419421
for column in mask.columns() {
420-
challenger.observe_slice(column);
422+
for el in column {
423+
challenger.observe_ext_element(*el);
424+
}
421425
}
422426
layer_masks_by_instance[instance].push(mask.clone());
423427
}
424428

425-
let challenge = challenger.sample();
429+
let challenge: EF = challenger.sample_ext_element();
426430
ood_point = sumcheck_ood_point;
427431
ood_point.push(challenge);
428432

@@ -501,7 +505,7 @@ pub fn correct_sum_as_poly_in_first_variable<F: Field>(
501505
let r_at_2 = f_at_2 * hypercube_eq(&[F::TWO], &[y[n - k]]) * a_const;
502506

503507
// Interpolate.
504-
UnivariatePolynomial::from_interpolation(&[
508+
UnivariatePolynomial::from_points(&[
505509
(F::ZERO, r_at_0),
506510
(F::ONE, r_at_1),
507511
(F::TWO, r_at_2),

crates/stark-backend/src/gkr/tests.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ fn test_batch() -> Result<(), GkrError<BabyBear>> {
2020
let engine = default_engine();
2121
let mut rng = create_seeded_rng();
2222

23-
let col0 = Mle::new((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
24-
let col1 = Mle::new((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
23+
let col0 = Mle::from_vec((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
24+
let col1 = Mle::from_vec((0..1 << LOG_N).map(|_| rng.gen()).collect_vec());
2525

2626
let product0 = col0.iter().copied().product();
2727
let product1 = col1.iter().copied().product();
@@ -62,8 +62,8 @@ fn test_batch_with_different_sizes() -> Result<(), GkrError<BabyBear>> {
6262
const LOG_N0: usize = 5;
6363
const LOG_N1: usize = 7;
6464

65-
let col0 = Mle::new((0..1 << LOG_N0).map(|_| rng.gen()).collect());
66-
let col1 = Mle::new((0..1 << LOG_N1).map(|_| rng.gen()).collect());
65+
let col0 = Mle::from_vec((0..1 << LOG_N0).map(|_| rng.gen()).collect());
66+
let col1 = Mle::from_vec((0..1 << LOG_N1).map(|_| rng.gen()).collect());
6767

6868
let product0 = col0.iter().copied().product();
6969
let product1 = col1.iter().copied().product();
@@ -106,7 +106,7 @@ fn test_grand_product() -> Result<(), GkrError<BabyBear>> {
106106

107107
let values = (0..N).map(|_| rng.gen()).collect_vec();
108108
let product = values.iter().copied().product();
109-
let col = Mle::<BabyBear>::new(values);
109+
let col = Mle::<BabyBear>::from_vec(values);
110110
let input_layer = Layer::GrandProduct(col.clone());
111111
let (proof, _) = gkr::prove_batch(&mut engine.new_challenger(), vec![input_layer]);
112112

@@ -137,8 +137,8 @@ fn test_logup_with_generic_trace() -> Result<(), GkrError<BabyBear>> {
137137
let sum: Fraction<F> = zip(&numerator_values, &denominator_values)
138138
.map(|(&n, &d)| Fraction::new(n, d))
139139
.sum();
140-
let numerators = Mle::<F>::new(numerator_values);
141-
let denominators = Mle::<F>::new(denominator_values);
140+
let numerators = Mle::<F>::from_vec(numerator_values);
141+
let denominators = Mle::<F>::from_vec(denominator_values);
142142
let top_layer = Layer::LogUpGeneric {
143143
numerators: numerators.clone(),
144144
denominators: denominators.clone(),
@@ -178,7 +178,7 @@ fn test_logup_with_singles_trace() -> Result<(), GkrError<BabyBear>> {
178178
.iter()
179179
.map(|&d| Fraction::new(F::ONE, d))
180180
.sum();
181-
let denominators = Mle::new(denominator_values);
181+
let denominators = Mle::from_vec(denominator_values);
182182
let top_layer = Layer::LogUpSingles {
183183
denominators: denominators.clone(),
184184
};
@@ -214,8 +214,8 @@ fn test_logup_with_multiplicities_trace() -> Result<(), GkrError<BabyBear>> {
214214
let sum: Fraction<BabyBear> = zip(&numerator_values, &denominator_values)
215215
.map(|(&n, &d)| Fraction::new(n, d))
216216
.sum();
217-
let numerators = Mle::new(numerator_values);
218-
let denominators = Mle::new(denominator_values);
217+
let numerators = Mle::from_vec(numerator_values);
218+
let denominators = Mle::from_vec(denominator_values);
219219
let top_layer = Layer::LogUpMultiplicities {
220220
numerators: numerators.clone(),
221221
denominators: denominators.clone(),

crates/stark-backend/src/gkr/types.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ops::Index;
22

33
use p3_field::Field;
4+
use serde::{Deserialize, Serialize};
45
use thiserror::Error;
56

67
use crate::{
@@ -12,6 +13,7 @@ use crate::{
1213
};
1314

1415
/// Batch GKR proof.
16+
#[derive(Clone, Serialize, Deserialize)]
1517
pub struct GkrBatchProof<F> {
1618
/// Sum-check proof for each layer.
1719
pub sumcheck_proofs: Vec<SumcheckProof<F>>,
@@ -22,6 +24,7 @@ pub struct GkrBatchProof<F> {
2224
}
2325

2426
/// Values of interest obtained from the execution of the GKR protocol.
27+
#[derive(Clone, Serialize, Deserialize)]
2528
pub struct GkrArtifact<F> {
2629
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
2730
pub ood_point: Vec<F>,
@@ -32,7 +35,7 @@ pub struct GkrArtifact<F> {
3235
}
3336

3437
/// Stores two evaluations of each column in a GKR layer.
35-
#[derive(Debug, Clone)]
38+
#[derive(Debug, Clone, Serialize, Deserialize)]
3639
pub struct GkrMask<F> {
3740
columns: Vec<[F; 2]>,
3841
}
@@ -169,7 +172,7 @@ impl<F: Field> Layer<F> {
169172
.chunks_exact(2) // Process in chunks of 2 elements
170173
.map(|chunk| chunk[0] * chunk[1]) // Multiply each pair
171174
.collect();
172-
Layer::GrandProduct(Mle::new(res))
175+
Layer::GrandProduct(Mle::from_vec(res))
173176
}
174177

175178
fn next_logup_layer(numerators: MleExpr<'_, F>, denominators: &Mle<F>) -> Layer<F> {
@@ -186,8 +189,8 @@ impl<F: Field> Layer<F> {
186189
}
187190

188191
Layer::LogUpGeneric {
189-
numerators: Mle::new(next_numerators),
190-
denominators: Mle::new(next_denominators),
192+
numerators: Mle::from_vec(next_numerators),
193+
denominators: Mle::from_vec(next_denominators),
191194
}
192195
}
193196

crates/stark-backend/src/gkr/verifier.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::{
99
gate::Gate,
1010
types::{GkrArtifact, GkrBatchProof, GkrError},
1111
},
12+
p3_field::ExtensionField,
1213
poly::{multi::hypercube_eq, uni::random_linear_combination},
1314
sumcheck,
1415
};
@@ -18,11 +19,11 @@ use crate::{
1819
/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain
1920
/// point and claimed evaluations in the input layer columns for each instance at the OOD point.
2021
/// These claimed evaluations are not checked in this function - hence partial verification.
21-
pub fn partially_verify_batch<F: Field>(
22+
pub fn partially_verify_batch<F: Field, EF: ExtensionField<F>>(
2223
gate_by_instance: Vec<Gate>,
23-
proof: &GkrBatchProof<F>,
24+
proof: &GkrBatchProof<EF>,
2425
challenger: &mut impl FieldChallenger<F>,
25-
) -> Result<GkrArtifact<F>, GkrError<F>> {
26+
) -> Result<GkrArtifact<EF>, GkrError<EF>> {
2627
let GkrBatchProof {
2728
sumcheck_proofs,
2829
layer_masks_by_instance,
@@ -64,11 +65,13 @@ pub fn partially_verify_batch<F: Field>(
6465

6566
// Seed the channel with layer claims.
6667
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
67-
challenger.observe_slice(claims_to_verify);
68+
for claim in claims_to_verify {
69+
challenger.observe_ext_element(*claim);
70+
}
6871
}
6972

70-
let sumcheck_alpha = challenger.sample();
71-
let instance_lambda = challenger.sample();
73+
let sumcheck_alpha = challenger.sample_ext_element();
74+
let instance_lambda = challenger.sample_ext_element();
7275

7376
let mut sumcheck_claims = Vec::new();
7477
let mut sumcheck_instances = Vec::new();
@@ -125,12 +128,14 @@ pub fn partially_verify_batch<F: Field>(
125128
let n_unused = n_layers - instance_n_layers(instance);
126129
let mask = &layer_masks_by_instance[instance][layer - n_unused];
127130
for column in mask.columns() {
128-
challenger.observe_slice(column);
131+
for elt in column {
132+
challenger.observe_ext_element(*elt);
133+
}
129134
}
130135
}
131136

132137
// Set the OOD evaluation point for layer above.
133-
let challenge = challenger.sample();
138+
let challenge = challenger.sample_ext_element();
134139
ood_point = sumcheck_ood_point;
135140
ood_point.push(challenge);
136141

0 commit comments

Comments
 (0)