Skip to content

Optimize computing batched opened values using fri_reduced_opening chip #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 50 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6417088
Switch ceno reliance
yczhangsjtu Jul 11, 2025
4ac1628
Merge remote-tracking branch 'origin/main' into feat/simplify-basefol…
yczhangsjtu Jul 11, 2025
6069a53
Fix compilation errors due to out of date code
yczhangsjtu Jul 11, 2025
4e6d3e0
Update test query phase batch
yczhangsjtu Jul 14, 2025
07d2fba
Fix query opening proof
yczhangsjtu Jul 14, 2025
662396a
Implement basefold proof variable
yczhangsjtu Jul 15, 2025
e1c9ded
Update query phase verifier input
yczhangsjtu Jul 15, 2025
11fbced
Preparing test data for query phase with updated code
yczhangsjtu Jul 15, 2025
b6bdd1a
Implement basefold proof transform
yczhangsjtu Jul 15, 2025
a511a98
Prepare query phase verifier input
yczhangsjtu Jul 15, 2025
fc4e5ba
Prepare query phase verifier input
yczhangsjtu Jul 15, 2025
7ee316c
Fix final message access
yczhangsjtu Jul 16, 2025
0dc34f7
Switch ceno reliance to small field support
yczhangsjtu Jul 16, 2025
34f53ce
Create basefold verifier function
yczhangsjtu Jul 16, 2025
1dd84da
Check final message sizes
yczhangsjtu Jul 16, 2025
3744986
Fix final message size
yczhangsjtu Jul 16, 2025
ba814a5
Fix final message size
yczhangsjtu Jul 16, 2025
f130d21
Check query opening proof len
yczhangsjtu Jul 16, 2025
e883efb
Compute total number of polys
yczhangsjtu Jul 16, 2025
3629b44
Sample batch coeffs
yczhangsjtu Jul 16, 2025
12b81e3
Compute max_num_var
yczhangsjtu Jul 17, 2025
a68999b
Write sumcheck messages and commits to transcript
yczhangsjtu Jul 17, 2025
f9f244c
Write final message to transcript
yczhangsjtu Jul 17, 2025
89e5e7e
Complete the code for batch verifier
yczhangsjtu Jul 17, 2025
5ac05e7
Add verifier test
yczhangsjtu Jul 21, 2025
ac4cc44
Try to fix some compilation errors in e2e
yczhangsjtu Jul 21, 2025
33fba3d
Connecting pcs with e2e
yczhangsjtu Jul 21, 2025
389e2b2
Merge remote-tracking branch 'origin/cyte/fix-query-phase' into feat/…
yczhangsjtu Jul 21, 2025
b6c3942
Fix some issues after merge
yczhangsjtu Jul 21, 2025
d5861f2
Make compilation pass temporarily
yczhangsjtu Jul 21, 2025
785bb58
Make test pass before query phase
yczhangsjtu Jul 21, 2025
d3b6627
Merge remote
yczhangsjtu Jul 22, 2025
197ed35
Supply the permutation and make the random case pass
yczhangsjtu Jul 22, 2025
6824da4
Try fixing transcript inconsistency
yczhangsjtu Jul 22, 2025
db26b0f
Use bin to dec le
yczhangsjtu Jul 22, 2025
45a2c4c
Add pow witness
yczhangsjtu Jul 22, 2025
dcc4c5c
Basefold verifier passes for simple case
yczhangsjtu Jul 22, 2025
5aa023c
Update dependency
yczhangsjtu Jul 23, 2025
7e67692
Basefold verifier passes decreasing and random batches
yczhangsjtu Jul 23, 2025
8e0911a
Change the way for computing batch
yczhangsjtu Jul 24, 2025
ec51a95
Remove the unnecessary slice
yczhangsjtu Jul 24, 2025
156f5e6
Simplify code
yczhangsjtu Jul 24, 2025
86476a1
Fix bug: the first mmcs batch verify passes
yczhangsjtu Jul 24, 2025
6f91a61
Initialize all zeros
yczhangsjtu Jul 24, 2025
75b0270
Merge remote
yczhangsjtu Jul 24, 2025
989eb3e
Fix
yczhangsjtu Jul 24, 2025
9d02328
Fix one bug: alpha is batch coeffs 1 not 0
yczhangsjtu Jul 24, 2025
946612f
Handle case when batch size is 1
yczhangsjtu Jul 24, 2025
d16990c
Allocate all zeros only once
yczhangsjtu Jul 25, 2025
64f47e7
Construct test multiple rounds and test with fibonacci e2e data
yczhangsjtu Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 135 additions & 50 deletions src/basefold_verifier/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ impl

#[derive(DslVariable, Clone)]
pub struct BatchOpeningVariable<C: Config> {
pub opened_values: Array<C, Array<C, Felt<C::F>>>,
pub opened_values: HintSlice<C>,
pub opening_proof: HintSlice<C>,
}

impl Hintable<InnerConfig> for BatchOpening {
type HintVariable = BatchOpeningVariable<InnerConfig>;

fn read(builder: &mut Builder<InnerConfig>) -> Self::HintVariable {
let opened_values = Vec::<Vec<F>>::read(builder);
let opened_values = read_hint_slice(builder);
let opening_proof = read_hint_slice(builder);

BatchOpeningVariable {
Expand All @@ -78,7 +78,14 @@ impl Hintable<InnerConfig> for BatchOpening {

fn write(&self) -> Vec<Vec<<InnerConfig as Config>::N>> {
let mut stream = Vec::new();
stream.extend(self.opened_values.write());
stream.extend(vec![
vec![F::from_canonical_usize(self.opened_values.len())],
self.opened_values
.iter()
.flatten()
.copied()
.collect::<Vec<_>>(),
]);
stream.extend(vec![
vec![F::from_canonical_usize(self.opening_proof.len())],
self.opening_proof
Expand Down Expand Up @@ -238,11 +245,13 @@ pub struct PointAndEvalsVariable<C: Config> {
pub struct QueryPhaseVerifierInput {
// pub t_inv_halves: Vec<Vec<<E as ExtensionField>::BaseField>>,
pub max_num_var: usize,
pub max_width: usize,
pub batch_coeffs: Vec<E>,
pub fold_challenges: Vec<E>,
pub indices: Vec<usize>,
pub proof: BasefoldProof,
pub rounds: Vec<Round>,
pub perms: Vec<Vec<usize>>,
}

impl Hintable<InnerConfig> for QueryPhaseVerifierInput {
Expand All @@ -251,15 +260,19 @@ impl Hintable<InnerConfig> for QueryPhaseVerifierInput {
fn read(builder: &mut Builder<InnerConfig>) -> Self::HintVariable {
// let t_inv_halves = Vec::<Vec<F>>::read(builder);
let max_num_var = Usize::Var(usize::read(builder));
let max_width = Usize::Var(usize::read(builder));
let batch_coeffs = Vec::<E>::read(builder);
let fold_challenges = Vec::<E>::read(builder);
let indices = Vec::<usize>::read(builder);
let proof = BasefoldProof::read(builder);
let rounds = Vec::<Round>::read(builder);
let perms: Array<InnerConfig, Array<InnerConfig, Var<F>>> =
Vec::<Vec<usize>>::read(builder);

QueryPhaseVerifierInputVariable {
// t_inv_halves,
max_num_var,
max_width,
batch_coeffs,
fold_challenges,
indices,
Expand All @@ -272,11 +285,13 @@ impl Hintable<InnerConfig> for QueryPhaseVerifierInput {
let mut stream = Vec::new();
// stream.extend(self.t_inv_halves.write());
stream.extend(<usize as Hintable<InnerConfig>>::write(&self.max_num_var));
stream.extend(<usize as Hintable<InnerConfig>>::write(&self.max_width));
stream.extend(self.batch_coeffs.write());
stream.extend(self.fold_challenges.write());
stream.extend(self.indices.write());
stream.extend(self.proof.write());
stream.extend(self.rounds.write());
stream.extend(self.perms.write());
stream
}
}
Expand All @@ -285,6 +300,7 @@ impl Hintable<InnerConfig> for QueryPhaseVerifierInput {
pub struct QueryPhaseVerifierInputVariable<C: Config> {
// pub t_inv_halves: Array<C, Array<C, Felt<C::F>>>,
pub max_num_var: Usize<C::N>,
pub max_width: Usize<C::N>,
pub batch_coeffs: Array<C, Ext<C::F, C::EF>>,
pub fold_challenges: Array<C, Ext<C::F, C::EF>>,
pub indices: Array<C, Var<C::N>>,
Expand Down Expand Up @@ -316,12 +332,20 @@ pub(crate) fn batch_verifier_query_phase<C: Config>(
let generator = builder.constant(C::F::from_canonical_usize(*val).inverse());
builder.set_value(&two_adic_generators_inverses, index, generator);
}
let zero: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let zero_flag = builder.constant(C::N::ZERO);
let two: Var<C::N> = builder.constant(C::N::from_canonical_usize(2));

// encode_small
let final_message = &input.proof.final_message;
let final_rmm_values_len = builder.get(final_message, 0).len();
let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone());

let all_zeros = builder.dyn_array(input.max_width.clone());
iter_zip!(builder, all_zeros).for_each(|ptr_vec, builder| {
builder.set_value(&all_zeros, ptr_vec[0], zero.clone());
});

builder
.range(0, final_rmm_values_len.clone())
.for_each(|i_vec, builder| {
Expand All @@ -346,7 +370,13 @@ pub(crate) fn batch_verifier_query_phase<C: Config>(
let log2_max_codeword_size: Var<C::N> =
builder.eval(input.max_num_var.clone() + Usize::from(get_rate_log()));

let zero: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let alpha: Ext<C::F, C::EF> = builder.constant(C::EF::ONE);
builder
.if_ne(input.batch_coeffs.len(), C::N::ONE)
.then(|builder| {
let batch_coeff = builder.get(&input.batch_coeffs, 1);
builder.assign(&alpha, batch_coeff);
});

iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each(
|ptr_vec, builder| {
Expand Down Expand Up @@ -380,24 +410,96 @@ pub(crate) fn batch_verifier_query_phase<C: Config>(
let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]);
let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]);
let opened_values = batch_opening.opened_values;
let perm_opened_values = builder.dyn_array(opened_values.len());
let dimensions = builder.dyn_array(opened_values.len());
let perm_opened_values = builder.dyn_array(opened_values.length.clone());
let dimensions = builder.dyn_array(opened_values.length.clone());
let opening_proof = batch_opening.opening_proof;

let opened_values_buffer: Array<C, Array<C, Felt<C::F>>> =
builder.dyn_array(opened_values.length);

// TODO: optimize this procedure
iter_zip!(builder, opened_values_buffer, round.openings).for_each(
|ptr_vec, builder| {
let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]);
let log2_height: Var<C::N> =
builder.eval(opening.num_var + Usize::from(get_rate_log() - 1));
let width = opening.point_and_evals.evals.len();

let opened_value_len: Var<C::N> = builder.eval(width.clone() * two);
let opened_value_buffer = builder.dyn_array(opened_value_len);
builder.iter_ptr_set(
&opened_values_buffer,
ptr_vec[0],
opened_value_buffer.clone(),
);

let low_values = opened_value_buffer.slice(builder, 0, width.clone());
let high_values = opened_value_buffer.slice(
builder,
width.clone(),
opened_value_buffer.len(),
);

// The linear combination is by (alpha^offset, ..., alpha^(offset+width-1)), which is equal to
// alpha^offset * (1, ..., alpha^(width-1))
let alpha_offset =
builder.get(&input.batch_coeffs, batch_coeffs_offset.clone());
// Will need to negate the values of low and high
// because `fri_single_reduced_opening_eval` is
// computing \sum_i alpha^i (0 - opened_value[i]).
// We want \sum_i alpha^(i + offset) opened_value[i]
// Let's negate it here.
builder.assign(&alpha_offset, -alpha_offset);
let all_zeros_slice = all_zeros.slice(builder, 0, width.clone());

let low = builder.fri_single_reduced_opening_eval(
alpha,
opened_values.id.get_var(),
zero_flag,
&low_values,
&all_zeros_slice,
);
let high = builder.fri_single_reduced_opening_eval(
alpha,
opened_values.id.get_var(),
zero_flag,
&high_values,
&all_zeros_slice,
);
builder.assign(&low, low * alpha_offset);
builder.assign(&high, high * alpha_offset);

let codeword: PackedCodeword<C> = PackedCodeword { low, high };
let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height);

// reduced_openings[log2_height] += codeword
builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low);
builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high);

builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc);
builder.assign(&batch_coeffs_offset, batch_coeffs_offset + width.clone());
},
);

// TODO: ensure that perm is indeed a permutation of 0, ..., opened_values.len()-1

// reorder (opened values, dimension) according to the permutation
builder
.range(0, opened_values.len())
.range(0, opened_values_buffer.len())
.for_each(|j_vec, builder| {
let j = j_vec[0];
let mat_j = builder.get(&opened_values, j);

let mat_j = builder.get(&opened_values_buffer, j);
let num_var_j = builder.get(&round.openings, j).num_var;
let height_j = builder.eval(num_var_j + Usize::from(get_rate_log() - 1));

let permuted_j = builder.get(&round.perm, j);
// let permuted_j = j;

builder.set_value(&perm_opened_values, permuted_j, mat_j);
builder.set_value(&dimensions, permuted_j, height_j);
});
// TODO: ensure that dimensions is indeed sorted decreasingly

// i >>= (log2_max_codeword_size - commit.log2_max_codeword_size);
let bits_shift: Var<C::N> = builder
Expand All @@ -414,48 +516,6 @@ pub(crate) fn batch_verifier_query_phase<C: Config>(
};

mmcs_verify_batch(builder, mmcs_verifier_input);

// TODO: optimize this procedure
iter_zip!(builder, opened_values, round.openings).for_each(|ptr_vec, builder| {
let opened_value = builder.iter_ptr_get(&opened_values, ptr_vec[0]);
let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]);
let log2_height: Var<C::N> =
builder.eval(opening.num_var + Usize::from(get_rate_log() - 1));
let width = opening.point_and_evals.evals.len();

let batch_coeffs_next_offset: Var<C::N> =
builder.eval(batch_coeffs_offset + width.clone());
let coeffs = input.batch_coeffs.slice(
builder,
batch_coeffs_offset.clone(),
batch_coeffs_next_offset.clone(),
);
let low_values = opened_value.slice(builder, 0, width.clone());
let high_values =
opened_value.slice(builder, width.clone(), opened_value.len());
let low: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let high: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);

iter_zip!(builder, coeffs, low_values, high_values).for_each(
|ptr_vec, builder| {
let coeff = builder.iter_ptr_get(&coeffs, ptr_vec[0]);
let low_value = builder.iter_ptr_get(&low_values, ptr_vec[1]);
let high_value = builder.iter_ptr_get(&high_values, ptr_vec[2]);

builder.assign(&low, low + coeff * low_value);
builder.assign(&high, high + coeff * high_value);
},
);
let codeword: PackedCodeword<C> = PackedCodeword { low, high };
let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height);

// reduced_openings[log2_height] += codeword
builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low);
builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high);

builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc);
builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset);
});
});

let opening_ext = query.commit_phase_openings;
Expand Down Expand Up @@ -718,6 +778,23 @@ pub mod tests {
let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap();
let (pp, vp) = pcs_trim::<E, PCS>(pp, 1 << 20).unwrap();

// Sort the dimensions decreasingly and compute the permutation array
let mut dimensions_with_index = dimensions.iter().enumerate().collect::<Vec<_>>();
dimensions_with_index.sort_by(|(_, (a, _)), (_, (b, _))| b.cmp(a));
// The perm array should satisfy that: sorted_dimensions[perm[i]] = dimensions[i]
// However, if we just pick the indices now, we get the inverse permutation:
// sorted_dimensions[i] = dimensions[perm[i]]
let perm = dimensions_with_index
.iter()
.map(|(i, _)| *i)
.collect::<Vec<_>>();
// So we need to invert the permutation
let mut inverted_perm = vec![0usize; perm.len()];
for (i, &j) in perm.iter().enumerate() {
inverted_perm[j] = i;
}
let perm = inverted_perm;

let mut num_total_polys = 0;
let (matrices, mles): (Vec<_>, Vec<_>) = dimensions
.into_iter()
Expand Down Expand Up @@ -770,6 +847,11 @@ pub mod tests {
.map(|(point, _)| point.len())
.max()
.unwrap();
let max_width = point_and_evals
.iter()
.map(|(_, evals)| evals.len())
.max()
.unwrap();
let num_rounds = max_num_var; // The final message is of length 1

// prepare folding challenges via sumcheck round msg + FRI commitment
Expand Down Expand Up @@ -801,9 +883,11 @@ pub mod tests {
<BasefoldRSParams as BasefoldSpec<E>>::get_number_queries(),
max_num_var + <BasefoldRSParams as BasefoldSpec<E>>::get_rate_log(),
);
let perms = vec![perm];

let query_input = QueryPhaseVerifierInput {
max_num_var,
max_width,
fold_challenges,
batch_coeffs,
indices: queries,
Expand All @@ -825,6 +909,7 @@ pub mod tests {
.collect(),
})
.collect(),
perms,
};
let (program, witness) = build_batch_verifier_query_phase(query_input);

Expand Down
Loading