diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index c7ff136..d0e52f6 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -59,7 +59,7 @@ impl #[derive(DslVariable, Clone)] pub struct BatchOpeningVariable { - pub opened_values: Array>>, + pub opened_values: HintSlice, pub opening_proof: HintSlice, } @@ -67,7 +67,7 @@ impl Hintable for BatchOpening { type HintVariable = BatchOpeningVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let opened_values = Vec::>::read(builder); + let opened_values = read_hint_slice(builder); let opening_proof = read_hint_slice(builder); BatchOpeningVariable { @@ -78,7 +78,14 @@ impl Hintable for BatchOpening { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.opened_values.write()); + stream.extend(vec![ + vec![F::from_canonical_usize(self.opened_values.len())], + self.opened_values + .iter() + .flatten() + .copied() + .collect::>(), + ]); stream.extend(vec![ vec![F::from_canonical_usize(self.opening_proof.len())], self.opening_proof @@ -238,11 +245,13 @@ pub struct PointAndEvalsVariable { pub struct QueryPhaseVerifierInput { // pub t_inv_halves: Vec::BaseField>>, pub max_num_var: usize, + pub max_width: usize, pub batch_coeffs: Vec, pub fold_challenges: Vec, pub indices: Vec, pub proof: BasefoldProof, pub rounds: Vec, + pub perms: Vec>, } impl Hintable for QueryPhaseVerifierInput { @@ -251,15 +260,19 @@ impl Hintable for QueryPhaseVerifierInput { fn read(builder: &mut Builder) -> Self::HintVariable { // let t_inv_halves = Vec::>::read(builder); let max_num_var = Usize::Var(usize::read(builder)); + let max_width = Usize::Var(usize::read(builder)); let batch_coeffs = Vec::::read(builder); let fold_challenges = Vec::::read(builder); let indices = Vec::::read(builder); let proof = BasefoldProof::read(builder); let rounds = Vec::::read(builder); + let perms: Array>> = + Vec::>::read(builder); QueryPhaseVerifierInputVariable { // t_inv_halves, max_num_var, + max_width, batch_coeffs, fold_challenges, indices, @@ -272,11 +285,13 @@ impl Hintable for QueryPhaseVerifierInput { let mut stream = Vec::new(); // stream.extend(self.t_inv_halves.write()); stream.extend(>::write(&self.max_num_var)); + stream.extend(>::write(&self.max_width)); stream.extend(self.batch_coeffs.write()); stream.extend(self.fold_challenges.write()); stream.extend(self.indices.write()); stream.extend(self.proof.write()); stream.extend(self.rounds.write()); + stream.extend(self.perms.write()); stream } } @@ -285,6 +300,7 @@ impl Hintable for QueryPhaseVerifierInput { pub struct QueryPhaseVerifierInputVariable { // pub t_inv_halves: Array>>, pub max_num_var: Usize, + pub max_width: Usize, pub batch_coeffs: Array>, pub fold_challenges: Array>, pub indices: Array>, @@ -316,12 +332,20 @@ pub(crate) fn batch_verifier_query_phase( let generator = builder.constant(C::F::from_canonical_usize(*val).inverse()); builder.set_value(&two_adic_generators_inverses, index, generator); } + let zero: Ext = builder.constant(C::EF::ZERO); + let zero_flag = builder.constant(C::N::ZERO); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); // encode_small let final_message = &input.proof.final_message; let final_rmm_values_len = builder.get(final_message, 0).len(); let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); + let all_zeros = builder.dyn_array(input.max_width.clone()); + iter_zip!(builder, all_zeros).for_each(|ptr_vec, builder| { + builder.set_value(&all_zeros, ptr_vec[0], zero.clone()); + }); + builder .range(0, final_rmm_values_len.clone()) .for_each(|i_vec, builder| { @@ -346,7 +370,13 @@ pub(crate) fn batch_verifier_query_phase( let log2_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + Usize::from(get_rate_log())); - let zero: Ext = builder.constant(C::EF::ZERO); + let alpha: Ext = builder.constant(C::EF::ONE); + builder + .if_ne(input.batch_coeffs.len(), C::N::ONE) + .then(|builder| { + let batch_coeff = builder.get(&input.batch_coeffs, 1); + builder.assign(&alpha, batch_coeff); + }); iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each( |ptr_vec, builder| { @@ -380,24 +410,96 @@ pub(crate) fn batch_verifier_query_phase( let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]); let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]); let opened_values = batch_opening.opened_values; - let perm_opened_values = builder.dyn_array(opened_values.len()); - let dimensions = builder.dyn_array(opened_values.len()); + let perm_opened_values = builder.dyn_array(opened_values.length.clone()); + let dimensions = builder.dyn_array(opened_values.length.clone()); let opening_proof = batch_opening.opening_proof; + let opened_values_buffer: Array>> = + builder.dyn_array(opened_values.length); + + // TODO: optimize this procedure + iter_zip!(builder, opened_values_buffer, round.openings).for_each( + |ptr_vec, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]); + let log2_height: Var = + builder.eval(opening.num_var + Usize::from(get_rate_log() - 1)); + let width = opening.point_and_evals.evals.len(); + + let opened_value_len: Var = builder.eval(width.clone() * two); + let opened_value_buffer = builder.dyn_array(opened_value_len); + builder.iter_ptr_set( + &opened_values_buffer, + ptr_vec[0], + opened_value_buffer.clone(), + ); + + let low_values = opened_value_buffer.slice(builder, 0, width.clone()); + let high_values = opened_value_buffer.slice( + builder, + width.clone(), + opened_value_buffer.len(), + ); + + // The linear combination is by (alpha^offset, ..., alpha^(offset+width-1)), which is equal to + // alpha^offset * (1, ..., alpha^(width-1)) + let alpha_offset = + builder.get(&input.batch_coeffs, batch_coeffs_offset.clone()); + // Will need to negate the values of low and high + // because `fri_single_reduced_opening_eval` is + // computing \sum_i alpha^i (0 - opened_value[i]). + // We want \sum_i alpha^(i + offset) opened_value[i] + // Let's negate it here. + builder.assign(&alpha_offset, -alpha_offset); + let all_zeros_slice = all_zeros.slice(builder, 0, width.clone()); + + let low = builder.fri_single_reduced_opening_eval( + alpha, + opened_values.id.get_var(), + zero_flag, + &low_values, + &all_zeros_slice, + ); + let high = builder.fri_single_reduced_opening_eval( + alpha, + opened_values.id.get_var(), + zero_flag, + &high_values, + &all_zeros_slice, + ); + builder.assign(&low, low * alpha_offset); + builder.assign(&high, high * alpha_offset); + + let codeword: PackedCodeword = PackedCodeword { low, high }; + let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height); + + // reduced_openings[log2_height] += codeword + builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low); + builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high); + + builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc); + builder.assign(&batch_coeffs_offset, batch_coeffs_offset + width.clone()); + }, + ); + + // TODO: ensure that perm is indeed a permutation of 0, ..., opened_values.len()-1 + // reorder (opened values, dimension) according to the permutation builder - .range(0, opened_values.len()) + .range(0, opened_values_buffer.len()) .for_each(|j_vec, builder| { let j = j_vec[0]; - let mat_j = builder.get(&opened_values, j); + + let mat_j = builder.get(&opened_values_buffer, j); let num_var_j = builder.get(&round.openings, j).num_var; let height_j = builder.eval(num_var_j + Usize::from(get_rate_log() - 1)); let permuted_j = builder.get(&round.perm, j); + // let permuted_j = j; builder.set_value(&perm_opened_values, permuted_j, mat_j); builder.set_value(&dimensions, permuted_j, height_j); }); + // TODO: ensure that dimensions is indeed sorted decreasingly // i >>= (log2_max_codeword_size - commit.log2_max_codeword_size); let bits_shift: Var = builder @@ -414,48 +516,6 @@ pub(crate) fn batch_verifier_query_phase( }; mmcs_verify_batch(builder, mmcs_verifier_input); - - // TODO: optimize this procedure - iter_zip!(builder, opened_values, round.openings).for_each(|ptr_vec, builder| { - let opened_value = builder.iter_ptr_get(&opened_values, ptr_vec[0]); - let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]); - let log2_height: Var = - builder.eval(opening.num_var + Usize::from(get_rate_log() - 1)); - let width = opening.point_and_evals.evals.len(); - - let batch_coeffs_next_offset: Var = - builder.eval(batch_coeffs_offset + width.clone()); - let coeffs = input.batch_coeffs.slice( - builder, - batch_coeffs_offset.clone(), - batch_coeffs_next_offset.clone(), - ); - let low_values = opened_value.slice(builder, 0, width.clone()); - let high_values = - opened_value.slice(builder, width.clone(), opened_value.len()); - let low: Ext = builder.constant(C::EF::ZERO); - let high: Ext = builder.constant(C::EF::ZERO); - - iter_zip!(builder, coeffs, low_values, high_values).for_each( - |ptr_vec, builder| { - let coeff = builder.iter_ptr_get(&coeffs, ptr_vec[0]); - let low_value = builder.iter_ptr_get(&low_values, ptr_vec[1]); - let high_value = builder.iter_ptr_get(&high_values, ptr_vec[2]); - - builder.assign(&low, low + coeff * low_value); - builder.assign(&high, high + coeff * high_value); - }, - ); - let codeword: PackedCodeword = PackedCodeword { low, high }; - let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height); - - // reduced_openings[log2_height] += codeword - builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low); - builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high); - - builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc); - builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset); - }); }); let opening_ext = query.commit_phase_openings; @@ -718,6 +778,23 @@ pub mod tests { let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + // Sort the dimensions decreasingly and compute the permutation array + let mut dimensions_with_index = dimensions.iter().enumerate().collect::>(); + dimensions_with_index.sort_by(|(_, (a, _)), (_, (b, _))| b.cmp(a)); + // The perm array should satisfy that: sorted_dimensions[perm[i]] = dimensions[i] + // However, if we just pick the indices now, we get the inverse permutation: + // sorted_dimensions[i] = dimensions[perm[i]] + let perm = dimensions_with_index + .iter() + .map(|(i, _)| *i) + .collect::>(); + // So we need to invert the permutation + let mut inverted_perm = vec![0usize; perm.len()]; + for (i, &j) in perm.iter().enumerate() { + inverted_perm[j] = i; + } + let perm = inverted_perm; + let mut num_total_polys = 0; let (matrices, mles): (Vec<_>, Vec<_>) = dimensions .into_iter() @@ -770,6 +847,11 @@ pub mod tests { .map(|(point, _)| point.len()) .max() .unwrap(); + let max_width = point_and_evals + .iter() + .map(|(_, evals)| evals.len()) + .max() + .unwrap(); let num_rounds = max_num_var; // The final message is of length 1 // prepare folding challenges via sumcheck round msg + FRI commitment @@ -801,9 +883,11 @@ pub mod tests { >::get_number_queries(), max_num_var + >::get_rate_log(), ); + let perms = vec![perm]; let query_input = QueryPhaseVerifierInput { max_num_var, + max_width, fold_challenges, batch_coeffs, indices: queries, @@ -825,6 +909,7 @@ pub mod tests { .collect(), }) .collect(), + perms, }; let (program, witness) = build_batch_verifier_query_phase(query_input); diff --git a/src/basefold_verifier/verifier.rs b/src/basefold_verifier/verifier.rs index 5eb39de..f7de271 100644 --- a/src/basefold_verifier/verifier.rs +++ b/src/basefold_verifier/verifier.rs @@ -25,6 +25,7 @@ pub type InnerConfig = AsmConfig; pub fn batch_verify( builder: &mut Builder, max_num_var: Var, + max_width: Var, rounds: Array>, proof: BasefoldProofVariable, challenger: &mut DuplexChallengerVariable, @@ -72,12 +73,13 @@ pub fn batch_verify( builder.assign(&running_coeff, running_coeff * batch_coeff); }); - // The max num var is provided by the prover and not guaranteed to be correct. + // The max num var and max width are provided by the prover and not guaranteed to be correct. // Check that - // 1. it is greater than or equal to every num var; + // 1. max_num_var is greater than or equal to every num var (same for width); // 2. it is equal to at least one of the num vars by multiplying all the differences - // together and assert the product is zero. - let diff_product: Var = builder.eval(Usize::from(1)); + // together and assert the product is zero (same for width). + let diff_product_num_var: Var = builder.eval(Usize::from(1)); + let diff_product_width: Var = builder.eval(Usize::from(1)); iter_zip!(builder, rounds).for_each(|ptr_vec, builder| { let round = builder.iter_ptr_get(&rounds, ptr_vec[0]); @@ -86,12 +88,19 @@ pub fn batch_verify( let diff: Var = builder.eval(max_num_var.clone() - opening.num_var); // num_var is always smaller than 32. builder.range_check_var(diff, 5); - builder.assign(&diff_product, diff_product * diff); + builder.assign(&diff_product_num_var, diff_product_num_var * diff); + + let diff: Var = + builder.eval(max_width.clone() - opening.point_and_evals.evals.len()); + // width is always smaller than 2^20. + builder.range_check_var(diff, 20); + builder.assign(&diff_product_width, diff_product_width * diff); }); }); // Check that at least one num_var is equal to max_num_var let zero: Var = builder.eval(C::N::ZERO); - builder.assert_eq::>(diff_product, zero); + builder.assert_eq::>(diff_product_num_var, zero); + builder.assert_eq::>(diff_product_width, zero); let num_rounds: Var = builder.eval(max_num_var - Usize::from(get_basecode_msg_size_log())); @@ -142,6 +151,7 @@ pub fn batch_verify( let input = QueryPhaseVerifierInputVariable { max_num_var: builder.eval(max_num_var), + max_width: builder.eval(max_width), batch_coeffs, fold_challenges, indices: queries, @@ -197,6 +207,7 @@ pub mod tests { #[derive(Deserialize)] pub struct VerifierInput { pub max_num_var: usize, + pub max_width: usize, pub proof: BasefoldProof, pub rounds: Vec, } @@ -206,11 +217,13 @@ pub mod tests { fn read(builder: &mut Builder) -> Self::HintVariable { let max_num_var = usize::read(builder); + let max_width = usize::read(builder); let proof = BasefoldProof::read(builder); let rounds = Vec::::read(builder); VerifierInputVariable { max_num_var, + max_width, proof, rounds, } @@ -219,6 +232,7 @@ pub mod tests { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(>::write(&self.max_num_var)); + stream.extend(>::write(&self.max_width)); stream.extend(self.proof.write()); stream.extend(self.rounds.write()); stream @@ -228,6 +242,7 @@ pub mod tests { #[derive(DslVariable, Clone)] pub struct VerifierInputVariable { pub max_num_var: Var, + pub max_width: Var, pub proof: BasefoldProofVariable, pub rounds: Array>, } @@ -241,6 +256,7 @@ pub mod tests { batch_verify( &mut builder, verifier_input.max_num_var, + verifier_input.max_width, verifier_input.rounds, verifier_input.proof, &mut challenger, @@ -254,65 +270,96 @@ pub mod tests { (program, witness_stream) } - fn construct_test(dimensions: Vec<(usize, usize)>) { + fn construct_test(dimensions: Vec>) { let mut rng = thread_rng(); // setup PCS - let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); - let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); - - let mut num_total_polys = 0; - let (matrices, mles): (Vec<_>, Vec<_>) = dimensions - .into_iter() - .map(|(num_vars, width)| { - let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); - let mles = m.to_mles(); - num_total_polys += width; - - (m, mles) + let pp = PCS::setup(1 << 22, mpcs::SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = pcs_trim::(pp, 1 << 22).unwrap(); + + let rounds = dimensions + .iter() + .map(|dimensions| { + let mut num_total_polys = 0; + let (matrices, mles): (Vec<_>, Vec<_>) = dimensions + .into_iter() + .map(|(num_vars, width)| { + let m = ceno_witness::RowMajorMatrix::::rand( + &mut rng, + 1 << num_vars, + *width, + ); + let mles = m.to_mles(); + num_total_polys += width; + + (m, mles) + }) + .unzip(); + + // commit to matrices + let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); + + let point_and_evals = mles + .iter() + .map(|mles| { + let point = E::random_vec(mles[0].num_vars(), &mut rng); + let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + + (point, evals) + }) + .collect_vec(); + (pcs_data, point_and_evals.clone()) }) - .unzip(); + .collect_vec(); - // commit to matrices - let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); - let comm = PCS::get_pure_commitment(&pcs_data); + let prover_rounds = rounds + .iter() + .map(|(comm, other)| (comm, other.clone())) + .collect_vec(); - let point_and_evals = mles + let max_num_var = rounds .iter() - .map(|mles| { - let point = E::random_vec(mles[0].num_vars(), &mut rng); - let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + .map(|round| round.1.iter().map(|(point, _)| point.len()).max().unwrap()) + .max() + .unwrap(); + let max_width = rounds + .iter() + .map(|round| round.1.iter().map(|(_, evals)| evals.len()).max().unwrap()) + .max() + .unwrap(); - (point, evals) + let verifier_rounds = rounds + .iter() + .map(|round| { + ( + PCS::get_pure_commitment(&round.0), + round + .1 + .iter() + .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) + .collect_vec(), + ) }) .collect_vec(); // batch open let mut transcript = BasicTranscript::::new(&[]); - let rounds = vec![(&pcs_data, point_and_evals.clone())]; - let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); + let opening_proof = PCS::batch_open(&pp, prover_rounds, &mut transcript).unwrap(); // batch verify let mut transcript = BasicTranscript::::new(&[]); - let rounds = vec![( - comm, - point_and_evals - .iter() - .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) - .collect_vec(), - )]; - PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) - .expect("Native verification failed"); - - let max_num_var = point_and_evals - .iter() - .map(|(point, _)| point.len()) - .max() - .unwrap(); + PCS::batch_verify( + &vp, + verifier_rounds.clone(), + &opening_proof, + &mut transcript, + ) + .expect("Native verification failed"); let verifier_input = VerifierInput { max_num_var, - rounds: rounds + max_width, + rounds: verifier_rounds .into_iter() .map(|(commit, openings)| Round { commit: commit.into(), @@ -321,9 +368,7 @@ pub mod tests { .map(|(num_var, (point, evals))| RoundOpening { num_var, point_and_evals: PointAndEvals { - point: Point { - fs: point, - }, + point: Point { fs: point }, evals, }, }) @@ -353,24 +398,80 @@ pub mod tests { #[test] fn test_simple_batch() { for num_var in 5..20 { - construct_test(vec![(num_var, 20)]); + construct_test(vec![vec![(num_var, 20)]]); } } #[test] fn test_decreasing_batch() { - construct_test(vec![ + construct_test(vec![vec![ (14, 20), (14, 40), (13, 30), (12, 30), (11, 10), (10, 15), - ]); + ]]); } #[test] fn test_random_batch() { - construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]); + construct_test(vec![vec![(10, 20), (12, 30), (11, 10), (12, 15)]]); + } + + // TODO: e2e generates two rounds, only using the witness part in the current test, need to add the fixed part later + #[test] + fn test_e2e_fibonacci_batch() { + construct_test(vec![ + vec![ + (22, 22), + (22, 18), + (1, 28), + (2, 24), + (3, 18), + (1, 21), + (4, 19), + (21, 18), + (1, 8), + (1, 11), + (4, 22), + (3, 27), + (5, 22), + (16, 1), + (16, 1), + (16, 1), + (5, 1), + (16, 1), + (1, 28), + (9, 1), + (3, 2), + (3, 1), + (5, 2), + (10, 2), + (6, 3), + (14, 1), + (16, 1), + (5, 1), + (8, 1), + (4, 29), + (1, 29), + (1, 18), + (1, 23), + (21, 20), + (21, 22), + (5, 22), + ], + vec![ + (16, 3), + (16, 3), + (16, 3), + (5, 3), + (16, 3), + (9, 6), + (3, 1), + (10, 2), + (6, 3), + ], + ]); } }