Skip to content

Commit 2548d06

Browse files
committed
clean up
1 parent 1d67be2 commit 2548d06

File tree

4 files changed

+135
-135
lines changed

4 files changed

+135
-135
lines changed

crates/stark-backend/src/interaction/gkr_log_up.rs

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ pub struct GkrLogUpProvingKey;
4242
pub struct GkrLogUpPartialProof<T, Witness> {
4343
/// The rational sumcheck proof that can be verified via [gkr::partially_verify].
4444
pub gkr_proof: GkrBatchProof<T>,
45-
/// The purported evaluations of the count MLEs at `r`, per AIR, per interaction.
46-
pub count_mle_claims_per_instance: Vec<Vec<T>>,
47-
/// The purported evaluations of the sigma MLEs at `r`, per AIR, per interaction.
48-
pub sigma_mle_claims_per_instance: Vec<Vec<T>>,
45+
/// The purported evaluations of the numerator MLEs at `r`, per AIR.
46+
pub numer_mle_claims_per_instance: Vec<Vec<T>>,
47+
/// The purported evaluations of the denominator MLEs at `r`, per AIR.
48+
pub denom_mle_claims_per_instance: Vec<Vec<T>>,
4949
pub logup_pow_witness: Witness,
5050
}
5151

@@ -68,8 +68,8 @@ pub enum GkrLogUpError<F> {
6868
struct GkrAuxData<T> {
6969
after_challenge_trace_per_air: Vec<Option<DenseMatrix<T, Vec<T>>>>,
7070
exposed_values_per_air: Vec<Option<Vec<T>>>,
71-
count_mle_claims_per_instance: Vec<Vec<T>>,
72-
sigma_mle_claims_per_instance: Vec<Vec<T>>,
71+
numer_mle_claims_per_instance: Vec<Vec<T>>,
72+
denom_mle_claims_per_instance: Vec<Vec<T>>,
7373
}
7474

7575
impl<F, EF, Challenger> GkrLogUpPhase<F, EF, Challenger> {
@@ -119,17 +119,17 @@ where
119119

120120
let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
121121

122-
let (alpha, beta) = self.generate_challenges(challenger);
122+
let (alpha, beta, gamma) = self.generate_challenges(challenger);
123123
let beta_pows = generate_betas(beta, &all_interactions);
124124

125125
// Build GKR instances.
126126
let gkr_instances: Vec<_> =
127-
Self::build_gkr_instances(trace_view_per_air, interactions_per_air, &beta_pows);
127+
Self::build_gkr_instances(trace_view_per_air, interactions_per_air, alpha, &beta_pows);
128128

129129
// Construct input layers and run GKR proof.
130130
let input_layers: Vec<_> = gkr_instances
131131
.par_iter()
132-
.map(|gkr_instance| gkr_instance.build_gkr_input_layer(alpha))
132+
.map(|gkr_instance| gkr_instance.build_gkr_input_layer())
133133
.collect();
134134

135135
let (gkr_proof, gkr_artifact) = metrics_span("gkr_prove_batch_ms", || {
@@ -140,25 +140,25 @@ where
140140
let GkrAuxData {
141141
after_challenge_trace_per_air,
142142
exposed_values_per_air,
143-
count_mle_claims_per_instance,
144-
sigma_mle_claims_per_instance,
143+
numer_mle_claims_per_instance,
144+
denom_mle_claims_per_instance,
145145
} = Self::generate_aux_per_air(
146146
challenger,
147147
interactions_per_air,
148148
trace_view_per_air,
149-
alpha,
149+
gamma,
150150
&gkr_instances,
151151
&gkr_artifact,
152152
);
153153

154-
let mut challenges = vec![beta, alpha];
154+
let mut challenges = vec![alpha, beta, gamma];
155155
challenges.extend_from_slice(&gkr_artifact.ood_point);
156156

157157
Some((
158158
GkrLogUpPartialProof {
159159
gkr_proof,
160-
count_mle_claims_per_instance,
161-
sigma_mle_claims_per_instance,
160+
numer_mle_claims_per_instance,
161+
denom_mle_claims_per_instance,
162162
logup_pow_witness,
163163
},
164164
RapPhaseProverData {
@@ -215,7 +215,7 @@ where
215215
return Err(GkrLogUpError::InvalidPowWitness);
216216
}
217217

218-
let (alpha, beta) = self.generate_challenges(challenger);
218+
let (alpha, beta, gamma) = self.generate_challenges(challenger);
219219

220220
let gkr_proof = &partial_proof.gkr_proof;
221221

@@ -234,13 +234,13 @@ where
234234
gkr::partially_verify_batch(vec![Gate::LogUp; n_instances], gkr_proof, challenger)
235235
.map_err(|e| Self::Error::GkrError(e))?;
236236

237-
for (count_mle_claims, sigma_mle_claims) in izip!(
238-
partial_proof.count_mle_claims_per_instance.iter(),
239-
partial_proof.sigma_mle_claims_per_instance.iter()
237+
for (numer_mle_claims, denom_mle_claims) in izip!(
238+
partial_proof.numer_mle_claims_per_instance.iter(),
239+
partial_proof.denom_mle_claims_per_instance.iter()
240240
) {
241-
for (count_mle_claim, sigma_mle_claim) in izip!(count_mle_claims, sigma_mle_claims) {
242-
challenger.observe_ext_element(*count_mle_claim);
243-
challenger.observe_ext_element(*sigma_mle_claim);
241+
for (denom_mle_claim, numer_mle_claim) in izip!(numer_mle_claims, denom_mle_claims) {
242+
challenger.observe_ext_element(*denom_mle_claim);
243+
challenger.observe_ext_element(*numer_mle_claim);
244244
}
245245
}
246246

@@ -254,8 +254,7 @@ where
254254
&bus_indices_per_air,
255255
exposed_values_per_air_per_phase,
256256
partial_proof,
257-
alpha,
258-
alpha,
257+
gamma,
259258
)?;
260259

261260
let mut j = 0;
@@ -276,10 +275,10 @@ where
276275
let numerator_claim = claims_to_verify[0];
277276
let denominator_claim = claims_to_verify[1];
278277

279-
let count_mle_claims = &partial_proof.count_mle_claims_per_instance[j];
280-
let sigma_mle_claims = &partial_proof.sigma_mle_claims_per_instance[j];
278+
let numer_mle_claims = &partial_proof.numer_mle_claims_per_instance[j];
279+
let denom_mle_claims = &partial_proof.denom_mle_claims_per_instance[j];
281280

282-
if count_mle_claims.len() != padded_len || sigma_mle_claims.len() != padded_len {
281+
if numer_mle_claims.len() != padded_len || denom_mle_claims.len() != padded_len {
283282
return Err(Self::Error::MalformedGkrLogUpProof);
284283
}
285284

@@ -300,11 +299,9 @@ where
300299
numerator_claim,
301300
denominator_claim,
302301
actual_sr,
303-
count_mle_claims,
304-
sigma_mle_claims,
305-
bus_indices,
306-
alpha,
307-
alpha, // using alpha as gamma—check this
302+
numer_mle_claims,
303+
denom_mle_claims,
304+
gamma,
308305
)?;
309306

310307
j += 1;
@@ -320,7 +317,7 @@ where
320317
return Err(Self::Error::NonZeroCumulativeSum);
321318
}
322319

323-
let mut challenges = vec![beta, alpha];
320+
let mut challenges = vec![alpha, beta, gamma];
324321
challenges.extend_from_slice(&gkr_artifact.ood_point);
325322

326323
Ok(RapPhaseVerifierData {
@@ -396,10 +393,11 @@ where
396393
EF: ExtensionField<F>,
397394
Challenger: FieldChallenger<F>,
398395
{
399-
fn generate_challenges(&self, challenger: &mut Challenger) -> (EF, EF) {
396+
fn generate_challenges(&self, challenger: &mut Challenger) -> (EF, EF, EF) {
400397
let alpha: EF = challenger.sample_ext_element();
401398
let beta: EF = challenger.sample_ext_element();
402-
(alpha, beta)
399+
let gamma: EF = challenger.sample_ext_element();
400+
(alpha, beta, gamma)
403401
}
404402
}
405403

@@ -446,26 +444,41 @@ pub fn eval_gkr_log_up_phase<AB>(builder: &mut AB)
446444
where
447445
AB: InteractionBuilder + PermutationAirBuilderWithExposedValues,
448446
{
449-
let &[beta, gamma] = builder.permutation_randomness() else {
450-
panic!("PermutationAirBuilderWithExposedValues requires 2 randomness elements");
447+
let &[alpha, beta, gamma] = builder.permutation_randomness() else {
448+
panic!("PermutationAirBuilderWithExposedValues requires 3 randomness elements");
451449
};
452450

453451
let all_interactions = builder.all_interactions();
452+
let num_padded_interactions = if all_interactions.len() == 1 {
453+
2
454+
} else {
455+
all_interactions.len().next_power_of_two()
456+
};
454457

455-
let mut gamma_pows = gamma.into().powers().take(2 * all_interactions.len());
458+
let mut gamma_pows = gamma.into().powers().take(2 * num_padded_interactions);
456459
let beta_pows = generate_betas(beta.into(), all_interactions);
457460

458461
let mut s_next = AB::ExprEF::ZERO;
459-
for interaction in all_interactions {
460-
s_next += gamma_pows.next().unwrap() * AB::ExprEF::from(interaction.count.clone());
462+
for i in 0..num_padded_interactions {
463+
let numerator = if i < all_interactions.len() {
464+
all_interactions[i].count.clone().into()
465+
} else {
466+
AB::ExprEF::ZERO
467+
};
468+
s_next += gamma_pows.next().unwrap() * numerator;
461469

462-
let b = AB::Expr::from_canonical_u32(interaction.bus_index as u32 + 1);
463-
let message = interaction.message.iter().chain(iter::once(&b));
470+
let denominator = if i < all_interactions.len() {
471+
let b = AB::Expr::from_canonical_u32(all_interactions[i].bus_index as u32 + 1);
472+
let message = all_interactions[i].message.iter().chain(iter::once(&b));
464473

465-
let sigma = zip(message, &beta_pows).fold(AB::ExprEF::ZERO, |acc, (field, beta)| {
466-
acc + beta.clone() * field.clone()
467-
});
468-
s_next += gamma_pows.next().unwrap() * sigma;
474+
let sigma = zip(message, &beta_pows).fold(AB::ExprEF::ZERO, |acc, (field, beta)| {
475+
acc + beta.clone() * field.clone()
476+
});
477+
alpha.into() + sigma
478+
} else {
479+
AB::ExprEF::ONE
480+
};
481+
s_next += gamma_pows.next().unwrap() * denominator;
469482
}
470483

471484
let exposed_values = builder.permutation_exposed_values();

0 commit comments

Comments
 (0)