From fcb8ed5fc7c6cab8ae8d0621b3425e42058cb8c1 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 22 Apr 2025 11:35:37 -0400 Subject: [PATCH 01/70] Dense Matrix --- Cargo.lock | 12 +-- src/basefold_verifier/binding.rs | 0 src/basefold_verifier/mod.rs | 3 + src/basefold_verifier/program.rs | 1 + src/basefold_verifier/rs.rs | 162 +++++++++++++++++++++++++++++++ src/lib.rs | 1 + 6 files changed, 173 insertions(+), 6 deletions(-) create mode 100644 src/basefold_verifier/binding.rs create mode 100644 src/basefold_verifier/mod.rs create mode 100644 src/basefold_verifier/program.rs create mode 100644 src/basefold_verifier/rs.rs diff --git a/Cargo.lock b/Cargo.lock index 6174ffb..4d05e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,9 +547,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", @@ -795,9 +795,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2df961d8c8a0d08aa9945718ccf584145eee3f3aa06cddbeac12933781102e04" +checksum = "eccb054f56cbd38340b380d4a8e69ef1f02f1af43db2f0cc817a4774d80ae071" dependencies = [ "clap_builder", "clap_derive", @@ -805,9 +805,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "132dbda40fb6753878316a489d5a1242a8ef2f0d9e47ba01c951ea8aa7d013a5" +checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" dependencies = [ "anstream", "anstyle", diff --git a/src/basefold_verifier/binding.rs b/src/basefold_verifier/binding.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs new file mode 100644 index 0000000..d2714ae --- /dev/null +++ b/src/basefold_verifier/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod binding; +pub(crate) mod program; +pub(crate) mod rs; \ No newline at end of file diff --git a/src/basefold_verifier/program.rs b/src/basefold_verifier/program.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/basefold_verifier/program.rs @@ -0,0 +1 @@ + diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs new file mode 100644 index 0000000..192fa58 --- /dev/null +++ b/src/basefold_verifier/rs.rs @@ -0,0 +1,162 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub struct DenseMatrix { + pub values: Vec, + pub width: usize, +} + +impl Hintable for DenseMatrix { + type HintVariable = DenseMatrixVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let values = Vec::::read(builder); + let width = Usize::Var(usize::read(builder)); + + DenseMatrixVariable { + values, + width, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.values.write()); + stream.extend(>::write(&self.width)); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct DenseMatrixVariable { + pub values: Array>, + pub width: Usize, +} +pub type RowMajorMatrixVariable = DenseMatrixVariable; + +impl DenseMatrixVariable { + // XXX: Find better ways to handle this without cloning + pub fn pad_to_height( + &self, + builder: &mut Builder, + new_height: Usize, + fill: Ext, + ) -> DenseMatrixVariable { + // assert!(new_height >= self.height()); + let new_size = builder.eval_expr(self.width.clone() * new_height); + let evals: Array> = builder.dyn_array(new_size); + builder.range(0, self.values.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp: Ext = builder.get(&self.values, i); + builder.set(&evals, i, tmp); + }); + builder.range(self.values.len(), evals.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&evals, i, fill); + }); + DenseMatrixVariable:: { + values: evals, + width: self.width.clone(), + } + } +} + +/* +/// The DIT FFT algorithm. +#[derive(DslVariable, Clone)] +pub struct Radix2DitVariable { + /// Memoized twiddle factors for each length log_n. + /// Precise definition is a map from usize to E + pub twiddles: Array>, +} + +#[derive(DslVariable, Clone)] +pub struct RSCodeVerifierParametersVariable { + pub dft: Radix2DitVariable, + pub t_inv_halves: Array>, + pub full_message_size_log: Usize, +} + +pub(crate) fn encode_small( + builder: &mut Builder, + vp: RSCodeVerifierParametersVariable, + rmm: RowMajorMatrixVariable, +) -> RowMajorMatrixVariable { + let mut m = rmm; +} +*/ + +pub mod tests { + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_compiler::prelude::*; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, + p3_baby_bear::BabyBear, + }; + use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; + type SC = BabyBearPoseidon2Config; + + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + use super::DenseMatrix; + + #[allow(dead_code)] + pub fn build_test_dense_matrix_pad() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let dense_matrix_variable = DenseMatrix::read(&mut builder); + let new_height = Usize::from(8); + let fill = Ext::new(0); + dense_matrix_variable.pad_to_height(&mut builder, new_height, fill); + builder.halt(); + + // Pass in witness stream + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + + let verifier_input = DenseMatrix { + values: vec![E::ONE; 25], + width: 5, + }; + witness_stream.extend(verifier_input.write()); + + let program: Program< + p3_monty_31::MontyField31, + > = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_dense_matrix_pad() { + let (program, witness) = build_test_dense_matrix_pad(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 428deca..ca90a71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod arithmetics; +mod basefold_verifier; pub mod constants; pub mod json; mod tests; From 66c77349c3bc0c2d6033da409b343e8724d8eac2 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 22 Apr 2025 15:09:57 -0400 Subject: [PATCH 02/70] Added hints --- src/basefold_verifier/rs.rs | 78 +++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 192fa58..0aa0910 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -1,3 +1,5 @@ +// Note: check all XXX comments! + use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -41,15 +43,33 @@ pub struct DenseMatrixVariable { pub type RowMajorMatrixVariable = DenseMatrixVariable; impl DenseMatrixVariable { + pub fn height( + &self, + builder: &mut Builder, + ) -> Var { + // Supply height as hint + let height = builder.hint_var(); + builder.if_eq(self.width.clone(), Usize::from(0)).then(|builder| { + builder.assert_usize_eq(height, Usize::from(0)); + }); + builder.if_ne(self.width.clone(), Usize::from(0)).then(|builder| { + // XXX: check that width * height is not a field multiplication + builder.assert_usize_eq(self.width.clone() * height, self.values.len()); + }); + height + } + // XXX: Find better ways to handle this without cloning pub fn pad_to_height( &self, builder: &mut Builder, - new_height: Usize, + new_height: RVar, fill: Ext, - ) -> DenseMatrixVariable { - // assert!(new_height >= self.height()); - let new_size = builder.eval_expr(self.width.clone() * new_height); + ) { + // XXX: Not necessary, only for testing purpose + let old_height = self.height(builder); + builder.assert_less_than_slow_small_rhs(old_height, new_height + RVar::from(1)); + let new_size = builder.eval_expr(self.width.clone() * new_height.clone()); let evals: Array> = builder.dyn_array(new_size); builder.range(0, self.values.len()).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -60,10 +80,7 @@ impl DenseMatrixVariable { let i = i_vec[0]; builder.set(&evals, i, fill); }); - DenseMatrixVariable:: { - values: evals, - width: self.width.clone(), - } + builder.assign(&self.values, evals); } } @@ -76,6 +93,32 @@ pub struct Radix2DitVariable { pub twiddles: Array>, } +impl Radix2DitVariable { + fn dft_batch( + &self, + builder: &mut Builder, + mat: RowMajorMatrixVariable + ) -> RowMajorMatrixVariable { + let h = mat.height(builder); + // TODO: Verify correspondence between log_h and h + let log_h = builder.hint_var(); + + // TODO: support memoization + // Compute twiddle factors, or take memoized ones if already available. + let twiddles = { + let root = F::two_adic_generator(log_h); + root.powers().take(1 << log_h).collect() + }; + + // DIT butterfly + reverse_matrix_index_bits(&mut mat); + for layer in 0..log_h { + dit_layer(&mut mat.as_view_mut(), layer, twiddles); + } + mat + } +} + #[derive(DslVariable, Clone)] pub struct RSCodeVerifierParametersVariable { pub dft: Radix2DitVariable, @@ -83,12 +126,23 @@ pub struct RSCodeVerifierParametersVariable { pub full_message_size_log: Usize, } +fn get_rate_log() -> usize { + 1 +} + pub(crate) fn encode_small( builder: &mut Builder, vp: RSCodeVerifierParametersVariable, rmm: RowMajorMatrixVariable, ) -> RowMajorMatrixVariable { - let mut m = rmm; + let m = rmm; + // Add current setup this is unnecessary + let old_height = m.height(builder); + let new_height = builder.eval_expr( + old_height * Usize::from(1 << get_rate_log()) + ); + m.pad_to_height(builder, new_height, Ext::new(0)); + m } */ @@ -109,7 +163,7 @@ pub mod tests { type F = BabyBear; type E = BinomialExtensionField; type EF = ::Challenge; - use super::DenseMatrix; + use super::{DenseMatrix, InnerConfig}; #[allow(dead_code)] pub fn build_test_dense_matrix_pad() -> (Program, Vec>) { @@ -118,7 +172,7 @@ pub mod tests { // Witness inputs let dense_matrix_variable = DenseMatrix::read(&mut builder); - let new_height = Usize::from(8); + let new_height = RVar::from(8); let fill = Ext::new(0); dense_matrix_variable.pad_to_height(&mut builder, new_height, fill); builder.halt(); @@ -133,6 +187,8 @@ pub mod tests { width: 5, }; witness_stream.extend(verifier_input.write()); + // Hint for height + witness_stream.extend(>::write(&5)); let program: Program< p3_monty_31::MontyField31, From 6f8fe11280556dbf94cbc697a8468bc7aea11134 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Wed, 23 Apr 2025 12:31:55 -0400 Subject: [PATCH 03/70] Finished get_base_codeword_dimensions --- src/basefold_verifier/binding.rs | 88 +++++++++++++++++ src/basefold_verifier/program.rs | 157 +++++++++++++++++++++++++++++++ src/basefold_verifier/rs.rs | 8 +- 3 files changed, 249 insertions(+), 4 deletions(-) diff --git a/src/basefold_verifier/binding.rs b/src/basefold_verifier/binding.rs index e69de29..0b1cdb5 100644 --- a/src/basefold_verifier/binding.rs +++ b/src/basefold_verifier/binding.rs @@ -0,0 +1,88 @@ +use openvm_native_compiler::{asm::AsmConfig, ir::*}; +use openvm_native_compiler_derive::DslVariable; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +#[derive(DslVariable, Clone)] +pub struct CircuitIndexMetaVariable { + pub witin_num_vars: Usize, + pub witin_num_polys: Usize, + pub fixed_num_vars: Usize, + pub fixed_num_polys: Usize, +} + +pub struct CircuitIndexMeta { + pub witin_num_vars: usize, + pub witin_num_polys: usize, + pub fixed_num_vars: usize, + pub fixed_num_polys: usize, +} + +impl Hintable for CircuitIndexMeta { + type HintVariable = CircuitIndexMetaVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let witin_num_vars = Usize::Var(usize::read(builder)); + let witin_num_polys = Usize::Var(usize::read(builder)); + let fixed_num_vars = Usize::Var(usize::read(builder)); + let fixed_num_polys = Usize::Var(usize::read(builder)); + + CircuitIndexMetaVariable { + witin_num_vars, + witin_num_polys, + fixed_num_vars, + fixed_num_polys, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write( + &self.witin_num_vars, + )); + stream.extend(>::write( + &self.witin_num_polys, + )); + stream.extend(>::write( + &self.fixed_num_vars, + )); + stream.extend(>::write( + &self.fixed_num_polys, + )); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct DimensionsVariable { + pub width: Var, + pub height: Var, +} + +pub struct Dimensions { + pub width: usize, + pub height: usize, +} + +impl Hintable for Dimensions { + type HintVariable = DimensionsVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let width = usize::read(builder); + let height = usize::read(builder); + + DimensionsVariable { width, height } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.width)); + stream.extend(>::write(&self.height)); + stream + } +} diff --git a/src/basefold_verifier/program.rs b/src/basefold_verifier/program.rs index 8b13789..67f14ee 100644 --- a/src/basefold_verifier/program.rs +++ b/src/basefold_verifier/program.rs @@ -1 +1,158 @@ +use super::binding::*; +use openvm_native_compiler::ir::*; +use p3_field::FieldAlgebra; +use crate::basefold_verifier::rs::get_rate_log; +fn get_base_codeword_dimensions( + builder: &mut Builder, + circuit_meta_map: Array>, +) -> ( + Array>, + Array>, +) { + let dim_len = circuit_meta_map.len(); + let wit_dim: Array> = builder.dyn_array(dim_len.clone()); + let fixed_dim: Array> = builder.dyn_array(dim_len.clone()); + + builder.range(0, dim_len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp = builder.get(&circuit_meta_map, i); + let witin_num_vars = tmp.witin_num_vars; + let witin_num_polys = tmp.witin_num_polys; + let fixed_num_vars = tmp.fixed_num_vars; + let fixed_num_polys = tmp.fixed_num_polys; + // wit_dim + let width = builder.eval(witin_num_polys * Usize::from(2)); + let height_exp = builder.eval_expr(witin_num_vars + get_rate_log::() - Usize::from(1)); + // XXX: more efficient pow implementation + let height: Var = builder.constant(C::N::ONE); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + builder.range(0, height_exp).for_each(|_, builder| { + builder.assign(&height, height * two); + }); + let next_wit: DimensionsVariable = DimensionsVariable { + width, + height, + }; + builder.set_value(&wit_dim, i, next_wit); + + // fixed_dim + // XXX: since fixed_num_vars is usize, fixed_num_vars > 0 is equivalent to fixed_num_vars != 0 + builder.if_ne(fixed_num_vars.clone(), Usize::from(0)).then(|builder| { + let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); + let height_exp = builder.eval_expr(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); + // XXX: more efficient pow implementation + let height: Var = builder.constant(C::N::ONE); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + builder.range(0, height_exp).for_each(|_, builder| { + builder.assign(&height, height * two); + }); + let next_fixed: DimensionsVariable = DimensionsVariable { + width, + height, + }; + builder.set_value(&fixed_dim, i, next_fixed); + }); + }); + (wit_dim, fixed_dim) +} + + +pub mod tests { + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_compiler::prelude::*; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, + p3_baby_bear::BabyBear, + }; + use p3_field::extension::BinomialExtensionField; + type SC = BabyBearPoseidon2Config; + + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + use crate::basefold_verifier::binding::*; + + use super::{get_base_codeword_dimensions, InnerConfig}; + + #[allow(dead_code)] + pub fn build_test_get_base_codeword_dimensions() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let map_len = Usize::Var(usize::read(&mut builder)); + let circuit_meta_map = builder.dyn_array(map_len.clone()); + builder.range(0, map_len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_meta = CircuitIndexMeta::read(builder); + builder.set(&circuit_meta_map, i, next_meta); + }); + + let (wit_dim, fixed_dim) = get_base_codeword_dimensions(&mut builder, circuit_meta_map); + builder.range(0, map_len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let wit = builder.get(&wit_dim, i); + let fixed = builder.get(&fixed_dim, i); + let i_val: Var<_> = builder.eval(i); + builder.print_v(i_val); + let ww_val: Var<_> = builder.eval(wit.width); + let wh_val: Var<_> = builder.eval(wit.height); + let fw_val: Var<_> = builder.eval(fixed.width); + let fh_val: Var<_> = builder.eval(fixed.height); + builder.print_v(ww_val); + builder.print_v(wh_val); + builder.print_v(fw_val); + builder.print_v(fh_val); + }); + builder.halt(); + + // Pass in witness stream + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + + // Map length + let map_len = 5; + witness_stream.extend(>::write(&map_len)); + for i in 0..map_len { + // Individual metas + let circuit_meta = CircuitIndexMeta { + witin_num_vars: i, + witin_num_polys: i, + fixed_num_vars: i, + fixed_num_polys: i, + }; + witness_stream.extend(circuit_meta.write()); + } + + let program: Program< + p3_monty_31::MontyField31, + > = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_dense_matrix_pad() { + let (program, witness) = build_test_get_base_codeword_dimensions(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} \ No newline at end of file diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 0aa0910..1926f51 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -84,6 +84,10 @@ impl DenseMatrixVariable { } } +pub fn get_rate_log() -> Usize { + Usize::from(1) +} + /* /// The DIT FFT algorithm. #[derive(DslVariable, Clone)] @@ -126,10 +130,6 @@ pub struct RSCodeVerifierParametersVariable { pub full_message_size_log: Usize, } -fn get_rate_log() -> usize { - 1 -} - pub(crate) fn encode_small( builder: &mut Builder, vp: RSCodeVerifierParametersVariable, From 381307801d9b3529f68a6f9c025aa0ab3a4ec317 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Wed, 23 Apr 2025 15:07:52 -0400 Subject: [PATCH 04/70] WIP mmcs --- src/basefold_verifier/binding.rs | 3 +- src/basefold_verifier/mmcs.rs | 182 +++++++++++++++++++++++++++++++ src/basefold_verifier/mod.rs | 4 +- src/basefold_verifier/program.rs | 10 +- src/basefold_verifier/utils.rs | 15 +++ 5 files changed, 205 insertions(+), 9 deletions(-) create mode 100644 src/basefold_verifier/mmcs.rs create mode 100644 src/basefold_verifier/utils.rs diff --git a/src/basefold_verifier/binding.rs b/src/basefold_verifier/binding.rs index 0b1cdb5..c0b419f 100644 --- a/src/basefold_verifier/binding.rs +++ b/src/basefold_verifier/binding.rs @@ -1,6 +1,6 @@ use openvm_native_compiler::{asm::AsmConfig, ir::*}; use openvm_native_compiler_derive::DslVariable; -use openvm_native_recursion::hints::Hintable; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; @@ -86,3 +86,4 @@ impl Hintable for Dimensions { stream } } +impl VecAutoHintable for Dimensions {} \ No newline at end of file diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs new file mode 100644 index 0000000..c20dbcd --- /dev/null +++ b/src/basefold_verifier/mmcs.rs @@ -0,0 +1,182 @@ +// Note: check all XXX comments! + +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; + +use super::{binding::*, utils::*}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +const DIGEST_ELEMS: usize = 4; + +pub struct Hash { + pub value: [F; DIGEST_ELEMS], +} + +impl Hintable for Hash { + type HintVariable = HashVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let value = builder.uninit_fixed_array(DIGEST_ELEMS); + for i in 0..DIGEST_ELEMS { + let tmp = F::read(builder); + builder.set(&value, i, tmp); + } + + HashVariable { + value, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + // Write out each entries + for i in 0..DIGEST_ELEMS { + stream.extend(self.value[i].write()); + } + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct HashVariable { + pub value: Array>, +} + +type Commitment = Hash; +type Proof = Vec<[F; DIGEST_ELEMS]>; +pub struct MmcsVerifierInput { + pub commit: Commitment, + pub dimensions: Vec, + pub index: usize, + pub opened_values: Vec>, + pub proof: Proof, +} + +impl Hintable for MmcsVerifierInput { + type HintVariable = MmcsVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = Commitment::read(builder); + let dimensions = Vec::::read(builder); + let index = usize::read(builder); + let opened_values = Vec::>::read(builder); + let proof = Vec::>::read(builder); + + MmcsVerifierInputVariable { + commit, + dimensions, + index, + opened_values, + proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commit.write()); + stream.extend(self.dimensions.write()); + stream.extend(>::write(&self.index)); + stream.extend(self.opened_values.write()); + stream.extend(self.proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream + } +} + +type CommitmentVariable = HashVariable; +type ProofVariable = Array::F>>>; +#[derive(DslVariable, Clone)] +pub struct MmcsVerifierInputVariable { + pub commit: CommitmentVariable, + pub dimensions: Array>, + pub index: Var, + pub opened_values: Array>>, + pub proof: ProofVariable, +} + +pub(crate) fn verify_batch( + builder: &mut Builder, + input: MmcsVerifierInputVariable, +) { + // Check that the openings have the correct shape. + builder.assert_eq(input.dimensions.len(), input.opened_values.len()); + + // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width. + + // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit. + // Nondeterministically supply max_height and check + let max_height = builder.hint_var(); + builder.range(0, input.dimensions.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_height = builder.get(&input.dimensions, i).height; + let max_height_plus_one = builder.eval(max_height + Usize::from(1)); + builder.assert_less_than_slow_bit_decomp(next_height, max_height_plus_one); + }); + + // Verify correspondence between log_h and h + let log_max_height = builder.hint_var(); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + let purported_max_height = pow(builder, two, log_max_height); + builder.assert_eq(purported_max_height, max_height); + builder.assert_eq(input.proof.len(), log_max_height); + + + let mut heights_tallest_first = dimensions + .iter() + .enumerate() + .sorted_by_key(|(_, dims)| Reverse(dims.height)) + .peekable(); + + let Some(mut curr_height_padded) = heights_tallest_first + .peek() + .map(|x| x.1.height.next_power_of_two()) + else { + // dimensions is empty + return Err(EmptyBatch); + }; + + let mut root = self.hash.hash_iter_slices( + heights_tallest_first + .peeking_take_while(|(_, dims)| { + dims.height.next_power_of_two() == curr_height_padded + }) + .map(|(i, _)| opened_values[i].as_slice()), + ); + + for &sibling in proof { + let (left, right) = if index & 1 == 0 { + (root, sibling) + } else { + (sibling, root) + }; + + root = self.compress.compress([left, right]); + index >>= 1; + curr_height_padded >>= 1; + + let next_height = heights_tallest_first + .peek() + .map(|(_, dims)| dims.height) + .filter(|h| h.next_power_of_two() == curr_height_padded); + if let Some(next_height) = next_height { + let next_height_openings_digest = self.hash.hash_iter_slices( + heights_tallest_first + .peeking_take_while(|(_, dims)| dims.height == next_height) + .map(|(i, _)| opened_values[i].as_slice()), + ); + + root = self.compress.compress([root, next_height_openings_digest]); + } + } + + if commit == &root { + Ok(()) + } else { + Err(RootMismatch) + } +} diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index d2714ae..c02d4ee 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -1,3 +1,5 @@ pub(crate) mod binding; pub(crate) mod program; -pub(crate) mod rs; \ No newline at end of file +pub(crate) mod rs; +pub(crate) mod mmcs; +pub(crate) mod utils; \ No newline at end of file diff --git a/src/basefold_verifier/program.rs b/src/basefold_verifier/program.rs index 67f14ee..d9b6c1a 100644 --- a/src/basefold_verifier/program.rs +++ b/src/basefold_verifier/program.rs @@ -1,4 +1,4 @@ -use super::binding::*; +use super::{binding::*, utils::*}; use openvm_native_compiler::ir::*; use p3_field::FieldAlgebra; use crate::basefold_verifier::rs::get_rate_log; @@ -23,13 +23,9 @@ fn get_base_codeword_dimensions( let fixed_num_polys = tmp.fixed_num_polys; // wit_dim let width = builder.eval(witin_num_polys * Usize::from(2)); - let height_exp = builder.eval_expr(witin_num_vars + get_rate_log::() - Usize::from(1)); - // XXX: more efficient pow implementation - let height: Var = builder.constant(C::N::ONE); + let height_exp = builder.eval(witin_num_vars + get_rate_log::() - Usize::from(1)); let two: Var = builder.constant(C::N::from_canonical_usize(2)); - builder.range(0, height_exp).for_each(|_, builder| { - builder.assign(&height, height * two); - }); + let height = pow(builder, two, height_exp); let next_wit: DimensionsVariable = DimensionsVariable { width, height, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs new file mode 100644 index 0000000..40bf9b8 --- /dev/null +++ b/src/basefold_verifier/utils.rs @@ -0,0 +1,15 @@ +use openvm_native_compiler::ir::*; +use p3_field::FieldAlgebra; + +// XXX: more efficient pow implementation +pub fn pow( + builder: &mut Builder, + base: Var, + exponent: Var, +) -> Var { + let value: Var = builder.constant(C::N::ONE); + builder.range(0, exponent).for_each(|_, builder| { + builder.assign(&value, value * base); + }); + value +} \ No newline at end of file From b10539ba704f3ef9cfee95abce97e4b26f337206 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 24 Apr 2025 14:39:19 -0400 Subject: [PATCH 05/70] WIP mmcs --- src/basefold_verifier/hash.rs | 53 ++++++++++++ src/basefold_verifier/mmcs.rs | 142 +++++++++++++++++-------------- src/basefold_verifier/mod.rs | 1 + src/basefold_verifier/program.rs | 12 +-- src/basefold_verifier/utils.rs | 25 ++++++ 5 files changed, 161 insertions(+), 72 deletions(-) create mode 100644 src/basefold_verifier/hash.rs diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs new file mode 100644 index 0000000..d07ad43 --- /dev/null +++ b/src/basefold_verifier/hash.rs @@ -0,0 +1,53 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub const DIGEST_ELEMS: usize = 4; + +pub struct Hash { + pub value: [F; DIGEST_ELEMS], +} + +impl Hintable for Hash { + type HintVariable = HashVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let value = builder.uninit_fixed_array(DIGEST_ELEMS); + for i in 0..DIGEST_ELEMS { + let tmp = F::read(builder); + builder.set(&value, i, tmp); + } + + HashVariable { + value, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + // Write out each entries + for i in 0..DIGEST_ELEMS { + stream.extend(self.value[i].write()); + } + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct HashVariable { + pub value: Array>, +} + +pub fn hash_iter_slices( + builder: &mut Builder, + // _hash: HashVariable, + _values: Array>>, +) -> Array> { + // XXX: verify hash + builder.hint_felts_fixed(DIGEST_ELEMS) +} \ No newline at end of file diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index c20dbcd..b939950 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -4,50 +4,13 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use p3_field::FieldAlgebra; -use super::{binding::*, utils::*}; +use super::{binding::*, utils::*, hash::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -const DIGEST_ELEMS: usize = 4; - -pub struct Hash { - pub value: [F; DIGEST_ELEMS], -} - -impl Hintable for Hash { - type HintVariable = HashVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let value = builder.uninit_fixed_array(DIGEST_ELEMS); - for i in 0..DIGEST_ELEMS { - let tmp = F::read(builder); - builder.set(&value, i, tmp); - } - - HashVariable { - value, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - // Write out each entries - for i in 0..DIGEST_ELEMS { - stream.extend(self.value[i].write()); - } - stream - } -} - -#[derive(DslVariable, Clone)] -pub struct HashVariable { - pub value: Array>, -} - type Commitment = Hash; type Proof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { @@ -104,12 +67,15 @@ pub(crate) fn verify_batch( input: MmcsVerifierInputVariable, ) { // Check that the openings have the correct shape. - builder.assert_eq(input.dimensions.len(), input.opened_values.len()); + let num_dims = input.dimensions.len(); + // Assert dimensions is not empty + builder.assert_nonzero(&num_dims); + builder.assert_eq(num_dims.clone(), input.opened_values.len()); // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width. // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit. - // Nondeterministically supply max_height and check + // Nondeterministically supplies max_height let max_height = builder.hint_var(); builder.range(0, input.dimensions.len()).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -120,33 +86,83 @@ pub(crate) fn verify_batch( // Verify correspondence between log_h and h let log_max_height = builder.hint_var(); - let two: Var = builder.constant(C::N::from_canonical_usize(2)); - let purported_max_height = pow(builder, two, log_max_height); + let purported_max_height = pow_2(builder, log_max_height); builder.assert_eq(purported_max_height, max_height); builder.assert_eq(input.proof.len(), log_max_height); + // Nondeterministically supplies: + // 1. num_unique_height: number of different heights + // 2. unique_height_count: for each unique height, number of dimensions of that height + // 3. height_order: after sorting by decreasing height, the original index of each entry + // 4. height_diff: whether the height of the sorted entry differ from the next. 0 - no diff; 1 - diff + let num_unique_height = builder.hint_var(); + let unique_height_count = builder.dyn_array(num_unique_height); + builder.range(0, num_unique_height).for_each(|i_vec, builder| { + let i = i_vec[0]; + let mut next_count = builder.hint_var(); + builder.set_value(&unique_height_count, i, next_count); + }); + + let height_order = builder.dyn_array(num_dims); + let height_diff = builder.dyn_array(num_dims); + let mut last_order = builder.hint_var(); + let mut last_diff = builder.hint_var(); + builder.set_value(&height_order, 0, last_order); + builder.set_value(&height_diff, 0, last_diff); + let mut last_height = builder.get(&input.dimensions, last_order).height; + + let curr_height_padded = next_power_of_two(builder, last_height); + + let last_unique_height_index: Var = builder.eval(Usize::from(0)); + let last_unique_height_count: Var = builder.eval(Usize::from(1)); + builder.range(1, num_dims).for_each(|i_vec, builder| { + let i = i_vec[0]; + let mut next_order = builder.hint_var(); + let mut next_diff = builder.hint_var(); + let next_height = builder.get(&input.dimensions, next_order).height; + + builder.if_eq(last_diff, Usize::from(0)).then(|builder| { + // last_diff == 0 ==> next_height == last_height + builder.assert_eq(last_height, next_height); + builder.assign(&last_unique_height_count, last_unique_height_count + Usize::from(1)); + }); + builder.if_ne(last_diff, Usize::from(0)).then(|builder| { + // last_diff != 0 ==> next_height < last_height + builder.assert_less_than_slow_small_rhs(next_height, last_height); + + // Verify correctness of unique_height_count + let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); + builder.assert_eq(purported_unique_height_count, last_unique_height_count); + builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); + builder.assign(&last_unique_height_count, Usize::from(1)); + }); + + last_order = next_order; + last_diff = next_diff; + builder.set_value(&height_order, i, last_order); + builder.set_value(&height_diff, i, last_diff); + last_height = next_height; + }); + // Final check on num_unique_height and unique_height_count + let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); + builder.assert_eq(purported_unique_height_count, last_unique_height_count); + builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); + builder.assert_eq(last_unique_height_index, num_unique_height); + + // Construct root through hashing + let root_dims_count = builder.get(&unique_height_count, 0); + let root_values = builder.dyn_array(root_dims_count); + builder.range(0, root_dims_count).for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp = builder.get(&input.opened_values, i); + builder.set_value(&root_values, i, tmp); + }); + let root = hash_iter_slices(builder, root_values); - let mut heights_tallest_first = dimensions - .iter() - .enumerate() - .sorted_by_key(|(_, dims)| Reverse(dims.height)) - .peekable(); - - let Some(mut curr_height_padded) = heights_tallest_first - .peek() - .map(|x| x.1.height.next_power_of_two()) - else { - // dimensions is empty - return Err(EmptyBatch); - }; - - let mut root = self.hash.hash_iter_slices( - heights_tallest_first - .peeking_take_while(|(_, dims)| { - dims.height.next_power_of_two() == curr_height_padded - }) - .map(|(i, _)| opened_values[i].as_slice()), - ); + builder.range(0, input.proof.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + + }); for &sibling in proof { let (left, right) = if index & 1 == 0 { diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index c02d4ee..825ab09 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -2,4 +2,5 @@ pub(crate) mod binding; pub(crate) mod program; pub(crate) mod rs; pub(crate) mod mmcs; +pub(crate) mod hash; pub(crate) mod utils; \ No newline at end of file diff --git a/src/basefold_verifier/program.rs b/src/basefold_verifier/program.rs index d9b6c1a..419ef70 100644 --- a/src/basefold_verifier/program.rs +++ b/src/basefold_verifier/program.rs @@ -1,6 +1,5 @@ use super::{binding::*, utils::*}; use openvm_native_compiler::ir::*; -use p3_field::FieldAlgebra; use crate::basefold_verifier::rs::get_rate_log; fn get_base_codeword_dimensions( @@ -24,8 +23,7 @@ fn get_base_codeword_dimensions( // wit_dim let width = builder.eval(witin_num_polys * Usize::from(2)); let height_exp = builder.eval(witin_num_vars + get_rate_log::() - Usize::from(1)); - let two: Var = builder.constant(C::N::from_canonical_usize(2)); - let height = pow(builder, two, height_exp); + let height = pow_2(builder, height_exp); let next_wit: DimensionsVariable = DimensionsVariable { width, height, @@ -36,13 +34,9 @@ fn get_base_codeword_dimensions( // XXX: since fixed_num_vars is usize, fixed_num_vars > 0 is equivalent to fixed_num_vars != 0 builder.if_ne(fixed_num_vars.clone(), Usize::from(0)).then(|builder| { let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); - let height_exp = builder.eval_expr(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); + let height_exp = builder.eval(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); // XXX: more efficient pow implementation - let height: Var = builder.constant(C::N::ONE); - let two: Var = builder.constant(C::N::from_canonical_usize(2)); - builder.range(0, height_exp).for_each(|_, builder| { - builder.assign(&height, height * two); - }); + let height = pow_2(builder, height_exp); let next_fixed: DimensionsVariable = DimensionsVariable { width, height, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 40bf9b8..ceb77f9 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -12,4 +12,29 @@ pub fn pow( builder.assign(&value, value * base); }); value +} + +pub fn pow_2( + builder: &mut Builder, + exponent: Var, +) -> Var { + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + pow(builder, two, exponent) +} + +// XXX: Equally outrageously inefficient +pub fn next_power_of_two( + builder: &mut Builder, + value: Var, +) -> Var { + // Non-deterministically supply the exponent n such that + // 2^n < v <= 2^{n+1} + let n: Var = builder.hint_var(); + let ret = pow_2(builder, n); + builder.assert_less_than_slow_bit_decomp(ret, value); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + builder.assign(&ret, ret * two); + let ret_plus_one = builder.eval(ret.clone() + Usize::from(1)); + builder.assert_less_than_slow_bit_decomp(value, ret_plus_one); + ret } \ No newline at end of file From 8c35009d2ee342e85db563414deb3c88541d63f1 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Fri, 25 Apr 2025 12:29:02 -0400 Subject: [PATCH 06/70] Update mmcs --- src/basefold_verifier/hash.rs | 9 +++ src/basefold_verifier/mmcs.rs | 132 ++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 46 deletions(-) diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index d07ad43..4a3497d 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -50,4 +50,13 @@ pub fn hash_iter_slices( ) -> Array> { // XXX: verify hash builder.hint_felts_fixed(DIGEST_ELEMS) +} + +pub fn compress( + builder: &mut Builder, + // _hash: HashVariable, + _values: Array>>, +) -> Array> { + // XXX: verify hash + builder.hint_felts_fixed(DIGEST_ELEMS) } \ No newline at end of file diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index b939950..fd1ca19 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -70,7 +70,7 @@ pub(crate) fn verify_batch( let num_dims = input.dimensions.len(); // Assert dimensions is not empty builder.assert_nonzero(&num_dims); - builder.assert_eq(num_dims.clone(), input.opened_values.len()); + builder.assert_usize_eq(num_dims.clone(), input.opened_values.len()); // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width. @@ -87,8 +87,8 @@ pub(crate) fn verify_batch( // Verify correspondence between log_h and h let log_max_height = builder.hint_var(); let purported_max_height = pow_2(builder, log_max_height); - builder.assert_eq(purported_max_height, max_height); - builder.assert_eq(input.proof.len(), log_max_height); + builder.assert_var_eq(purported_max_height, max_height); + builder.assert_usize_eq(input.proof.len(), log_max_height); // Nondeterministically supplies: // 1. num_unique_height: number of different heights @@ -99,12 +99,12 @@ pub(crate) fn verify_batch( let unique_height_count = builder.dyn_array(num_unique_height); builder.range(0, num_unique_height).for_each(|i_vec, builder| { let i = i_vec[0]; - let mut next_count = builder.hint_var(); + let next_count = builder.hint_var(); builder.set_value(&unique_height_count, i, next_count); }); - let height_order = builder.dyn_array(num_dims); - let height_diff = builder.dyn_array(num_dims); + let height_order = builder.dyn_array(num_dims.clone()); + let height_diff = builder.dyn_array(num_dims.clone()); let mut last_order = builder.hint_var(); let mut last_diff = builder.hint_var(); builder.set_value(&height_order, 0, last_order); @@ -117,13 +117,13 @@ pub(crate) fn verify_batch( let last_unique_height_count: Var = builder.eval(Usize::from(1)); builder.range(1, num_dims).for_each(|i_vec, builder| { let i = i_vec[0]; - let mut next_order = builder.hint_var(); - let mut next_diff = builder.hint_var(); + let next_order = builder.hint_var(); + let next_diff = builder.hint_var(); let next_height = builder.get(&input.dimensions, next_order).height; builder.if_eq(last_diff, Usize::from(0)).then(|builder| { // last_diff == 0 ==> next_height == last_height - builder.assert_eq(last_height, next_height); + builder.assert_var_eq(last_height, next_height); builder.assign(&last_unique_height_count, last_unique_height_count + Usize::from(1)); }); builder.if_ne(last_diff, Usize::from(0)).then(|builder| { @@ -132,7 +132,7 @@ pub(crate) fn verify_batch( // Verify correctness of unique_height_count let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); - builder.assert_eq(purported_unique_height_count, last_unique_height_count); + builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); builder.assign(&last_unique_height_count, Usize::from(1)); }); @@ -145,54 +145,94 @@ pub(crate) fn verify_batch( }); // Final check on num_unique_height and unique_height_count let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); - builder.assert_eq(purported_unique_height_count, last_unique_height_count); + builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); - builder.assert_eq(last_unique_height_index, num_unique_height); + builder.assert_var_eq(last_unique_height_index, num_unique_height); // Construct root through hashing let root_dims_count = builder.get(&unique_height_count, 0); let root_values = builder.dyn_array(root_dims_count); builder.range(0, root_dims_count).for_each(|i_vec, builder| { let i = i_vec[0]; - let tmp = builder.get(&input.opened_values, i); + let index = builder.get(&height_order, i); + let tmp = builder.get(&input.opened_values, index); builder.set_value(&root_values, i, tmp); }); let root = hash_iter_slices(builder, root_values); + // Index_pow and reassembled_index for bit split + let index_pow: Var = builder.eval(Usize::from(1)); + let reassembled_index: Var = builder.eval(Usize::from(0)); + // next_height is the height of the next dim to be incorporated into root + let next_unique_height_index: Var = builder.eval(Usize::from(1)); + let next_unique_height_count: Var = builder.eval(root_dims_count); + let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let next_height_padded = next_power_of_two(builder, next_height); builder.range(0, input.proof.len()).for_each(|i_vec, builder| { let i = i_vec[0]; - - }); - - for &sibling in proof { - let (left, right) = if index & 1 == 0 { - (root, sibling) - } else { - (sibling, root) - }; - - root = self.compress.compress([left, right]); - index >>= 1; - curr_height_padded >>= 1; - - let next_height = heights_tallest_first - .peek() - .map(|(_, dims)| dims.height) - .filter(|h| h.next_power_of_two() == curr_height_padded); - if let Some(next_height) = next_height { - let next_height_openings_digest = self.hash.hash_iter_slices( - heights_tallest_first - .peeking_take_while(|(_, dims)| dims.height == next_height) - .map(|(i, _)| opened_values[i].as_slice()), - ); - - root = self.compress.compress([root, next_height_openings_digest]); - } - } + let sibling = builder.get(&input.proof, i); + let two_var: Var = builder.eval(Usize::from(2)); // XXX: is there a better way to do this? + // Supply the next index bit as hint, assert that it is a bit + let next_index_bit = builder.hint_var(); + builder.assert_var_eq(next_index_bit, next_index_bit * next_index_bit); + builder.assign(&reassembled_index, reassembled_index + index_pow * next_index_bit); + builder.assign(&index_pow, index_pow * two_var); + + // left, right + let compress_elem = builder.dyn_array(2); + builder.if_eq(next_index_bit, Usize::from(0)).then(|builder| { + // root, sibling + builder.set_value(&compress_elem, 0, root.clone()); + builder.set_value(&compress_elem, 0, sibling.clone()); + }); + builder.if_ne(next_index_bit, Usize::from(0)).then(|builder| { + // sibling, root + builder.set_value(&compress_elem, 0, sibling.clone()); + builder.set_value(&compress_elem, 0, root.clone()); + }); + let new_root = compress(builder, compress_elem); + builder.assign(&root, new_root); + + // curr_height_padded >>= 1 given curr_height_padded is a power of two + // Nondeterministically supply next_curr_height_padded + let next_curr_height_padded = builder.hint_var(); + builder.assert_var_eq(next_curr_height_padded * two_var, curr_height_padded); + builder.assign(&curr_height_padded, next_curr_height_padded); + + // determine whether next_height matches curr_height + builder.if_eq(curr_height_padded, next_height_padded).then(|builder| { + // hash opened_values of all dims of next_height to root + let root_dims_count = builder.get(&unique_height_count, next_unique_height_index); + let root_size: Var = builder.eval(root_dims_count + Usize::from(1)); + let root_values = builder.dyn_array(root_size); + builder.set_value(&root_values, 0, root.clone()); + builder.range(0, root_dims_count).for_each(|i_vec, builder| { + let i = i_vec[0]; + let index = builder.get(&height_order, i); + let tmp = builder.get(&input.opened_values, index); + let j = builder.eval_expr(i + RVar::from(1)); + builder.set_value(&root_values, j, tmp); + }); + let new_root = hash_iter_slices(builder, root_values); + builder.assign(&root, new_root); + + // Update parameters + builder.assign(&next_unique_height_count, next_unique_height_count + root_dims_count); + builder.assign(&next_unique_height_index, next_unique_height_index + Usize::from(1)); + builder.if_eq(next_unique_height_index, num_unique_height).then(|builder| { + builder.assign(&next_height_padded, Usize::from(0)); + }); + let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let next_tmp_height_padded = next_power_of_two(builder, next_height); + builder.assign(&next_height_padded, next_tmp_height_padded); + }); - if commit == &root { - Ok(()) - } else { - Err(RootMismatch) - } + }); + builder.assert_var_eq(reassembled_index, input.index); + builder.range(0, DIGEST_ELEMS).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_input = builder.get(&input.commit.value, i); + let next_root = builder.get(&root, i); + builder.assert_felt_eq(next_input, next_root); + }); } From 3afc01b9e37ca372447099307e954e531717f5a0 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Fri, 25 Apr 2025 20:46:34 -0400 Subject: [PATCH 07/70] WIP mmcs --- src/basefold_verifier/mmcs.rs | 104 ++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 5 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index fd1ca19..c40456f 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -12,13 +12,13 @@ pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; type Commitment = Hash; -type Proof = Vec<[F; DIGEST_ELEMS]>; +type MmcsProof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { pub commit: Commitment, pub dimensions: Vec, pub index: usize, pub opened_values: Vec>, - pub proof: Proof, + pub proof: MmcsProof, } impl Hintable for MmcsVerifierInput { @@ -52,17 +52,17 @@ impl Hintable for MmcsVerifierInput { } type CommitmentVariable = HashVariable; -type ProofVariable = Array::F>>>; +type MmcsProofVariable = Array::F>>>; #[derive(DslVariable, Clone)] pub struct MmcsVerifierInputVariable { pub commit: CommitmentVariable, pub dimensions: Array>, pub index: Var, pub opened_values: Array>>, - pub proof: ProofVariable, + pub proof: MmcsProofVariable, } -pub(crate) fn verify_batch( +pub(crate) fn mmcs_verify_batch( builder: &mut Builder, input: MmcsVerifierInputVariable, ) { @@ -236,3 +236,97 @@ pub(crate) fn verify_batch( builder.assert_felt_eq(next_input, next_root); }); } + +pub mod tests { + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, + p3_baby_bear::BabyBear, + }; + use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; + type SC = BabyBearPoseidon2Config; + + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + use crate::basefold_verifier::{binding::Dimensions, hash::DIGEST_ELEMS}; + + use super::{mmcs_verify_batch, Commitment, InnerConfig, MmcsVerifierInput}; + + #[allow(dead_code)] + pub fn build_mmcs_verify_batch() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let mmcs_input = MmcsVerifierInput::read(&mut builder); + mmcs_verify_batch(&mut builder, mmcs_input); + builder.halt(); + + // Pass in witness stream + let f = |n: usize| F::from_canonical_usize(n); + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + let commit = Commitment { + value: [f(0); DIGEST_ELEMS] + }; + let dimensions = vec![ + Dimensions { width: 9, height: 9 }, + Dimensions { width: 7, height: 7 }, + Dimensions { width: 5, height: 5 }, + Dimensions { width: 3, height: 3 }, + ]; + let opened_values = vec![ + vec![f(47), f(22), f(14), f(6)], + vec![f(35), f(3), f(1)], + vec![f(29), f(11), f(2)], + vec![f(14), f(4)], + ]; + let proof = vec![ + [f(47); 4], + [f(35); 4], + [f(29); 4], + [f(14); 4], + ]; + let mmcs_input = MmcsVerifierInput { + commit, + dimensions, + index: 7, + opened_values, + proof, + }; + witness_stream.extend(mmcs_input.write()); + // Hints + witness_stream.extend(>::write(&5)); + + let program: Program< + p3_monty_31::MontyField31, + > = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_mmcs_verify_batch() { + let (program, witness) = build_mmcs_verify_batch(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} \ No newline at end of file From 1156331f1746dfd31059cbc5aabff8f023d8f073 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Mon, 28 Apr 2025 14:27:21 -0400 Subject: [PATCH 08/70] Finished MMCS --- src/basefold_verifier/hash.rs | 4 +- src/basefold_verifier/mmcs.rs | 295 ++++++++++++++++++++++++++++----- src/basefold_verifier/utils.rs | 16 +- 3 files changed, 264 insertions(+), 51 deletions(-) diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index 4a3497d..64c0c8a 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -3,12 +3,12 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +pub const DIGEST_ELEMS: usize = 8; + pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -pub const DIGEST_ELEMS: usize = 4; - pub struct Hash { pub value: [F; DIGEST_ELEMS], } diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index c40456f..1ec9114 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -80,14 +80,17 @@ pub(crate) fn mmcs_verify_batch( builder.range(0, input.dimensions.len()).for_each(|i_vec, builder| { let i = i_vec[0]; let next_height = builder.get(&input.dimensions, i).height; - let max_height_plus_one = builder.eval(max_height + Usize::from(1)); - builder.assert_less_than_slow_bit_decomp(next_height, max_height_plus_one); + let max_height_plus_one: Var = builder.eval(max_height + Usize::from(1)); + builder.assert_less_than_slow_small_rhs(next_height, max_height_plus_one); }); // Verify correspondence between log_h and h let log_max_height = builder.hint_var(); - let purported_max_height = pow_2(builder, log_max_height); - builder.assert_var_eq(purported_max_height, max_height); + let log_max_height_minus_1: Var = builder.eval(log_max_height - Usize::from(1)); + let purported_max_height_lower_bound = pow_2(builder, log_max_height_minus_1); + let purported_max_height_upper_bound = pow_2(builder, log_max_height); + builder.assert_less_than_slow_small_rhs(purported_max_height_lower_bound, max_height); + builder.assert_less_than_slow_small_rhs(max_height, purported_max_height_upper_bound); builder.assert_usize_eq(input.proof.len(), log_max_height); // Nondeterministically supplies: @@ -104,51 +107,48 @@ pub(crate) fn mmcs_verify_batch( }); let height_order = builder.dyn_array(num_dims.clone()); - let height_diff = builder.dyn_array(num_dims.clone()); - let mut last_order = builder.hint_var(); - let mut last_diff = builder.hint_var(); + let last_order = builder.hint_var(); builder.set_value(&height_order, 0, last_order); - builder.set_value(&height_diff, 0, last_diff); - let mut last_height = builder.get(&input.dimensions, last_order).height; - - let curr_height_padded = next_power_of_two(builder, last_height); + let last_height = builder.get(&input.dimensions, last_order).height; let last_unique_height_index: Var = builder.eval(Usize::from(0)); let last_unique_height_count: Var = builder.eval(Usize::from(1)); builder.range(1, num_dims).for_each(|i_vec, builder| { let i = i_vec[0]; let next_order = builder.hint_var(); - let next_diff = builder.hint_var(); let next_height = builder.get(&input.dimensions, next_order).height; - builder.if_eq(last_diff, Usize::from(0)).then(|builder| { - // last_diff == 0 ==> next_height == last_height - builder.assert_var_eq(last_height, next_height); + builder.if_eq(last_height, next_height).then(|builder| { + // next_height == last_height builder.assign(&last_unique_height_count, last_unique_height_count + Usize::from(1)); }); - builder.if_ne(last_diff, Usize::from(0)).then(|builder| { - // last_diff != 0 ==> next_height < last_height + builder.if_ne(last_height, next_height).then(|builder| { + // next_height < last_height builder.assert_less_than_slow_small_rhs(next_height, last_height); // Verify correctness of unique_height_count let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); + builder.assign(&last_height, next_height); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); builder.assign(&last_unique_height_count, Usize::from(1)); }); - last_order = next_order; - last_diff = next_diff; + builder.assign(&last_order, next_order); builder.set_value(&height_order, i, last_order); - builder.set_value(&height_diff, i, last_diff); - last_height = next_height; }); + // Final check on num_unique_height and unique_height_count let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); builder.assert_var_eq(last_unique_height_index, num_unique_height); + // First padded_height + let first_order = builder.get(&height_order, 0); + let first_height = builder.get(&input.dimensions, first_order).height; + let curr_height_padded = next_power_of_two(builder, first_height); + // Construct root through hashing let root_dims_count = builder.get(&unique_height_count, 0); let root_values = builder.dyn_array(root_dims_count); @@ -166,8 +166,12 @@ pub(crate) fn mmcs_verify_batch( // next_height is the height of the next dim to be incorporated into root let next_unique_height_index: Var = builder.eval(Usize::from(1)); let next_unique_height_count: Var = builder.eval(root_dims_count); - let next_height = builder.get(&input.dimensions, next_unique_height_count).height; - let next_height_padded = next_power_of_two(builder, next_height); + let next_height_padded: Var = builder.eval(Usize::from(0)); + builder.if_ne(num_unique_height, Usize::from(1)).then(|builder| { + let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let tmp_next_height_padded = next_power_of_two(builder, next_height); + builder.assign(&next_height_padded, tmp_next_height_padded); + }); builder.range(0, input.proof.len()).for_each(|i_vec, builder| { let i = i_vec[0]; let sibling = builder.get(&input.proof, i); @@ -222,11 +226,12 @@ pub(crate) fn mmcs_verify_batch( builder.if_eq(next_unique_height_index, num_unique_height).then(|builder| { builder.assign(&next_height_padded, Usize::from(0)); }); - let next_height = builder.get(&input.dimensions, next_unique_height_count).height; - let next_tmp_height_padded = next_power_of_two(builder, next_height); - builder.assign(&next_height_padded, next_tmp_height_padded); + builder.if_ne(next_unique_height_index, num_unique_height).then(|builder| { + let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let next_tmp_height_padded = next_power_of_two(builder, next_height); + builder.assign(&next_height_padded, next_tmp_height_padded); + }); }); - }); builder.assert_var_eq(reassembled_index, input.index); builder.range(0, DIGEST_ELEMS).for_each(|i_vec, builder| { @@ -273,37 +278,239 @@ pub mod tests { Vec>, > = Vec::new(); let commit = Commitment { - value: [f(0); DIGEST_ELEMS] + value: [ + f(778527199), + f(28726932), + f(1315347420), + f(1824757698), + f(154429821), + f(1391932058), + f(826833161), + f(1793433773), + ] }; let dimensions = vec![ - Dimensions { width: 9, height: 9 }, - Dimensions { width: 7, height: 7 }, - Dimensions { width: 5, height: 5 }, - Dimensions { width: 3, height: 3 }, + Dimensions { width: 8, height: 1 }, + Dimensions { width: 8, height: 1 }, + Dimensions { width: 8, height: 70 }, ]; + let index = 6; let opened_values = vec![ - vec![f(47), f(22), f(14), f(6)], - vec![f(35), f(3), f(1)], - vec![f(29), f(11), f(2)], - vec![f(14), f(4)], + vec![f(1105434748), f(689726213), f(688169105), f(1988100049), f(1580478319), f(1706067197), f(513975191), f(1741109149)], + vec![f(1522482301), f(479042531), f(1086100811), f(734531439), f(705797008), f(1234295284), f(937641372), f(553060608)], + vec![f(744749480), f(1063269152), f(300382655), f(1107270768), f(1172794741), f(274350305), f(1359913694), f(179073086)], ]; let proof = vec![ - [f(47); 4], - [f(35); 4], - [f(29); 4], - [f(14); 4], + [f(1073443193), f(894272286), f(588425464), f(1974315438), f(376335434), f(1149692201), f(543618925), f(1485228078)], + [f(1196372702), f(867462678), f(871921129), f(1745802269), f(1878325218), f(1200890208), f(955410895), f(588843483)], + [f(348296419), f(1531857785), f(1922560959), f(1197467594), f(1441649143), f(914359927), f(1924320269), f(1056370810)], + [f(1581777890), f(1925056505), f(1645298574), f(515725387), f(1060947616), f(1614093762), f(967068928), f(968302842)], + [f(961265251), f(1008373514), f(72654335), f(16568774), f(1778075526), f(1938499582), f(23748437), f(30462657)], + [f(1638730933), f(698689687), f(116457371), f(1466997263), f(993891206), f(1568724141), f(1402556463), f(1903080766)], + [f(1451476441), f(480987775), f(1782294403), f(709729703), f(500945265), f(1280038868), f(1762204994), f(240464)], ]; let mmcs_input = MmcsVerifierInput { commit, dimensions, - index: 7, + index, opened_values, proof, }; witness_stream.extend(mmcs_input.write()); - // Hints - witness_stream.extend(>::write(&5)); - + // max_height + witness_stream.extend(>::write(&70)); + // log_max_height + witness_stream.extend(>::write(&7)); + // num_unique_height + witness_stream.extend(>::write(&2)); + // unique_height_count + witness_stream.extend(>::write(&1)); + // unique_height_count + witness_stream.extend(>::write(&2)); + // height_order + witness_stream.extend(>::write(&2)); + // height_order + witness_stream.extend(>::write(&0)); + // height_order + witness_stream.extend(>::write(&1)); + // curr_height_log + witness_stream.extend(>::write(&6)); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(410616511))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(1016155415))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(1214189198))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(227596423))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(638999723))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(1793520096))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(1497010699))); + // root + witness_stream.extend(>::write(&F::from_canonical_usize(307833588))); + // next_height_log + witness_stream.extend(>::write(&0)); + // next_bit + witness_stream.extend(>::write(&0)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(847632309))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(597844957))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(471673299))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1929998464))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1285517017))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(383750469))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1336144331))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(89856465))); + // next_curr_height_padded + witness_stream.extend(>::write(&64)); + // next_bit + witness_stream.extend(>::write(&1)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1661952137))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(675247342))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(358879322))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(328576074))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(45664218))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1026458030))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(670890979))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1522300104))); + // next_curr_height_padded + witness_stream.extend(>::write(&32)); + // next_bit + witness_stream.extend(>::write(&1)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1134267269))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(621171717))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(231890617))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(500108260))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1862498334))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(168633872))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(399123277))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1301607042))); + // next_curr_height_padded + witness_stream.extend(>::write(&16)); + // next_bit + witness_stream.extend(>::write(&0)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1081303431))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1607649945))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1290504702))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(149378))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1025603059))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1980340366))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(172368574))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1449539534))); + // next_curr_height_padded + witness_stream.extend(>::write(&8)); + // next_bit + witness_stream.extend(>::write(&0)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1779401002))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1329892692))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1551737751))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1315686077))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1218609253))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1532387083))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(80357312))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1697204536))); + // next_curr_height_padded + witness_stream.extend(>::write(&4)); + // next_bit + witness_stream.extend(>::write(&0)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(922874076))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1357099772))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(91993648))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1335971015))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(295319780))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(790352918))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1988018190))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1079914414))); + // next_curr_height_padded + witness_stream.extend(>::write(&2)); + // next_bit + witness_stream.extend(>::write(&0)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(590430057))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1802104709))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1602739834))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(578735974))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1828105722))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(279136942))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(120317613))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(849588480))); + // next_curr_height_padded + witness_stream.extend(>::write(&1)); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(778527199))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(28726932))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1315347420))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1824757698))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(154429821))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1391932058))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(826833161))); + // new_root + witness_stream.extend(>::write(&F::from_canonical_usize(1793433773))); + + // PROGRAM let program: Program< p3_monty_31::MontyField31, > = builder.compile_isa(); diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index ceb77f9..f1bc4c7 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -29,12 +29,18 @@ pub fn next_power_of_two( ) -> Var { // Non-deterministically supply the exponent n such that // 2^n < v <= 2^{n+1} + // Ignore if v == 1 let n: Var = builder.hint_var(); let ret = pow_2(builder, n); - builder.assert_less_than_slow_bit_decomp(ret, value); - let two: Var = builder.constant(C::N::from_canonical_usize(2)); - builder.assign(&ret, ret * two); - let ret_plus_one = builder.eval(ret.clone() + Usize::from(1)); - builder.assert_less_than_slow_bit_decomp(value, ret_plus_one); + builder.if_eq(value, Usize::from(1)).then(|builder| { + builder.assign(&ret, Usize::from(1)); + }); + builder.if_ne(value, Usize::from(1)).then(|builder| { + builder.assert_less_than_slow_bit_decomp(ret, value); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + builder.assign(&ret, ret * two); + let ret_plus_one = builder.eval(ret.clone() + Usize::from(1)); + builder.assert_less_than_slow_bit_decomp(value, ret_plus_one); + }); ret } \ No newline at end of file From 18aef52c3c56f2b8ac0f38521341eb8c80afabac Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Mon, 28 Apr 2025 22:32:14 -0400 Subject: [PATCH 09/70] Added dot_product --- src/basefold_verifier/field.rs | 54 ++++++++++++++++++++++++++++ src/basefold_verifier/mmcs.rs | 2 +- src/basefold_verifier/mod.rs | 2 ++ src/basefold_verifier/query_phase.rs | 0 src/basefold_verifier/rs.rs | 10 +++++- src/basefold_verifier/utils.rs | 19 ++++++++++ 6 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 src/basefold_verifier/field.rs create mode 100644 src/basefold_verifier/query_phase.rs diff --git a/src/basefold_verifier/field.rs b/src/basefold_verifier/field.rs new file mode 100644 index 0000000..64eea03 --- /dev/null +++ b/src/basefold_verifier/field.rs @@ -0,0 +1,54 @@ +const TWO_ADICITY: usize = 32; +const TWO_ADIC_GENERATORS: [usize; 33] = [ + 0x0000000000000001, + 0xffffffff00000000, + 0x0001000000000000, + 0xfffffffeff000001, + 0xefffffff00000001, + 0x00003fffffffc000, + 0x0000008000000000, + 0xf80007ff08000001, + 0xbf79143ce60ca966, + 0x1905d02a5c411f4e, + 0x9d8f2ad78bfed972, + 0x0653b4801da1c8cf, + 0xf2c35199959dfcb6, + 0x1544ef2335d17997, + 0xe0ee099310bba1e2, + 0xf6b2cffe2306baac, + 0x54df9630bf79450e, + 0xabd0a6e8aa3d8a0e, + 0x81281a7b05f9beac, + 0xfbd41c6b8caa3302, + 0x30ba2ecd5e93e76d, + 0xf502aef532322654, + 0x4b2a18ade67246b5, + 0xea9d5a1336fbc98b, + 0x86cdcc31c307e171, + 0x4bbaf5976ecfefd8, + 0xed41d05b78d6e286, + 0x10d78dd8915a171d, + 0x59049500004a4485, + 0xdfa8c93ba46d2666, + 0x7e9bd009b86a0845, + 0x400a7f755588e659, + 0x185629dcda58878c, +]; + +use openvm_native_compiler::prelude::*; +use p3_field::FieldAlgebra; + +fn two_adic_generator( + builder: &mut Builder, + bits: Var, +) -> Var { + let bits_limit = builder.eval(Usize::from(TWO_ADICITY) + Usize::from(1)); + builder.assert_less_than_slow_small_rhs(bits, bits_limit); + + let two_adic_generator: Array::F>> = builder.dyn_array(TWO_ADICITY + 1); + builder.range(0, TWO_ADICITY + 1).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set_value(&two_adic_generator, i, C::F::from_canonical_usize(TWO_ADIC_GENERATORS[i.value()])); + }); + builder.get(&two_adic_generator, bits) +} \ No newline at end of file diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 1ec9114..b9aa19f 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -258,7 +258,7 @@ pub mod tests { type F = BabyBear; type E = BinomialExtensionField; type EF = ::Challenge; - use crate::basefold_verifier::{binding::Dimensions, hash::DIGEST_ELEMS}; + use crate::basefold_verifier::binding::Dimensions; use super::{mmcs_verify_batch, Commitment, InnerConfig, MmcsVerifierInput}; diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index 825ab09..ac25318 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -1,6 +1,8 @@ pub(crate) mod binding; pub(crate) mod program; +pub(crate) mod query_phase; pub(crate) mod rs; pub(crate) mod mmcs; pub(crate) mod hash; +// pub(crate) mod field; pub(crate) mod utils; \ No newline at end of file diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 1926f51..b73b8f2 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -5,6 +5,8 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use super::utils::*; + pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; @@ -104,8 +106,12 @@ impl Radix2DitVariable { mat: RowMajorMatrixVariable ) -> RowMajorMatrixVariable { let h = mat.height(builder); - // TODO: Verify correspondence between log_h and h let log_h = builder.hint_var(); + let log_h_minus_1: Var = builder.eval(log_h - Usize::from(1)); + let purported_h_lower_bound = pow_2(builder, log_h_minus_1); + let purported_h_upper_bound = pow_2(builder, log_h); + builder.assert_less_than_slow_small_rhs(purported_h_lower_bound, h); + builder.assert_less_than_slow_small_rhs(h, purported_h_upper_bound); // TODO: support memoization // Compute twiddle factors, or take memoized ones if already available. @@ -122,7 +128,9 @@ impl Radix2DitVariable { mat } } +*/ +/* #[derive(DslVariable, Clone)] pub struct RSCodeVerifierParametersVariable { pub dft: Radix2DitVariable, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index f1bc4c7..72b3ee0 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -43,4 +43,23 @@ pub fn next_power_of_two( builder.assert_less_than_slow_bit_decomp(value, ret_plus_one); }); ret +} + +// Generic dot product +pub fn dot_product( + builder: &mut Builder, + li: Array>, + ri: Array>, +) -> Ext { + let ret: Ext = builder.constant(C::EF::ZERO); + builder.assert_eq::>(li.len(), ri.len()); + let len = li.len(); + + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let l = builder.get(&li, i); + let r = builder.get(&ri, i); + builder.assign(&ret, ret + l * r); + }); + ret } \ No newline at end of file From 31cd29ccbe78c0d82148cbfb0048d0adb3bc3438 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 29 Apr 2025 16:17:34 -0400 Subject: [PATCH 10/70] query_phase input --- src/basefold_verifier/basefold.rs | 49 ++++ src/basefold_verifier/binding.rs | 89 ------ src/basefold_verifier/hash.rs | 38 ++- src/basefold_verifier/mmcs.rs | 22 +- src/basefold_verifier/mod.rs | 4 +- src/basefold_verifier/query_phase.rs | 264 ++++++++++++++++++ src/basefold_verifier/rs.rs | 88 +++++- .../{program.rs => structs.rs} | 98 ++++++- src/tower_verifier/binding.rs | 26 ++ 9 files changed, 548 insertions(+), 130 deletions(-) create mode 100644 src/basefold_verifier/basefold.rs delete mode 100644 src/basefold_verifier/binding.rs rename src/basefold_verifier/{program.rs => structs.rs} (65%) diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs new file mode 100644 index 0000000..799abf7 --- /dev/null +++ b/src/basefold_verifier/basefold.rs @@ -0,0 +1,49 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +use super::mmcs::*; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub type HashDigest = MmcsCommitment; +pub struct BasefoldCommitment { + pub commit: HashDigest, + pub log2_max_codeword_size: usize, + pub trivial_commits: Vec, +} + +impl Hintable for BasefoldCommitment { + type HintVariable = BasefoldCommitmentVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = HashDigest::read(builder); + let log2_max_codeword_size = Usize::Var(usize::read(builder)); + let trivial_commits = Vec::::read(builder); + + BasefoldCommitmentVariable { + commit, + log2_max_codeword_size, + trivial_commits, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commit.write()); + stream.extend(>::write(&self.log2_max_codeword_size)); + stream.extend(self.trivial_commits.write()); + stream + } +} + +pub type HashDigestVariable = MmcsCommitmentVariable; +#[derive(DslVariable, Clone)] +pub struct BasefoldCommitmentVariable { + pub commit: HashDigestVariable, + pub log2_max_codeword_size: Usize, + pub trivial_commits: Array>, +} diff --git a/src/basefold_verifier/binding.rs b/src/basefold_verifier/binding.rs deleted file mode 100644 index c0b419f..0000000 --- a/src/basefold_verifier/binding.rs +++ /dev/null @@ -1,89 +0,0 @@ -use openvm_native_compiler::{asm::AsmConfig, ir::*}; -use openvm_native_compiler_derive::DslVariable; -use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use p3_field::extension::BinomialExtensionField; - -pub type F = BabyBear; -pub type E = BinomialExtensionField; -pub type InnerConfig = AsmConfig; - -#[derive(DslVariable, Clone)] -pub struct CircuitIndexMetaVariable { - pub witin_num_vars: Usize, - pub witin_num_polys: Usize, - pub fixed_num_vars: Usize, - pub fixed_num_polys: Usize, -} - -pub struct CircuitIndexMeta { - pub witin_num_vars: usize, - pub witin_num_polys: usize, - pub fixed_num_vars: usize, - pub fixed_num_polys: usize, -} - -impl Hintable for CircuitIndexMeta { - type HintVariable = CircuitIndexMetaVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let witin_num_vars = Usize::Var(usize::read(builder)); - let witin_num_polys = Usize::Var(usize::read(builder)); - let fixed_num_vars = Usize::Var(usize::read(builder)); - let fixed_num_polys = Usize::Var(usize::read(builder)); - - CircuitIndexMetaVariable { - witin_num_vars, - witin_num_polys, - fixed_num_vars, - fixed_num_polys, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(>::write( - &self.witin_num_vars, - )); - stream.extend(>::write( - &self.witin_num_polys, - )); - stream.extend(>::write( - &self.fixed_num_vars, - )); - stream.extend(>::write( - &self.fixed_num_polys, - )); - stream - } -} - -#[derive(DslVariable, Clone)] -pub struct DimensionsVariable { - pub width: Var, - pub height: Var, -} - -pub struct Dimensions { - pub width: usize, - pub height: usize, -} - -impl Hintable for Dimensions { - type HintVariable = DimensionsVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let width = usize::read(builder); - let height = usize::read(builder); - - DimensionsVariable { width, height } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(>::write(&self.width)); - stream.extend(>::write(&self.height)); - stream - } -} -impl VecAutoHintable for Dimensions {} \ No newline at end of file diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index 64c0c8a..af23a2e 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -1,7 +1,8 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; -use openvm_native_recursion::hints::Hintable; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; pub const DIGEST_ELEMS: usize = 8; @@ -13,6 +14,14 @@ pub struct Hash { pub value: [F; DIGEST_ELEMS], } +impl Default for Hash { + fn default() -> Self { + Hash { + value: [F::ZERO; DIGEST_ELEMS], + } + } +} + impl Hintable for Hash { type HintVariable = HashVariable; @@ -23,9 +32,7 @@ impl Hintable for Hash { builder.set(&value, i, tmp); } - HashVariable { - value, - } + HashVariable { value } } fn write(&self) -> Vec::N>> { @@ -37,6 +44,7 @@ impl Hintable for Hash { stream } } +impl VecAutoHintable for Hash {} #[derive(DslVariable, Clone)] pub struct HashVariable { @@ -44,19 +52,19 @@ pub struct HashVariable { } pub fn hash_iter_slices( - builder: &mut Builder, - // _hash: HashVariable, - _values: Array>>, + builder: &mut Builder, + // _hash: HashVariable, + _values: Array>>, ) -> Array> { - // XXX: verify hash - builder.hint_felts_fixed(DIGEST_ELEMS) + // XXX: verify hash + builder.hint_felts_fixed(DIGEST_ELEMS) } pub fn compress( - builder: &mut Builder, - // _hash: HashVariable, - _values: Array>>, + builder: &mut Builder, + // _hash: HashVariable, + _values: Array>>, ) -> Array> { - // XXX: verify hash - builder.hint_felts_fixed(DIGEST_ELEMS) -} \ No newline at end of file + // XXX: verify hash + builder.hint_felts_fixed(DIGEST_ELEMS) +} diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index b9aa19f..ab6eb57 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -5,16 +5,16 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use super::{binding::*, utils::*, hash::*}; +use super::{structs::*, utils::*, hash::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -type Commitment = Hash; -type MmcsProof = Vec<[F; DIGEST_ELEMS]>; +pub type MmcsCommitment = Hash; +pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { - pub commit: Commitment, + pub commit: MmcsCommitment, pub dimensions: Vec, pub index: usize, pub opened_values: Vec>, @@ -25,7 +25,7 @@ impl Hintable for MmcsVerifierInput { type HintVariable = MmcsVerifierInputVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let commit = Commitment::read(builder); + let commit = MmcsCommitment::read(builder); let dimensions = Vec::::read(builder); let index = usize::read(builder); let opened_values = Vec::>::read(builder); @@ -51,11 +51,11 @@ impl Hintable for MmcsVerifierInput { } } -type CommitmentVariable = HashVariable; -type MmcsProofVariable = Array::F>>>; +pub type MmcsCommitmentVariable = HashVariable; +pub type MmcsProofVariable = Array::F>>>; #[derive(DslVariable, Clone)] pub struct MmcsVerifierInputVariable { - pub commit: CommitmentVariable, + pub commit: MmcsCommitmentVariable, pub dimensions: Array>, pub index: Var, pub opened_values: Array>>, @@ -258,9 +258,9 @@ pub mod tests { type F = BabyBear; type E = BinomialExtensionField; type EF = ::Challenge; - use crate::basefold_verifier::binding::Dimensions; + use crate::basefold_verifier::structs::Dimensions; - use super::{mmcs_verify_batch, Commitment, InnerConfig, MmcsVerifierInput}; + use super::{mmcs_verify_batch, MmcsCommitment, InnerConfig, MmcsVerifierInput}; #[allow(dead_code)] pub fn build_mmcs_verify_batch() -> (Program, Vec>) { @@ -277,7 +277,7 @@ pub mod tests { let mut witness_stream: Vec< Vec>, > = Vec::new(); - let commit = Commitment { + let commit = MmcsCommitment { value: [ f(778527199), f(28726932), diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index ac25318..d80609c 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -1,5 +1,5 @@ -pub(crate) mod binding; -pub(crate) mod program; +pub(crate) mod structs; +pub(crate) mod basefold; pub(crate) mod query_phase; pub(crate) mod rs; pub(crate) mod mmcs; diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index e69de29..43dcdc0 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -0,0 +1,264 @@ +// Note: check all XXX comments! + +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +use crate::tower_verifier::binding::*; +use super::{basefold::*, mmcs::*, rs::*, structs::*}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub struct BatchOpening { + pub opened_values: Vec>, + pub opening_proof: MmcsProof, +} + +impl Hintable for BatchOpening { + type HintVariable = BatchOpeningVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let opened_values = Vec::>::read(builder); + let opening_proof = Vec::>::read(builder); + BatchOpeningVariable { + opened_values, + opening_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.opened_values.write()); + stream.extend(self.opening_proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct BatchOpeningVariable { + pub opened_values: Array>>, + pub opening_proof: MmcsProofVariable, +} + +pub struct CommitPhaseProofStep { + pub sibling_value: F, + pub opening_proof: MmcsProof, +} + +impl Hintable for CommitPhaseProofStep { + type HintVariable = CommitPhaseProofStepVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let sibling_value = F::read(builder); + let opening_proof = Vec::>::read(builder); + CommitPhaseProofStepVariable { + sibling_value, + opening_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.sibling_value.write()); + stream.extend(self.opening_proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream + } + } +impl VecAutoHintable for CommitPhaseProofStep {} + +#[derive(DslVariable, Clone)] +pub struct CommitPhaseProofStepVariable { + pub sibling_value: Felt, + pub opening_proof: MmcsProofVariable, +} + +pub struct QueryOpeningProof { + pub witin_base_proof: BatchOpening, + pub fixed_base_proof: Option, + pub commit_phase_openings: Vec, +} +type QueryOpeningProofs = Vec; + +impl Hintable for QueryOpeningProof { + type HintVariable = QueryOpeningProofVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let witin_base_proof = BatchOpening::read(builder); + let fixed_is_some = Usize::Var(usize::read(builder)); + let fixed_base_proof = BatchOpening::read(builder); + let commit_phase_openings = Vec::::read(builder); + QueryOpeningProofVariable { + witin_base_proof, + fixed_is_some, + fixed_base_proof, + commit_phase_openings, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.witin_base_proof.write()); + if let Some(fixed_base_proof) = &self.fixed_base_proof { + stream.extend(>::write(&1)); + stream.extend(fixed_base_proof.write()); + } else { + stream.extend(>::write(&0)); + let tmp_proof = BatchOpening { + opened_values: Vec::new(), + opening_proof: Vec::new(), + }; + stream.extend(tmp_proof.write()); + } + stream.extend(self.commit_phase_openings.write()); + stream + } + } +impl VecAutoHintable for QueryOpeningProof {} + +#[derive(DslVariable, Clone)] +pub struct QueryOpeningProofVariable { + pub witin_base_proof: BatchOpeningVariable, + pub fixed_is_some: Usize, // 0 <==> false + pub fixed_base_proof: BatchOpeningVariable, + pub commit_phase_openings: Array>, +} +type QueryOpeningProofsVariable = Array>; + + +// NOTE: Different from PointAndEval in tower_verifier! +pub struct PointAndEvals { + pub point: Point, + pub evals: Vec, +} +impl Hintable for PointAndEvals { + type HintVariable = PointAndEvalsVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let point = Point::read(builder); + let evals = Vec::::read(builder); + PointAndEvalsVariable { + point, + evals, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.point.write()); + stream.extend(self.evals.write()); + stream + } +} +impl VecAutoHintable for PointAndEvals {} + +#[derive(DslVariable, Clone)] +pub struct PointAndEvalsVariable { + pub point: PointVariable, + pub evals: Array>, +} + +pub struct QueryPhaseVerifierInput { + pub max_num_var: usize, + pub indices: Vec, + pub vp: RSCodeVerifierParameters, + pub final_message: Vec>, + pub batch_coeffs: Vec, + pub queries: QueryOpeningProofs, + pub fixed_comm: Option, + pub witin_comm: BasefoldCommitment, + pub circuit_meta: Vec, + pub commits: Vec, + pub fold_challenges: Vec, + pub sumcheck_messages: Vec, + pub point_evals: Vec<(Point, Vec)>, +} + +impl Hintable for QueryPhaseVerifierInput { + type HintVariable = QueryPhaseVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let max_num_var = Usize::Var(usize::read(builder)); + let indices = Vec::::read(builder); + let vp = RSCodeVerifierParameters::read(builder); + let final_message = Vec::>::read(builder); + let batch_coeffs = Vec::::read(builder); + let queries = QueryOpeningProofs::read(builder); + let fixed_is_some = Usize::Var(usize::read(builder)); + let fixed_comm = BasefoldCommitment::read(builder); + let witin_comm = BasefoldCommitment::read(builder); + let circuit_meta = Vec::::read(builder); + let commits = Vec::::read(builder); + let fold_challenges = Vec::::read(builder); + let sumcheck_messages = Vec::::read(builder); + let point_evals = Vec::::read(builder); + + QueryPhaseVerifierInputVariable { + max_num_var, + indices, + vp, + final_message, + batch_coeffs, + queries, + fixed_is_some, + fixed_comm, + witin_comm, + circuit_meta, + commits, + fold_challenges, + sumcheck_messages, + point_evals, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.max_num_var)); + stream.extend(self.indices.write()); + stream.extend(self.vp.write()); + stream.extend(self.final_message.write()); + stream.extend(self.batch_coeffs.write()); + stream.extend(self.queries.write()); + if let Some(fixed_comm) = &self.fixed_comm { + stream.extend(>::write(&1)); + stream.extend(fixed_comm.write()); + } else { + stream.extend(>::write(&0)); + let tmp_comm = BasefoldCommitment { + commit: Default::default(), + log2_max_codeword_size: 0, + trivial_commits: Vec::new(), + }; + stream.extend(tmp_comm.write()); + } + stream.extend(self.witin_comm.write()); + stream.extend(self.circuit_meta.write()); + stream.extend(self.commits.write()); + stream.extend(self.fold_challenges.write()); + stream.extend(self.sumcheck_messages.write()); + stream.extend(self.point_evals.iter().map(|(p, e)| + PointAndEvals { point: p.clone(), evals: e.clone() } + ).collect::>().write()); + stream + } + } + +#[derive(DslVariable, Clone)] +pub struct QueryPhaseVerifierInputVariable { + pub max_num_var: Usize, + pub indices: Array>, + pub vp: RSCodeVerifierParametersVariable, + pub final_message: Array>>, + pub batch_coeffs: Array>, + pub queries: QueryOpeningProofsVariable, + pub fixed_is_some: Usize, // 0 <==> false + pub fixed_comm: BasefoldCommitmentVariable, + pub witin_comm: BasefoldCommitmentVariable, + pub circuit_meta: Array>, + pub commits: Array>, + pub fold_challenges: Array>, + pub sumcheck_messages: Array>, + pub point_evals: Array>, +} \ No newline at end of file diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index b73b8f2..a1402fb 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -5,8 +5,6 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use super::utils::*; - pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; @@ -21,7 +19,7 @@ impl Hintable for DenseMatrix { fn read(builder: &mut Builder) -> Self::HintVariable { let values = Vec::::read(builder); - let width = Usize::Var(usize::read(builder)); + let width = usize::read(builder); DenseMatrixVariable { values, @@ -40,7 +38,7 @@ impl Hintable for DenseMatrix { #[derive(DslVariable, Clone)] pub struct DenseMatrixVariable { pub values: Array>, - pub width: Usize, + pub width: Var, } pub type RowMajorMatrixVariable = DenseMatrixVariable; @@ -65,7 +63,7 @@ impl DenseMatrixVariable { pub fn pad_to_height( &self, builder: &mut Builder, - new_height: RVar, + new_height: Var, fill: Ext, ) { // XXX: Not necessary, only for testing purpose @@ -90,8 +88,29 @@ pub fn get_rate_log() -> Usize { Usize::from(1) } -/* /// The DIT FFT algorithm. +pub struct Radix2Dit { + pub twiddles: Vec, +} + +impl Hintable for Radix2Dit { + type HintVariable = Radix2DitVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let twiddles = Vec::::read(builder); + + Radix2DitVariable { + twiddles, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.twiddles.write()); + stream + } +} + #[derive(DslVariable, Clone)] pub struct Radix2DitVariable { /// Memoized twiddle factors for each length log_n. @@ -99,6 +118,7 @@ pub struct Radix2DitVariable { pub twiddles: Array>, } +/* impl Radix2DitVariable { fn dft_batch( &self, @@ -130,14 +150,44 @@ impl Radix2DitVariable { } */ -/* +pub struct RSCodeVerifierParameters { + pub dft: Radix2Dit, + pub t_inv_halves: Vec>, + pub full_message_size_log: usize, +} + +impl Hintable for RSCodeVerifierParameters { + type HintVariable = RSCodeVerifierParametersVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let dft = Radix2Dit::read(builder); + let t_inv_halves = Vec::>::read(builder); + let full_message_size_log = Usize::Var(usize::read(builder)); + + RSCodeVerifierParametersVariable { + dft, + t_inv_halves, + full_message_size_log, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.dft.write()); + stream.extend(self.t_inv_halves.write()); + stream.extend(>::write(&self.full_message_size_log)); + stream + } +} + #[derive(DslVariable, Clone)] pub struct RSCodeVerifierParametersVariable { pub dft: Radix2DitVariable, - pub t_inv_halves: Array>, + pub t_inv_halves: Array>>, pub full_message_size_log: Usize, } +/* pub(crate) fn encode_small( builder: &mut Builder, vp: RSCodeVerifierParametersVariable, @@ -154,6 +204,26 @@ pub(crate) fn encode_small( } */ +pub(crate) fn encode_small( + builder: &mut Builder, + _vp: RSCodeVerifierParametersVariable, + _rmm: RowMajorMatrixVariable, +) -> RowMajorMatrixVariable { + // XXX: nondeterministically supply the results for now + let len = builder.hint_var(); + let values = builder.dyn_array(len); + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_input = builder.hint_ext(); + builder.set_value(&values, i, next_input); + }); + let width = builder.hint_var(); + DenseMatrixVariable { + values, + width, + } +} + pub mod tests { use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; @@ -180,7 +250,7 @@ pub mod tests { // Witness inputs let dense_matrix_variable = DenseMatrix::read(&mut builder); - let new_height = RVar::from(8); + let new_height = builder.eval(Usize::from(8)); let fill = Ext::new(0); dense_matrix_variable.pad_to_height(&mut builder, new_height, fill); builder.halt(); diff --git a/src/basefold_verifier/program.rs b/src/basefold_verifier/structs.rs similarity index 65% rename from src/basefold_verifier/program.rs rename to src/basefold_verifier/structs.rs index 419ef70..ccdfb7c 100644 --- a/src/basefold_verifier/program.rs +++ b/src/basefold_verifier/structs.rs @@ -1,6 +1,96 @@ -use super::{binding::*, utils::*}; -use openvm_native_compiler::ir::*; -use crate::basefold_verifier::rs::get_rate_log; +use openvm_native_compiler::{asm::AsmConfig, ir::*}; +use openvm_native_compiler_derive::DslVariable; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +use super::rs::get_rate_log; +use super::utils::pow_2; + +#[derive(DslVariable, Clone)] +pub struct CircuitIndexMetaVariable { + pub witin_num_vars: Usize, + pub witin_num_polys: Usize, + pub fixed_num_vars: Usize, + pub fixed_num_polys: Usize, +} + +pub struct CircuitIndexMeta { + pub witin_num_vars: usize, + pub witin_num_polys: usize, + pub fixed_num_vars: usize, + pub fixed_num_polys: usize, +} + +impl Hintable for CircuitIndexMeta { + type HintVariable = CircuitIndexMetaVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let witin_num_vars = Usize::Var(usize::read(builder)); + let witin_num_polys = Usize::Var(usize::read(builder)); + let fixed_num_vars = Usize::Var(usize::read(builder)); + let fixed_num_polys = Usize::Var(usize::read(builder)); + + CircuitIndexMetaVariable { + witin_num_vars, + witin_num_polys, + fixed_num_vars, + fixed_num_polys, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write( + &self.witin_num_vars, + )); + stream.extend(>::write( + &self.witin_num_polys, + )); + stream.extend(>::write( + &self.fixed_num_vars, + )); + stream.extend(>::write( + &self.fixed_num_polys, + )); + stream + } +} +impl VecAutoHintable for CircuitIndexMeta {} + +#[derive(DslVariable, Clone)] +pub struct DimensionsVariable { + pub width: Var, + pub height: Var, +} + +pub struct Dimensions { + pub width: usize, + pub height: usize, +} + +impl Hintable for Dimensions { + type HintVariable = DimensionsVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let width = usize::read(builder); + let height = usize::read(builder); + + DimensionsVariable { width, height } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.width)); + stream.extend(>::write(&self.height)); + stream + } +} +impl VecAutoHintable for Dimensions {} fn get_base_codeword_dimensions( builder: &mut Builder, @@ -65,7 +155,7 @@ pub mod tests { type F = BabyBear; type E = BinomialExtensionField; type EF = ::Challenge; - use crate::basefold_verifier::binding::*; + use crate::basefold_verifier::structs::*; use super::{get_base_codeword_dimensions, InnerConfig}; diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 8736067..686632d 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -52,6 +52,7 @@ pub struct TowerVerifierInputVariable { pub logup_specs_eval: Array>>>, } +#[derive(Clone)] pub struct Point { pub fs: Vec, } @@ -72,6 +73,31 @@ impl Hintable for Point { } impl VecAutoHintable for Point {} +pub struct PointAndEval { + pub point: Point, + pub eval: E, +} +impl Hintable for PointAndEval { + type HintVariable = PointAndEvalVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let point = Point::read(builder); + let eval = E::read(builder); + PointAndEvalVariable { + point, + eval, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.point.write()); + stream.extend(self.eval.write()); + stream + } +} +impl VecAutoHintable for PointAndEval {} + #[derive(Debug)] pub struct IOPProverMessage { pub evaluations: Vec, From 85f00f2939a8a479f5a737b91fd380fba4f289fe Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 29 Apr 2025 16:30:26 -0400 Subject: [PATCH 11/70] WIP query_phase --- src/basefold_verifier/query_phase.rs | 313 ++++++++++++++++++++++++++- 1 file changed, 312 insertions(+), 1 deletion(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 43dcdc0..93d4b32 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -4,6 +4,7 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; use crate::tower_verifier::binding::*; use super::{basefold::*, mmcs::*, rs::*, structs::*}; @@ -261,4 +262,314 @@ pub struct QueryPhaseVerifierInputVariable { pub fold_challenges: Array>, pub sumcheck_messages: Array>, pub point_evals: Array>, -} \ No newline at end of file +} + +pub(crate) fn batch_verifier_query_phase( + builder: &mut Builder, + input: QueryPhaseVerifierInputVariable, +) { + // Nondeterministically supply inv_2 + let inv_2 = builder.hint_felt(); + builder.assert_eq::>(inv_2 * C::F::from_canonical_usize(2), C::F::from_canonical_usize(1)); + + // encode_small + let final_rmm_values_len = builder.get(&input.final_message, 0).len(); + let final_rmm_values = builder.dyn_array(final_rmm_values_len); + builder.range(0, final_rmm_values_len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let row = builder.get(&input.final_message, i); + let sum = builder.constant(C::EF::ZERO); + builder.range(0, row.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let row_j = builder.get(&row, j); + builder.assign(&sum, sum + row_j); + }); + builder.set_value(&final_rmm_values, i, sum); + }); + let final_rmm = RowMajorMatrixVariable { + values: final_rmm_values, + width: builder.eval(Usize::from(1)), + }; + let final_codeword = encode_small( + builder, + input.vp.clone(), + final_rmm, + ); + + let mmcs_ext = ExtensionMmcs::::new(poseidon2_merkle_tree::()); + let mmcs = poseidon2_merkle_tree::(); + let check_queries_span = entered_span!("check_queries"); + // can't use witin_comm.log2_max_codeword_size since it's untrusted + let log2_witin_max_codeword_size = + max_num_var + >::get_rate_log(); + + // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index + // for reverse lookup when retrieving next base codeword to involve into batching + let folding_sorted_order = circuit_meta + .iter() + .enumerate() + .sorted_by_key(|(_, CircuitIndexMeta { witin_num_vars, .. })| Reverse(witin_num_vars)) + .map(|(index, CircuitIndexMeta { witin_num_vars, .. })| (witin_num_vars, index)) + .collect_vec(); + + indices.iter().zip_eq(queries).for_each( + |( + idx, + QueryOpeningProof { + witin_base_proof: + BatchOpening { + opened_values: witin_opened_values, + opening_proof: witin_opening_proof, + }, + fixed_base_proof: fixed_commit_option, + commit_phase_openings: opening_ext, + }, + )| { + // verify base oracle query proof + // refer to prover documentation for the reason of right shift by 1 + let mut idx = idx >> 1; + + let (witin_dimentions, fixed_dimentions) = + get_base_codeword_dimentions::(circuit_meta); + // verify witness + mmcs.verify_batch( + &witin_comm.commit, + &witin_dimentions, + idx, + witin_opened_values, + witin_opening_proof, + ) + .expect("verify witin commit batch failed"); + + // verify fixed + let fixed_commit_leafs = if let Some(fixed_comm) = fixed_comm { + let BatchOpening { + opened_values: fixed_opened_values, + opening_proof: fixed_opening_proof, + } = &fixed_commit_option.as_ref().unwrap(); + + + mmcs.verify_batch( + &fixed_comm.commit, + &fixed_dimentions, + { + let idx_shift = log2_witin_max_codeword_size as i32 + - fixed_comm.log2_max_codeword_size as i32; + if idx_shift > 0 { + idx >> idx_shift + } else { + idx << -idx_shift + } + }, + fixed_opened_values, + fixed_opening_proof, + ) + .expect("verify fixed commit batch failed"); + fixed_opened_values + } else { + &vec![] + }; + + let mut fixed_commit_leafs_iter = fixed_commit_leafs.iter(); + let mut batch_coeffs_iter = batch_coeffs.iter(); + + let base_codeword_lo_hi = circuit_meta + .iter() + .zip_eq(witin_opened_values) + .map( + |( + CircuitIndexMeta { + witin_num_polys, + fixed_num_vars, + fixed_num_polys, + .. + }, + witin_leafs, + )| { + let (lo, hi) = std::iter::once((witin_leafs, *witin_num_polys)) + .chain((*fixed_num_vars > 0).then(|| { + (fixed_commit_leafs_iter.next().unwrap(), *fixed_num_polys) + })) + .map(|(leafs, num_polys)| { + let batch_coeffs = batch_coeffs_iter + .by_ref() + .take(num_polys) + .copied() + .collect_vec(); + let (lo, hi): (&[E::BaseField], &[E::BaseField]) = + leafs.split_at(leafs.len() / 2); + ( + dot_product::( + batch_coeffs.iter().copied(), + lo.iter().copied(), + ), + dot_product::( + batch_coeffs.iter().copied(), + hi.iter().copied(), + ), + ) + }) + // fold witin/fixed lo, hi together because they share the same num_vars + .reduce(|(lo_wit, hi_wit), (lo_fixed, hi_fixed)| { + (lo_wit + lo_fixed, hi_wit + hi_fixed) + }) + .expect("unreachable"); + (lo, hi) + }, + ) + .collect_vec(); + debug_assert_eq!(folding_sorted_order.len(), base_codeword_lo_hi.len()); + debug_assert!(fixed_commit_leafs_iter.next().is_none()); + debug_assert!(batch_coeffs_iter.next().is_none()); + + // fold and query + let mut cur_num_var = max_num_var; + // -1 because for there are only #max_num_var-1 openings proof + let rounds = cur_num_var + - >::get_basecode_msg_size_log() + - 1; + let n_d_next = 1 + << (cur_num_var + >::get_rate_log() - 1); + debug_assert_eq!(rounds, fold_challenges.len() - 1); + debug_assert_eq!(rounds, commits.len(),); + debug_assert_eq!(rounds, opening_ext.len(),); + + // first folding challenge + let r = fold_challenges.first().unwrap(); + + let mut folding_sorted_order_iter = folding_sorted_order.iter(); + // take first batch which num_vars match max_num_var to initial fold value + let mut folded = folding_sorted_order_iter + .by_ref() + .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) + .map(|(_, index)| { + let (lo, hi) = &base_codeword_lo_hi[*index]; + let coeff = + >::verifier_folding_coeffs_level( + vp, + cur_num_var + + >::get_rate_log() + - 1, + )[idx]; + codeword_fold_with_challenge(&[*lo, *hi], *r, coeff, inv_2) + }) + .sum::(); + + let mut n_d_i = n_d_next; + for ( + (pi_comm, r), + CommitPhaseProofStep { + sibling_value: leaf, + opening_proof: proof, + }, + ) in commits + .iter() + .zip_eq(fold_challenges.iter().skip(1)) + .zip_eq(opening_ext) + { + cur_num_var -= 1; + + let is_interpolate_to_right_index = (idx & 1) == 1; + let new_involved_codewords = folding_sorted_order_iter + .by_ref() + .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) + .map(|(_, index)| { + let (lo, hi) = &base_codeword_lo_hi[*index]; + if is_interpolate_to_right_index { + *hi + } else { + *lo + } + }) + .sum::(); + + let mut leafs = vec![*leaf; 2]; + leafs[is_interpolate_to_right_index as usize] = folded + new_involved_codewords; + idx >>= 1; + mmcs_ext + .verify_batch( + pi_comm, + &[Dimensions { + width: 2, + // width is 2, thus height divide by 2 via right shift + height: n_d_i >> 1, + }], + idx, + slice::from_ref(&leafs), + proof, + ) + .expect("verify failed"); + let coeff = + >::verifier_folding_coeffs_level( + vp, + log2_strict_usize(n_d_i) - 1, + )[idx]; + debug_assert_eq!( + >::verifier_folding_coeffs_level( + vp, + log2_strict_usize(n_d_i) - 1, + ) + .len(), + n_d_i >> 1 + ); + folded = codeword_fold_with_challenge(&[leafs[0], leafs[1]], *r, coeff, inv_2); + n_d_i >>= 1; + } + debug_assert!(folding_sorted_order_iter.next().is_none()); + assert!( + final_codeword.values[idx] == folded, + "final_codeword.values[idx] value {:?} != folded {:?}", + final_codeword.values[idx], + folded + ); + }, + ); + exit_span!(check_queries_span); + + // 1. check initial claim match with first round sumcheck value + assert_eq!( + // we need to scale up with scalar for witin_num_vars < max_num_var + dot_product::( + batch_coeffs.iter().copied(), + point_evals.iter().zip_eq(circuit_meta.iter()).flat_map( + |((_, evals), CircuitIndexMeta { witin_num_vars, .. })| { + evals.iter().copied().map(move |eval| { + eval * E::from_u64(1 << (max_num_var - witin_num_vars) as u64) + }) + } + ) + ), + { sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1] } + ); + // 2. check every round of sumcheck match with prev claims + for i in 0..fold_challenges.len() - 1 { + assert_eq!( + interpolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]), + { sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] } + ); + } + // 3. check final evaluation are correct + assert_eq!( + interpolate_uni_poly( + &sumcheck_messages[fold_challenges.len() - 1].evaluations, + fold_challenges[fold_challenges.len() - 1] + ), + izip!(final_message, point_evals.iter().map(|(point, _)| point)) + .map(|(final_message, point)| { + // coeff is the eq polynomial evaluated at the first challenge.len() variables + let num_vars_evaluated = point.len() + - >::get_basecode_msg_size_log(); + let coeff = eq_eval( + &point[..num_vars_evaluated], + &fold_challenges[fold_challenges.len() - num_vars_evaluated..], + ); + // Compute eq as the partially evaluated eq polynomial + let eq = build_eq_x_r_vec(&point[num_vars_evaluated..]); + dot_product( + final_message.iter().copied(), + eq.into_iter().map(|e| e * coeff), + ) + }) + .sum() + ); +} From ef34a424bee4001baa0d4938d221142e4e6c08d0 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 29 Apr 2025 23:03:44 -0400 Subject: [PATCH 12/70] Fix bug in sorting --- src/basefold_verifier/basefold.rs | 2 +- src/basefold_verifier/mmcs.rs | 242 +++++++++++++++------------ src/basefold_verifier/query_phase.rs | 26 ++- src/tower_verifier/program.rs | 5 +- 4 files changed, 154 insertions(+), 121 deletions(-) diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index 799abf7..453410b 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -40,7 +40,7 @@ impl Hintable for BasefoldCommitment { } } -pub type HashDigestVariable = MmcsCommitmentVariable; +pub type HashDigestVariable = MmcsCommitmentVariable; #[derive(DslVariable, Clone)] pub struct BasefoldCommitmentVariable { pub commit: HashDigestVariable, diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index ab6eb57..8863c4a 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -1,9 +1,12 @@ // Note: check all XXX comments! +use std::marker::PhantomData; + use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; use super::{structs::*, utils::*, hash::*}; @@ -11,6 +14,19 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; +// XXX: Fill in MerkleTreeMmcs +pub struct MerkleTreeMmcs { + pub hash: (), + pub compress: (), +} + +#[derive(Default)] +pub struct MerkleTreeMmcsVariables { + pub hash: (), + pub compress: (), + _phantom: PhantomData, +} + pub type MmcsCommitment = Hash; pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { @@ -64,6 +80,7 @@ pub struct MmcsVerifierInputVariable { pub(crate) fn mmcs_verify_batch( builder: &mut Builder, + _mmcs: MerkleTreeMmcsVariables, // self input: MmcsVerifierInputVariable, ) { // Check that the openings have the correct shape. @@ -95,19 +112,29 @@ pub(crate) fn mmcs_verify_batch( // Nondeterministically supplies: // 1. num_unique_height: number of different heights - // 2. unique_height_count: for each unique height, number of dimensions of that height - // 3. height_order: after sorting by decreasing height, the original index of each entry - // 4. height_diff: whether the height of the sorted entry differ from the next. 0 - no diff; 1 - diff + // 2. height_order: after sorting by decreasing height, the original index of each entry + // To ensure that height_order represents sorted index, assert that + // 1. It has the same length as input.dimensions (checked by requesting num_dims hints) + // 2. It does not contain the same index twice (checked via a correspondence array) + // 3. Indexed heights are sorted in decreasing order + // While checking, record: + // 1. unique_height_count: for each unique height, number of dimensions of that height let num_unique_height = builder.hint_var(); let unique_height_count = builder.dyn_array(num_unique_height); - builder.range(0, num_unique_height).for_each(|i_vec, builder| { + let zero: Ext = builder.constant(C::EF::ZERO); + let one: Ext = builder.constant(C::EF::ONE); + let height_order_surjection_check: Array> = builder.dyn_array(num_dims.clone()); + builder.range(0, num_dims.clone()).for_each(|i_vec, builder| { let i = i_vec[0]; - let next_count = builder.hint_var(); - builder.set_value(&unique_height_count, i, next_count); + builder.set(&height_order_surjection_check, i, zero.clone()); }); let height_order = builder.dyn_array(num_dims.clone()); let last_order = builder.hint_var(); + // Check surjection + let surjection_check = builder.get(&height_order_surjection_check, last_order); + builder.assert_ext_eq(surjection_check, zero.clone()); + builder.set(&height_order_surjection_check, last_order, one.clone()); builder.set_value(&height_order, 0, last_order); let last_height = builder.get(&input.dimensions, last_order).height; @@ -115,9 +142,13 @@ pub(crate) fn mmcs_verify_batch( let last_unique_height_count: Var = builder.eval(Usize::from(1)); builder.range(1, num_dims).for_each(|i_vec, builder| { let i = i_vec[0]; + // Check surjection let next_order = builder.hint_var(); + let surjection_check = builder.get(&height_order_surjection_check, next_order); + builder.assert_ext_eq(surjection_check, zero.clone()); + builder.set(&height_order_surjection_check, next_order, one.clone()); + // Check height let next_height = builder.get(&input.dimensions, next_order).height; - builder.if_eq(last_height, next_height).then(|builder| { // next_height == last_height builder.assign(&last_unique_height_count, last_unique_height_count + Usize::from(1)); @@ -126,9 +157,8 @@ pub(crate) fn mmcs_verify_batch( // next_height < last_height builder.assert_less_than_slow_small_rhs(next_height, last_height); - // Verify correctness of unique_height_count - let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); - builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); + // Update unique_height_count + builder.set(&unique_height_count, last_unique_height_index, last_unique_height_count); builder.assign(&last_height, next_height); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); builder.assign(&last_unique_height_count, Usize::from(1)); @@ -139,8 +169,7 @@ pub(crate) fn mmcs_verify_batch( }); // Final check on num_unique_height and unique_height_count - let purported_unique_height_count = builder.get(&unique_height_count, last_unique_height_index); - builder.assert_var_eq(purported_unique_height_count, last_unique_height_count); + builder.set(&unique_height_count, last_unique_height_index, last_unique_height_count); builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); builder.assert_var_eq(last_unique_height_index, num_unique_height); @@ -150,7 +179,7 @@ pub(crate) fn mmcs_verify_batch( let curr_height_padded = next_power_of_two(builder, first_height); // Construct root through hashing - let root_dims_count = builder.get(&unique_height_count, 0); + let root_dims_count: Var = builder.get(&unique_height_count, 0); let root_values = builder.dyn_array(root_dims_count); builder.range(0, root_dims_count).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -268,8 +297,9 @@ pub mod tests { let mut builder = AsmBuilder::::default(); // Witness inputs + let mmcs_self = Default::default(); let mmcs_input = MmcsVerifierInput::read(&mut builder); - mmcs_verify_batch(&mut builder, mmcs_input); + mmcs_verify_batch(&mut builder, mmcs_self, mmcs_input); builder.halt(); // Pass in witness stream @@ -279,14 +309,14 @@ pub mod tests { > = Vec::new(); let commit = MmcsCommitment { value: [ - f(778527199), - f(28726932), - f(1315347420), - f(1824757698), - f(154429821), - f(1391932058), - f(826833161), - f(1793433773), + f(1715944678), + f(1204294900), + f(59582177), + f(320945505), + f(1470843790), + f(1773915204), + f(380281369), + f(383365269), ] }; let dimensions = vec![ @@ -296,18 +326,18 @@ pub mod tests { ]; let index = 6; let opened_values = vec![ - vec![f(1105434748), f(689726213), f(688169105), f(1988100049), f(1580478319), f(1706067197), f(513975191), f(1741109149)], - vec![f(1522482301), f(479042531), f(1086100811), f(734531439), f(705797008), f(1234295284), f(937641372), f(553060608)], - vec![f(744749480), f(1063269152), f(300382655), f(1107270768), f(1172794741), f(274350305), f(1359913694), f(179073086)], + vec![f(774319227), f(1631186743), f(254325873), f(504149682), f(239740532), f(1126519109), f(1044404585), f(1274764277)], + vec![f(1486505160), f(631183960), f(329388712), f(1934479253), f(115532954), f(1978455077), f(66346996), f(821157541)], + vec![f(149196326), f(1186650877), f(1970038391), f(1893286029), f(1249658956), f(1618951617), f(419030634), f(1967997848)], ]; let proof = vec![ - [f(1073443193), f(894272286), f(588425464), f(1974315438), f(376335434), f(1149692201), f(543618925), f(1485228078)], - [f(1196372702), f(867462678), f(871921129), f(1745802269), f(1878325218), f(1200890208), f(955410895), f(588843483)], - [f(348296419), f(1531857785), f(1922560959), f(1197467594), f(1441649143), f(914359927), f(1924320269), f(1056370810)], - [f(1581777890), f(1925056505), f(1645298574), f(515725387), f(1060947616), f(1614093762), f(967068928), f(968302842)], - [f(961265251), f(1008373514), f(72654335), f(16568774), f(1778075526), f(1938499582), f(23748437), f(30462657)], - [f(1638730933), f(698689687), f(116457371), f(1466997263), f(993891206), f(1568724141), f(1402556463), f(1903080766)], - [f(1451476441), f(480987775), f(1782294403), f(709729703), f(500945265), f(1280038868), f(1762204994), f(240464)], + [f(845920358), f(1201648213), f(1087654550), f(264553580), f(633209321), f(877945079), f(1674449089), f(1062812099)], + [f(5498027), f(1901489519), f(179361222), f(41261871), f(1546446894), f(266690586), f(1882928070), f(844710372)], + [f(721245096), f(388358486), f(1443363461), f(1349470697), f(253624794), f(1359455861), f(237485093), f(1955099141)], + [f(1816731864), f(402719753), f(1972161922), f(693018227), f(1617207065), f(1848150948), f(360933015), f(669793414)], + [f(1746479395), f(457185725), f(1263857148), f(328668702), f(1743038915), f(582282833), f(927410326), f(376217274)], + [f(1146845382), f(1117439420), f(1622226137), f(1449227765), f(138752938), f(1251889563), f(1266915653), f(267248408)], + [f(1992750195), f(1604624754), f(1748646393), f(1777984113), f(861317745), f(564150089), f(1371546358), f(460033967)], ]; let mmcs_input = MmcsVerifierInput { commit, @@ -323,10 +353,6 @@ pub mod tests { witness_stream.extend(>::write(&7)); // num_unique_height witness_stream.extend(>::write(&2)); - // unique_height_count - witness_stream.extend(>::write(&1)); - // unique_height_count - witness_stream.extend(>::write(&2)); // height_order witness_stream.extend(>::write(&2)); // height_order @@ -336,179 +362,179 @@ pub mod tests { // curr_height_log witness_stream.extend(>::write(&6)); // root - witness_stream.extend(>::write(&F::from_canonical_usize(410616511))); + witness_stream.extend(>::write(&F::from_canonical_usize(1782972889))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1016155415))); + witness_stream.extend(>::write(&F::from_canonical_usize(279434715))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1214189198))); + witness_stream.extend(>::write(&F::from_canonical_usize(1209301918))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(227596423))); + witness_stream.extend(>::write(&F::from_canonical_usize(1853868602))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(638999723))); + witness_stream.extend(>::write(&F::from_canonical_usize(883945353))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1793520096))); + witness_stream.extend(>::write(&F::from_canonical_usize(368353728))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1497010699))); + witness_stream.extend(>::write(&F::from_canonical_usize(1699837443))); // root - witness_stream.extend(>::write(&F::from_canonical_usize(307833588))); + witness_stream.extend(>::write(&F::from_canonical_usize(908962698))); // next_height_log witness_stream.extend(>::write(&0)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(847632309))); + witness_stream.extend(>::write(&F::from_canonical_usize(271352274))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(597844957))); + witness_stream.extend(>::write(&F::from_canonical_usize(1918158485))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(471673299))); + witness_stream.extend(>::write(&F::from_canonical_usize(1538604111))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1929998464))); + witness_stream.extend(>::write(&F::from_canonical_usize(1122013445))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1285517017))); + witness_stream.extend(>::write(&F::from_canonical_usize(1844193149))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(383750469))); + witness_stream.extend(>::write(&F::from_canonical_usize(501326061))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1336144331))); + witness_stream.extend(>::write(&F::from_canonical_usize(1508959271))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(89856465))); + witness_stream.extend(>::write(&F::from_canonical_usize(1549189152))); // next_curr_height_padded witness_stream.extend(>::write(&64)); // next_bit witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1661952137))); + witness_stream.extend(>::write(&F::from_canonical_usize(222162520))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(675247342))); + witness_stream.extend(>::write(&F::from_canonical_usize(785634830))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(358879322))); + witness_stream.extend(>::write(&F::from_canonical_usize(1461778378))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(328576074))); + witness_stream.extend(>::write(&F::from_canonical_usize(836284568))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(45664218))); + witness_stream.extend(>::write(&F::from_canonical_usize(1141654637))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1026458030))); + witness_stream.extend(>::write(&F::from_canonical_usize(1339589042))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(670890979))); + witness_stream.extend(>::write(&F::from_canonical_usize(1081824021))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1522300104))); + witness_stream.extend(>::write(&F::from_canonical_usize(698316542))); // next_curr_height_padded witness_stream.extend(>::write(&32)); // next_bit witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1134267269))); + witness_stream.extend(>::write(&F::from_canonical_usize(567517164))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(621171717))); + witness_stream.extend(>::write(&F::from_canonical_usize(915833994))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(231890617))); + witness_stream.extend(>::write(&F::from_canonical_usize(621327606))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(500108260))); + witness_stream.extend(>::write(&F::from_canonical_usize(476128789))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1862498334))); + witness_stream.extend(>::write(&F::from_canonical_usize(1976747536))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(168633872))); + witness_stream.extend(>::write(&F::from_canonical_usize(1385950652))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(399123277))); + witness_stream.extend(>::write(&F::from_canonical_usize(1416073024))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1301607042))); + witness_stream.extend(>::write(&F::from_canonical_usize(862764478))); // next_curr_height_padded witness_stream.extend(>::write(&16)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1081303431))); + witness_stream.extend(>::write(&F::from_canonical_usize(822965313))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1607649945))); + witness_stream.extend(>::write(&F::from_canonical_usize(1036402058))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1290504702))); + witness_stream.extend(>::write(&F::from_canonical_usize(117603799))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(149378))); + witness_stream.extend(>::write(&F::from_canonical_usize(1087591966))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1025603059))); + witness_stream.extend(>::write(&F::from_canonical_usize(443405499))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1980340366))); + witness_stream.extend(>::write(&F::from_canonical_usize(1334745091))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(172368574))); + witness_stream.extend(>::write(&F::from_canonical_usize(901165815))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1449539534))); + witness_stream.extend(>::write(&F::from_canonical_usize(1187124281))); // next_curr_height_padded witness_stream.extend(>::write(&8)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1779401002))); + witness_stream.extend(>::write(&F::from_canonical_usize(875508647))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1329892692))); + witness_stream.extend(>::write(&F::from_canonical_usize(1313410483))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1551737751))); + witness_stream.extend(>::write(&F::from_canonical_usize(355713834))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1315686077))); + witness_stream.extend(>::write(&F::from_canonical_usize(1976667383))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1218609253))); + witness_stream.extend(>::write(&F::from_canonical_usize(1804021525))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1532387083))); + witness_stream.extend(>::write(&F::from_canonical_usize(294385081))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(80357312))); + witness_stream.extend(>::write(&F::from_canonical_usize(669164730))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1697204536))); + witness_stream.extend(>::write(&F::from_canonical_usize(1187763617))); // next_curr_height_padded witness_stream.extend(>::write(&4)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(922874076))); + witness_stream.extend(>::write(&F::from_canonical_usize(1992024140))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1357099772))); + witness_stream.extend(>::write(&F::from_canonical_usize(439080849))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(91993648))); + witness_stream.extend(>::write(&F::from_canonical_usize(1032272714))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1335971015))); + witness_stream.extend(>::write(&F::from_canonical_usize(1304584689))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(295319780))); + witness_stream.extend(>::write(&F::from_canonical_usize(1795447062))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(790352918))); + witness_stream.extend(>::write(&F::from_canonical_usize(859522945))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1988018190))); + witness_stream.extend(>::write(&F::from_canonical_usize(1661892383))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1079914414))); + witness_stream.extend(>::write(&F::from_canonical_usize(1980559722))); // next_curr_height_padded witness_stream.extend(>::write(&2)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(590430057))); + witness_stream.extend(>::write(&F::from_canonical_usize(1121119596))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1802104709))); + witness_stream.extend(>::write(&F::from_canonical_usize(369487248))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1602739834))); + witness_stream.extend(>::write(&F::from_canonical_usize(834451573))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(578735974))); + witness_stream.extend(>::write(&F::from_canonical_usize(1120744826))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1828105722))); + witness_stream.extend(>::write(&F::from_canonical_usize(758930984))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(279136942))); + witness_stream.extend(>::write(&F::from_canonical_usize(632316631))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(120317613))); + witness_stream.extend(>::write(&F::from_canonical_usize(1593276657))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(849588480))); + witness_stream.extend(>::write(&F::from_canonical_usize(507031465))); // next_curr_height_padded witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(778527199))); + witness_stream.extend(>::write(&F::from_canonical_usize(1715944678))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(28726932))); + witness_stream.extend(>::write(&F::from_canonical_usize(1204294900))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1315347420))); + witness_stream.extend(>::write(&F::from_canonical_usize(59582177))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1824757698))); + witness_stream.extend(>::write(&F::from_canonical_usize(320945505))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(154429821))); + witness_stream.extend(>::write(&F::from_canonical_usize(1470843790))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1391932058))); + witness_stream.extend(>::write(&F::from_canonical_usize(1773915204))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(826833161))); + witness_stream.extend(>::write(&F::from_canonical_usize(380281369))); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1793433773))); + witness_stream.extend(>::write(&F::from_canonical_usize(383365269))); // PROGRAM let program: Program< diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 93d4b32..246d4d9 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -274,8 +274,8 @@ pub(crate) fn batch_verifier_query_phase( // encode_small let final_rmm_values_len = builder.get(&input.final_message, 0).len(); - let final_rmm_values = builder.dyn_array(final_rmm_values_len); - builder.range(0, final_rmm_values_len).for_each(|i_vec, builder| { + let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); + builder.range(0, final_rmm_values_len.clone()).for_each(|i_vec, builder| { let i = i_vec[0]; let row = builder.get(&input.final_message, i); let sum = builder.constant(C::EF::ZERO); @@ -295,14 +295,21 @@ pub(crate) fn batch_verifier_query_phase( input.vp.clone(), final_rmm, ); - - let mmcs_ext = ExtensionMmcs::::new(poseidon2_merkle_tree::()); - let mmcs = poseidon2_merkle_tree::(); - let check_queries_span = entered_span!("check_queries"); + // XXX: we might need to add generics to MMCS to account for different field types + let mmcs_ext: MerkleTreeMmcsVariables = Default::default(); + let mmcs: MerkleTreeMmcsVariables = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted - let log2_witin_max_codeword_size = - max_num_var + >::get_rate_log(); - + let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var + get_rate_log::()); + // Nondeterministically supply the index folding_sorted_order + // Check that: + // 1. It has the same length as input.circuit_meta (checked by requesting folding_len hints) + // 2. It does not contain the same index twice (checked via a correspondence array) + // 3. Indexed witin_num_vars are sorted in decreasing order + // Infer witin_num_vars through index + let folding_len = input.circuit_meta.len(); + // let folding_sorted_order_index = builder.dyn_array(folding_len); + + /* // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index // for reverse lookup when retrieving next base codeword to involve into batching let folding_sorted_order = circuit_meta @@ -572,4 +579,5 @@ pub(crate) fn batch_verifier_query_phase( }) .sum() ); + */ } diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index 3d0061b..0d14fd9 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -390,11 +390,10 @@ pub fn verify_tower_proof( let initial_claim: Ext = builder.constant(C::EF::ZERO); builder.assign(&initial_claim, prod_sub_sum + logup_sub_sum); - let mut curr_pt = initial_rt.clone(); - let mut curr_eval = initial_claim.clone(); + let curr_pt = initial_rt.clone(); + let curr_eval = initial_claim.clone(); let op_range = builder.eval_expr(tower_verifier_input.max_num_variables - RVar::from(1)); let round: Felt = builder.constant(C::F::ZERO); - let one: Ext<::F, ::EF> = builder.constant(C::EF::ONE); let mut next_rt = PointAndEvalVariable { point: PointVariable { From b63abd8247be942bb8f326cbba0530288c602b72 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 1 May 2025 11:06:10 -0400 Subject: [PATCH 13/70] WIP query phase --- src/basefold_verifier/mmcs.rs | 32 +++++++-------- src/basefold_verifier/query_phase.rs | 58 ++++++++++++++++++++++++---- src/basefold_verifier/utils.rs | 10 +++++ 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 8863c4a..c451ca3 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -104,8 +104,9 @@ pub(crate) fn mmcs_verify_batch( // Verify correspondence between log_h and h let log_max_height = builder.hint_var(); let log_max_height_minus_1: Var = builder.eval(log_max_height - Usize::from(1)); - let purported_max_height_lower_bound = pow_2(builder, log_max_height_minus_1); - let purported_max_height_upper_bound = pow_2(builder, log_max_height); + let purported_max_height_lower_bound: Var = pow_2(builder, log_max_height_minus_1); + let two: Var = builder.constant(C::N::TWO); + let purported_max_height_upper_bound: Var = builder.eval(purported_max_height_lower_bound * two); builder.assert_less_than_slow_small_rhs(purported_max_height_lower_bound, max_height); builder.assert_less_than_slow_small_rhs(max_height, purported_max_height_upper_bound); builder.assert_usize_eq(input.proof.len(), log_max_height); @@ -123,30 +124,30 @@ pub(crate) fn mmcs_verify_batch( let unique_height_count = builder.dyn_array(num_unique_height); let zero: Ext = builder.constant(C::EF::ZERO); let one: Ext = builder.constant(C::EF::ONE); - let height_order_surjection_check: Array> = builder.dyn_array(num_dims.clone()); + let height_sort_surjective: Array> = builder.dyn_array(num_dims.clone()); builder.range(0, num_dims.clone()).for_each(|i_vec, builder| { let i = i_vec[0]; - builder.set(&height_order_surjection_check, i, zero.clone()); + builder.set(&height_sort_surjective, i, zero.clone()); }); let height_order = builder.dyn_array(num_dims.clone()); - let last_order = builder.hint_var(); + let next_order = builder.hint_var(); // Check surjection - let surjection_check = builder.get(&height_order_surjection_check, last_order); - builder.assert_ext_eq(surjection_check, zero.clone()); - builder.set(&height_order_surjection_check, last_order, one.clone()); - builder.set_value(&height_order, 0, last_order); - let last_height = builder.get(&input.dimensions, last_order).height; + let surjective = builder.get(&height_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&height_sort_surjective, next_order, one.clone()); + builder.set_value(&height_order, 0, next_order); + let last_height = builder.get(&input.dimensions, next_order).height; let last_unique_height_index: Var = builder.eval(Usize::from(0)); let last_unique_height_count: Var = builder.eval(Usize::from(1)); builder.range(1, num_dims).for_each(|i_vec, builder| { let i = i_vec[0]; - // Check surjection let next_order = builder.hint_var(); - let surjection_check = builder.get(&height_order_surjection_check, next_order); - builder.assert_ext_eq(surjection_check, zero.clone()); - builder.set(&height_order_surjection_check, next_order, one.clone()); + // Check surjection + let surjective = builder.get(&height_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&height_sort_surjective, next_order, one.clone()); // Check height let next_height = builder.get(&input.dimensions, next_order).height; builder.if_eq(last_height, next_height).then(|builder| { @@ -164,8 +165,7 @@ pub(crate) fn mmcs_verify_batch( builder.assign(&last_unique_height_count, Usize::from(1)); }); - builder.assign(&last_order, next_order); - builder.set_value(&height_order, i, last_order); + builder.set_value(&height_order, i, next_order); }); // Final check on num_unique_height and unique_height_count diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 246d4d9..dbe3e7e 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -300,6 +300,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs: MerkleTreeMmcsVariables = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var + get_rate_log::()); + // Nondeterministically supply the index folding_sorted_order // Check that: // 1. It has the same length as input.circuit_meta (checked by requesting folding_len hints) @@ -307,18 +308,59 @@ pub(crate) fn batch_verifier_query_phase( // 3. Indexed witin_num_vars are sorted in decreasing order // Infer witin_num_vars through index let folding_len = input.circuit_meta.len(); - // let folding_sorted_order_index = builder.dyn_array(folding_len); + let zero: Ext = builder.constant(C::EF::ZERO); + let one: Ext = builder.constant(C::EF::ONE); + let folding_sort_surjective: Array> = builder.dyn_array(folding_len.clone()); + builder.range(0, folding_len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&folding_sort_surjective, i, zero.clone()); + }); - /* // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index // for reverse lookup when retrieving next base codeword to involve into batching - let folding_sorted_order = circuit_meta - .iter() - .enumerate() - .sorted_by_key(|(_, CircuitIndexMeta { witin_num_vars, .. })| Reverse(witin_num_vars)) - .map(|(index, CircuitIndexMeta { witin_num_vars, .. })| (witin_num_vars, index)) - .collect_vec(); + let folding_sorted_order_witin_num_vars: Array> = builder.dyn_array(folding_len.clone()); + let folding_sorted_order_index: Array> = builder.dyn_array(folding_len.clone()); + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&folding_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&folding_sort_surjective, next_order, one.clone()); + // Assignment + let next_witin_num_vars = builder.get(&input.circuit_meta, next_order).witin_num_vars; + builder.set_value(&folding_sorted_order_witin_num_vars, 0, next_witin_num_vars.clone()); + builder.set_value(&folding_sorted_order_index, 0, Usize::Var(next_order)); + let last_witin_num_vars_plus_one: Var = builder.eval(next_witin_num_vars + Usize::from(1)); + builder.range(1, folding_len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&folding_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&folding_sort_surjective, next_order, one.clone()); + // Check witin_num_vars, next_witin_num_vars < last_witin_num_vars_plus_one + let next_witin_num_vars = builder.get(&input.circuit_meta, next_order).witin_num_vars; + builder.assert_less_than_slow_small_rhs(next_witin_num_vars.clone(), last_witin_num_vars_plus_one); + builder.assign(&last_witin_num_vars_plus_one, next_witin_num_vars.clone() + Usize::from(1)); + // Assignment + builder.set_value(&folding_sorted_order_witin_num_vars, i, next_witin_num_vars); + builder.set_value(&folding_sorted_order_index, i, Usize::Var(next_order)); + }); + builder.range(0, input.indices.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let idx = builder.get(&input.indices, i); + let query = builder.get(&input.queries, i); + let witin_opened_values = query.witin_base_proof.opened_values; + let witin_opened_proof = query.witin_base_proof.opening_proof; + let fixed_is_some = query.fixed_is_some; + let fixed_commit = query.fixed_base_proof; + let opening_ext = query.commit_phase_openings; + + // verify base oracle query proof + // refer to prover documentation for the reason of right shift by 1 + let mut idx = idx >> 1; + }); + /* indices.iter().zip_eq(queries).for_each( |( idx, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 72b3ee0..44827a6 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -62,4 +62,14 @@ pub fn dot_product( builder.assign(&ret, ret + l * r); }); ret +} + +// Right shift +// Note: we try to avoid this as much as possible. This is unnecessary in the case where Var is a pow of 2. +pub fn right_shift( + builder: &mut Builder, + base: Var, + exp: Var, +) -> Var { + } \ No newline at end of file From 2a5f3daa2013e3cc6f64d2b2be214260d19a4604 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 1 May 2025 13:45:54 -0400 Subject: [PATCH 14/70] WIP query phase --- src/basefold_verifier/mmcs.rs | 2 +- src/basefold_verifier/query_phase.rs | 42 ++++++++++++++++++++++++++-- src/basefold_verifier/structs.rs | 2 +- src/basefold_verifier/utils.rs | 20 +++++++++---- 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index c451ca3..d897a66 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -20,7 +20,7 @@ pub struct MerkleTreeMmcs { pub compress: (), } -#[derive(Default)] +#[derive(Default, Clone)] pub struct MerkleTreeMmcsVariables { pub hash: (), pub compress: (), diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index dbe3e7e..904644e 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -7,7 +7,7 @@ use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; use crate::tower_verifier::binding::*; -use super::{basefold::*, mmcs::*, rs::*, structs::*}; +use super::{basefold::*, mmcs::*, rs::*, structs::*, utils::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -351,14 +351,50 @@ pub(crate) fn batch_verifier_query_phase( let idx = builder.get(&input.indices, i); let query = builder.get(&input.queries, i); let witin_opened_values = query.witin_base_proof.opened_values; - let witin_opened_proof = query.witin_base_proof.opening_proof; + let witin_opening_proof = query.witin_base_proof.opening_proof; let fixed_is_some = query.fixed_is_some; let fixed_commit = query.fixed_base_proof; let opening_ext = query.commit_phase_openings; // verify base oracle query proof // refer to prover documentation for the reason of right shift by 1 - let mut idx = idx >> 1; + // Nondeterministically supply the bits of idx in BIG ENDIAN + // These are not only used by the right shift here but also later on idx_shift + let idx_len = builder.hint_var(); + let idx_bits: Array> = builder.dyn_array(idx_len); + builder.range(0, idx_len).for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_bit = builder.hint_var(); + // Assert that it is a bit + builder.assert_eq::>(next_bit * next_bit, next_bit); + builder.set_value(&idx_bits, j, next_bit); + }); + // Right shift + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + builder.assign(&idx_len, idx_len_minus_one); + let new_idx = bin_to_dec(builder, &idx_bits, idx_len); + let last_bit = builder.get(&idx_bits, idx_len); + builder.assert_eq::>(new_idx + last_bit, idx); + builder.assign(&idx, new_idx); + + let (witin_dimensions, fixed_dimensions) = + get_base_codeword_dimensions(builder, input.circuit_meta.clone()); + // verify witness + let mmcs_verifier_input = MmcsVerifierInputVariable { + commit: input.witin_comm.commit.clone(), + dimensions: witin_dimensions, + index: idx, + opened_values: witin_opened_values, + proof: witin_opening_proof, + }; + mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); + + // verify fixed + builder.if_eq(fixed_is_some, Usize::from(1)).then(|builder| { + // idx_shift and idx + // let idx_shift = log2_witin_max_codeword_size - input.fixed_comm.log2_max_codeword_size.clone(); + + }); }); /* indices.iter().zip_eq(queries).for_each( diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index ccdfb7c..f45e061 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -92,7 +92,7 @@ impl Hintable for Dimensions { } impl VecAutoHintable for Dimensions {} -fn get_base_codeword_dimensions( +pub fn get_base_codeword_dimensions( builder: &mut Builder, circuit_meta_map: Array>, ) -> ( diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 44827a6..30f82e2 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -64,12 +64,20 @@ pub fn dot_product( ret } -// Right shift -// Note: we try to avoid this as much as possible. This is unnecessary in the case where Var is a pow of 2. -pub fn right_shift( +// Convert the first len entries of binary to decimal +// BIN is in big endian +pub fn bin_to_dec( builder: &mut Builder, - base: Var, - exp: Var, + bin: &Array>, + len: Var, ) -> Var { - + let value: Var = builder.constant(C::N::ZERO); + let two: Var = builder.constant(C::N::TWO); + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.assign(&value, value * two); + let next_bit = builder.get(bin, i); + builder.assign(&value, value + next_bit); + }); + value } \ No newline at end of file From f4c21af915f25b7a94d8bb947aec38295cd5b4fd Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 1 May 2025 19:28:14 -0400 Subject: [PATCH 15/70] WIP query phase --- src/basefold_verifier/query_phase.rs | 96 +++++++++++++--------------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 904644e..5c088e3 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -390,10 +390,52 @@ pub(crate) fn batch_verifier_query_phase( mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); // verify fixed + let fixed_commit_leafs = builder.dyn_array(0); builder.if_eq(fixed_is_some, Usize::from(1)).then(|builder| { - // idx_shift and idx - // let idx_shift = log2_witin_max_codeword_size - input.fixed_comm.log2_max_codeword_size.clone(); - + let fixed_opened_values = fixed_commit.opened_values.clone(); + let fixed_opening_proof = fixed_commit.opening_proof.clone(); + // new_idx used by mmcs proof + let new_idx: Var = builder.eval(idx); + // Nondeterministically supply a hint: + // 0: input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size + // 1: >= + let branch_le = builder.hint_var(); + builder.if_eq(branch_le, Usize::from(0)).then(|builder| { + // input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size + builder.assert_less_than_slow_small_rhs(input.fixed_comm.log2_max_codeword_size.clone(), log2_witin_max_codeword_size); + // idx >> idx_shift + let idx_shift_remain: Var = builder.eval(idx_len - (log2_witin_max_codeword_size - input.fixed_comm.log2_max_codeword_size.clone())); + let tmp_idx = bin_to_dec(builder, &idx_bits, idx_shift_remain); + builder.assign(&new_idx, tmp_idx); + }); + builder.if_ne(branch_le, Usize::from(0)).then(|builder| { + // input.fixed_comm.log2_max_codeword_size >= log2_witin_max_codeword_size + let input_codeword_size_plus_one: Var = builder.eval(input.fixed_comm.log2_max_codeword_size.clone() + Usize::from(1)); + builder.assert_less_than_slow_small_rhs(log2_witin_max_codeword_size, input_codeword_size_plus_one); + // idx << -idx_shift + let idx_shift = builder.eval(input.fixed_comm.log2_max_codeword_size.clone() - log2_witin_max_codeword_size); + let idx_factor = pow_2(builder, idx_shift); + builder.assign(&new_idx, new_idx * idx_factor); + }); + // verify witness + let mmcs_verifier_input = MmcsVerifierInputVariable { + commit: input.fixed_comm.commit.clone(), + dimensions: fixed_dimensions.clone(), + index: new_idx, + opened_values: fixed_opened_values.clone(), + proof: fixed_opening_proof, + }; + mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); + builder.assign(&fixed_commit_leafs, fixed_opened_values); + }); + + builder.range(0, folding_len.clone()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let circuit_meta = builder.get(&input.circuit_meta, j); + let witin_num_polys = circuit_meta.witin_num_polys; + let fixed_num_vars = circuit_meta.fixed_num_vars; + let fixed_num_polys = circuit_meta.fixed_num_polys; + let witin_leafs = builder.get(&witin_opened_values, j); }); }); /* @@ -410,54 +452,6 @@ pub(crate) fn batch_verifier_query_phase( commit_phase_openings: opening_ext, }, )| { - // verify base oracle query proof - // refer to prover documentation for the reason of right shift by 1 - let mut idx = idx >> 1; - - let (witin_dimentions, fixed_dimentions) = - get_base_codeword_dimentions::(circuit_meta); - // verify witness - mmcs.verify_batch( - &witin_comm.commit, - &witin_dimentions, - idx, - witin_opened_values, - witin_opening_proof, - ) - .expect("verify witin commit batch failed"); - - // verify fixed - let fixed_commit_leafs = if let Some(fixed_comm) = fixed_comm { - let BatchOpening { - opened_values: fixed_opened_values, - opening_proof: fixed_opening_proof, - } = &fixed_commit_option.as_ref().unwrap(); - - - mmcs.verify_batch( - &fixed_comm.commit, - &fixed_dimentions, - { - let idx_shift = log2_witin_max_codeword_size as i32 - - fixed_comm.log2_max_codeword_size as i32; - if idx_shift > 0 { - idx >> idx_shift - } else { - idx << -idx_shift - } - }, - fixed_opened_values, - fixed_opening_proof, - ) - .expect("verify fixed commit batch failed"); - fixed_opened_values - } else { - &vec![] - }; - - let mut fixed_commit_leafs_iter = fixed_commit_leafs.iter(); - let mut batch_coeffs_iter = batch_coeffs.iter(); - let base_codeword_lo_hi = circuit_meta .iter() .zip_eq(witin_opened_values) From 746f04feedbb37671f54042e8d12d61c54576b99 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 1 May 2025 22:08:08 -0400 Subject: [PATCH 16/70] WIP query phase --- src/basefold_verifier/query_phase.rs | 133 +++++++++++++-------------- src/basefold_verifier/rs.rs | 4 + src/basefold_verifier/utils.rs | 17 ++-- 3 files changed, 80 insertions(+), 74 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 5c088e3..418af88 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -299,7 +299,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_ext: MerkleTreeMmcsVariables = Default::default(); let mmcs: MerkleTreeMmcsVariables = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted - let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var + get_rate_log::()); + let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); // Nondeterministically supply the index folding_sorted_order // Check that: @@ -384,14 +384,14 @@ pub(crate) fn batch_verifier_query_phase( commit: input.witin_comm.commit.clone(), dimensions: witin_dimensions, index: idx, - opened_values: witin_opened_values, + opened_values: witin_opened_values.clone(), proof: witin_opening_proof, }; mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); // verify fixed let fixed_commit_leafs = builder.dyn_array(0); - builder.if_eq(fixed_is_some, Usize::from(1)).then(|builder| { + builder.if_eq(fixed_is_some.clone(), Usize::from(1)).then(|builder| { let fixed_opened_values = fixed_commit.opened_values.clone(); let fixed_opening_proof = fixed_commit.opening_proof.clone(); // new_idx used by mmcs proof @@ -429,6 +429,9 @@ pub(crate) fn batch_verifier_query_phase( builder.assign(&fixed_commit_leafs, fixed_opened_values); }); + // base_codeword_lo_hi + let base_codeword_lo = builder.dyn_array(folding_len.clone()); + let base_codeword_hi = builder.dyn_array(folding_len.clone()); builder.range(0, folding_len.clone()).for_each(|j_vec, builder| { let j = j_vec[0]; let circuit_meta = builder.get(&input.circuit_meta, j); @@ -436,8 +439,68 @@ pub(crate) fn batch_verifier_query_phase( let fixed_num_vars = circuit_meta.fixed_num_vars; let fixed_num_polys = circuit_meta.fixed_num_polys; let witin_leafs = builder.get(&witin_opened_values, j); + // lo_wit, hi_wit + let leafs_len_div_2 = builder.hint_var(); + let two: Var = builder.eval(Usize::from(2)); + builder.assert_eq::>(leafs_len_div_2 * two, witin_leafs.len()); // Can we assume that leafs.len() is even? + // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) + let lo_wit = dot_product(builder, + &input.batch_coeffs, + &witin_leafs, + Usize::from(0), + Usize::from(0), + witin_num_polys.clone(), + ); + let hi_wit = dot_product(builder, + &input.batch_coeffs, + &witin_leafs, + Usize::from(0), + Usize::Var(leafs_len_div_2), + witin_num_polys.clone(), + ); + // lo_fixed, hi_fixed + let lo_fixed: Ext = builder.constant(C::EF::from_canonical_usize(0)); + let hi_fixed: Ext = builder.constant(C::EF::from_canonical_usize(0)); + builder.if_ne(fixed_num_vars, Usize::from(0)).then(|builder| { + let fixed_leafs = builder.get(&fixed_commit_leafs, j); + let leafs_len_div_2 = builder.hint_var(); + let two: Var = builder.eval(Usize::from(2)); + builder.assert_eq::>(leafs_len_div_2 * two, fixed_leafs.len()); // Can we assume that leafs.len() is even? + // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) + let tmp_lo_fixed = dot_product(builder, + &input.batch_coeffs, + &fixed_leafs, + Usize::from(0), + Usize::from(0), + fixed_num_polys.clone(), + ); + let tmp_hi_fixed = dot_product(builder, + &input.batch_coeffs, + &fixed_leafs, + Usize::from(0), + Usize::Var(leafs_len_div_2), + fixed_num_polys.clone(), + ); + builder.assign(&lo_fixed, tmp_lo_fixed); + builder.assign(&hi_fixed, tmp_hi_fixed); + }); + let lo: Ext = builder.eval(lo_wit + lo_fixed); + let hi: Ext = builder.eval(hi_wit + hi_fixed); + builder.set_value(&base_codeword_lo, j, lo); + builder.set_value(&base_codeword_hi, j, hi); }); + + // fold and query + let cur_num_var: Var = builder.eval(input.max_num_var.clone()); + let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); + let n_d_next_log: Var = builder.eval(cur_num_var - get_rate_log::() - Usize::from(1)); + let n_d_next = pow_2(builder, n_d_next_log); + + // first folding challenge + let r = builder.get(&input.fold_challenges, 0); }); + + /* indices.iter().zip_eq(queries).for_each( |( @@ -452,70 +515,6 @@ pub(crate) fn batch_verifier_query_phase( commit_phase_openings: opening_ext, }, )| { - let base_codeword_lo_hi = circuit_meta - .iter() - .zip_eq(witin_opened_values) - .map( - |( - CircuitIndexMeta { - witin_num_polys, - fixed_num_vars, - fixed_num_polys, - .. - }, - witin_leafs, - )| { - let (lo, hi) = std::iter::once((witin_leafs, *witin_num_polys)) - .chain((*fixed_num_vars > 0).then(|| { - (fixed_commit_leafs_iter.next().unwrap(), *fixed_num_polys) - })) - .map(|(leafs, num_polys)| { - let batch_coeffs = batch_coeffs_iter - .by_ref() - .take(num_polys) - .copied() - .collect_vec(); - let (lo, hi): (&[E::BaseField], &[E::BaseField]) = - leafs.split_at(leafs.len() / 2); - ( - dot_product::( - batch_coeffs.iter().copied(), - lo.iter().copied(), - ), - dot_product::( - batch_coeffs.iter().copied(), - hi.iter().copied(), - ), - ) - }) - // fold witin/fixed lo, hi together because they share the same num_vars - .reduce(|(lo_wit, hi_wit), (lo_fixed, hi_fixed)| { - (lo_wit + lo_fixed, hi_wit + hi_fixed) - }) - .expect("unreachable"); - (lo, hi) - }, - ) - .collect_vec(); - debug_assert_eq!(folding_sorted_order.len(), base_codeword_lo_hi.len()); - debug_assert!(fixed_commit_leafs_iter.next().is_none()); - debug_assert!(batch_coeffs_iter.next().is_none()); - - // fold and query - let mut cur_num_var = max_num_var; - // -1 because for there are only #max_num_var-1 openings proof - let rounds = cur_num_var - - >::get_basecode_msg_size_log() - - 1; - let n_d_next = 1 - << (cur_num_var + >::get_rate_log() - 1); - debug_assert_eq!(rounds, fold_challenges.len() - 1); - debug_assert_eq!(rounds, commits.len(),); - debug_assert_eq!(rounds, opening_ext.len(),); - - // first folding challenge - let r = fold_challenges.first().unwrap(); - let mut folding_sorted_order_iter = folding_sorted_order.iter(); // take first batch which num_vars match max_num_var to initial fold value let mut folded = folding_sorted_order_iter diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index a1402fb..b20a665 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -88,6 +88,10 @@ pub fn get_rate_log() -> Usize { Usize::from(1) } +pub fn get_basecode_msg_size_log() -> Usize { + Usize::from(7) +} + /// The DIT FFT algorithm. pub struct Radix2Dit { pub twiddles: Vec, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 30f82e2..e21133d 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -45,20 +45,23 @@ pub fn next_power_of_two( ret } -// Generic dot product +// Generic dot product of li[llo..llo+len] * ri[rlo..rlo+len] pub fn dot_product( builder: &mut Builder, - li: Array>, - ri: Array>, + li: &Array>, + ri: &Array>, + llo: Usize, + rlo: Usize, + len: Usize, ) -> Ext { let ret: Ext = builder.constant(C::EF::ZERO); - builder.assert_eq::>(li.len(), ri.len()); - let len = li.len(); builder.range(0, len).for_each(|i_vec, builder| { let i = i_vec[0]; - let l = builder.get(&li, i); - let r = builder.get(&ri, i); + let lidx: Var = builder.eval(llo.clone() + i); + let ridx: Var = builder.eval(rlo.clone() + i); + let l = builder.get(li, lidx); + let r = builder.get(ri, ridx); builder.assign(&ret, ret + l * r); }); ret From 6f98a1790c119f66e6e71efc52a0e264fa5c3ed1 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Sun, 4 May 2025 15:11:45 -0700 Subject: [PATCH 17/70] New Sorting Impl --- src/basefold_verifier/mmcs.rs | 86 ++++++---------------------- src/basefold_verifier/query_phase.rs | 41 +++++-------- src/basefold_verifier/utils.rs | 78 +++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 94 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index d897a66..a6bf276 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -111,67 +111,19 @@ pub(crate) fn mmcs_verify_batch( builder.assert_less_than_slow_small_rhs(max_height, purported_max_height_upper_bound); builder.assert_usize_eq(input.proof.len(), log_max_height); - // Nondeterministically supplies: - // 1. num_unique_height: number of different heights - // 2. height_order: after sorting by decreasing height, the original index of each entry - // To ensure that height_order represents sorted index, assert that - // 1. It has the same length as input.dimensions (checked by requesting num_dims hints) - // 2. It does not contain the same index twice (checked via a correspondence array) - // 3. Indexed heights are sorted in decreasing order - // While checking, record: - // 1. unique_height_count: for each unique height, number of dimensions of that height - let num_unique_height = builder.hint_var(); - let unique_height_count = builder.dyn_array(num_unique_height); - let zero: Ext = builder.constant(C::EF::ZERO); - let one: Ext = builder.constant(C::EF::ONE); - let height_sort_surjective: Array> = builder.dyn_array(num_dims.clone()); - builder.range(0, num_dims.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - builder.set(&height_sort_surjective, i, zero.clone()); - }); - - let height_order = builder.dyn_array(num_dims.clone()); - let next_order = builder.hint_var(); - // Check surjection - let surjective = builder.get(&height_sort_surjective, next_order); - builder.assert_ext_eq(surjective, zero.clone()); - builder.set(&height_sort_surjective, next_order, one.clone()); - builder.set_value(&height_order, 0, next_order); - let last_height = builder.get(&input.dimensions, next_order).height; - - let last_unique_height_index: Var = builder.eval(Usize::from(0)); - let last_unique_height_count: Var = builder.eval(Usize::from(1)); - builder.range(1, num_dims).for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_order = builder.hint_var(); - // Check surjection - let surjective = builder.get(&height_sort_surjective, next_order); - builder.assert_ext_eq(surjective, zero.clone()); - builder.set(&height_sort_surjective, next_order, one.clone()); - // Check height - let next_height = builder.get(&input.dimensions, next_order).height; - builder.if_eq(last_height, next_height).then(|builder| { - // next_height == last_height - builder.assign(&last_unique_height_count, last_unique_height_count + Usize::from(1)); - }); - builder.if_ne(last_height, next_height).then(|builder| { - // next_height < last_height - builder.assert_less_than_slow_small_rhs(next_height, last_height); - - // Update unique_height_count - builder.set(&unique_height_count, last_unique_height_index, last_unique_height_count); - builder.assign(&last_height, next_height); - builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); - builder.assign(&last_unique_height_count, Usize::from(1)); - }); - - builder.set_value(&height_order, i, next_order); - }); - - // Final check on num_unique_height and unique_height_count - builder.set(&unique_height_count, last_unique_height_index, last_unique_height_count); - builder.assign(&last_unique_height_index, last_unique_height_index + Usize::from(1)); - builder.assert_var_eq(last_unique_height_index, num_unique_height); + // Sort input.dimensions by height, returns + // 1. height_order: after sorting by decreasing height, the original index of each entry + // 2. num_unique_height: number of different heights + // 3. count_per_unique_height: for each unique height, number of dimensions of that height + let ( + height_order, + num_unique_height, + count_per_unique_height + ) = sort_with_count( + builder, + &input.dimensions, + |d: DimensionsVariable| d.height, + ); // First padded_height let first_order = builder.get(&height_order, 0); @@ -179,7 +131,7 @@ pub(crate) fn mmcs_verify_batch( let curr_height_padded = next_power_of_two(builder, first_height); // Construct root through hashing - let root_dims_count: Var = builder.get(&unique_height_count, 0); + let root_dims_count: Var = builder.get(&count_per_unique_height, 0); let root_values = builder.dyn_array(root_dims_count); builder.range(0, root_dims_count).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -194,10 +146,10 @@ pub(crate) fn mmcs_verify_batch( let reassembled_index: Var = builder.eval(Usize::from(0)); // next_height is the height of the next dim to be incorporated into root let next_unique_height_index: Var = builder.eval(Usize::from(1)); - let next_unique_height_count: Var = builder.eval(root_dims_count); + let next_count_per_unique_height: Var = builder.eval(root_dims_count); let next_height_padded: Var = builder.eval(Usize::from(0)); builder.if_ne(num_unique_height, Usize::from(1)).then(|builder| { - let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let next_height = builder.get(&input.dimensions, next_count_per_unique_height).height; let tmp_next_height_padded = next_power_of_two(builder, next_height); builder.assign(&next_height_padded, tmp_next_height_padded); }); @@ -235,7 +187,7 @@ pub(crate) fn mmcs_verify_batch( // determine whether next_height matches curr_height builder.if_eq(curr_height_padded, next_height_padded).then(|builder| { // hash opened_values of all dims of next_height to root - let root_dims_count = builder.get(&unique_height_count, next_unique_height_index); + let root_dims_count = builder.get(&count_per_unique_height, next_unique_height_index); let root_size: Var = builder.eval(root_dims_count + Usize::from(1)); let root_values = builder.dyn_array(root_size); builder.set_value(&root_values, 0, root.clone()); @@ -250,13 +202,13 @@ pub(crate) fn mmcs_verify_batch( builder.assign(&root, new_root); // Update parameters - builder.assign(&next_unique_height_count, next_unique_height_count + root_dims_count); + builder.assign(&next_count_per_unique_height, next_count_per_unique_height + root_dims_count); builder.assign(&next_unique_height_index, next_unique_height_index + Usize::from(1)); builder.if_eq(next_unique_height_index, num_unique_height).then(|builder| { builder.assign(&next_height_padded, Usize::from(0)); }); builder.if_ne(next_unique_height_index, num_unique_height).then(|builder| { - let next_height = builder.get(&input.dimensions, next_unique_height_count).height; + let next_height = builder.get(&input.dimensions, next_count_per_unique_height).height; let next_tmp_height_padded = next_power_of_two(builder, next_height); builder.assign(&next_height_padded, next_tmp_height_padded); }); diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 418af88..0f4addc 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -318,33 +318,19 @@ pub(crate) fn batch_verifier_query_phase( // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index // for reverse lookup when retrieving next base codeword to involve into batching - let folding_sorted_order_witin_num_vars: Array> = builder.dyn_array(folding_len.clone()); - let folding_sorted_order_index: Array> = builder.dyn_array(folding_len.clone()); - let next_order = builder.hint_var(); - // Check surjection - let surjective = builder.get(&folding_sort_surjective, next_order); - builder.assert_ext_eq(surjective, zero.clone()); - builder.set(&folding_sort_surjective, next_order, one.clone()); - // Assignment - let next_witin_num_vars = builder.get(&input.circuit_meta, next_order).witin_num_vars; - builder.set_value(&folding_sorted_order_witin_num_vars, 0, next_witin_num_vars.clone()); - builder.set_value(&folding_sorted_order_index, 0, Usize::Var(next_order)); - let last_witin_num_vars_plus_one: Var = builder.eval(next_witin_num_vars + Usize::from(1)); - builder.range(1, folding_len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_order = builder.hint_var(); - // Check surjection - let surjective = builder.get(&folding_sort_surjective, next_order); - builder.assert_ext_eq(surjective, zero.clone()); - builder.set(&folding_sort_surjective, next_order, one.clone()); - // Check witin_num_vars, next_witin_num_vars < last_witin_num_vars_plus_one - let next_witin_num_vars = builder.get(&input.circuit_meta, next_order).witin_num_vars; - builder.assert_less_than_slow_small_rhs(next_witin_num_vars.clone(), last_witin_num_vars_plus_one); - builder.assign(&last_witin_num_vars_plus_one, next_witin_num_vars.clone() + Usize::from(1)); - // Assignment - builder.set_value(&folding_sorted_order_witin_num_vars, i, next_witin_num_vars); - builder.set_value(&folding_sorted_order_index, i, Usize::Var(next_order)); - }); + // Sort input.dimensions by height, returns + // 1. height_order: after sorting by decreasing height, the original index of each entry + // 2. num_unique_height: number of different heights + // 3. count_per_unique_height: for each unique height, number of dimensions of that height + let ( + folding_sorted_order_index, + num_unique_num_vars, + count_per_unique_num_var + ) = sort_with_count( + builder, + &input.circuit_meta, + |m: CircuitIndexMetaVariable| m.witin_num_vars, + ); builder.range(0, input.indices.len()).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -498,6 +484,7 @@ pub(crate) fn batch_verifier_query_phase( // first folding challenge let r = builder.get(&input.fold_challenges, 0); + }); diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index e21133d..944f002 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -83,4 +83,82 @@ pub fn bin_to_dec( builder.assign(&value, value + next_bit); }); value +} + +// Sort a list in decreasing order, returns: +// 1. The original index of each sorted entry +// 2. Number of unique entries +// 3. Number of counts of each unique entry +pub fn sort_with_count( + builder: &mut Builder, + list: &Array, + ind: Ind, // Convert loaded out entries into comparable ones +) -> (Array>, Var, Array>) + where E: openvm_native_compiler::ir::MemVariable, + N: Into::N>> + openvm_native_compiler::ir::Variable, + Ind: Fn(E) -> N { + let len = list.len(); + // Nondeterministically supplies: + // 1. num_unique_entries: number of different entries + // 2. entry_order: after sorting by decreasing order, the original index of each entry + // To ensure that entry_order represents sorted index, assert that + // 1. It has the same length as list (checked by requesting list.len() hints) + // 2. It does not contain the same index twice (checked via a correspondence array) + // 3. Sorted entries are in decreasing order + // While checking, record: + // 1. count_per_unique_entry: for each unique entry value, count of entries of that value + let num_unique_entries = builder.hint_var(); + let count_per_unique_entry = builder.dyn_array(num_unique_entries); + let zero: Ext = builder.constant(C::EF::ZERO); + let one: Ext = builder.constant(C::EF::ONE); + let entries_sort_surjective: Array> = builder.dyn_array(len.clone()); + builder.range(0, len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&entries_sort_surjective, i, zero.clone()); + }); + + let entries_order = builder.dyn_array(len.clone()); + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&entries_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&entries_sort_surjective, next_order, one.clone()); + builder.set_value(&entries_order, 0, next_order); + let last_entry = ind(builder.get(&list, next_order)); + + let last_unique_entry_index: Var = builder.eval(Usize::from(0)); + let last_count_per_unique_entry: Var = builder.eval(Usize::from(1)); + builder.range(1, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&entries_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&entries_sort_surjective, next_order, one.clone()); + // Check entries + let next_entry = ind(builder.get(&list, next_order)); + builder.if_eq(last_entry.clone(), next_entry.clone()).then(|builder| { + // next_entry == last_entry + builder.assign(&last_count_per_unique_entry, last_count_per_unique_entry + Usize::from(1)); + }); + builder.if_ne(last_entry.clone(), next_entry.clone()).then(|builder| { + // next_entry < last_entry + builder.assert_less_than_slow_small_rhs(next_entry.clone(), last_entry.clone()); + + // Update count_per_unique_entry + builder.set(&count_per_unique_entry, last_unique_entry_index, last_count_per_unique_entry); + builder.assign(&last_entry, next_entry.clone()); + builder.assign(&last_unique_entry_index, last_unique_entry_index + Usize::from(1)); + builder.assign(&last_count_per_unique_entry, Usize::from(1)); + }); + + builder.set_value(&entries_order, i, next_order); + }); + + // Final check on num_unique_entries and count_per_unique_entry + builder.set(&count_per_unique_entry, last_unique_entry_index, last_count_per_unique_entry); + builder.assign(&last_unique_entry_index, last_unique_entry_index + Usize::from(1)); + builder.assert_var_eq(last_unique_entry_index, num_unique_entries); + + (entries_order, num_unique_entries, count_per_unique_entry) } \ No newline at end of file From 8e3312bf40e5b567cf693ceffa742d53dde29e67 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Sun, 4 May 2025 23:04:03 -0700 Subject: [PATCH 18/70] WIP query phase --- src/basefold_verifier/mmcs.rs | 8 +-- src/basefold_verifier/query_phase.rs | 97 +++++++++++++++++----------- src/basefold_verifier/rs.rs | 8 +++ src/basefold_verifier/utils.rs | 20 +++++- 4 files changed, 91 insertions(+), 42 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index a6bf276..619b881 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -146,10 +146,10 @@ pub(crate) fn mmcs_verify_batch( let reassembled_index: Var = builder.eval(Usize::from(0)); // next_height is the height of the next dim to be incorporated into root let next_unique_height_index: Var = builder.eval(Usize::from(1)); - let next_count_per_unique_height: Var = builder.eval(root_dims_count); + let cumul_dims_count: Var = builder.eval(root_dims_count); let next_height_padded: Var = builder.eval(Usize::from(0)); builder.if_ne(num_unique_height, Usize::from(1)).then(|builder| { - let next_height = builder.get(&input.dimensions, next_count_per_unique_height).height; + let next_height = builder.get(&input.dimensions, cumul_dims_count).height; let tmp_next_height_padded = next_power_of_two(builder, next_height); builder.assign(&next_height_padded, tmp_next_height_padded); }); @@ -202,13 +202,13 @@ pub(crate) fn mmcs_verify_batch( builder.assign(&root, new_root); // Update parameters - builder.assign(&next_count_per_unique_height, next_count_per_unique_height + root_dims_count); + builder.assign(&cumul_dims_count, cumul_dims_count + root_dims_count); builder.assign(&next_unique_height_index, next_unique_height_index + Usize::from(1)); builder.if_eq(next_unique_height_index, num_unique_height).then(|builder| { builder.assign(&next_height_padded, Usize::from(0)); }); builder.if_ne(next_unique_height_index, num_unique_height).then(|builder| { - let next_height = builder.get(&input.dimensions, next_count_per_unique_height).height; + let next_height = builder.get(&input.dimensions, cumul_dims_count).height; let next_tmp_height_padded = next_power_of_two(builder, next_height); builder.assign(&next_height_padded, next_tmp_height_padded); }); diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 0f4addc..69a51ee 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -360,7 +360,7 @@ pub(crate) fn batch_verifier_query_phase( builder.assign(&idx_len, idx_len_minus_one); let new_idx = bin_to_dec(builder, &idx_bits, idx_len); let last_bit = builder.get(&idx_bits, idx_len); - builder.assert_eq::>(new_idx + last_bit, idx); + builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); builder.assign(&idx, new_idx); let (witin_dimensions, fixed_dimensions) = @@ -484,7 +484,65 @@ pub(crate) fn batch_verifier_query_phase( // first folding challenge let r = builder.get(&input.fold_challenges, 0); - + let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, 0); + let folded: Ext = builder.constant(C::EF::ZERO); + builder.range(0, next_unique_num_vars_count).for_each(|j_vec, builder| { + let j = j_vec[0]; + let index = builder.get(&folding_sorted_order_index, j); + let lo = builder.get(&base_codeword_lo, index.clone()); + let hi = builder.get(&base_codeword_hi, index.clone()); + let level: Var = builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); + let coeffs = verifier_folding_coeffs_level(builder, &input.vp, level); + let coeff = builder.get(&coeffs, idx); + let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); + builder.assign(&folded, folded + fold); + }); + let next_unique_num_vars_index: Var = builder.eval(Usize::from(1)); + let cumul_num_vars_count: Var = builder.eval(next_unique_num_vars_count); + let n_d_i: Var = builder.eval(n_d_next); + // zip_eq + builder.assert_eq::>(input.commits.len() + Usize::from(1), input.fold_challenges.len()); + builder.assert_eq::>(input.commits.len(), opening_ext.len()); + builder.range(0, input.commits.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let pi_comm = builder.get(&input.commits, j); + let j_plus_one = builder.eval_expr(j + RVar::from(1)); + let r = builder.get(&input.fold_challenges, j_plus_one); + let leaf = builder.get(&opening_ext, j).sibling_value; + let proof = builder.get(&opening_ext, j).opening_proof; + builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); + + // next folding challenges + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + let is_interpolate_to_right_index = builder.get(&idx_bits, idx_len_minus_one); + let new_involved_codewords: Ext = builder.constant(C::EF::ZERO); + let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, next_unique_num_vars_index); + builder.range(0, next_unique_num_vars_count).for_each(|k_vec, builder| { + let k = builder.eval_expr(k_vec[0] + cumul_num_vars_count); + let index = builder.get(&folding_sorted_order_index, k); + let lo = builder.get(&base_codeword_lo, index.clone()); + let hi = builder.get(&base_codeword_hi, index.clone()); + builder.if_eq(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { + builder.assign(&new_involved_codewords, new_involved_codewords + hi); + }); + builder.if_ne(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { + builder.assign(&new_involved_codewords, new_involved_codewords + lo); + }); + }); + builder.assign(&cumul_num_vars_count, cumul_num_vars_count + next_unique_num_vars_count); + builder.assign(&next_unique_num_vars_index, next_unique_num_vars_index + Usize::from(1)); + + // leafs + let leafs = builder.dyn_array(2); + builder.if_eq(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { + builder.set_value(&leafs, 0, leaf); + builder.set_value(&leafs, 1, folded + new_involved_codewords); + }); + builder.if_ne(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { + builder.set_value(&leafs, 0, folded + new_involved_codewords); + builder.set_value(&leafs, 1, leaf); + }); + }); }); @@ -502,25 +560,6 @@ pub(crate) fn batch_verifier_query_phase( commit_phase_openings: opening_ext, }, )| { - let mut folding_sorted_order_iter = folding_sorted_order.iter(); - // take first batch which num_vars match max_num_var to initial fold value - let mut folded = folding_sorted_order_iter - .by_ref() - .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) - .map(|(_, index)| { - let (lo, hi) = &base_codeword_lo_hi[*index]; - let coeff = - >::verifier_folding_coeffs_level( - vp, - cur_num_var - + >::get_rate_log() - - 1, - )[idx]; - codeword_fold_with_challenge(&[*lo, *hi], *r, coeff, inv_2) - }) - .sum::(); - - let mut n_d_i = n_d_next; for ( (pi_comm, r), CommitPhaseProofStep { @@ -532,22 +571,6 @@ pub(crate) fn batch_verifier_query_phase( .zip_eq(fold_challenges.iter().skip(1)) .zip_eq(opening_ext) { - cur_num_var -= 1; - - let is_interpolate_to_right_index = (idx & 1) == 1; - let new_involved_codewords = folding_sorted_order_iter - .by_ref() - .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) - .map(|(_, index)| { - let (lo, hi) = &base_codeword_lo_hi[*index]; - if is_interpolate_to_right_index { - *hi - } else { - *lo - } - }) - .sum::(); - let mut leafs = vec![*leaf; 2]; leafs[is_interpolate_to_right_index as usize] = folded + new_involved_codewords; idx >>= 1; diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index b20a665..892327a 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -92,6 +92,14 @@ pub fn get_basecode_msg_size_log() -> Usize { Usize::from(7) } +pub fn verifier_folding_coeffs_level( + builder: &mut Builder, + pp: &RSCodeVerifierParametersVariable, + level: Var, +) -> Array> { + builder.get(&pp.t_inv_halves, level) +} + /// The DIT FFT algorithm. pub struct Radix2Dit { pub twiddles: Vec, diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 944f002..a61b8a2 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -96,7 +96,8 @@ pub fn sort_with_count( ) -> (Array>, Var, Array>) where E: openvm_native_compiler::ir::MemVariable, N: Into::N>> + openvm_native_compiler::ir::Variable, - Ind: Fn(E) -> N { + Ind: Fn(E) -> N +{ let len = list.len(); // Nondeterministically supplies: // 1. num_unique_entries: number of different entries @@ -161,4 +162,21 @@ pub fn sort_with_count( builder.assert_var_eq(last_unique_entry_index, num_unique_entries); (entries_order, num_unique_entries, count_per_unique_entry) +} + +pub fn codeword_fold_with_challenge( + builder: &mut Builder, + left: Ext, + right: Ext, + challenge: Ext, + coeff: Felt, + inv_2: Felt, +) -> Ext { + // original (left, right) = (lo + hi*x, lo - hi*x), lo, hi are codeword, but after times x it's not codeword + // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) + let lo: Ext = builder.eval((left + right) * inv_2); + let hi: Ext = builder.eval((left - right) * coeff); // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code + // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi (with lo, hi are two codewords), as it match perfectly with raw message in lagrange domain fixed variable + let ret: Ext = builder.eval(lo + challenge * (hi - lo)); + ret } \ No newline at end of file From dda2fff93755b46d58c1ab7dd1f05fbf4992c21b Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Mon, 5 May 2025 20:38:13 -0700 Subject: [PATCH 19/70] WIP query phase --- src/basefold_verifier/basefold.rs | 4 +- src/basefold_verifier/extension_mmcs.rs | 89 +++++++++++++++++++++++++ src/basefold_verifier/hash.rs | 4 +- src/basefold_verifier/mmcs.rs | 6 +- src/basefold_verifier/mod.rs | 1 + src/basefold_verifier/query_phase.rs | 31 +++++---- src/basefold_verifier/rs.rs | 4 +- src/basefold_verifier/structs.rs | 4 +- 8 files changed, 123 insertions(+), 20 deletions(-) create mode 100644 src/basefold_verifier/extension_mmcs.rs diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index 453410b..c705ee8 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -3,10 +3,10 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use super::mmcs::*; +use super::{mmcs::*, structs::DIMENSIONS}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub type HashDigest = MmcsCommitment; diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs new file mode 100644 index 0000000..4447bea --- /dev/null +++ b/src/basefold_verifier/extension_mmcs.rs @@ -0,0 +1,89 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; +use p3_field::FieldExtensionAlgebra; + +use super::{mmcs::*, structs::*}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub struct ExtensionMmcs { + pub inner: MerkleTreeMmcs, +} + +#[derive(Default, Clone)] +pub struct ExtensionMmcsVariable { + pub inner: MerkleTreeMmcsVariable, +} + +pub struct ExtMmcsVerifierInput { + pub commit: MmcsCommitment, + pub dimensions: Vec, + pub index: usize, + pub opened_values: Vec>, + pub proof: MmcsProof, +} + +impl Hintable for ExtMmcsVerifierInput { + type HintVariable = ExtMmcsVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = MmcsCommitment::read(builder); + let dimensions = Vec::::read(builder); + let index = usize::read(builder); + let opened_values = Vec::>::read(builder); + let proof = Vec::>::read(builder); + + ExtMmcsVerifierInputVariable { + commit, + dimensions, + index, + opened_values, + proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commit.write()); + stream.extend(self.dimensions.write()); + stream.extend(>::write(&self.index)); + stream.extend(self.opened_values.write()); + stream.extend(self.proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct ExtMmcsVerifierInputVariable { + pub commit: MmcsCommitmentVariable, + pub dimensions: Array>, + pub index: Var, + pub opened_values: Array>>, + pub proof: MmcsProofVariable, +} + +pub(crate) fn ext_mmcs_verify_batch( + builder: &mut Builder, + _mmcs: ExtensionMmcsVariable, // self + input: ExtMmcsVerifierInputVariable, +) { + let dim_factor: Var = builder.eval(Usize::from(C::EF::D)); + let opened_base_values = builder.dyn_array(input.opened_values.len()); + let next_base_index: Var = builder.eval(Usize::from(0)); + builder.range(0, input.opened_values.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_opened_values = builder.get(&input.opened_values, i); + let next_opened_base_values_len: Var = builder.eval(next_opened_values.len() * dim_factor); + let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); + builder.range(0, next_opened_values.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_opened_value = builder.get(&next_opened_values, j); + + }); + builder.set_value(&opened_base_values, i, next_opened_base_values); + }); +} \ No newline at end of file diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index af23a2e..4ed395b 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -4,10 +4,12 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; +use super::structs::DIMENSIONS; + pub const DIGEST_ELEMS: usize = 8; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub struct Hash { diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 619b881..c7e7efa 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -11,7 +11,7 @@ use p3_field::FieldAlgebra; use super::{structs::*, utils::*, hash::*}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; // XXX: Fill in MerkleTreeMmcs @@ -21,7 +21,7 @@ pub struct MerkleTreeMmcs { } #[derive(Default, Clone)] -pub struct MerkleTreeMmcsVariables { +pub struct MerkleTreeMmcsVariable { pub hash: (), pub compress: (), _phantom: PhantomData, @@ -80,7 +80,7 @@ pub struct MmcsVerifierInputVariable { pub(crate) fn mmcs_verify_batch( builder: &mut Builder, - _mmcs: MerkleTreeMmcsVariables, // self + _mmcs: MerkleTreeMmcsVariable, // self input: MmcsVerifierInputVariable, ) { // Check that the openings have the correct shape. diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index d80609c..01ea914 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod structs; pub(crate) mod basefold; pub(crate) mod query_phase; pub(crate) mod rs; +pub(crate) mod extension_mmcs; pub(crate) mod mmcs; pub(crate) mod hash; // pub(crate) mod field; diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 69a51ee..20f5de1 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -7,10 +7,10 @@ use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; use crate::tower_verifier::binding::*; -use super::{basefold::*, mmcs::*, rs::*, structs::*, utils::*}; +use super::{basefold::*, extension_mmcs::ExtensionMmcsVariable, mmcs::*, rs::*, structs::*, utils::*}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub struct BatchOpening { @@ -45,7 +45,7 @@ pub struct BatchOpeningVariable { } pub struct CommitPhaseProofStep { - pub sibling_value: F, + pub sibling_value: E, pub opening_proof: MmcsProof, } @@ -53,7 +53,7 @@ impl Hintable for CommitPhaseProofStep { type HintVariable = CommitPhaseProofStepVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let sibling_value = F::read(builder); + let sibling_value = E::read(builder); let opening_proof = Vec::>::read(builder); CommitPhaseProofStepVariable { sibling_value, @@ -72,7 +72,7 @@ impl VecAutoHintable for CommitPhaseProofStep {} #[derive(DslVariable, Clone)] pub struct CommitPhaseProofStepVariable { - pub sibling_value: Felt, + pub sibling_value: Ext, pub opening_proof: MmcsProofVariable, } @@ -296,8 +296,8 @@ pub(crate) fn batch_verifier_query_phase( final_rmm, ); // XXX: we might need to add generics to MMCS to account for different field types - let mmcs_ext: MerkleTreeMmcsVariables = Default::default(); - let mmcs: MerkleTreeMmcsVariables = Default::default(); + let mmcs_ext: ExtensionMmcsVariable = Default::default(); + let mmcs: MerkleTreeMmcsVariable = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); @@ -534,14 +534,24 @@ pub(crate) fn batch_verifier_query_phase( // leafs let leafs = builder.dyn_array(2); + let new_leaf = builder.eval(folded + new_involved_codewords); builder.if_eq(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { builder.set_value(&leafs, 0, leaf); - builder.set_value(&leafs, 1, folded + new_involved_codewords); + builder.set_value(&leafs, 1, new_leaf); }); builder.if_ne(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { - builder.set_value(&leafs, 0, folded + new_involved_codewords); + builder.set_value(&leafs, 0, new_leaf); builder.set_value(&leafs, 1, leaf); }); + // idx >>= 1 + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + builder.assign(&idx_len, idx_len_minus_one); + let new_idx = bin_to_dec(builder, &idx_bits, idx_len); + let last_bit = builder.get(&idx_bits, idx_len); + builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); + builder.assign(&idx, new_idx); + // mmcs_ext.verify_batch + }); }); @@ -571,9 +581,6 @@ pub(crate) fn batch_verifier_query_phase( .zip_eq(fold_challenges.iter().skip(1)) .zip_eq(opening_ext) { - let mut leafs = vec![*leaf; 2]; - leafs[is_interpolate_to_right_index as usize] = folded + new_involved_codewords; - idx >>= 1; mmcs_ext .verify_batch( pi_comm, diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 892327a..a176d34 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -5,8 +5,10 @@ use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use super::structs::*; + pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub struct DenseMatrix { diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index f45e061..0d2cd55 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -4,8 +4,10 @@ use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +pub const DIMENSIONS: usize = 4; + pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; use super::rs::get_rate_log; From 39443efaccba5a105b371d12bbe8d361671ddc05 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Mon, 5 May 2025 21:36:06 -0700 Subject: [PATCH 20/70] WIP query phase --- src/basefold_verifier/extension_mmcs.rs | 23 ++++- src/basefold_verifier/query_phase.rs | 112 +++++++++--------------- 2 files changed, 61 insertions(+), 74 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index 4447bea..f78d40d 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -68,22 +68,39 @@ pub struct ExtMmcsVerifierInputVariable { pub(crate) fn ext_mmcs_verify_batch( builder: &mut Builder, - _mmcs: ExtensionMmcsVariable, // self + mmcs: ExtensionMmcsVariable, // self input: ExtMmcsVerifierInputVariable, ) { let dim_factor: Var = builder.eval(Usize::from(C::EF::D)); let opened_base_values = builder.dyn_array(input.opened_values.len()); - let next_base_index: Var = builder.eval(Usize::from(0)); + let base_dimensions = builder.dyn_array(input.dimensions.len()); builder.range(0, input.opened_values.len()).for_each(|i_vec, builder| { let i = i_vec[0]; + // opened_values let next_opened_values = builder.get(&input.opened_values, i); let next_opened_base_values_len: Var = builder.eval(next_opened_values.len() * dim_factor); let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); builder.range(0, next_opened_values.len()).for_each(|j_vec, builder| { let j = j_vec[0]; let next_opened_value = builder.get(&next_opened_values, j); - + // XXX: how to convert Ext to [Felt]? }); builder.set_value(&opened_base_values, i, next_opened_base_values); + + // dimensions + let next_dimension = builder.get(&input.dimensions, i); + let next_base_dimension = DimensionsVariable { + width: builder.eval(next_dimension.width.clone() * dim_factor), + height: next_dimension.height.clone(), + }; + builder.set_value(&base_dimensions, i, next_base_dimension); }); + let input = MmcsVerifierInputVariable { + commit: input.commit, + dimensions: base_dimensions, + index: input.index, + opened_values: opened_base_values, + proof: input.proof, + }; + mmcs_verify_batch(builder, mmcs.inner, input); } \ No newline at end of file diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 20f5de1..771b982 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -7,7 +7,7 @@ use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; use crate::tower_verifier::binding::*; -use super::{basefold::*, extension_mmcs::ExtensionMmcsVariable, mmcs::*, rs::*, structs::*, utils::*}; +use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -309,7 +309,6 @@ pub(crate) fn batch_verifier_query_phase( // Infer witin_num_vars through index let folding_len = input.circuit_meta.len(); let zero: Ext = builder.constant(C::EF::ZERO); - let one: Ext = builder.constant(C::EF::ONE); let folding_sort_surjective: Array> = builder.dyn_array(folding_len.clone()); builder.range(0, folding_len.clone()).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -324,7 +323,7 @@ pub(crate) fn batch_verifier_query_phase( // 3. count_per_unique_height: for each unique height, number of dimensions of that height let ( folding_sorted_order_index, - num_unique_num_vars, + _num_unique_num_vars, count_per_unique_num_var ) = sort_with_count( builder, @@ -478,9 +477,9 @@ pub(crate) fn batch_verifier_query_phase( // fold and query let cur_num_var: Var = builder.eval(input.max_num_var.clone()); - let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); + // let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); let n_d_next_log: Var = builder.eval(cur_num_var - get_rate_log::() - Usize::from(1)); - let n_d_next = pow_2(builder, n_d_next_log); + // let n_d_next = pow_2(builder, n_d_next_log); // first folding challenge let r = builder.get(&input.fold_challenges, 0); @@ -499,7 +498,8 @@ pub(crate) fn batch_verifier_query_phase( }); let next_unique_num_vars_index: Var = builder.eval(Usize::from(1)); let cumul_num_vars_count: Var = builder.eval(next_unique_num_vars_count); - let n_d_i: Var = builder.eval(n_d_next); + let n_d_i_log: Var = builder.eval(n_d_next_log); + // let n_d_i: Var = builder.eval(n_d_next); // zip_eq builder.assert_eq::>(input.commits.len() + Usize::from(1), input.fold_challenges.len()); builder.assert_eq::>(input.commits.len(), opening_ext.len()); @@ -550,77 +550,47 @@ pub(crate) fn batch_verifier_query_phase( let last_bit = builder.get(&idx_bits, idx_len); builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); builder.assign(&idx, new_idx); + // n_d_i >> 1 + builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); + let n_d_i = pow_2(builder, n_d_i_log); // mmcs_ext.verify_batch - + let dimensions = builder.uninit_fixed_array(1); + let two = builder.eval(Usize::from(2)); + builder.set_value(&dimensions, 0, DimensionsVariable { + width: two, + height: n_d_i.clone(), + }); + let opened_values = builder.uninit_fixed_array(1); + builder.set_value(&opened_values, 0, leafs.clone()); + let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { + commit: pi_comm.clone(), + dimensions, + index: idx.clone(), + opened_values, + proof, + }; + ext_mmcs_verify_batch::(builder, mmcs_ext.clone(), ext_mmcs_verifier_input); + + let coeffs = verifier_folding_coeffs_level(builder, &input.vp, n_d_i_log.clone()); + let coeff = builder.get(&coeffs, idx.clone()); + let left = builder.get(&leafs, 0); + let right = builder.get(&leafs, 1); + let new_folded = codeword_fold_with_challenge( + builder, + left, + right, + r.clone(), + coeff, + inv_2 + ); + builder.assign(&folded, new_folded); }); + let final_value = builder.get(&final_codeword.values, idx.clone()); + builder.assert_eq::>(final_value, folded); }); /* - indices.iter().zip_eq(queries).for_each( - |( - idx, - QueryOpeningProof { - witin_base_proof: - BatchOpening { - opened_values: witin_opened_values, - opening_proof: witin_opening_proof, - }, - fixed_base_proof: fixed_commit_option, - commit_phase_openings: opening_ext, - }, - )| { - for ( - (pi_comm, r), - CommitPhaseProofStep { - sibling_value: leaf, - opening_proof: proof, - }, - ) in commits - .iter() - .zip_eq(fold_challenges.iter().skip(1)) - .zip_eq(opening_ext) - { - mmcs_ext - .verify_batch( - pi_comm, - &[Dimensions { - width: 2, - // width is 2, thus height divide by 2 via right shift - height: n_d_i >> 1, - }], - idx, - slice::from_ref(&leafs), - proof, - ) - .expect("verify failed"); - let coeff = - >::verifier_folding_coeffs_level( - vp, - log2_strict_usize(n_d_i) - 1, - )[idx]; - debug_assert_eq!( - >::verifier_folding_coeffs_level( - vp, - log2_strict_usize(n_d_i) - 1, - ) - .len(), - n_d_i >> 1 - ); - folded = codeword_fold_with_challenge(&[leafs[0], leafs[1]], *r, coeff, inv_2); - n_d_i >>= 1; - } - debug_assert!(folding_sorted_order_iter.next().is_none()); - assert!( - final_codeword.values[idx] == folded, - "final_codeword.values[idx] value {:?} != folded {:?}", - final_codeword.values[idx], - folded - ); - }, - ); - exit_span!(check_queries_span); - // 1. check initial claim match with first round sumcheck value assert_eq!( // we need to scale up with scalar for witin_num_vars < max_num_var From 520da7ec45388066933f8a5cf86971ceb52e836e Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Tue, 6 May 2025 14:07:27 -0700 Subject: [PATCH 21/70] Finished query_phase encoding --- src/arithmetics/mod.rs | 60 +++++++++-- src/basefold_verifier/extension_mmcs.rs | 8 ++ src/basefold_verifier/query_phase.rs | 129 ++++++++++++++---------- src/basefold_verifier/utils.rs | 26 ++++- 4 files changed, 160 insertions(+), 63 deletions(-) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index fcec8a3..296ffa6 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -177,19 +177,38 @@ pub fn reverse_idx_arr( res } -// Evaluate eq polynomial. pub fn eq_eval( builder: &mut Builder, x: &Array>, y: &Array>, +) -> Ext { + eq_eval_with_index::( + builder, + x, + y, + Usize::from(0), + Usize::from(0), + x.len(), + ) +} + +// Evaluate eq polynomial. +pub fn eq_eval_with_index( + builder: &mut Builder, + x: &Array>, + y: &Array>, + xlo: Usize, + ylo: Usize, + len: Usize, ) -> Ext { let acc: Ext = builder.constant(C::EF::ONE); - iter_zip!(builder, x, y).for_each(|idx_vec, builder| { - let ptr_x = idx_vec[0]; - let ptr_y = idx_vec[1]; - let v_x = builder.iter_ptr_get(&x, ptr_x); - let v_y = builder.iter_ptr_get(&y, ptr_y); + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let ptr_x: Var = builder.eval(xlo.clone() + i); + let ptr_y: Var = builder.eval(ylo.clone() + i); + let v_x = builder.get(&x, ptr_x); + let v_y = builder.get(&y, ptr_y); let xi_yi: Ext = builder.eval(v_x * v_y); let one: Ext = builder.constant(C::EF::ONE); let new_acc: Ext = builder.eval(acc * (xi_yi + xi_yi - v_x - v_y + one)); @@ -391,6 +410,35 @@ pub fn build_eq_x_r_vec_sequential( evals } +pub fn build_eq_x_r_vec_sequential_with_offset( + builder: &mut Builder, + r: &Array>, + offset: Usize, +) -> Array> { + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + + let r_len: Var = builder.eval(r.len() - offset); + let evals_len: Felt = builder.constant(C::F::ONE); + let evals_len = builder.exp_power_of_2_v::>(evals_len, r_len); + let evals_len = builder.cast_felt_to_var(evals_len); + + let evals: Array> = builder.dyn_array(evals_len); + + // _debug + // build_eq_x_r_helper_sequential_offset(r, &mut evals, E::ONE); + // unsafe { std::mem::transmute(evals) } + evals +} + // _debug // /// A helper function to build eq(x, r)*init via dynamic programing tricks. // /// This function takes 2^num_var iterations, and per iteration with 1 multiplication. diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index f78d40d..af65cc0 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -80,10 +80,18 @@ pub(crate) fn ext_mmcs_verify_batch( let next_opened_values = builder.get(&input.opened_values, i); let next_opened_base_values_len: Var = builder.eval(next_opened_values.len() * dim_factor); let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); + let next_opened_base_index: Var = builder.eval(Usize::from(0)); builder.range(0, next_opened_values.len()).for_each(|j_vec, builder| { let j = j_vec[0]; let next_opened_value = builder.get(&next_opened_values, j); // XXX: how to convert Ext to [Felt]? + let next_opened_value_felt = builder.ext2felt(next_opened_value); + builder.range(0, next_opened_value_felt.len()).for_each(|k_vec, builder| { + let k = k_vec[0]; + let next_felt = builder.get(&next_opened_value_felt, k); + builder.set_value(&next_opened_base_values, next_opened_base_index, next_felt); + builder.assign(&next_opened_base_index, next_opened_base_index + Usize::from(1)); + }); }); builder.set_value(&opened_base_values, i, next_opened_base_values); diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 771b982..0c87670 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -6,7 +6,7 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; -use crate::tower_verifier::binding::*; +use crate::{arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, tower_verifier::{binding::*, program::interpolate_uni_poly}}; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; pub type F = BabyBear; @@ -431,12 +431,9 @@ pub(crate) fn batch_verifier_query_phase( // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) let lo_wit = dot_product(builder, &input.batch_coeffs, - &witin_leafs, - Usize::from(0), - Usize::from(0), - witin_num_polys.clone(), + &witin_leafs, ); - let hi_wit = dot_product(builder, + let hi_wit = dot_product_with_index(builder, &input.batch_coeffs, &witin_leafs, Usize::from(0), @@ -455,11 +452,8 @@ pub(crate) fn batch_verifier_query_phase( let tmp_lo_fixed = dot_product(builder, &input.batch_coeffs, &fixed_leafs, - Usize::from(0), - Usize::from(0), - fixed_num_polys.clone(), ); - let tmp_hi_fixed = dot_product(builder, + let tmp_hi_fixed = dot_product_with_index(builder, &input.batch_coeffs, &fixed_leafs, Usize::from(0), @@ -589,53 +583,80 @@ pub(crate) fn batch_verifier_query_phase( builder.assert_eq::>(final_value, folded); }); - - /* // 1. check initial claim match with first round sumcheck value - assert_eq!( + let points = builder.dyn_array(input.batch_coeffs.len()); + let next_point_index: Var = builder.eval(Usize::from(0)); + builder.range(0, input.point_evals.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let evals = builder.get(&input.point_evals, i).evals; + let witin_num_vars = builder.get(&input.circuit_meta, i).witin_num_vars; // we need to scale up with scalar for witin_num_vars < max_num_var - dot_product::( - batch_coeffs.iter().copied(), - point_evals.iter().zip_eq(circuit_meta.iter()).flat_map( - |((_, evals), CircuitIndexMeta { witin_num_vars, .. })| { - evals.iter().copied().map(move |eval| { - eval * E::from_u64(1 << (max_num_var - witin_num_vars) as u64) - }) - } - ) - ), - { sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1] } + let scale_log = builder.eval(input.max_num_var.clone() - witin_num_vars); + let scale = pow_2(builder, scale_log); + builder.range(0, evals.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let eval = builder.get(&evals, j); + let scaled_eval: Ext = builder.eval(eval * scale); + builder.set_value(&points, next_point_index, scaled_eval); + builder.assign(&next_point_index, next_point_index + Usize::from(1)); + }); + }); + let left = dot_product( + builder, + &input.batch_coeffs, + &points, ); + let next_sumcheck_evals = builder.get(&input.sumcheck_messages, 0).evaluations; + let eval0 = builder.get(&next_sumcheck_evals, 0); + let eval1 = builder.get(&next_sumcheck_evals, 1); + let right: Ext = builder.eval(eval0 + eval1); + builder.assert_eq::>(left, right); + // 2. check every round of sumcheck match with prev claims - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - interpolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]), - { sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] } - ); - } + let fold_len_minus_one: Var = builder.eval(input.fold_challenges.len() - Usize::from(1)); + builder.range(0, fold_len_minus_one).for_each(|i_vec, builder| { + let i = i_vec[0]; + let evals = builder.get(&input.sumcheck_messages, i).evaluations; + let challenge = builder.get(&input.fold_challenges, i); + let left = interpolate_uni_poly(builder, evals, challenge); + let i_plus_one = builder.eval_expr(i + Usize::from(1)); + let next_evals = builder.get(&input.sumcheck_messages, i_plus_one).evaluations; + let eval0 = builder.get(&next_evals, 0); + let eval1 = builder.get(&next_evals, 1); + let right: Ext = builder.eval(eval0 + eval1); + builder.assert_eq::>(left, right); + }); + // 3. check final evaluation are correct - assert_eq!( - interpolate_uni_poly( - &sumcheck_messages[fold_challenges.len() - 1].evaluations, - fold_challenges[fold_challenges.len() - 1] - ), - izip!(final_message, point_evals.iter().map(|(point, _)| point)) - .map(|(final_message, point)| { - // coeff is the eq polynomial evaluated at the first challenge.len() variables - let num_vars_evaluated = point.len() - - >::get_basecode_msg_size_log(); - let coeff = eq_eval( - &point[..num_vars_evaluated], - &fold_challenges[fold_challenges.len() - num_vars_evaluated..], - ); - // Compute eq as the partially evaluated eq polynomial - let eq = build_eq_x_r_vec(&point[num_vars_evaluated..]); - dot_product( - final_message.iter().copied(), - eq.into_iter().map(|e| e * coeff), - ) - }) - .sum() - ); - */ + let final_evals = builder.get(&input.sumcheck_messages, fold_len_minus_one.clone()).evaluations; + let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one.clone()); + let left = interpolate_uni_poly(builder, final_evals, final_challenge); + let right: Ext = builder.constant(C::EF::ZERO); + builder.range(0, input.final_message.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let final_message = builder.get(&input.final_message, i); + let point = builder.get(&input.point_evals, i).point; + // coeff is the eq polynomial evaluated at the first challenge.len() variables + let num_vars_evaluated: Var = builder.eval(point.fs.len() - get_basecode_msg_size_log::()); + let ylo = builder.eval(input.fold_challenges.len() - num_vars_evaluated); + let coeff = eq_eval_with_index( + builder, + &point.fs, + &input.fold_challenges, + Usize::from(0), + Usize::Var(ylo), + Usize::Var(num_vars_evaluated), + ); + let eq = build_eq_x_r_vec_sequential_with_offset::(builder, &point.fs, Usize::Var(num_vars_evaluated)); + let eq_coeff = builder.dyn_array(eq.len()); + builder.range(0, eq.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_eq = builder.get(&eq, j); + let next_eq_coeff: Ext = builder.eval(next_eq * coeff); + builder.set_value(&eq_coeff, j, next_eq_coeff); + }); + let dot_prod = dot_product(builder, &final_message, &eq_coeff); + builder.assign(&right, right + dot_prod); + }); + builder.assert_eq::>(left, right); } diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index a61b8a2..a390b6d 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -45,15 +45,35 @@ pub fn next_power_of_two( ret } +// Dot product: li * ri +pub fn dot_product( + builder: &mut Builder, + li: &Array>, + ri: &Array, +) -> Ext +where F: openvm_native_compiler::ir::MemVariable + 'static +{ + dot_product_with_index::( + builder, + li, + ri, + Usize::from(0), + Usize::from(0), + li.len(), + ) +} + // Generic dot product of li[llo..llo+len] * ri[rlo..rlo+len] -pub fn dot_product( +pub fn dot_product_with_index( builder: &mut Builder, li: &Array>, - ri: &Array>, + ri: &Array, llo: Usize, rlo: Usize, len: Usize, -) -> Ext { +) -> Ext + where F: openvm_native_compiler::ir::MemVariable + 'static +{ let ret: Ext = builder.constant(C::EF::ZERO); builder.range(0, len).for_each(|i_vec, builder| { From 282892ff54468ce9fa57e9ca9165419182899f6b Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Sat, 10 May 2025 16:21:18 -0700 Subject: [PATCH 22/70] Support serialized input --- Cargo.lock | 1 + Cargo.toml | 1 + src/basefold_verifier/basefold.rs | 2 + src/basefold_verifier/hash.rs | 10 ++-- src/basefold_verifier/mmcs.rs | 2 +- src/basefold_verifier/query_phase.rs | 77 ++++++++++++++++++++++++++++ src/basefold_verifier/rs.rs | 11 +++- src/basefold_verifier/structs.rs | 2 + src/tower_verifier/binding.rs | 5 +- 9 files changed, 102 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d05e79..99d32d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -650,6 +650,7 @@ dependencies = [ "ark-poly", "ark-serialize 0.5.0", "ark-std 0.5.0", + "bincode", "ceno_emul", "ceno_zkvm", "ff_ext", diff --git a/Cargo.toml b/Cargo.toml index ffed690..3685879 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", ta rand = { version = "0.8.5", default-features = false } itertools = { version = "0.13.0", default-features = false } +bincode = "1" # Plonky3 p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index c705ee8..b083880 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -2,6 +2,7 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; use super::{mmcs::*, structs::DIMENSIONS}; @@ -10,6 +11,7 @@ pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub type HashDigest = MmcsCommitment; +#[derive(Deserialize)] pub struct BasefoldCommitment { pub commit: HashDigest, pub log2_max_codeword_size: usize, diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index 4ed395b..dd901d2 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -3,6 +3,7 @@ use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; +use serde::Deserialize; use super::structs::DIMENSIONS; @@ -12,11 +13,12 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -pub struct Hash { +#[derive(Deserialize)] +pub struct Hash { pub value: [F; DIGEST_ELEMS], } -impl Default for Hash { +impl Default for Hash { fn default() -> Self { Hash { value: [F::ZERO; DIGEST_ELEMS], @@ -24,7 +26,7 @@ impl Default for Hash { } } -impl Hintable for Hash { +impl Hintable for Hash { type HintVariable = HashVariable; fn read(builder: &mut Builder) -> Self::HintVariable { @@ -46,7 +48,7 @@ impl Hintable for Hash { stream } } -impl VecAutoHintable for Hash {} +impl VecAutoHintable for Hash {} #[derive(DslVariable, Clone)] pub struct HashVariable { diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index c7e7efa..e052ffa 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -27,7 +27,7 @@ pub struct MerkleTreeMmcsVariable { _phantom: PhantomData, } -pub type MmcsCommitment = Hash; +pub type MmcsCommitment = Hash; pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { pub commit: MmcsCommitment, diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 0c87670..7a77ef5 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -5,6 +5,7 @@ use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; +use serde::Deserialize; use crate::{arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, tower_verifier::{binding::*, program::interpolate_uni_poly}}; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; @@ -13,6 +14,7 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; +#[derive(Deserialize)] pub struct BatchOpening { pub opened_values: Vec>, pub opening_proof: MmcsProof, @@ -44,6 +46,7 @@ pub struct BatchOpeningVariable { pub opening_proof: MmcsProofVariable, } +#[derive(Deserialize)] pub struct CommitPhaseProofStep { pub sibling_value: E, pub opening_proof: MmcsProof, @@ -76,6 +79,7 @@ pub struct CommitPhaseProofStepVariable { pub opening_proof: MmcsProofVariable, } +#[derive(Deserialize)] pub struct QueryOpeningProof { pub witin_base_proof: BatchOpening, pub fixed_base_proof: Option, @@ -161,6 +165,7 @@ pub struct PointAndEvalsVariable { pub evals: Array>, } +#[derive(Deserialize)] pub struct QueryPhaseVerifierInput { pub max_num_var: usize, pub indices: Vec, @@ -660,3 +665,75 @@ pub(crate) fn batch_verifier_query_phase( }); builder.assert_eq::>(left, right); } + +pub mod tests { + use std::{fs::File, io::Read}; + + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, + p3_baby_bear::BabyBear, + }; + use p3_field::{extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra}; + type SC = BabyBearPoseidon2Config; + + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + + use super::{batch_verifier_query_phase, QueryPhaseVerifierInput}; + + #[allow(dead_code)] + pub fn build_batch_verifier_query_phase() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let query_phase_input = QueryPhaseVerifierInput::read(&mut builder); + batch_verifier_query_phase(&mut builder, query_phase_input); + builder.halt(); + + // Pass in witness stream + let f = |n: usize| F::from_canonical_usize(n); + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + + // INPUT + let mut f = File::open("input.bin".to_string()).unwrap(); + let mut content: Vec = Vec::new(); + f.read_to_end(&mut content).unwrap(); + let input: QueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); + witness_stream.extend(input.write()); + + // PROGRAM + let program: Program< + p3_monty_31::MontyField31, + > = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_mmcs_verify_batch() { + let (program, witness) = build_batch_verifier_query_phase(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} \ No newline at end of file diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index a176d34..49188d9 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -1,9 +1,12 @@ // Note: check all XXX comments! +use std::{cell::RefCell, collections::BTreeMap}; + use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; use super::structs::*; @@ -103,8 +106,9 @@ pub fn verifier_folding_coeffs_level( } /// The DIT FFT algorithm. +#[derive(Deserialize)] pub struct Radix2Dit { - pub twiddles: Vec, + pub twiddles: RefCell>>, } impl Hintable for Radix2Dit { @@ -120,7 +124,9 @@ impl Hintable for Radix2Dit { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.twiddles.write()); + // XXX: process BTreeMap + let twiddles_vec: Vec = Vec::new(); + stream.extend(twiddles_vec.write()); stream } } @@ -164,6 +170,7 @@ impl Radix2DitVariable { } */ +#[derive(Deserialize)] pub struct RSCodeVerifierParameters { pub dft: Radix2Dit, pub t_inv_halves: Vec>, diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index 0d2cd55..5d2c324 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -3,6 +3,7 @@ use openvm_native_compiler_derive::DslVariable; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; pub const DIMENSIONS: usize = 4; @@ -21,6 +22,7 @@ pub struct CircuitIndexMetaVariable { pub fixed_num_polys: Usize, } +#[derive(Deserialize)] pub struct CircuitIndexMeta { pub witin_num_vars: usize, pub witin_num_polys: usize, diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 686632d..e2af25a 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -17,6 +17,7 @@ use openvm_stark_sdk::{ p3_baby_bear::{BabyBear, Poseidon2BabyBear}, }; use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; #[derive(DslVariable, Clone)] pub struct PointVariable { @@ -52,7 +53,7 @@ pub struct TowerVerifierInputVariable { pub logup_specs_eval: Array>>>, } -#[derive(Clone)] +#[derive(Clone, Deserialize)] pub struct Point { pub fs: Vec, } @@ -98,7 +99,7 @@ impl Hintable for PointAndEval { } impl VecAutoHintable for PointAndEval {} -#[derive(Debug)] +#[derive(Debug, Deserialize)] pub struct IOPProverMessage { pub evaluations: Vec, } From 0dc7b5642036655f6bfebae591d356bf577bb6d2 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Thu, 15 May 2025 17:07:25 +0800 Subject: [PATCH 23/70] Implement the naive encode small method --- src/basefold_verifier/rs.rs | 90 +++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 49188d9..d180942 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -26,10 +26,7 @@ impl Hintable for DenseMatrix { let values = Vec::::read(builder); let width = usize::read(builder); - DenseMatrixVariable { - values, - width, - } + DenseMatrixVariable { values, width } } fn write(&self) -> Vec::N>> { @@ -48,19 +45,20 @@ pub struct DenseMatrixVariable { pub type RowMajorMatrixVariable = DenseMatrixVariable; impl DenseMatrixVariable { - pub fn height( - &self, - builder: &mut Builder, - ) -> Var { + pub fn height(&self, builder: &mut Builder) -> Var { // Supply height as hint let height = builder.hint_var(); - builder.if_eq(self.width.clone(), Usize::from(0)).then(|builder| { - builder.assert_usize_eq(height, Usize::from(0)); - }); - builder.if_ne(self.width.clone(), Usize::from(0)).then(|builder| { - // XXX: check that width * height is not a field multiplication - builder.assert_usize_eq(self.width.clone() * height, self.values.len()); - }); + builder + .if_eq(self.width.clone(), Usize::from(0)) + .then(|builder| { + builder.assert_usize_eq(height, Usize::from(0)); + }); + builder + .if_ne(self.width.clone(), Usize::from(0)) + .then(|builder| { + // XXX: check that width * height is not a field multiplication + builder.assert_usize_eq(self.width.clone() * height, self.values.len()); + }); height } @@ -76,15 +74,19 @@ impl DenseMatrixVariable { builder.assert_less_than_slow_small_rhs(old_height, new_height + RVar::from(1)); let new_size = builder.eval_expr(self.width.clone() * new_height.clone()); let evals: Array> = builder.dyn_array(new_size); - builder.range(0, self.values.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let tmp: Ext = builder.get(&self.values, i); - builder.set(&evals, i, tmp); - }); - builder.range(self.values.len(), evals.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - builder.set(&evals, i, fill); - }); + builder + .range(0, self.values.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp: Ext = builder.get(&self.values, i); + builder.set(&evals, i, tmp); + }); + builder + .range(self.values.len(), evals.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&evals, i, fill); + }); builder.assign(&self.values, evals); } } @@ -117,9 +119,7 @@ impl Hintable for Radix2Dit { fn read(builder: &mut Builder) -> Self::HintVariable { let twiddles = Vec::::read(builder); - Radix2DitVariable { - twiddles, - } + Radix2DitVariable { twiddles } } fn write(&self) -> Vec::N>> { @@ -141,7 +141,7 @@ pub struct Radix2DitVariable { /* impl Radix2DitVariable { fn dft_batch( - &self, + &self, builder: &mut Builder, mat: RowMajorMatrixVariable ) -> RowMajorMatrixVariable { @@ -182,7 +182,7 @@ impl Hintable for RSCodeVerifierParameters { fn read(builder: &mut Builder) -> Self::HintVariable { let dft = Radix2Dit::read(builder); - let t_inv_halves = Vec::>::read(builder); + let t_inv_halves = Vec::>::read(builder); let full_message_size_log = Usize::Var(usize::read(builder)); RSCodeVerifierParametersVariable { @@ -196,7 +196,9 @@ impl Hintable for RSCodeVerifierParameters { let mut stream = Vec::new(); stream.extend(self.dft.write()); stream.extend(self.t_inv_halves.write()); - stream.extend(>::write(&self.full_message_size_log)); + stream.extend(>::write( + &self.full_message_size_log, + )); stream } } @@ -225,23 +227,26 @@ pub(crate) fn encode_small( } */ +/// Encode the last message sent from the prover to the verifier +/// in the commit phase. Currently, for simplicity, we drop the +/// early stopping strategy so the last message has just one +/// element, and the encoding is simply repeating this element +/// by the expansion rate. pub(crate) fn encode_small( builder: &mut Builder, _vp: RSCodeVerifierParametersVariable, - _rmm: RowMajorMatrixVariable, + rmm: RowMajorMatrixVariable, // Assumed to have only one row and one column ) -> RowMajorMatrixVariable { // XXX: nondeterministically supply the results for now - let len = builder.hint_var(); - let values = builder.dyn_array(len); - builder.range(0, len).for_each(|i_vec, builder| { + let result = builder.array(2); // Assume the expansion rate is fixed to 2 by now + let value = builder.get(&rmm.values, 0); + builder.range(0, 2).for_each(|i_vec, builder| { let i = i_vec[0]; - let next_input = builder.hint_ext(); - builder.set_value(&values, i, next_input); + builder.set_value(&result, i, value); }); - let width = builder.hint_var(); - DenseMatrixVariable { - values, - width, + DenseMatrixVariable { + values: result, + width: builder.eval(Usize::from(1)), } } @@ -253,8 +258,7 @@ pub mod tests { use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, - p3_baby_bear::BabyBear, + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; type SC = BabyBearPoseidon2Config; @@ -314,4 +318,4 @@ pub mod tests { // println!("=> cycle count: {:?}", seg.metrics.cycle_count); // } } -} \ No newline at end of file +} From c06f3c9315fb5470201215901fee5845a34e8445 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 15 May 2025 18:43:58 -0700 Subject: [PATCH 24/70] Bug workaround --- src/basefold_verifier/basefold.rs | 10 +++++----- src/basefold_verifier/query_phase.rs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index b083880..b8ed73c 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -15,7 +15,7 @@ pub type HashDigest = MmcsCommitment; pub struct BasefoldCommitment { pub commit: HashDigest, pub log2_max_codeword_size: usize, - pub trivial_commits: Vec, + // pub trivial_commits: Vec, } impl Hintable for BasefoldCommitment { @@ -24,12 +24,12 @@ impl Hintable for BasefoldCommitment { fn read(builder: &mut Builder) -> Self::HintVariable { let commit = HashDigest::read(builder); let log2_max_codeword_size = Usize::Var(usize::read(builder)); - let trivial_commits = Vec::::read(builder); + // let trivial_commits = Vec::::read(builder); BasefoldCommitmentVariable { commit, log2_max_codeword_size, - trivial_commits, + // trivial_commits, } } @@ -37,7 +37,7 @@ impl Hintable for BasefoldCommitment { let mut stream = Vec::new(); stream.extend(self.commit.write()); stream.extend(>::write(&self.log2_max_codeword_size)); - stream.extend(self.trivial_commits.write()); + // stream.extend(self.trivial_commits.write()); stream } } @@ -47,5 +47,5 @@ pub type HashDigestVariable = MmcsCommitmentVariable; pub struct BasefoldCommitmentVariable { pub commit: HashDigestVariable, pub log2_max_codeword_size: Usize, - pub trivial_commits: Array>, + // pub trivial_commits: Array>, } diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 7a77ef5..4ac3c11 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -235,7 +235,7 @@ impl Hintable for QueryPhaseVerifierInput { let tmp_comm = BasefoldCommitment { commit: Default::default(), log2_max_codeword_size: 0, - trivial_commits: Vec::new(), + // trivial_commits: Vec::new(), }; stream.extend(tmp_comm.write()); } From 4553fc3ae38adfc7bad99d340c90b29b5ac41f66 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 30 May 2025 10:01:51 +0800 Subject: [PATCH 25/70] Temp store: starting to use openvm mmcs instruction --- src/basefold_verifier/extension_mmcs.rs | 99 ++- src/basefold_verifier/mmcs.rs | 642 ++++++++++++------ src/basefold_verifier/query_phase.rs | 823 +++++++++++++----------- 3 files changed, 983 insertions(+), 581 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index af65cc0..23d711f 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -34,8 +34,8 @@ impl Hintable for ExtMmcsVerifierInput { let commit = MmcsCommitment::read(builder); let dimensions = Vec::::read(builder); let index = usize::read(builder); - let opened_values = Vec::>::read(builder); - let proof = Vec::>::read(builder); + let opened_values = Vec::>::read(builder); + let proof = Vec::>::read(builder); ExtMmcsVerifierInputVariable { commit, @@ -52,7 +52,13 @@ impl Hintable for ExtMmcsVerifierInput { stream.extend(self.dimensions.write()); stream.extend(>::write(&self.index)); stream.extend(self.opened_values.write()); - stream.extend(self.proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream.extend( + self.proof + .iter() + .map(|p| p.to_vec()) + .collect::>() + .write(), + ); stream } } @@ -74,35 +80,64 @@ pub(crate) fn ext_mmcs_verify_batch( let dim_factor: Var = builder.eval(Usize::from(C::EF::D)); let opened_base_values = builder.dyn_array(input.opened_values.len()); let base_dimensions = builder.dyn_array(input.dimensions.len()); - builder.range(0, input.opened_values.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - // opened_values - let next_opened_values = builder.get(&input.opened_values, i); - let next_opened_base_values_len: Var = builder.eval(next_opened_values.len() * dim_factor); - let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); - let next_opened_base_index: Var = builder.eval(Usize::from(0)); - builder.range(0, next_opened_values.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_opened_value = builder.get(&next_opened_values, j); - // XXX: how to convert Ext to [Felt]? - let next_opened_value_felt = builder.ext2felt(next_opened_value); - builder.range(0, next_opened_value_felt.len()).for_each(|k_vec, builder| { - let k = k_vec[0]; - let next_felt = builder.get(&next_opened_value_felt, k); - builder.set_value(&next_opened_base_values, next_opened_base_index, next_felt); - builder.assign(&next_opened_base_index, next_opened_base_index + Usize::from(1)); - }); + + builder + .range(0, input.opened_values.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + // opened_values + let next_opened_values = builder.get(&input.opened_values, i); + let next_opened_base_values_len: Var = + builder.eval(next_opened_values.len() * dim_factor); + let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); + let next_opened_base_index: Var = builder.eval(Usize::from(0)); + builder + .range(0, next_opened_values.len()) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_opened_value = builder.get(&next_opened_values, j); + // XXX: how to convert Ext to [Felt]? + let next_opened_value_felt = builder.ext2felt(next_opened_value); + builder + .range(0, next_opened_value_felt.len()) + .for_each(|k_vec, builder| { + let k = k_vec[0]; + let next_felt = builder.get(&next_opened_value_felt, k); + builder.set_value( + &next_opened_base_values, + next_opened_base_index, + next_felt, + ); + builder.assign( + &next_opened_base_index, + next_opened_base_index + Usize::from(1), + ); + }); + }); + builder.set_value(&opened_base_values, i, next_opened_base_values); + + // dimensions + let next_dimension = builder.get(&input.dimensions, i); + let next_base_dimension = DimensionsVariable { + width: builder.eval(next_dimension.width.clone() * dim_factor), + height: next_dimension.height.clone(), + }; + builder.set_value(&base_dimensions, i, next_base_dimension); }); - builder.set_value(&opened_base_values, i, next_opened_base_values); - - // dimensions - let next_dimension = builder.get(&input.dimensions, i); - let next_base_dimension = DimensionsVariable { - width: builder.eval(next_dimension.width.clone() * dim_factor), - height: next_dimension.height.clone(), - }; - builder.set_value(&base_dimensions, i, next_base_dimension); - }); + + let dimensions = match input.dimensions { + Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), + _ => panic!("Expected a dynamic array of felts"), + }; + + builder.verify_batch_ext( + &dimensions, + &input.opened_values, + &input.proof_id, + &input.index_bits, + &input.commit.value, + ); + let input = MmcsVerifierInputVariable { commit: input.commit, dimensions: base_dimensions, @@ -111,4 +146,4 @@ pub(crate) fn ext_mmcs_verify_batch( proof: input.proof, }; mmcs_verify_batch(builder, mmcs.inner, input); -} \ No newline at end of file +} diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index e052ffa..7499430 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -8,7 +8,7 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; -use super::{structs::*, utils::*, hash::*}; +use super::{hash::*, structs::*, utils::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -44,8 +44,8 @@ impl Hintable for MmcsVerifierInput { let commit = MmcsCommitment::read(builder); let dimensions = Vec::::read(builder); let index = usize::read(builder); - let opened_values = Vec::>::read(builder); - let proof = Vec::>::read(builder); + let opened_values = Vec::>::read(builder); + let proof = Vec::>::read(builder); MmcsVerifierInputVariable { commit, @@ -62,7 +62,13 @@ impl Hintable for MmcsVerifierInput { stream.extend(self.dimensions.write()); stream.extend(>::write(&self.index)); stream.extend(self.opened_values.write()); - stream.extend(self.proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream.extend( + self.proof + .iter() + .map(|p| p.to_vec()) + .collect::>() + .write(), + ); stream } } @@ -72,7 +78,7 @@ pub type MmcsProofVariable = Array::F>>>; #[derive(DslVariable, Clone)] pub struct MmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, - pub dimensions: Array>, + pub dimensions: Array>, pub index: Var, pub opened_values: Array>>, pub proof: MmcsProofVariable, @@ -83,6 +89,18 @@ pub(crate) fn mmcs_verify_batch( _mmcs: MerkleTreeMmcsVariable, // self input: MmcsVerifierInputVariable, ) { + let dimensions = match input.dimensions { + Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), + _ => panic!("Expected a dynamic array of felts"), + }; + builder.verify_batch_felt( + &dimensions, + &input.opened_values, + input.proof_id, + input.index_bits, + &input.commit.value, + ); + // Check that the openings have the correct shape. let num_dims = input.dimensions.len(); // Assert dimensions is not empty @@ -94,19 +112,22 @@ pub(crate) fn mmcs_verify_batch( // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit. // Nondeterministically supplies max_height let max_height = builder.hint_var(); - builder.range(0, input.dimensions.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_height = builder.get(&input.dimensions, i).height; - let max_height_plus_one: Var = builder.eval(max_height + Usize::from(1)); - builder.assert_less_than_slow_small_rhs(next_height, max_height_plus_one); - }); + builder + .range(0, input.dimensions.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_height = builder.get(&input.dimensions, i).height; + let max_height_plus_one: Var = builder.eval(max_height + Usize::from(1)); + builder.assert_less_than_slow_small_rhs(next_height, max_height_plus_one); + }); // Verify correspondence between log_h and h let log_max_height = builder.hint_var(); let log_max_height_minus_1: Var = builder.eval(log_max_height - Usize::from(1)); let purported_max_height_lower_bound: Var = pow_2(builder, log_max_height_minus_1); let two: Var = builder.constant(C::N::TWO); - let purported_max_height_upper_bound: Var = builder.eval(purported_max_height_lower_bound * two); + let purported_max_height_upper_bound: Var = + builder.eval(purported_max_height_lower_bound * two); builder.assert_less_than_slow_small_rhs(purported_max_height_lower_bound, max_height); builder.assert_less_than_slow_small_rhs(max_height, purported_max_height_upper_bound); builder.assert_usize_eq(input.proof.len(), log_max_height); @@ -115,15 +136,10 @@ pub(crate) fn mmcs_verify_batch( // 1. height_order: after sorting by decreasing height, the original index of each entry // 2. num_unique_height: number of different heights // 3. count_per_unique_height: for each unique height, number of dimensions of that height - let ( - height_order, - num_unique_height, - count_per_unique_height - ) = sort_with_count( - builder, - &input.dimensions, - |d: DimensionsVariable| d.height, - ); + let (height_order, num_unique_height, count_per_unique_height) = + sort_with_count(builder, &input.dimensions, |d: DimensionsVariable| { + d.height + }); // First padded_height let first_order = builder.get(&height_order, 0); @@ -133,12 +149,14 @@ pub(crate) fn mmcs_verify_batch( // Construct root through hashing let root_dims_count: Var = builder.get(&count_per_unique_height, 0); let root_values = builder.dyn_array(root_dims_count); - builder.range(0, root_dims_count).for_each(|i_vec, builder| { - let i = i_vec[0]; - let index = builder.get(&height_order, i); - let tmp = builder.get(&input.opened_values, index); - builder.set_value(&root_values, i, tmp); - }); + builder + .range(0, root_dims_count) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let index = builder.get(&height_order, i); + let tmp = builder.get(&input.opened_values, index); + builder.set_value(&root_values, i, tmp); + }); let root = hash_iter_slices(builder, root_values); // Index_pow and reassembled_index for bit split @@ -148,72 +166,96 @@ pub(crate) fn mmcs_verify_batch( let next_unique_height_index: Var = builder.eval(Usize::from(1)); let cumul_dims_count: Var = builder.eval(root_dims_count); let next_height_padded: Var = builder.eval(Usize::from(0)); - builder.if_ne(num_unique_height, Usize::from(1)).then(|builder| { - let next_height = builder.get(&input.dimensions, cumul_dims_count).height; - let tmp_next_height_padded = next_power_of_two(builder, next_height); - builder.assign(&next_height_padded, tmp_next_height_padded); - }); - builder.range(0, input.proof.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let sibling = builder.get(&input.proof, i); - let two_var: Var = builder.eval(Usize::from(2)); // XXX: is there a better way to do this? - // Supply the next index bit as hint, assert that it is a bit - let next_index_bit = builder.hint_var(); - builder.assert_var_eq(next_index_bit, next_index_bit * next_index_bit); - builder.assign(&reassembled_index, reassembled_index + index_pow * next_index_bit); - builder.assign(&index_pow, index_pow * two_var); - - // left, right - let compress_elem = builder.dyn_array(2); - builder.if_eq(next_index_bit, Usize::from(0)).then(|builder| { - // root, sibling - builder.set_value(&compress_elem, 0, root.clone()); - builder.set_value(&compress_elem, 0, sibling.clone()); - }); - builder.if_ne(next_index_bit, Usize::from(0)).then(|builder| { - // sibling, root - builder.set_value(&compress_elem, 0, sibling.clone()); - builder.set_value(&compress_elem, 0, root.clone()); + builder + .if_ne(num_unique_height, Usize::from(1)) + .then(|builder| { + let next_height = builder.get(&input.dimensions, cumul_dims_count).height; + let tmp_next_height_padded = next_power_of_two(builder, next_height); + builder.assign(&next_height_padded, tmp_next_height_padded); }); - let new_root = compress(builder, compress_elem); - builder.assign(&root, new_root); - - // curr_height_padded >>= 1 given curr_height_padded is a power of two - // Nondeterministically supply next_curr_height_padded - let next_curr_height_padded = builder.hint_var(); - builder.assert_var_eq(next_curr_height_padded * two_var, curr_height_padded); - builder.assign(&curr_height_padded, next_curr_height_padded); - - // determine whether next_height matches curr_height - builder.if_eq(curr_height_padded, next_height_padded).then(|builder| { - // hash opened_values of all dims of next_height to root - let root_dims_count = builder.get(&count_per_unique_height, next_unique_height_index); - let root_size: Var = builder.eval(root_dims_count + Usize::from(1)); - let root_values = builder.dyn_array(root_size); - builder.set_value(&root_values, 0, root.clone()); - builder.range(0, root_dims_count).for_each(|i_vec, builder| { - let i = i_vec[0]; - let index = builder.get(&height_order, i); - let tmp = builder.get(&input.opened_values, index); - let j = builder.eval_expr(i + RVar::from(1)); - builder.set_value(&root_values, j, tmp); - }); - let new_root = hash_iter_slices(builder, root_values); + builder + .range(0, input.proof.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let sibling = builder.get(&input.proof, i); + let two_var: Var = builder.eval(Usize::from(2)); // XXX: is there a better way to do this? + // Supply the next index bit as hint, assert that it is a bit + let next_index_bit = builder.hint_var(); + builder.assert_var_eq(next_index_bit, next_index_bit * next_index_bit); + builder.assign( + &reassembled_index, + reassembled_index + index_pow * next_index_bit, + ); + builder.assign(&index_pow, index_pow * two_var); + + // left, right + let compress_elem = builder.dyn_array(2); + builder + .if_eq(next_index_bit, Usize::from(0)) + .then(|builder| { + // root, sibling + builder.set_value(&compress_elem, 0, root.clone()); + builder.set_value(&compress_elem, 0, sibling.clone()); + }); + builder + .if_ne(next_index_bit, Usize::from(0)) + .then(|builder| { + // sibling, root + builder.set_value(&compress_elem, 0, sibling.clone()); + builder.set_value(&compress_elem, 0, root.clone()); + }); + let new_root = compress(builder, compress_elem); builder.assign(&root, new_root); - // Update parameters - builder.assign(&cumul_dims_count, cumul_dims_count + root_dims_count); - builder.assign(&next_unique_height_index, next_unique_height_index + Usize::from(1)); - builder.if_eq(next_unique_height_index, num_unique_height).then(|builder| { - builder.assign(&next_height_padded, Usize::from(0)); - }); - builder.if_ne(next_unique_height_index, num_unique_height).then(|builder| { - let next_height = builder.get(&input.dimensions, cumul_dims_count).height; - let next_tmp_height_padded = next_power_of_two(builder, next_height); - builder.assign(&next_height_padded, next_tmp_height_padded); - }); + // curr_height_padded >>= 1 given curr_height_padded is a power of two + // Nondeterministically supply next_curr_height_padded + let next_curr_height_padded = builder.hint_var(); + builder.assert_var_eq(next_curr_height_padded * two_var, curr_height_padded); + builder.assign(&curr_height_padded, next_curr_height_padded); + + // determine whether next_height matches curr_height + builder + .if_eq(curr_height_padded, next_height_padded) + .then(|builder| { + // hash opened_values of all dims of next_height to root + let root_dims_count = + builder.get(&count_per_unique_height, next_unique_height_index); + let root_size: Var = builder.eval(root_dims_count + Usize::from(1)); + let root_values = builder.dyn_array(root_size); + builder.set_value(&root_values, 0, root.clone()); + builder + .range(0, root_dims_count) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let index = builder.get(&height_order, i); + let tmp = builder.get(&input.opened_values, index); + let j = builder.eval_expr(i + RVar::from(1)); + builder.set_value(&root_values, j, tmp); + }); + let new_root = hash_iter_slices(builder, root_values); + builder.assign(&root, new_root); + + // Update parameters + builder.assign(&cumul_dims_count, cumul_dims_count + root_dims_count); + builder.assign( + &next_unique_height_index, + next_unique_height_index + Usize::from(1), + ); + builder + .if_eq(next_unique_height_index, num_unique_height) + .then(|builder| { + builder.assign(&next_height_padded, Usize::from(0)); + }); + builder + .if_ne(next_unique_height_index, num_unique_height) + .then(|builder| { + let next_height = + builder.get(&input.dimensions, cumul_dims_count).height; + let next_tmp_height_padded = next_power_of_two(builder, next_height); + builder.assign(&next_height_padded, next_tmp_height_padded); + }); + }); }); - }); builder.assert_var_eq(reassembled_index, input.index); builder.range(0, DIGEST_ELEMS).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -230,8 +272,7 @@ pub mod tests { use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, - p3_baby_bear::BabyBear, + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; type SC = BabyBearPoseidon2Config; @@ -241,7 +282,7 @@ pub mod tests { type EF = ::Challenge; use crate::basefold_verifier::structs::Dimensions; - use super::{mmcs_verify_batch, MmcsCommitment, InnerConfig, MmcsVerifierInput}; + use super::{mmcs_verify_batch, InnerConfig, MmcsCommitment, MmcsVerifierInput}; #[allow(dead_code)] pub fn build_mmcs_verify_batch() -> (Program, Vec>) { @@ -269,27 +310,126 @@ pub mod tests { f(1773915204), f(380281369), f(383365269), - ] + ], }; let dimensions = vec![ - Dimensions { width: 8, height: 1 }, - Dimensions { width: 8, height: 1 }, - Dimensions { width: 8, height: 70 }, + Dimensions { + width: 8, + height: 1, + }, + Dimensions { + width: 8, + height: 1, + }, + Dimensions { + width: 8, + height: 70, + }, ]; let index = 6; let opened_values = vec![ - vec![f(774319227), f(1631186743), f(254325873), f(504149682), f(239740532), f(1126519109), f(1044404585), f(1274764277)], - vec![f(1486505160), f(631183960), f(329388712), f(1934479253), f(115532954), f(1978455077), f(66346996), f(821157541)], - vec![f(149196326), f(1186650877), f(1970038391), f(1893286029), f(1249658956), f(1618951617), f(419030634), f(1967997848)], + vec![ + f(774319227), + f(1631186743), + f(254325873), + f(504149682), + f(239740532), + f(1126519109), + f(1044404585), + f(1274764277), + ], + vec![ + f(1486505160), + f(631183960), + f(329388712), + f(1934479253), + f(115532954), + f(1978455077), + f(66346996), + f(821157541), + ], + vec![ + f(149196326), + f(1186650877), + f(1970038391), + f(1893286029), + f(1249658956), + f(1618951617), + f(419030634), + f(1967997848), + ], ]; let proof = vec![ - [f(845920358), f(1201648213), f(1087654550), f(264553580), f(633209321), f(877945079), f(1674449089), f(1062812099)], - [f(5498027), f(1901489519), f(179361222), f(41261871), f(1546446894), f(266690586), f(1882928070), f(844710372)], - [f(721245096), f(388358486), f(1443363461), f(1349470697), f(253624794), f(1359455861), f(237485093), f(1955099141)], - [f(1816731864), f(402719753), f(1972161922), f(693018227), f(1617207065), f(1848150948), f(360933015), f(669793414)], - [f(1746479395), f(457185725), f(1263857148), f(328668702), f(1743038915), f(582282833), f(927410326), f(376217274)], - [f(1146845382), f(1117439420), f(1622226137), f(1449227765), f(138752938), f(1251889563), f(1266915653), f(267248408)], - [f(1992750195), f(1604624754), f(1748646393), f(1777984113), f(861317745), f(564150089), f(1371546358), f(460033967)], + [ + f(845920358), + f(1201648213), + f(1087654550), + f(264553580), + f(633209321), + f(877945079), + f(1674449089), + f(1062812099), + ], + [ + f(5498027), + f(1901489519), + f(179361222), + f(41261871), + f(1546446894), + f(266690586), + f(1882928070), + f(844710372), + ], + [ + f(721245096), + f(388358486), + f(1443363461), + f(1349470697), + f(253624794), + f(1359455861), + f(237485093), + f(1955099141), + ], + [ + f(1816731864), + f(402719753), + f(1972161922), + f(693018227), + f(1617207065), + f(1848150948), + f(360933015), + f(669793414), + ], + [ + f(1746479395), + f(457185725), + f(1263857148), + f(328668702), + f(1743038915), + f(582282833), + f(927410326), + f(376217274), + ], + [ + f(1146845382), + f(1117439420), + f(1622226137), + f(1449227765), + f(138752938), + f(1251889563), + f(1266915653), + f(267248408), + ], + [ + f(1992750195), + f(1604624754), + f(1748646393), + f(1777984113), + f(861317745), + f(564150089), + f(1371546358), + f(460033967), + ], ]; let mmcs_input = MmcsVerifierInput { commit, @@ -314,179 +454,323 @@ pub mod tests { // curr_height_log witness_stream.extend(>::write(&6)); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1782972889))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1782972889), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(279434715))); + witness_stream.extend(>::write( + &F::from_canonical_usize(279434715), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1209301918))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1209301918), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1853868602))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1853868602), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(883945353))); + witness_stream.extend(>::write( + &F::from_canonical_usize(883945353), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(368353728))); + witness_stream.extend(>::write( + &F::from_canonical_usize(368353728), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(1699837443))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1699837443), + )); // root - witness_stream.extend(>::write(&F::from_canonical_usize(908962698))); + witness_stream.extend(>::write( + &F::from_canonical_usize(908962698), + )); // next_height_log witness_stream.extend(>::write(&0)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(271352274))); + witness_stream.extend(>::write( + &F::from_canonical_usize(271352274), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1918158485))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1918158485), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1538604111))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1538604111), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1122013445))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1122013445), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1844193149))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1844193149), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(501326061))); + witness_stream.extend(>::write( + &F::from_canonical_usize(501326061), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1508959271))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1508959271), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1549189152))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1549189152), + )); // next_curr_height_padded witness_stream.extend(>::write(&64)); // next_bit witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(222162520))); + witness_stream.extend(>::write( + &F::from_canonical_usize(222162520), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(785634830))); + witness_stream.extend(>::write( + &F::from_canonical_usize(785634830), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1461778378))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1461778378), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(836284568))); + witness_stream.extend(>::write( + &F::from_canonical_usize(836284568), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1141654637))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1141654637), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1339589042))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1339589042), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1081824021))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1081824021), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(698316542))); + witness_stream.extend(>::write( + &F::from_canonical_usize(698316542), + )); // next_curr_height_padded witness_stream.extend(>::write(&32)); // next_bit witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(567517164))); + witness_stream.extend(>::write( + &F::from_canonical_usize(567517164), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(915833994))); + witness_stream.extend(>::write( + &F::from_canonical_usize(915833994), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(621327606))); + witness_stream.extend(>::write( + &F::from_canonical_usize(621327606), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(476128789))); + witness_stream.extend(>::write( + &F::from_canonical_usize(476128789), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1976747536))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1976747536), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1385950652))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1385950652), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1416073024))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1416073024), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(862764478))); + witness_stream.extend(>::write( + &F::from_canonical_usize(862764478), + )); // next_curr_height_padded witness_stream.extend(>::write(&16)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(822965313))); + witness_stream.extend(>::write( + &F::from_canonical_usize(822965313), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1036402058))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1036402058), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(117603799))); + witness_stream.extend(>::write( + &F::from_canonical_usize(117603799), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1087591966))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1087591966), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(443405499))); + witness_stream.extend(>::write( + &F::from_canonical_usize(443405499), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1334745091))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1334745091), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(901165815))); + witness_stream.extend(>::write( + &F::from_canonical_usize(901165815), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1187124281))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1187124281), + )); // next_curr_height_padded witness_stream.extend(>::write(&8)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(875508647))); + witness_stream.extend(>::write( + &F::from_canonical_usize(875508647), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1313410483))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1313410483), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(355713834))); + witness_stream.extend(>::write( + &F::from_canonical_usize(355713834), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1976667383))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1976667383), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1804021525))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1804021525), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(294385081))); + witness_stream.extend(>::write( + &F::from_canonical_usize(294385081), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(669164730))); + witness_stream.extend(>::write( + &F::from_canonical_usize(669164730), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1187763617))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1187763617), + )); // next_curr_height_padded witness_stream.extend(>::write(&4)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1992024140))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1992024140), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(439080849))); + witness_stream.extend(>::write( + &F::from_canonical_usize(439080849), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1032272714))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1032272714), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1304584689))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1304584689), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1795447062))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1795447062), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(859522945))); + witness_stream.extend(>::write( + &F::from_canonical_usize(859522945), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1661892383))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1661892383), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1980559722))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1980559722), + )); // next_curr_height_padded witness_stream.extend(>::write(&2)); // next_bit witness_stream.extend(>::write(&0)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1121119596))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1121119596), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(369487248))); + witness_stream.extend(>::write( + &F::from_canonical_usize(369487248), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(834451573))); + witness_stream.extend(>::write( + &F::from_canonical_usize(834451573), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1120744826))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1120744826), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(758930984))); + witness_stream.extend(>::write( + &F::from_canonical_usize(758930984), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(632316631))); + witness_stream.extend(>::write( + &F::from_canonical_usize(632316631), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1593276657))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1593276657), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(507031465))); + witness_stream.extend(>::write( + &F::from_canonical_usize(507031465), + )); // next_curr_height_padded witness_stream.extend(>::write(&1)); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1715944678))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1715944678), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1204294900))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1204294900), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(59582177))); + witness_stream.extend(>::write( + &F::from_canonical_usize(59582177), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(320945505))); + witness_stream.extend(>::write( + &F::from_canonical_usize(320945505), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1470843790))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1470843790), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(1773915204))); + witness_stream.extend(>::write( + &F::from_canonical_usize(1773915204), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(380281369))); + witness_stream.extend(>::write( + &F::from_canonical_usize(380281369), + )); // new_root - witness_stream.extend(>::write(&F::from_canonical_usize(383365269))); + witness_stream.extend(>::write( + &F::from_canonical_usize(383365269), + )); // PROGRAM let program: Program< @@ -514,4 +798,4 @@ pub mod tests { // println!("=> cycle count: {:?}", seg.metrics.cycle_count); // } } -} \ No newline at end of file +} diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 4ac3c11..e39e8cd 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -7,8 +7,11 @@ use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; use serde::Deserialize; -use crate::{arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, tower_verifier::{binding::*, program::interpolate_uni_poly}}; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; +use crate::{ + arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, + tower_verifier::{binding::*, program::interpolate_uni_poly}, +}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -21,23 +24,29 @@ pub struct BatchOpening { } impl Hintable for BatchOpening { - type HintVariable = BatchOpeningVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let opened_values = Vec::>::read(builder); - let opening_proof = Vec::>::read(builder); - BatchOpeningVariable { - opened_values, - opening_proof, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(self.opened_values.write()); - stream.extend(self.opening_proof.iter().map(|p| p.to_vec()).collect::>().write()); - stream - } + type HintVariable = BatchOpeningVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let opened_values = Vec::>::read(builder); + let opening_proof = Vec::>::read(builder); + BatchOpeningVariable { + opened_values, + opening_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.opened_values.write()); + stream.extend( + self.opening_proof + .iter() + .map(|p| p.to_vec()) + .collect::>() + .write(), + ); + stream + } } #[derive(DslVariable, Clone)] @@ -54,23 +63,29 @@ pub struct CommitPhaseProofStep { impl Hintable for CommitPhaseProofStep { type HintVariable = CommitPhaseProofStepVariable; - + fn read(builder: &mut Builder) -> Self::HintVariable { let sibling_value = E::read(builder); - let opening_proof = Vec::>::read(builder); + let opening_proof = Vec::>::read(builder); CommitPhaseProofStepVariable { sibling_value, opening_proof, } } - + fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.sibling_value.write()); - stream.extend(self.opening_proof.iter().map(|p| p.to_vec()).collect::>().write()); + stream.extend( + self.opening_proof + .iter() + .map(|p| p.to_vec()) + .collect::>() + .write(), + ); stream } - } +} impl VecAutoHintable for CommitPhaseProofStep {} #[derive(DslVariable, Clone)] @@ -89,7 +104,7 @@ type QueryOpeningProofs = Vec; impl Hintable for QueryOpeningProof { type HintVariable = QueryOpeningProofVariable; - + fn read(builder: &mut Builder) -> Self::HintVariable { let witin_base_proof = BatchOpening::read(builder); let fixed_is_some = Usize::Var(usize::read(builder)); @@ -102,7 +117,7 @@ impl Hintable for QueryOpeningProof { commit_phase_openings, } } - + fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.witin_base_proof.write()); @@ -120,7 +135,7 @@ impl Hintable for QueryOpeningProof { stream.extend(self.commit_phase_openings.write()); stream } - } +} impl VecAutoHintable for QueryOpeningProof {} #[derive(DslVariable, Clone)] @@ -132,7 +147,6 @@ pub struct QueryOpeningProofVariable { } type QueryOpeningProofsVariable = Array>; - // NOTE: Different from PointAndEval in tower_verifier! pub struct PointAndEvals { pub point: Point, @@ -144,10 +158,7 @@ impl Hintable for PointAndEvals { fn read(builder: &mut Builder) -> Self::HintVariable { let point = Point::read(builder); let evals = Vec::::read(builder); - PointAndEvalsVariable { - point, - evals, - } + PointAndEvalsVariable { point, evals } } fn write(&self) -> Vec::N>> { @@ -184,12 +195,12 @@ pub struct QueryPhaseVerifierInput { impl Hintable for QueryPhaseVerifierInput { type HintVariable = QueryPhaseVerifierInputVariable; - + fn read(builder: &mut Builder) -> Self::HintVariable { let max_num_var = Usize::Var(usize::read(builder)); let indices = Vec::::read(builder); let vp = RSCodeVerifierParameters::read(builder); - let final_message = Vec::>::read(builder); + let final_message = Vec::>::read(builder); let batch_coeffs = Vec::::read(builder); let queries = QueryOpeningProofs::read(builder); let fixed_is_some = Usize::Var(usize::read(builder)); @@ -218,7 +229,7 @@ impl Hintable for QueryPhaseVerifierInput { point_evals, } } - + fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(>::write(&self.max_num_var)); @@ -244,12 +255,19 @@ impl Hintable for QueryPhaseVerifierInput { stream.extend(self.commits.write()); stream.extend(self.fold_challenges.write()); stream.extend(self.sumcheck_messages.write()); - stream.extend(self.point_evals.iter().map(|(p, e)| - PointAndEvals { point: p.clone(), evals: e.clone() } - ).collect::>().write()); + stream.extend( + self.point_evals + .iter() + .map(|(p, e)| PointAndEvals { + point: p.clone(), + evals: e.clone(), + }) + .collect::>() + .write(), + ); stream } - } +} #[derive(DslVariable, Clone)] pub struct QueryPhaseVerifierInputVariable { @@ -275,36 +293,38 @@ pub(crate) fn batch_verifier_query_phase( ) { // Nondeterministically supply inv_2 let inv_2 = builder.hint_felt(); - builder.assert_eq::>(inv_2 * C::F::from_canonical_usize(2), C::F::from_canonical_usize(1)); + builder.assert_eq::>( + inv_2 * C::F::from_canonical_usize(2), + C::F::from_canonical_usize(1), + ); // encode_small let final_rmm_values_len = builder.get(&input.final_message, 0).len(); let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); - builder.range(0, final_rmm_values_len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let row = builder.get(&input.final_message, i); - let sum = builder.constant(C::EF::ZERO); - builder.range(0, row.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let row_j = builder.get(&row, j); - builder.assign(&sum, sum + row_j); + builder + .range(0, final_rmm_values_len.clone()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let row = builder.get(&input.final_message, i); + let sum = builder.constant(C::EF::ZERO); + builder.range(0, row.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let row_j = builder.get(&row, j); + builder.assign(&sum, sum + row_j); + }); + builder.set_value(&final_rmm_values, i, sum); }); - builder.set_value(&final_rmm_values, i, sum); - }); let final_rmm = RowMajorMatrixVariable { values: final_rmm_values, width: builder.eval(Usize::from(1)), }; - let final_codeword = encode_small( - builder, - input.vp.clone(), - final_rmm, - ); + let final_codeword = encode_small(builder, input.vp.clone(), final_rmm); // XXX: we might need to add generics to MMCS to account for different field types let mmcs_ext: ExtensionMmcsVariable = Default::default(); let mmcs: MerkleTreeMmcsVariable = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted - let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); + let log2_witin_max_codeword_size: Var = + builder.eval(input.max_num_var.clone() + get_rate_log::()); // Nondeterministically supply the index folding_sorted_order // Check that: @@ -314,11 +334,14 @@ pub(crate) fn batch_verifier_query_phase( // Infer witin_num_vars through index let folding_len = input.circuit_meta.len(); let zero: Ext = builder.constant(C::EF::ZERO); - let folding_sort_surjective: Array> = builder.dyn_array(folding_len.clone()); - builder.range(0, folding_len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - builder.set(&folding_sort_surjective, i, zero.clone()); - }); + let folding_sort_surjective: Array> = + builder.dyn_array(folding_len.clone()); + builder + .range(0, folding_len.clone()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&folding_sort_surjective, i, zero.clone()); + }); // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index // for reverse lookup when retrieving next base codeword to involve into batching @@ -326,291 +349,339 @@ pub(crate) fn batch_verifier_query_phase( // 1. height_order: after sorting by decreasing height, the original index of each entry // 2. num_unique_height: number of different heights // 3. count_per_unique_height: for each unique height, number of dimensions of that height - let ( - folding_sorted_order_index, - _num_unique_num_vars, - count_per_unique_num_var - ) = sort_with_count( - builder, - &input.circuit_meta, - |m: CircuitIndexMetaVariable| m.witin_num_vars, - ); + let (folding_sorted_order_index, _num_unique_num_vars, count_per_unique_num_var) = + sort_with_count( + builder, + &input.circuit_meta, + |m: CircuitIndexMetaVariable| m.witin_num_vars, + ); - builder.range(0, input.indices.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let idx = builder.get(&input.indices, i); - let query = builder.get(&input.queries, i); - let witin_opened_values = query.witin_base_proof.opened_values; - let witin_opening_proof = query.witin_base_proof.opening_proof; - let fixed_is_some = query.fixed_is_some; - let fixed_commit = query.fixed_base_proof; - let opening_ext = query.commit_phase_openings; - - // verify base oracle query proof - // refer to prover documentation for the reason of right shift by 1 - // Nondeterministically supply the bits of idx in BIG ENDIAN - // These are not only used by the right shift here but also later on idx_shift - let idx_len = builder.hint_var(); - let idx_bits: Array> = builder.dyn_array(idx_len); - builder.range(0, idx_len).for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_bit = builder.hint_var(); - // Assert that it is a bit - builder.assert_eq::>(next_bit * next_bit, next_bit); - builder.set_value(&idx_bits, j, next_bit); - }); - // Right shift - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - builder.assign(&idx_len, idx_len_minus_one); - let new_idx = bin_to_dec(builder, &idx_bits, idx_len); - let last_bit = builder.get(&idx_bits, idx_len); - builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); - builder.assign(&idx, new_idx); - - let (witin_dimensions, fixed_dimensions) = - get_base_codeword_dimensions(builder, input.circuit_meta.clone()); - // verify witness - let mmcs_verifier_input = MmcsVerifierInputVariable { - commit: input.witin_comm.commit.clone(), - dimensions: witin_dimensions, - index: idx, - opened_values: witin_opened_values.clone(), - proof: witin_opening_proof, - }; - mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); - - // verify fixed - let fixed_commit_leafs = builder.dyn_array(0); - builder.if_eq(fixed_is_some.clone(), Usize::from(1)).then(|builder| { - let fixed_opened_values = fixed_commit.opened_values.clone(); - let fixed_opening_proof = fixed_commit.opening_proof.clone(); - // new_idx used by mmcs proof - let new_idx: Var = builder.eval(idx); - // Nondeterministically supply a hint: - // 0: input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size - // 1: >= - let branch_le = builder.hint_var(); - builder.if_eq(branch_le, Usize::from(0)).then(|builder| { - // input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size - builder.assert_less_than_slow_small_rhs(input.fixed_comm.log2_max_codeword_size.clone(), log2_witin_max_codeword_size); - // idx >> idx_shift - let idx_shift_remain: Var = builder.eval(idx_len - (log2_witin_max_codeword_size - input.fixed_comm.log2_max_codeword_size.clone())); - let tmp_idx = bin_to_dec(builder, &idx_bits, idx_shift_remain); - builder.assign(&new_idx, tmp_idx); - }); - builder.if_ne(branch_le, Usize::from(0)).then(|builder| { - // input.fixed_comm.log2_max_codeword_size >= log2_witin_max_codeword_size - let input_codeword_size_plus_one: Var = builder.eval(input.fixed_comm.log2_max_codeword_size.clone() + Usize::from(1)); - builder.assert_less_than_slow_small_rhs(log2_witin_max_codeword_size, input_codeword_size_plus_one); - // idx << -idx_shift - let idx_shift = builder.eval(input.fixed_comm.log2_max_codeword_size.clone() - log2_witin_max_codeword_size); - let idx_factor = pow_2(builder, idx_shift); - builder.assign(&new_idx, new_idx * idx_factor); + builder + .range(0, input.indices.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let idx = builder.get(&input.indices, i); + let query = builder.get(&input.queries, i); + let witin_opened_values = query.witin_base_proof.opened_values; + let witin_opening_proof = query.witin_base_proof.opening_proof; + let fixed_is_some = query.fixed_is_some; + let fixed_commit = query.fixed_base_proof; + let opening_ext = query.commit_phase_openings; + + // verify base oracle query proof + // refer to prover documentation for the reason of right shift by 1 + // Nondeterministically supply the bits of idx in BIG ENDIAN + // These are not only used by the right shift here but also later on idx_shift + let idx_len = builder.hint_var(); + let idx_bits: Array> = builder.dyn_array(idx_len); + builder.range(0, idx_len).for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_bit = builder.hint_var(); + // Assert that it is a bit + builder.assert_eq::>(next_bit * next_bit, next_bit); + builder.set_value(&idx_bits, j, next_bit); }); + // Right shift + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + builder.assign(&idx_len, idx_len_minus_one); + let new_idx = bin_to_dec(builder, &idx_bits, idx_len); + let last_bit = builder.get(&idx_bits, idx_len); + builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); + builder.assign(&idx, new_idx); + + let (witin_dimensions, fixed_dimensions) = + get_base_codeword_dimensions(builder, input.circuit_meta.clone()); // verify witness let mmcs_verifier_input = MmcsVerifierInputVariable { - commit: input.fixed_comm.commit.clone(), - dimensions: fixed_dimensions.clone(), - index: new_idx, - opened_values: fixed_opened_values.clone(), - proof: fixed_opening_proof, + commit: input.witin_comm.commit.clone(), + dimensions: witin_dimensions, + index: idx, + opened_values: witin_opened_values.clone(), + proof: witin_opening_proof, }; mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); - builder.assign(&fixed_commit_leafs, fixed_opened_values); - }); - // base_codeword_lo_hi - let base_codeword_lo = builder.dyn_array(folding_len.clone()); - let base_codeword_hi = builder.dyn_array(folding_len.clone()); - builder.range(0, folding_len.clone()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let circuit_meta = builder.get(&input.circuit_meta, j); - let witin_num_polys = circuit_meta.witin_num_polys; - let fixed_num_vars = circuit_meta.fixed_num_vars; - let fixed_num_polys = circuit_meta.fixed_num_polys; - let witin_leafs = builder.get(&witin_opened_values, j); - // lo_wit, hi_wit - let leafs_len_div_2 = builder.hint_var(); - let two: Var = builder.eval(Usize::from(2)); - builder.assert_eq::>(leafs_len_div_2 * two, witin_leafs.len()); // Can we assume that leafs.len() is even? - // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) - let lo_wit = dot_product(builder, - &input.batch_coeffs, - &witin_leafs, - ); - let hi_wit = dot_product_with_index(builder, - &input.batch_coeffs, - &witin_leafs, - Usize::from(0), - Usize::Var(leafs_len_div_2), - witin_num_polys.clone(), - ); - // lo_fixed, hi_fixed - let lo_fixed: Ext = builder.constant(C::EF::from_canonical_usize(0)); - let hi_fixed: Ext = builder.constant(C::EF::from_canonical_usize(0)); - builder.if_ne(fixed_num_vars, Usize::from(0)).then(|builder| { - let fixed_leafs = builder.get(&fixed_commit_leafs, j); - let leafs_len_div_2 = builder.hint_var(); - let two: Var = builder.eval(Usize::from(2)); - builder.assert_eq::>(leafs_len_div_2 * two, fixed_leafs.len()); // Can we assume that leafs.len() is even? - // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) - let tmp_lo_fixed = dot_product(builder, - &input.batch_coeffs, - &fixed_leafs, - ); - let tmp_hi_fixed = dot_product_with_index(builder, - &input.batch_coeffs, - &fixed_leafs, - Usize::from(0), - Usize::Var(leafs_len_div_2), - fixed_num_polys.clone(), - ); - builder.assign(&lo_fixed, tmp_lo_fixed); - builder.assign(&hi_fixed, tmp_hi_fixed); - }); - let lo: Ext = builder.eval(lo_wit + lo_fixed); - let hi: Ext = builder.eval(hi_wit + hi_fixed); - builder.set_value(&base_codeword_lo, j, lo); - builder.set_value(&base_codeword_hi, j, hi); - }); + // verify fixed + let fixed_commit_leafs = builder.dyn_array(0); + builder + .if_eq(fixed_is_some.clone(), Usize::from(1)) + .then(|builder| { + let fixed_opened_values = fixed_commit.opened_values.clone(); + let fixed_opening_proof = fixed_commit.opening_proof.clone(); + // new_idx used by mmcs proof + let new_idx: Var = builder.eval(idx); + // Nondeterministically supply a hint: + // 0: input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size + // 1: >= + let branch_le = builder.hint_var(); + builder.if_eq(branch_le, Usize::from(0)).then(|builder| { + // input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size + builder.assert_less_than_slow_small_rhs( + input.fixed_comm.log2_max_codeword_size.clone(), + log2_witin_max_codeword_size, + ); + // idx >> idx_shift + let idx_shift_remain: Var = builder.eval( + idx_len + - (log2_witin_max_codeword_size + - input.fixed_comm.log2_max_codeword_size.clone()), + ); + let tmp_idx = bin_to_dec(builder, &idx_bits, idx_shift_remain); + builder.assign(&new_idx, tmp_idx); + }); + builder.if_ne(branch_le, Usize::from(0)).then(|builder| { + // input.fixed_comm.log2_max_codeword_size >= log2_witin_max_codeword_size + let input_codeword_size_plus_one: Var = builder + .eval(input.fixed_comm.log2_max_codeword_size.clone() + Usize::from(1)); + builder.assert_less_than_slow_small_rhs( + log2_witin_max_codeword_size, + input_codeword_size_plus_one, + ); + // idx << -idx_shift + let idx_shift = builder.eval( + input.fixed_comm.log2_max_codeword_size.clone() + - log2_witin_max_codeword_size, + ); + let idx_factor = pow_2(builder, idx_shift); + builder.assign(&new_idx, new_idx * idx_factor); + }); + // verify witness + let mmcs_verifier_input = MmcsVerifierInputVariable { + commit: input.fixed_comm.commit.clone(), + dimensions: fixed_dimensions.clone(), + index: new_idx, + opened_values: fixed_opened_values.clone(), + proof: fixed_opening_proof, + }; + mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); + builder.assign(&fixed_commit_leafs, fixed_opened_values); + }); - // fold and query - let cur_num_var: Var = builder.eval(input.max_num_var.clone()); - // let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); - let n_d_next_log: Var = builder.eval(cur_num_var - get_rate_log::() - Usize::from(1)); - // let n_d_next = pow_2(builder, n_d_next_log); - - // first folding challenge - let r = builder.get(&input.fold_challenges, 0); - let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, 0); - let folded: Ext = builder.constant(C::EF::ZERO); - builder.range(0, next_unique_num_vars_count).for_each(|j_vec, builder| { - let j = j_vec[0]; - let index = builder.get(&folding_sorted_order_index, j); - let lo = builder.get(&base_codeword_lo, index.clone()); - let hi = builder.get(&base_codeword_hi, index.clone()); - let level: Var = builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); - let coeffs = verifier_folding_coeffs_level(builder, &input.vp, level); - let coeff = builder.get(&coeffs, idx); - let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); - builder.assign(&folded, folded + fold); - }); - let next_unique_num_vars_index: Var = builder.eval(Usize::from(1)); - let cumul_num_vars_count: Var = builder.eval(next_unique_num_vars_count); - let n_d_i_log: Var = builder.eval(n_d_next_log); - // let n_d_i: Var = builder.eval(n_d_next); - // zip_eq - builder.assert_eq::>(input.commits.len() + Usize::from(1), input.fold_challenges.len()); - builder.assert_eq::>(input.commits.len(), opening_ext.len()); - builder.range(0, input.commits.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let pi_comm = builder.get(&input.commits, j); - let j_plus_one = builder.eval_expr(j + RVar::from(1)); - let r = builder.get(&input.fold_challenges, j_plus_one); - let leaf = builder.get(&opening_ext, j).sibling_value; - let proof = builder.get(&opening_ext, j).opening_proof; - builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); - - // next folding challenges - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - let is_interpolate_to_right_index = builder.get(&idx_bits, idx_len_minus_one); - let new_involved_codewords: Ext = builder.constant(C::EF::ZERO); - let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, next_unique_num_vars_index); - builder.range(0, next_unique_num_vars_count).for_each(|k_vec, builder| { - let k = builder.eval_expr(k_vec[0] + cumul_num_vars_count); - let index = builder.get(&folding_sorted_order_index, k); - let lo = builder.get(&base_codeword_lo, index.clone()); - let hi = builder.get(&base_codeword_hi, index.clone()); - builder.if_eq(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { - builder.assign(&new_involved_codewords, new_involved_codewords + hi); + // base_codeword_lo_hi + let base_codeword_lo = builder.dyn_array(folding_len.clone()); + let base_codeword_hi = builder.dyn_array(folding_len.clone()); + builder + .range(0, folding_len.clone()) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let circuit_meta = builder.get(&input.circuit_meta, j); + let witin_num_polys = circuit_meta.witin_num_polys; + let fixed_num_vars = circuit_meta.fixed_num_vars; + let fixed_num_polys = circuit_meta.fixed_num_polys; + let witin_leafs = builder.get(&witin_opened_values, j); + // lo_wit, hi_wit + let leafs_len_div_2 = builder.hint_var(); + let two: Var = builder.eval(Usize::from(2)); + builder.assert_eq::>(leafs_len_div_2 * two, witin_leafs.len()); // Can we assume that leafs.len() is even? + // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) + let lo_wit = dot_product(builder, &input.batch_coeffs, &witin_leafs); + let hi_wit = dot_product_with_index( + builder, + &input.batch_coeffs, + &witin_leafs, + Usize::from(0), + Usize::Var(leafs_len_div_2), + witin_num_polys.clone(), + ); + // lo_fixed, hi_fixed + let lo_fixed: Ext = + builder.constant(C::EF::from_canonical_usize(0)); + let hi_fixed: Ext = + builder.constant(C::EF::from_canonical_usize(0)); + builder + .if_ne(fixed_num_vars, Usize::from(0)) + .then(|builder| { + let fixed_leafs = builder.get(&fixed_commit_leafs, j); + let leafs_len_div_2 = builder.hint_var(); + let two: Var = builder.eval(Usize::from(2)); + builder + .assert_eq::>(leafs_len_div_2 * two, fixed_leafs.len()); // Can we assume that leafs.len() is even? + // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) + let tmp_lo_fixed = + dot_product(builder, &input.batch_coeffs, &fixed_leafs); + let tmp_hi_fixed = dot_product_with_index( + builder, + &input.batch_coeffs, + &fixed_leafs, + Usize::from(0), + Usize::Var(leafs_len_div_2), + fixed_num_polys.clone(), + ); + builder.assign(&lo_fixed, tmp_lo_fixed); + builder.assign(&hi_fixed, tmp_hi_fixed); + }); + let lo: Ext = builder.eval(lo_wit + lo_fixed); + let hi: Ext = builder.eval(hi_wit + hi_fixed); + builder.set_value(&base_codeword_lo, j, lo); + builder.set_value(&base_codeword_hi, j, hi); }); - builder.if_ne(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { - builder.assign(&new_involved_codewords, new_involved_codewords + lo); + + // fold and query + let cur_num_var: Var = builder.eval(input.max_num_var.clone()); + // let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); + let n_d_next_log: Var = + builder.eval(cur_num_var - get_rate_log::() - Usize::from(1)); + // let n_d_next = pow_2(builder, n_d_next_log); + + // first folding challenge + let r = builder.get(&input.fold_challenges, 0); + let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, 0); + let folded: Ext = builder.constant(C::EF::ZERO); + builder + .range(0, next_unique_num_vars_count) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let index = builder.get(&folding_sorted_order_index, j); + let lo = builder.get(&base_codeword_lo, index.clone()); + let hi = builder.get(&base_codeword_hi, index.clone()); + let level: Var = + builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); + let coeffs = verifier_folding_coeffs_level(builder, &input.vp, level); + let coeff = builder.get(&coeffs, idx); + let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); + builder.assign(&folded, folded + fold); }); - }); - builder.assign(&cumul_num_vars_count, cumul_num_vars_count + next_unique_num_vars_count); - builder.assign(&next_unique_num_vars_index, next_unique_num_vars_index + Usize::from(1)); - - // leafs - let leafs = builder.dyn_array(2); - let new_leaf = builder.eval(folded + new_involved_codewords); - builder.if_eq(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { - builder.set_value(&leafs, 0, leaf); - builder.set_value(&leafs, 1, new_leaf); - }); - builder.if_ne(is_interpolate_to_right_index, Usize::from(1)).then(|builder| { - builder.set_value(&leafs, 0, new_leaf); - builder.set_value(&leafs, 1, leaf); - }); - // idx >>= 1 - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - builder.assign(&idx_len, idx_len_minus_one); - let new_idx = bin_to_dec(builder, &idx_bits, idx_len); - let last_bit = builder.get(&idx_bits, idx_len); - builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); - builder.assign(&idx, new_idx); - // n_d_i >> 1 - builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); - let n_d_i = pow_2(builder, n_d_i_log); - // mmcs_ext.verify_batch - let dimensions = builder.uninit_fixed_array(1); - let two = builder.eval(Usize::from(2)); - builder.set_value(&dimensions, 0, DimensionsVariable { - width: two, - height: n_d_i.clone(), - }); - let opened_values = builder.uninit_fixed_array(1); - builder.set_value(&opened_values, 0, leafs.clone()); - let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { - commit: pi_comm.clone(), - dimensions, - index: idx.clone(), - opened_values, - proof, - }; - ext_mmcs_verify_batch::(builder, mmcs_ext.clone(), ext_mmcs_verifier_input); - - let coeffs = verifier_folding_coeffs_level(builder, &input.vp, n_d_i_log.clone()); - let coeff = builder.get(&coeffs, idx.clone()); - let left = builder.get(&leafs, 0); - let right = builder.get(&leafs, 1); - let new_folded = codeword_fold_with_challenge( - builder, - left, - right, - r.clone(), - coeff, - inv_2 + let next_unique_num_vars_index: Var = builder.eval(Usize::from(1)); + let cumul_num_vars_count: Var = builder.eval(next_unique_num_vars_count); + let n_d_i_log: Var = builder.eval(n_d_next_log); + // let n_d_i: Var = builder.eval(n_d_next); + // zip_eq + builder.assert_eq::>( + input.commits.len() + Usize::from(1), + input.fold_challenges.len(), ); - builder.assign(&folded, new_folded); + builder.assert_eq::>(input.commits.len(), opening_ext.len()); + builder + .range(0, input.commits.len()) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let pi_comm = builder.get(&input.commits, j); + let j_plus_one = builder.eval_expr(j + RVar::from(1)); + let r = builder.get(&input.fold_challenges, j_plus_one); + let leaf = builder.get(&opening_ext, j).sibling_value; + let proof = builder.get(&opening_ext, j).opening_proof; + builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); + + // next folding challenges + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + let is_interpolate_to_right_index = builder.get(&idx_bits, idx_len_minus_one); + let new_involved_codewords: Ext = builder.constant(C::EF::ZERO); + let next_unique_num_vars_count: Var = + builder.get(&count_per_unique_num_var, next_unique_num_vars_index); + builder + .range(0, next_unique_num_vars_count) + .for_each(|k_vec, builder| { + let k = builder.eval_expr(k_vec[0] + cumul_num_vars_count); + let index = builder.get(&folding_sorted_order_index, k); + let lo = builder.get(&base_codeword_lo, index.clone()); + let hi = builder.get(&base_codeword_hi, index.clone()); + builder + .if_eq(is_interpolate_to_right_index, Usize::from(1)) + .then(|builder| { + builder.assign( + &new_involved_codewords, + new_involved_codewords + hi, + ); + }); + builder + .if_ne(is_interpolate_to_right_index, Usize::from(1)) + .then(|builder| { + builder.assign( + &new_involved_codewords, + new_involved_codewords + lo, + ); + }); + }); + builder.assign( + &cumul_num_vars_count, + cumul_num_vars_count + next_unique_num_vars_count, + ); + builder.assign( + &next_unique_num_vars_index, + next_unique_num_vars_index + Usize::from(1), + ); + + // leafs + let leafs = builder.dyn_array(2); + let new_leaf = builder.eval(folded + new_involved_codewords); + builder + .if_eq(is_interpolate_to_right_index, Usize::from(1)) + .then(|builder| { + builder.set_value(&leafs, 0, leaf); + builder.set_value(&leafs, 1, new_leaf); + }); + builder + .if_ne(is_interpolate_to_right_index, Usize::from(1)) + .then(|builder| { + builder.set_value(&leafs, 0, new_leaf); + builder.set_value(&leafs, 1, leaf); + }); + // idx >>= 1 + let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + builder.assign(&idx_len, idx_len_minus_one); + let new_idx = bin_to_dec(builder, &idx_bits, idx_len); + let last_bit = builder.get(&idx_bits, idx_len); + builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); + builder.assign(&idx, new_idx); + // n_d_i >> 1 + builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); + let n_d_i = pow_2(builder, n_d_i_log); + // mmcs_ext.verify_batch + let dimensions = builder.uninit_fixed_array(1); + let two = builder.eval(Usize::from(2)); + builder.set_value( + &dimensions, + 0, + DimensionsVariable { + width: two, + height: n_d_i.clone(), + }, + ); + let opened_values = builder.uninit_fixed_array(1); + builder.set_value(&opened_values, 0, leafs.clone()); + let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { + commit: pi_comm.clone(), + dimensions, + index: idx.clone(), + opened_values, + proof, + }; + ext_mmcs_verify_batch::(builder, mmcs_ext.clone(), ext_mmcs_verifier_input); + + let coeffs = + verifier_folding_coeffs_level(builder, &input.vp, n_d_i_log.clone()); + let coeff = builder.get(&coeffs, idx.clone()); + let left = builder.get(&leafs, 0); + let right = builder.get(&leafs, 1); + let new_folded = + codeword_fold_with_challenge(builder, left, right, r.clone(), coeff, inv_2); + builder.assign(&folded, new_folded); + }); + let final_value = builder.get(&final_codeword.values, idx.clone()); + builder.assert_eq::>(final_value, folded); }); - let final_value = builder.get(&final_codeword.values, idx.clone()); - builder.assert_eq::>(final_value, folded); - }); // 1. check initial claim match with first round sumcheck value let points = builder.dyn_array(input.batch_coeffs.len()); let next_point_index: Var = builder.eval(Usize::from(0)); - builder.range(0, input.point_evals.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let evals = builder.get(&input.point_evals, i).evals; - let witin_num_vars = builder.get(&input.circuit_meta, i).witin_num_vars; - // we need to scale up with scalar for witin_num_vars < max_num_var - let scale_log = builder.eval(input.max_num_var.clone() - witin_num_vars); - let scale = pow_2(builder, scale_log); - builder.range(0, evals.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let eval = builder.get(&evals, j); - let scaled_eval: Ext = builder.eval(eval * scale); - builder.set_value(&points, next_point_index, scaled_eval); - builder.assign(&next_point_index, next_point_index + Usize::from(1)); + builder + .range(0, input.point_evals.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let evals = builder.get(&input.point_evals, i).evals; + let witin_num_vars = builder.get(&input.circuit_meta, i).witin_num_vars; + // we need to scale up with scalar for witin_num_vars < max_num_var + let scale_log = builder.eval(input.max_num_var.clone() - witin_num_vars); + let scale = pow_2(builder, scale_log); + builder.range(0, evals.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let eval = builder.get(&evals, j); + let scaled_eval: Ext = builder.eval(eval * scale); + builder.set_value(&points, next_point_index, scaled_eval); + builder.assign(&next_point_index, next_point_index + Usize::from(1)); + }); }); - }); - let left = dot_product( - builder, - &input.batch_coeffs, - &points, - ); + let left = dot_product(builder, &input.batch_coeffs, &points); let next_sumcheck_evals = builder.get(&input.sumcheck_messages, 0).evaluations; let eval0 = builder.get(&next_sumcheck_evals, 0); let eval1 = builder.get(&next_sumcheck_evals, 1); @@ -619,50 +690,63 @@ pub(crate) fn batch_verifier_query_phase( // 2. check every round of sumcheck match with prev claims let fold_len_minus_one: Var = builder.eval(input.fold_challenges.len() - Usize::from(1)); - builder.range(0, fold_len_minus_one).for_each(|i_vec, builder| { - let i = i_vec[0]; - let evals = builder.get(&input.sumcheck_messages, i).evaluations; - let challenge = builder.get(&input.fold_challenges, i); - let left = interpolate_uni_poly(builder, evals, challenge); - let i_plus_one = builder.eval_expr(i + Usize::from(1)); - let next_evals = builder.get(&input.sumcheck_messages, i_plus_one).evaluations; - let eval0 = builder.get(&next_evals, 0); - let eval1 = builder.get(&next_evals, 1); - let right: Ext = builder.eval(eval0 + eval1); - builder.assert_eq::>(left, right); - }); + builder + .range(0, fold_len_minus_one) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let evals = builder.get(&input.sumcheck_messages, i).evaluations; + let challenge = builder.get(&input.fold_challenges, i); + let left = interpolate_uni_poly(builder, evals, challenge); + let i_plus_one = builder.eval_expr(i + Usize::from(1)); + let next_evals = builder + .get(&input.sumcheck_messages, i_plus_one) + .evaluations; + let eval0 = builder.get(&next_evals, 0); + let eval1 = builder.get(&next_evals, 1); + let right: Ext = builder.eval(eval0 + eval1); + builder.assert_eq::>(left, right); + }); // 3. check final evaluation are correct - let final_evals = builder.get(&input.sumcheck_messages, fold_len_minus_one.clone()).evaluations; + let final_evals = builder + .get(&input.sumcheck_messages, fold_len_minus_one.clone()) + .evaluations; let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one.clone()); let left = interpolate_uni_poly(builder, final_evals, final_challenge); let right: Ext = builder.constant(C::EF::ZERO); - builder.range(0, input.final_message.len()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let final_message = builder.get(&input.final_message, i); - let point = builder.get(&input.point_evals, i).point; - // coeff is the eq polynomial evaluated at the first challenge.len() variables - let num_vars_evaluated: Var = builder.eval(point.fs.len() - get_basecode_msg_size_log::()); - let ylo = builder.eval(input.fold_challenges.len() - num_vars_evaluated); - let coeff = eq_eval_with_index( - builder, - &point.fs, - &input.fold_challenges, - Usize::from(0), - Usize::Var(ylo), - Usize::Var(num_vars_evaluated), - ); - let eq = build_eq_x_r_vec_sequential_with_offset::(builder, &point.fs, Usize::Var(num_vars_evaluated)); - let eq_coeff = builder.dyn_array(eq.len()); - builder.range(0, eq.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_eq = builder.get(&eq, j); - let next_eq_coeff: Ext = builder.eval(next_eq * coeff); - builder.set_value(&eq_coeff, j, next_eq_coeff); + builder + .range(0, input.final_message.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let final_message = builder.get(&input.final_message, i); + let point = builder.get(&input.point_evals, i).point; + // coeff is the eq polynomial evaluated at the first challenge.len() variables + let num_vars_evaluated: Var = + builder.eval(point.fs.len() - get_basecode_msg_size_log::()); + let ylo = builder.eval(input.fold_challenges.len() - num_vars_evaluated); + let coeff = eq_eval_with_index( + builder, + &point.fs, + &input.fold_challenges, + Usize::from(0), + Usize::Var(ylo), + Usize::Var(num_vars_evaluated), + ); + let eq = build_eq_x_r_vec_sequential_with_offset::( + builder, + &point.fs, + Usize::Var(num_vars_evaluated), + ); + let eq_coeff = builder.dyn_array(eq.len()); + builder.range(0, eq.len()).for_each(|j_vec, builder| { + let j = j_vec[0]; + let next_eq = builder.get(&eq, j); + let next_eq_coeff: Ext = builder.eval(next_eq * coeff); + builder.set_value(&eq_coeff, j, next_eq_coeff); + }); + let dot_prod = dot_product(builder, &final_message, &eq_coeff); + builder.assign(&right, right + dot_prod); }); - let dot_prod = dot_product(builder, &final_message, &eq_coeff); - builder.assign(&right, right + dot_prod); - }); builder.assert_eq::>(left, right); } @@ -675,8 +759,7 @@ pub mod tests { use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, - p3_baby_bear::BabyBear, + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; use p3_field::{extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra}; type SC = BabyBearPoseidon2Config; @@ -702,7 +785,7 @@ pub mod tests { let mut witness_stream: Vec< Vec>, > = Vec::new(); - + // INPUT let mut f = File::open("input.bin".to_string()).unwrap(); let mut content: Vec = Vec::new(); @@ -736,4 +819,4 @@ pub mod tests { // println!("=> cycle count: {:?}", seg.metrics.cycle_count); // } } -} \ No newline at end of file +} From 39fa15ce6612fe7618bb89947248e61cf1ba1961 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 30 May 2025 10:09:17 +0800 Subject: [PATCH 26/70] Temp store: index to bits --- src/basefold_verifier/extension_mmcs.rs | 6 +++--- src/basefold_verifier/mmcs.rs | 10 +++++----- src/basefold_verifier/query_phase.rs | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index 23d711f..7025e3e 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -40,7 +40,7 @@ impl Hintable for ExtMmcsVerifierInput { ExtMmcsVerifierInputVariable { commit, dimensions, - index, + index_bits: index, opened_values, proof, } @@ -67,7 +67,7 @@ impl Hintable for ExtMmcsVerifierInput { pub struct ExtMmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, pub dimensions: Array>, - pub index: Var, + pub index_bits: Array>, pub opened_values: Array>>, pub proof: MmcsProofVariable, } @@ -141,7 +141,7 @@ pub(crate) fn ext_mmcs_verify_batch( let input = MmcsVerifierInputVariable { commit: input.commit, dimensions: base_dimensions, - index: input.index, + index_bits: input.index_bits, opened_values: opened_base_values, proof: input.proof, }; diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 7499430..01f8d0c 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -50,7 +50,7 @@ impl Hintable for MmcsVerifierInput { MmcsVerifierInputVariable { commit, dimensions, - index, + index_bits: index, opened_values, proof, } @@ -79,7 +79,7 @@ pub type MmcsProofVariable = Array::F>>>; pub struct MmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, pub dimensions: Array>, - pub index: Var, + pub index_bits: Array>, pub opened_values: Array>>, pub proof: MmcsProofVariable, } @@ -96,8 +96,8 @@ pub(crate) fn mmcs_verify_batch( builder.verify_batch_felt( &dimensions, &input.opened_values, - input.proof_id, - input.index_bits, + &input.proof_id, + &input.index_bits, &input.commit.value, ); @@ -256,7 +256,7 @@ pub(crate) fn mmcs_verify_batch( }); }); }); - builder.assert_var_eq(reassembled_index, input.index); + builder.assert_var_eq(reassembled_index, input.index_bits); builder.range(0, DIGEST_ELEMS).for_each(|i_vec, builder| { let i = i_vec[0]; let next_input = builder.get(&input.commit.value, i); diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index e39e8cd..a694796 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -272,7 +272,7 @@ impl Hintable for QueryPhaseVerifierInput { #[derive(DslVariable, Clone)] pub struct QueryPhaseVerifierInputVariable { pub max_num_var: Usize, - pub indices: Array>, + pub indices: Array>>, pub vp: RSCodeVerifierParametersVariable, pub final_message: Array>>, pub batch_coeffs: Array>, @@ -395,7 +395,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.witin_comm.commit.clone(), dimensions: witin_dimensions, - index: idx, + index_bits: idx, opened_values: witin_opened_values.clone(), proof: witin_opening_proof, }; @@ -449,7 +449,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.fixed_comm.commit.clone(), dimensions: fixed_dimensions.clone(), - index: new_idx, + index_bits: new_idx, opened_values: fixed_opened_values.clone(), proof: fixed_opening_proof, }; @@ -642,7 +642,7 @@ pub(crate) fn batch_verifier_query_phase( let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { commit: pi_comm.clone(), dimensions, - index: idx.clone(), + index_bits: idx.clone(), opened_values, proof, }; From 2acf9eb2c8540e4d6b49cc2dabc300afcd3cf4c3 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 30 May 2025 10:19:29 +0800 Subject: [PATCH 27/70] Temp store: clean up mmcs --- src/basefold_verifier/extension_mmcs.rs | 6 +++--- src/basefold_verifier/mmcs.rs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index 7025e3e..89ca08f 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -1,5 +1,5 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; -use openvm_native_recursion::hints::Hintable; +use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldExtensionAlgebra; @@ -69,7 +69,7 @@ pub struct ExtMmcsVerifierInputVariable { pub dimensions: Array>, pub index_bits: Array>, pub opened_values: Array>>, - pub proof: MmcsProofVariable, + pub proof: HintSlice, } pub(crate) fn ext_mmcs_verify_batch( @@ -133,7 +133,7 @@ pub(crate) fn ext_mmcs_verify_batch( builder.verify_batch_ext( &dimensions, &input.opened_values, - &input.proof_id, + &input.proof.id.get_var(), &input.index_bits, &input.commit.value, ); diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 01f8d0c..419d4a4 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; -use openvm_native_recursion::hints::Hintable; +use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; @@ -81,7 +81,7 @@ pub struct MmcsVerifierInputVariable { pub dimensions: Array>, pub index_bits: Array>, pub opened_values: Array>>, - pub proof: MmcsProofVariable, + pub proof: HintSlice, } pub(crate) fn mmcs_verify_batch( @@ -96,7 +96,7 @@ pub(crate) fn mmcs_verify_batch( builder.verify_batch_felt( &dimensions, &input.opened_values, - &input.proof_id, + input.proof.id.get_var(), &input.index_bits, &input.commit.value, ); From 4ffe36162be784ad785f607fb8d669c3f1966772 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 30 May 2025 17:22:24 +0800 Subject: [PATCH 28/70] Temp store: clear compilation errors --- src/basefold_verifier/extension_mmcs.rs | 81 ++--------- src/basefold_verifier/mmcs.rs | 182 ++---------------------- src/basefold_verifier/query_phase.rs | 44 +++--- src/basefold_verifier/structs.rs | 80 +++++------ 4 files changed, 85 insertions(+), 302 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index 89ca08f..3525cc5 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -2,7 +2,6 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use p3_field::FieldExtensionAlgebra; use super::{mmcs::*, structs::*}; @@ -32,15 +31,17 @@ impl Hintable for ExtMmcsVerifierInput { fn read(builder: &mut Builder) -> Self::HintVariable { let commit = MmcsCommitment::read(builder); - let dimensions = Vec::::read(builder); - let index = usize::read(builder); + let dimensions = Vec::::read(builder); + let index_bits = Vec::::read(builder); let opened_values = Vec::>::read(builder); - let proof = Vec::>::read(builder); + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + let proof = HintSlice { length, id }; ExtMmcsVerifierInputVariable { commit, dimensions, - index_bits: index, + index_bits, opened_values, proof, } @@ -50,7 +51,13 @@ impl Hintable for ExtMmcsVerifierInput { let mut stream = Vec::new(); stream.extend(self.commit.write()); stream.extend(self.dimensions.write()); - stream.extend(>::write(&self.index)); + let mut index_bits = Vec::new(); + let mut index = self.index; + while index > 0 { + index_bits.push(index % 2); + index /= 2; + } + stream.extend( as Hintable>::write(&index_bits)); stream.extend(self.opened_values.write()); stream.extend( self.proof @@ -66,7 +73,7 @@ impl Hintable for ExtMmcsVerifierInput { #[derive(DslVariable, Clone)] pub struct ExtMmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, - pub dimensions: Array>, + pub dimensions: Array>, pub index_bits: Array>, pub opened_values: Array>>, pub proof: HintSlice, @@ -74,57 +81,8 @@ pub struct ExtMmcsVerifierInputVariable { pub(crate) fn ext_mmcs_verify_batch( builder: &mut Builder, - mmcs: ExtensionMmcsVariable, // self input: ExtMmcsVerifierInputVariable, ) { - let dim_factor: Var = builder.eval(Usize::from(C::EF::D)); - let opened_base_values = builder.dyn_array(input.opened_values.len()); - let base_dimensions = builder.dyn_array(input.dimensions.len()); - - builder - .range(0, input.opened_values.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - // opened_values - let next_opened_values = builder.get(&input.opened_values, i); - let next_opened_base_values_len: Var = - builder.eval(next_opened_values.len() * dim_factor); - let next_opened_base_values = builder.dyn_array(next_opened_base_values_len); - let next_opened_base_index: Var = builder.eval(Usize::from(0)); - builder - .range(0, next_opened_values.len()) - .for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_opened_value = builder.get(&next_opened_values, j); - // XXX: how to convert Ext to [Felt]? - let next_opened_value_felt = builder.ext2felt(next_opened_value); - builder - .range(0, next_opened_value_felt.len()) - .for_each(|k_vec, builder| { - let k = k_vec[0]; - let next_felt = builder.get(&next_opened_value_felt, k); - builder.set_value( - &next_opened_base_values, - next_opened_base_index, - next_felt, - ); - builder.assign( - &next_opened_base_index, - next_opened_base_index + Usize::from(1), - ); - }); - }); - builder.set_value(&opened_base_values, i, next_opened_base_values); - - // dimensions - let next_dimension = builder.get(&input.dimensions, i); - let next_base_dimension = DimensionsVariable { - width: builder.eval(next_dimension.width.clone() * dim_factor), - height: next_dimension.height.clone(), - }; - builder.set_value(&base_dimensions, i, next_base_dimension); - }); - let dimensions = match input.dimensions { Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), _ => panic!("Expected a dynamic array of felts"), @@ -133,17 +91,8 @@ pub(crate) fn ext_mmcs_verify_batch( builder.verify_batch_ext( &dimensions, &input.opened_values, - &input.proof.id.get_var(), + input.proof.id.get_var(), &input.index_bits, &input.commit.value, ); - - let input = MmcsVerifierInputVariable { - commit: input.commit, - dimensions: base_dimensions, - index_bits: input.index_bits, - opened_values: opened_base_values, - proof: input.proof, - }; - mmcs_verify_batch(builder, mmcs.inner, input); } diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 419d4a4..ff0e76a 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -6,9 +6,8 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use p3_field::FieldAlgebra; -use super::{hash::*, structs::*, utils::*}; +use super::{hash::*, structs::*}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -42,15 +41,17 @@ impl Hintable for MmcsVerifierInput { fn read(builder: &mut Builder) -> Self::HintVariable { let commit = MmcsCommitment::read(builder); - let dimensions = Vec::::read(builder); - let index = usize::read(builder); + let dimensions = Vec::::read(builder); + let index_bits = Vec::::read(builder); let opened_values = Vec::>::read(builder); - let proof = Vec::>::read(builder); + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + let proof = HintSlice { length, id }; MmcsVerifierInputVariable { commit, dimensions, - index_bits: index, + index_bits, opened_values, proof, } @@ -78,7 +79,7 @@ pub type MmcsProofVariable = Array::F>>>; #[derive(DslVariable, Clone)] pub struct MmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, - pub dimensions: Array>, + pub dimensions: Array>, pub index_bits: Array>, pub opened_values: Array>>, pub proof: HintSlice, @@ -86,7 +87,6 @@ pub struct MmcsVerifierInputVariable { pub(crate) fn mmcs_verify_batch( builder: &mut Builder, - _mmcs: MerkleTreeMmcsVariable, // self input: MmcsVerifierInputVariable, ) { let dimensions = match input.dimensions { @@ -100,169 +100,6 @@ pub(crate) fn mmcs_verify_batch( &input.index_bits, &input.commit.value, ); - - // Check that the openings have the correct shape. - let num_dims = input.dimensions.len(); - // Assert dimensions is not empty - builder.assert_nonzero(&num_dims); - builder.assert_usize_eq(num_dims.clone(), input.opened_values.len()); - - // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width. - - // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit. - // Nondeterministically supplies max_height - let max_height = builder.hint_var(); - builder - .range(0, input.dimensions.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_height = builder.get(&input.dimensions, i).height; - let max_height_plus_one: Var = builder.eval(max_height + Usize::from(1)); - builder.assert_less_than_slow_small_rhs(next_height, max_height_plus_one); - }); - - // Verify correspondence between log_h and h - let log_max_height = builder.hint_var(); - let log_max_height_minus_1: Var = builder.eval(log_max_height - Usize::from(1)); - let purported_max_height_lower_bound: Var = pow_2(builder, log_max_height_minus_1); - let two: Var = builder.constant(C::N::TWO); - let purported_max_height_upper_bound: Var = - builder.eval(purported_max_height_lower_bound * two); - builder.assert_less_than_slow_small_rhs(purported_max_height_lower_bound, max_height); - builder.assert_less_than_slow_small_rhs(max_height, purported_max_height_upper_bound); - builder.assert_usize_eq(input.proof.len(), log_max_height); - - // Sort input.dimensions by height, returns - // 1. height_order: after sorting by decreasing height, the original index of each entry - // 2. num_unique_height: number of different heights - // 3. count_per_unique_height: for each unique height, number of dimensions of that height - let (height_order, num_unique_height, count_per_unique_height) = - sort_with_count(builder, &input.dimensions, |d: DimensionsVariable| { - d.height - }); - - // First padded_height - let first_order = builder.get(&height_order, 0); - let first_height = builder.get(&input.dimensions, first_order).height; - let curr_height_padded = next_power_of_two(builder, first_height); - - // Construct root through hashing - let root_dims_count: Var = builder.get(&count_per_unique_height, 0); - let root_values = builder.dyn_array(root_dims_count); - builder - .range(0, root_dims_count) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let index = builder.get(&height_order, i); - let tmp = builder.get(&input.opened_values, index); - builder.set_value(&root_values, i, tmp); - }); - let root = hash_iter_slices(builder, root_values); - - // Index_pow and reassembled_index for bit split - let index_pow: Var = builder.eval(Usize::from(1)); - let reassembled_index: Var = builder.eval(Usize::from(0)); - // next_height is the height of the next dim to be incorporated into root - let next_unique_height_index: Var = builder.eval(Usize::from(1)); - let cumul_dims_count: Var = builder.eval(root_dims_count); - let next_height_padded: Var = builder.eval(Usize::from(0)); - builder - .if_ne(num_unique_height, Usize::from(1)) - .then(|builder| { - let next_height = builder.get(&input.dimensions, cumul_dims_count).height; - let tmp_next_height_padded = next_power_of_two(builder, next_height); - builder.assign(&next_height_padded, tmp_next_height_padded); - }); - builder - .range(0, input.proof.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let sibling = builder.get(&input.proof, i); - let two_var: Var = builder.eval(Usize::from(2)); // XXX: is there a better way to do this? - // Supply the next index bit as hint, assert that it is a bit - let next_index_bit = builder.hint_var(); - builder.assert_var_eq(next_index_bit, next_index_bit * next_index_bit); - builder.assign( - &reassembled_index, - reassembled_index + index_pow * next_index_bit, - ); - builder.assign(&index_pow, index_pow * two_var); - - // left, right - let compress_elem = builder.dyn_array(2); - builder - .if_eq(next_index_bit, Usize::from(0)) - .then(|builder| { - // root, sibling - builder.set_value(&compress_elem, 0, root.clone()); - builder.set_value(&compress_elem, 0, sibling.clone()); - }); - builder - .if_ne(next_index_bit, Usize::from(0)) - .then(|builder| { - // sibling, root - builder.set_value(&compress_elem, 0, sibling.clone()); - builder.set_value(&compress_elem, 0, root.clone()); - }); - let new_root = compress(builder, compress_elem); - builder.assign(&root, new_root); - - // curr_height_padded >>= 1 given curr_height_padded is a power of two - // Nondeterministically supply next_curr_height_padded - let next_curr_height_padded = builder.hint_var(); - builder.assert_var_eq(next_curr_height_padded * two_var, curr_height_padded); - builder.assign(&curr_height_padded, next_curr_height_padded); - - // determine whether next_height matches curr_height - builder - .if_eq(curr_height_padded, next_height_padded) - .then(|builder| { - // hash opened_values of all dims of next_height to root - let root_dims_count = - builder.get(&count_per_unique_height, next_unique_height_index); - let root_size: Var = builder.eval(root_dims_count + Usize::from(1)); - let root_values = builder.dyn_array(root_size); - builder.set_value(&root_values, 0, root.clone()); - builder - .range(0, root_dims_count) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let index = builder.get(&height_order, i); - let tmp = builder.get(&input.opened_values, index); - let j = builder.eval_expr(i + RVar::from(1)); - builder.set_value(&root_values, j, tmp); - }); - let new_root = hash_iter_slices(builder, root_values); - builder.assign(&root, new_root); - - // Update parameters - builder.assign(&cumul_dims_count, cumul_dims_count + root_dims_count); - builder.assign( - &next_unique_height_index, - next_unique_height_index + Usize::from(1), - ); - builder - .if_eq(next_unique_height_index, num_unique_height) - .then(|builder| { - builder.assign(&next_height_padded, Usize::from(0)); - }); - builder - .if_ne(next_unique_height_index, num_unique_height) - .then(|builder| { - let next_height = - builder.get(&input.dimensions, cumul_dims_count).height; - let next_tmp_height_padded = next_power_of_two(builder, next_height); - builder.assign(&next_height_padded, next_tmp_height_padded); - }); - }); - }); - builder.assert_var_eq(reassembled_index, input.index_bits); - builder.range(0, DIGEST_ELEMS).for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_input = builder.get(&input.commit.value, i); - let next_root = builder.get(&root, i); - builder.assert_felt_eq(next_input, next_root); - }); } pub mod tests { @@ -290,9 +127,8 @@ pub mod tests { let mut builder = AsmBuilder::::default(); // Witness inputs - let mmcs_self = Default::default(); let mmcs_input = MmcsVerifierInput::read(&mut builder); - mmcs_verify_batch(&mut builder, mmcs_self, mmcs_input); + mmcs_verify_batch(&mut builder, mmcs_input); builder.halt(); // Pass in witness stream diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index a694796..761b754 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,7 +1,10 @@ // Note: check all XXX comments! use openvm_native_compiler::{asm::AsmConfig, prelude::*}; -use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_native_recursion::{ + hints::{Hintable, VecAutoHintable}, + vars::HintSlice, +}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; @@ -28,7 +31,10 @@ impl Hintable for BatchOpening { fn read(builder: &mut Builder) -> Self::HintVariable { let opened_values = Vec::>::read(builder); - let opening_proof = Vec::>::read(builder); + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + let opening_proof = HintSlice { length, id }; + BatchOpeningVariable { opened_values, opening_proof, @@ -52,7 +58,7 @@ impl Hintable for BatchOpening { #[derive(DslVariable, Clone)] pub struct BatchOpeningVariable { pub opened_values: Array>>, - pub opening_proof: MmcsProofVariable, + pub opening_proof: HintSlice, } #[derive(Deserialize)] @@ -66,7 +72,10 @@ impl Hintable for CommitPhaseProofStep { fn read(builder: &mut Builder) -> Self::HintVariable { let sibling_value = E::read(builder); - let opening_proof = Vec::>::read(builder); + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + let opening_proof = HintSlice { length, id }; + CommitPhaseProofStepVariable { sibling_value, opening_proof, @@ -91,7 +100,7 @@ impl VecAutoHintable for CommitPhaseProofStep {} #[derive(DslVariable, Clone)] pub struct CommitPhaseProofStepVariable { pub sibling_value: Ext, - pub opening_proof: MmcsProofVariable, + pub opening_proof: HintSlice, } #[derive(Deserialize)] @@ -272,7 +281,7 @@ impl Hintable for QueryPhaseVerifierInput { #[derive(DslVariable, Clone)] pub struct QueryPhaseVerifierInputVariable { pub max_num_var: Usize, - pub indices: Array>>, + pub indices: Array>, pub vp: RSCodeVerifierParametersVariable, pub final_message: Array>>, pub batch_coeffs: Array>, @@ -395,11 +404,11 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.witin_comm.commit.clone(), dimensions: witin_dimensions, - index_bits: idx, + index_bits: idx_bits.clone(), // TODO: double check, should be new idx bits here opened_values: witin_opened_values.clone(), proof: witin_opening_proof, }; - mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); + mmcs_verify_batch(builder, mmcs_verifier_input); // verify fixed let fixed_commit_leafs = builder.dyn_array(0); @@ -449,11 +458,11 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.fixed_comm.commit.clone(), dimensions: fixed_dimensions.clone(), - index_bits: new_idx, + index_bits: idx_bits.clone(), // TODO: should be new idx_bits opened_values: fixed_opened_values.clone(), proof: fixed_opening_proof, }; - mmcs_verify_batch(builder, mmcs.clone(), mmcs_verifier_input); + mmcs_verify_batch(builder, mmcs_verifier_input); builder.assign(&fixed_commit_leafs, fixed_opened_values); }); @@ -628,25 +637,18 @@ pub(crate) fn batch_verifier_query_phase( let n_d_i = pow_2(builder, n_d_i_log); // mmcs_ext.verify_batch let dimensions = builder.uninit_fixed_array(1); - let two = builder.eval(Usize::from(2)); - builder.set_value( - &dimensions, - 0, - DimensionsVariable { - width: two, - height: n_d_i.clone(), - }, - ); + // let two: Var<_> = builder.eval(Usize::from(2)); + builder.set_value(&dimensions, 0, n_d_i.clone()); let opened_values = builder.uninit_fixed_array(1); builder.set_value(&opened_values, 0, leafs.clone()); let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { commit: pi_comm.clone(), dimensions, - index_bits: idx.clone(), + index_bits: idx_bits.clone(), // TODO: new idx bits? opened_values, proof, }; - ext_mmcs_verify_batch::(builder, mmcs_ext.clone(), ext_mmcs_verifier_input); + ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); let coeffs = verifier_folding_coeffs_level(builder, &input.vp, n_d_i_log.clone()); diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index 5d2c324..366bda0 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -99,13 +99,10 @@ impl VecAutoHintable for Dimensions {} pub fn get_base_codeword_dimensions( builder: &mut Builder, circuit_meta_map: Array>, -) -> ( - Array>, - Array>, -) { +) -> (Array>, Array>) { let dim_len = circuit_meta_map.len(); - let wit_dim: Array> = builder.dyn_array(dim_len.clone()); - let fixed_dim: Array> = builder.dyn_array(dim_len.clone()); + let wit_dim: Array> = builder.dyn_array(dim_len.clone()); + let fixed_dim: Array> = builder.dyn_array(dim_len.clone()); builder.range(0, dim_len).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -115,33 +112,31 @@ pub fn get_base_codeword_dimensions( let fixed_num_vars = tmp.fixed_num_vars; let fixed_num_polys = tmp.fixed_num_polys; // wit_dim - let width = builder.eval(witin_num_polys * Usize::from(2)); + // let width = builder.eval(witin_num_polys * Usize::from(2)); let height_exp = builder.eval(witin_num_vars + get_rate_log::() - Usize::from(1)); let height = pow_2(builder, height_exp); - let next_wit: DimensionsVariable = DimensionsVariable { - width, - height, - }; - builder.set_value(&wit_dim, i, next_wit); - + // let next_wit: DimensionsVariable = DimensionsVariable { width, height }; + // Only keep the height because the width is not needed in the mmcs batch + // verify instruction + builder.set_value(&wit_dim, i, height); + // fixed_dim // XXX: since fixed_num_vars is usize, fixed_num_vars > 0 is equivalent to fixed_num_vars != 0 - builder.if_ne(fixed_num_vars.clone(), Usize::from(0)).then(|builder| { - let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); - let height_exp = builder.eval(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); - // XXX: more efficient pow implementation - let height = pow_2(builder, height_exp); - let next_fixed: DimensionsVariable = DimensionsVariable { - width, - height, - }; - builder.set_value(&fixed_dim, i, next_fixed); - }); + builder + .if_ne(fixed_num_vars.clone(), Usize::from(0)) + .then(|builder| { + // let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); + let height_exp = + builder.eval(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); + // XXX: more efficient pow implementation + let height = pow_2(builder, height_exp); + // let next_fixed: DimensionsVariable = DimensionsVariable { width, height }; + builder.set_value(&fixed_dim, i, height); + }); }); (wit_dim, fixed_dim) } - pub mod tests { use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; @@ -150,8 +145,7 @@ pub mod tests { use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, - p3_baby_bear::BabyBear, + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; use p3_field::extension::BinomialExtensionField; type SC = BabyBearPoseidon2Config; @@ -171,12 +165,14 @@ pub mod tests { // Witness inputs let map_len = Usize::Var(usize::read(&mut builder)); let circuit_meta_map = builder.dyn_array(map_len.clone()); - builder.range(0, map_len.clone()).for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_meta = CircuitIndexMeta::read(builder); - builder.set(&circuit_meta_map, i, next_meta); - }); - + builder + .range(0, map_len.clone()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_meta = CircuitIndexMeta::read(builder); + builder.set(&circuit_meta_map, i, next_meta); + }); + let (wit_dim, fixed_dim) = get_base_codeword_dimensions(&mut builder, circuit_meta_map); builder.range(0, map_len).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -184,14 +180,14 @@ pub mod tests { let fixed = builder.get(&fixed_dim, i); let i_val: Var<_> = builder.eval(i); builder.print_v(i_val); - let ww_val: Var<_> = builder.eval(wit.width); - let wh_val: Var<_> = builder.eval(wit.height); - let fw_val: Var<_> = builder.eval(fixed.width); - let fh_val: Var<_> = builder.eval(fixed.height); - builder.print_v(ww_val); - builder.print_v(wh_val); - builder.print_v(fw_val); - builder.print_v(fh_val); + // let ww_val: Var<_> = builder.eval(wit.width); + // let wh_val: Var<_> = builder.eval(wit.height); + // let fw_val: Var<_> = builder.eval(fixed.width); + // let fh_val: Var<_> = builder.eval(fixed.height); + // builder.print_v(ww_val); + builder.print_v(wit); + // builder.print_v(fw_val); + builder.print_v(fixed); }); builder.halt(); @@ -239,4 +235,4 @@ pub mod tests { // println!("=> cycle count: {:?}", seg.metrics.cycle_count); // } } -} \ No newline at end of file +} From 629bb23bbccd3f476eb47c1a6fb379a121004fde Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 3 Jun 2025 09:02:36 +0800 Subject: [PATCH 29/70] Fix compilation error --- src/arithmetics/mod.rs | 28 ++++++++++++++++++++++++++++ src/basefold_verifier/query_phase.rs | 8 +++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index b9516a9..e6fb83d 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -373,6 +373,34 @@ pub fn build_eq_x_r_vec_sequential( evals } +pub fn build_eq_x_r_vec_sequential_with_offset( + builder: &mut Builder, + r: &Array>, + offset: Usize, +) -> Array> { + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + + let r_len: Var = builder.eval(r.len() - offset); + let evals_len: Felt = builder.constant(C::F::ONE); + let evals_len = builder.exp_power_of_2_v::>(evals_len, r_len); + let evals_len = builder.cast_felt_to_var(evals_len); + + let evals: Array> = builder.dyn_array(evals_len); + // _debug + // build_eq_x_r_helper_sequential_offset(r, &mut evals, E::ONE); + // unsafe { std::mem::transmute(evals) } + evals +} + pub fn ceil_log2(x: usize) -> usize { assert!(x > 0, "ceil_log2: x must be positive"); // Calculate the number of bits in usize diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 30816e5..c936c73 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -12,7 +12,9 @@ use serde::Deserialize; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; use crate::{ - arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, + arithmetics::{ + build_eq_x_r_vec_sequential, build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index, + }, tower_verifier::{binding::*, program::interpolate_uni_poly}, }; @@ -698,7 +700,7 @@ pub(crate) fn batch_verifier_query_phase( let i = i_vec[0]; let evals = builder.get(&input.sumcheck_messages, i).evaluations; let challenge = builder.get(&input.fold_challenges, i); - let left = interpolate_uni_poly(builder, evals, challenge); + let left = interpolate_uni_poly(builder, &evals, challenge); let i_plus_one = builder.eval_expr(i + Usize::from(1)); let next_evals = builder .get(&input.sumcheck_messages, i_plus_one) @@ -714,7 +716,7 @@ pub(crate) fn batch_verifier_query_phase( .get(&input.sumcheck_messages, fold_len_minus_one.clone()) .evaluations; let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one.clone()); - let left = interpolate_uni_poly(builder, final_evals, final_challenge); + let left = interpolate_uni_poly(builder, &final_evals, final_challenge); let right: Ext = builder.constant(C::EF::ZERO); builder .range(0, input.final_message.len()) From 2516cf55b71e091f598c209de61f23652b17a906 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 3 Jun 2025 09:36:18 +0800 Subject: [PATCH 30/70] Fix hash variable reading bug --- src/basefold_verifier/hash.rs | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index dd901d2..ca76387 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -30,7 +30,7 @@ impl Hintable for Hash { type HintVariable = HashVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let value = builder.uninit_fixed_array(DIGEST_ELEMS); + let value = builder.dyn_array(DIGEST_ELEMS); for i in 0..DIGEST_ELEMS { let tmp = F::read(builder); builder.set(&value, i, tmp); @@ -72,3 +72,30 @@ pub fn compress( // XXX: verify hash builder.hint_felts_fixed(DIGEST_ELEMS) } + +#[cfg(test)] +mod tests { + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_compiler_derive::iter_zip; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; + type SC = BabyBearPoseidon2Config; + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + + use crate::basefold_verifier::basefold::HashDigest; + + use super::*; + #[test] + fn test_read_to_hash_variable() { + let mut builder = AsmBuilder::::default(); + + let hint = HashDigest::read(&mut builder); + println!(" hint: {:?}", hint.value); + let dst: HashVariable<_> = builder.uninit(); + println!(" dst: {:?}", dst.value); + // builder.set(&arr, 0, hint); + builder.assign(&dst, hint); + } +} From ba93b0b687c6a40bd4a9e9dc8e289b971c415533 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 3 Jun 2025 09:39:15 +0800 Subject: [PATCH 31/70] Use dyn array for dimensions --- src/basefold_verifier/query_phase.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index c936c73..b9ff730 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -638,7 +638,7 @@ pub(crate) fn batch_verifier_query_phase( builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); let n_d_i = pow_2(builder, n_d_i_log); // mmcs_ext.verify_batch - let dimensions = builder.uninit_fixed_array(1); + let dimensions = builder.dyn_array(1); // let two: Var<_> = builder.eval(Usize::from(2)); builder.set_value(&dimensions, 0, n_d_i.clone()); let opened_values = builder.uninit_fixed_array(1); From ec5e218f8517d5ebad577992cbb1a16e31706109 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 3 Jun 2025 10:15:41 +0800 Subject: [PATCH 32/70] Fix multiplication between var and ext --- src/basefold_verifier/hash.rs | 3 - src/basefold_verifier/query_phase.rs | 2 + src/basefold_verifier/utils.rs | 102 +++++++++++++++------------ 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index ca76387..d7b7813 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -92,10 +92,7 @@ mod tests { let mut builder = AsmBuilder::::default(); let hint = HashDigest::read(&mut builder); - println!(" hint: {:?}", hint.value); let dst: HashVariable<_> = builder.uninit(); - println!(" dst: {:?}", dst.value); - // builder.set(&arr, 0, hint); builder.assign(&dst, hint); } } diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index b9ff730..7243905 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -677,6 +677,8 @@ pub(crate) fn batch_verifier_query_phase( // we need to scale up with scalar for witin_num_vars < max_num_var let scale_log = builder.eval(input.max_num_var.clone() - witin_num_vars); let scale = pow_2(builder, scale_log); + // Transform scale into a field element + let scale = builder.unsafe_cast_var_to_felt(scale); builder.range(0, evals.len()).for_each(|j_vec, builder| { let j = j_vec[0]; let eval = builder.get(&evals, j); diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index a390b6d..8ca0b23 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -2,11 +2,7 @@ use openvm_native_compiler::ir::*; use p3_field::FieldAlgebra; // XXX: more efficient pow implementation -pub fn pow( - builder: &mut Builder, - base: Var, - exponent: Var, -) -> Var { +pub fn pow(builder: &mut Builder, base: Var, exponent: Var) -> Var { let value: Var = builder.constant(C::N::ONE); builder.range(0, exponent).for_each(|_, builder| { builder.assign(&value, value * base); @@ -14,19 +10,13 @@ pub fn pow( value } -pub fn pow_2( - builder: &mut Builder, - exponent: Var, -) -> Var { +pub fn pow_2(builder: &mut Builder, exponent: Var) -> Var { let two: Var = builder.constant(C::N::from_canonical_usize(2)); pow(builder, two, exponent) } // XXX: Equally outrageously inefficient -pub fn next_power_of_two( - builder: &mut Builder, - value: Var, -) -> Var { +pub fn next_power_of_two(builder: &mut Builder, value: Var) -> Var { // Non-deterministically supply the exponent n such that // 2^n < v <= 2^{n+1} // Ignore if v == 1 @@ -50,17 +40,11 @@ pub fn dot_product( builder: &mut Builder, li: &Array>, ri: &Array, -) -> Ext -where F: openvm_native_compiler::ir::MemVariable + 'static +) -> Ext +where + F: openvm_native_compiler::ir::MemVariable + 'static, { - dot_product_with_index::( - builder, - li, - ri, - Usize::from(0), - Usize::from(0), - li.len(), - ) + dot_product_with_index::(builder, li, ri, Usize::from(0), Usize::from(0), li.len()) } // Generic dot product of li[llo..llo+len] * ri[rlo..rlo+len] @@ -71,8 +55,9 @@ pub fn dot_product_with_index( llo: Usize, rlo: Usize, len: Usize, -) -> Ext - where F: openvm_native_compiler::ir::MemVariable + 'static +) -> Ext +where + F: openvm_native_compiler::ir::MemVariable + 'static, { let ret: Ext = builder.constant(C::EF::ZERO); @@ -114,9 +99,11 @@ pub fn sort_with_count( list: &Array, ind: Ind, // Convert loaded out entries into comparable ones ) -> (Array>, Var, Array>) - where E: openvm_native_compiler::ir::MemVariable, - N: Into::N>> + openvm_native_compiler::ir::Variable, - Ind: Fn(E) -> N +where + E: openvm_native_compiler::ir::MemVariable, + N: Into::N>> + + openvm_native_compiler::ir::Variable, + Ind: Fn(E) -> N, { let len = list.len(); // Nondeterministically supplies: @@ -146,7 +133,7 @@ pub fn sort_with_count( builder.set(&entries_sort_surjective, next_order, one.clone()); builder.set_value(&entries_order, 0, next_order); let last_entry = ind(builder.get(&list, next_order)); - + let last_unique_entry_index: Var = builder.eval(Usize::from(0)); let last_count_per_unique_entry: Var = builder.eval(Usize::from(1)); builder.range(1, len).for_each(|i_vec, builder| { @@ -158,27 +145,48 @@ pub fn sort_with_count( builder.set(&entries_sort_surjective, next_order, one.clone()); // Check entries let next_entry = ind(builder.get(&list, next_order)); - builder.if_eq(last_entry.clone(), next_entry.clone()).then(|builder| { - // next_entry == last_entry - builder.assign(&last_count_per_unique_entry, last_count_per_unique_entry + Usize::from(1)); - }); - builder.if_ne(last_entry.clone(), next_entry.clone()).then(|builder| { - // next_entry < last_entry - builder.assert_less_than_slow_small_rhs(next_entry.clone(), last_entry.clone()); - - // Update count_per_unique_entry - builder.set(&count_per_unique_entry, last_unique_entry_index, last_count_per_unique_entry); - builder.assign(&last_entry, next_entry.clone()); - builder.assign(&last_unique_entry_index, last_unique_entry_index + Usize::from(1)); - builder.assign(&last_count_per_unique_entry, Usize::from(1)); - }); + builder + .if_eq(last_entry.clone(), next_entry.clone()) + .then(|builder| { + // next_entry == last_entry + builder.assign( + &last_count_per_unique_entry, + last_count_per_unique_entry + Usize::from(1), + ); + }); + builder + .if_ne(last_entry.clone(), next_entry.clone()) + .then(|builder| { + // next_entry < last_entry + builder.assert_less_than_slow_small_rhs(next_entry.clone(), last_entry.clone()); + + // Update count_per_unique_entry + builder.set( + &count_per_unique_entry, + last_unique_entry_index, + last_count_per_unique_entry, + ); + builder.assign(&last_entry, next_entry.clone()); + builder.assign( + &last_unique_entry_index, + last_unique_entry_index + Usize::from(1), + ); + builder.assign(&last_count_per_unique_entry, Usize::from(1)); + }); builder.set_value(&entries_order, i, next_order); }); // Final check on num_unique_entries and count_per_unique_entry - builder.set(&count_per_unique_entry, last_unique_entry_index, last_count_per_unique_entry); - builder.assign(&last_unique_entry_index, last_unique_entry_index + Usize::from(1)); + builder.set( + &count_per_unique_entry, + last_unique_entry_index, + last_count_per_unique_entry, + ); + builder.assign( + &last_unique_entry_index, + last_unique_entry_index + Usize::from(1), + ); builder.assert_var_eq(last_unique_entry_index, num_unique_entries); (entries_order, num_unique_entries, count_per_unique_entry) @@ -196,7 +204,7 @@ pub fn codeword_fold_with_challenge( // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) let lo: Ext = builder.eval((left + right) * inv_2); let hi: Ext = builder.eval((left - right) * coeff); // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code - // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi (with lo, hi are two codewords), as it match perfectly with raw message in lagrange domain fixed variable + // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi (with lo, hi are two codewords), as it match perfectly with raw message in lagrange domain fixed variable let ret: Ext = builder.eval(lo + challenge * (hi - lo)); ret -} \ No newline at end of file +} From 586bc0da7ae919b9d1c45bebba56de4dde9a0ce1 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 6 Jun 2025 14:37:35 +0800 Subject: [PATCH 33/70] Fix mmcs reading --- src/basefold_verifier/extension_mmcs.rs | 11 +---- src/basefold_verifier/mmcs.rs | 55 +++++++++++-------------- src/basefold_verifier/query_phase.rs | 3 -- 3 files changed, 26 insertions(+), 43 deletions(-) diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index 3525cc5..cf4f7e0 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -9,18 +9,9 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -pub struct ExtensionMmcs { - pub inner: MerkleTreeMmcs, -} - -#[derive(Default, Clone)] -pub struct ExtensionMmcsVariable { - pub inner: MerkleTreeMmcsVariable, -} - pub struct ExtMmcsVerifierInput { pub commit: MmcsCommitment, - pub dimensions: Vec, + pub dimensions: Vec, pub index: usize, pub opened_values: Vec>, pub proof: MmcsProof, diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index ff0e76a..84d9a98 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -13,24 +13,11 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -// XXX: Fill in MerkleTreeMmcs -pub struct MerkleTreeMmcs { - pub hash: (), - pub compress: (), -} - -#[derive(Default, Clone)] -pub struct MerkleTreeMmcsVariable { - pub hash: (), - pub compress: (), - _phantom: PhantomData, -} - pub type MmcsCommitment = Hash; pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; pub struct MmcsVerifierInput { pub commit: MmcsCommitment, - pub dimensions: Vec, + pub dimensions: Vec, pub index: usize, pub opened_values: Vec>, pub proof: MmcsProof, @@ -59,9 +46,16 @@ impl Hintable for MmcsVerifierInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); + // Split index into bits stream.extend(self.commit.write()); stream.extend(self.dimensions.write()); - stream.extend(>::write(&self.index)); + let mut index_bits = Vec::new(); + let mut index = self.index; + while index > 0 { + index_bits.push(index % 2); + index /= 2; + } + stream.extend( as Hintable>::write(&index_bits)); stream.extend(self.opened_values.write()); stream.extend( self.proof @@ -75,7 +69,7 @@ impl Hintable for MmcsVerifierInput { } pub type MmcsCommitmentVariable = HashVariable; -pub type MmcsProofVariable = Array::F>>>; + #[derive(DslVariable, Clone)] pub struct MmcsVerifierInputVariable { pub commit: MmcsCommitmentVariable, @@ -148,20 +142,21 @@ pub mod tests { f(383365269), ], }; - let dimensions = vec![ - Dimensions { - width: 8, - height: 1, - }, - Dimensions { - width: 8, - height: 1, - }, - Dimensions { - width: 8, - height: 70, - }, - ]; + // let dimensions = vec![ + // Dimensions { + // width: 8, + // height: 1, + // }, + // Dimensions { + // width: 8, + // height: 1, + // }, + // Dimensions { + // width: 8, + // height: 70, + // }, + // ]; + let dimensions = vec![1, 1, 70]; let index = 6; let opened_values = vec![ vec![ diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 7243905..f736a71 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -330,9 +330,6 @@ pub(crate) fn batch_verifier_query_phase( width: builder.eval(Usize::from(1)), }; let final_codeword = encode_small(builder, input.vp.clone(), final_rmm); - // XXX: we might need to add generics to MMCS to account for different field types - let mmcs_ext: ExtensionMmcsVariable = Default::default(); - let mmcs: MerkleTreeMmcsVariable = Default::default(); // can't use witin_comm.log2_max_codeword_size since it's untrusted let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); From b35aa5d512829e871817add162918a5b1d393d12 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 10 Jun 2025 14:06:23 +0800 Subject: [PATCH 34/70] Remove unnecessary witness stream --- src/basefold_verifier/mmcs.rs | 332 ---------------------------------- 1 file changed, 332 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 84d9a98..994c1d1 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -270,338 +270,6 @@ pub mod tests { proof, }; witness_stream.extend(mmcs_input.write()); - // max_height - witness_stream.extend(>::write(&70)); - // log_max_height - witness_stream.extend(>::write(&7)); - // num_unique_height - witness_stream.extend(>::write(&2)); - // height_order - witness_stream.extend(>::write(&2)); - // height_order - witness_stream.extend(>::write(&0)); - // height_order - witness_stream.extend(>::write(&1)); - // curr_height_log - witness_stream.extend(>::write(&6)); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(1782972889), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(279434715), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(1209301918), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(1853868602), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(883945353), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(368353728), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(1699837443), - )); - // root - witness_stream.extend(>::write( - &F::from_canonical_usize(908962698), - )); - // next_height_log - witness_stream.extend(>::write(&0)); - // next_bit - witness_stream.extend(>::write(&0)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(271352274), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1918158485), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1538604111), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1122013445), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1844193149), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(501326061), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1508959271), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1549189152), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&64)); - // next_bit - witness_stream.extend(>::write(&1)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(222162520), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(785634830), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1461778378), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(836284568), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1141654637), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1339589042), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1081824021), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(698316542), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&32)); - // next_bit - witness_stream.extend(>::write(&1)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(567517164), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(915833994), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(621327606), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(476128789), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1976747536), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1385950652), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1416073024), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(862764478), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&16)); - // next_bit - witness_stream.extend(>::write(&0)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(822965313), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1036402058), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(117603799), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1087591966), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(443405499), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1334745091), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(901165815), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1187124281), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&8)); - // next_bit - witness_stream.extend(>::write(&0)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(875508647), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1313410483), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(355713834), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1976667383), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1804021525), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(294385081), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(669164730), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1187763617), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&4)); - // next_bit - witness_stream.extend(>::write(&0)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1992024140), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(439080849), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1032272714), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1304584689), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1795447062), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(859522945), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1661892383), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1980559722), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&2)); - // next_bit - witness_stream.extend(>::write(&0)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1121119596), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(369487248), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(834451573), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1120744826), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(758930984), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(632316631), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1593276657), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(507031465), - )); - // next_curr_height_padded - witness_stream.extend(>::write(&1)); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1715944678), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1204294900), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(59582177), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(320945505), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1470843790), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(1773915204), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(380281369), - )); - // new_root - witness_stream.extend(>::write( - &F::from_canonical_usize(383365269), - )); // PROGRAM let program: Program< From 0b8d75182578749ae1faee2392ddd95a24d6b2fb Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 10 Jun 2025 17:45:48 +0800 Subject: [PATCH 35/70] Add doc for generating mmcs test data --- src/basefold_verifier/mmcs.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 994c1d1..4df1a52 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -115,6 +115,9 @@ pub mod tests { use super::{mmcs_verify_batch, InnerConfig, MmcsCommitment, MmcsVerifierInput}; + /// The witness in this test is produced by: + /// https://github.com/Jiangkm3/Plonky3 + /// cargo test --package p3-merkle-tree --lib -- mmcs::tests::size_gaps --exact --show-output #[allow(dead_code)] pub fn build_mmcs_verify_batch() -> (Program, Vec>) { // OpenVM DSL From ac1cbfd4cf2bcb01753df4988c8099d57c7b25d7 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 09:19:47 +0800 Subject: [PATCH 36/70] Try fixing mmcs --- src/basefold_verifier/mmcs.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 4df1a52..e27cd38 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -57,10 +57,13 @@ impl Hintable for MmcsVerifierInput { } stream.extend( as Hintable>::write(&index_bits)); stream.extend(self.opened_values.write()); + stream.extend(>::write( + &(self.proof.len() * 8), + )); stream.extend( self.proof .iter() - .map(|p| p.to_vec()) + .flat_map(|p| p.iter().copied()) .collect::>() .write(), ); From 4f22895b917442f737b01e25e8c5a664024ac050 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 10:02:57 +0800 Subject: [PATCH 37/70] Try fixing mmcs --- src/basefold_verifier/mmcs.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index e27cd38..f37ffba 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -51,22 +51,21 @@ impl Hintable for MmcsVerifierInput { stream.extend(self.dimensions.write()); let mut index_bits = Vec::new(); let mut index = self.index; - while index > 0 { + for _ in 0..self.proof.len() { index_bits.push(index % 2); index /= 2; } + index_bits.reverse(); // Index bits is big endian ? stream.extend( as Hintable>::write(&index_bits)); stream.extend(self.opened_values.write()); - stream.extend(>::write( - &(self.proof.len() * 8), - )); + stream.extend(>::write(&(self.proof.len()))); // According to openvm extensions/native/recursion/src/fri/hints.rs stream.extend( self.proof .iter() .flat_map(|p| p.iter().copied()) .collect::>() .write(), - ); + ); // According to openvm extensions/native/recursion/src/fri/hints.rs stream } } From 0b6680130eef6248942c2efd1022d34f5ab0fd5a Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 11:04:20 +0800 Subject: [PATCH 38/70] Try fixing mmcs --- src/basefold_verifier/mmcs.rs | 194 ++++++++++++++++------------------ 1 file changed, 90 insertions(+), 104 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index f37ffba..93c8e60 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -55,7 +55,7 @@ impl Hintable for MmcsVerifierInput { index_bits.push(index % 2); index /= 2; } - index_bits.reverse(); // Index bits is big endian ? + // index_bits.reverse(); // Index bits is big endian ? stream.extend( as Hintable>::write(&index_bits)); stream.extend(self.opened_values.write()); stream.extend(>::write(&(self.proof.len()))); // According to openvm extensions/native/recursion/src/fri/hints.rs @@ -137,134 +137,120 @@ pub mod tests { > = Vec::new(); let commit = MmcsCommitment { value: [ - f(1715944678), - f(1204294900), - f(59582177), - f(320945505), - f(1470843790), - f(1773915204), - f(380281369), - f(383365269), + f(1680031158), + f(1530464150), + f(442938890), + f(1915006716), + f(1586505947), + f(567492512), + f(78546285), + f(122995307), ], }; - // let dimensions = vec![ - // Dimensions { - // width: 8, - // height: 1, - // }, - // Dimensions { - // width: 8, - // height: 1, - // }, - // Dimensions { - // width: 8, - // height: 70, - // }, - // ]; - let dimensions = vec![1, 1, 70]; + let dimensions = vec![7, 1, 1]; let index = 6; let opened_values = vec![ vec![ - f(774319227), - f(1631186743), - f(254325873), - f(504149682), - f(239740532), - f(1126519109), - f(1044404585), - f(1274764277), + f(960601660), + f(1192659670), + f(1578824022), + f(144975148), + f(1177686049), + f(1685481888), + f(743505857), + f(279845322), ], vec![ - f(1486505160), - f(631183960), - f(329388712), - f(1934479253), - f(115532954), - f(1978455077), - f(66346996), - f(821157541), + f(1097397493), + f(887027944), + f(980566941), + f(1572544252), + f(597464337), + f(396275662), + f(819983943), + f(1414101776), ], vec![ - f(149196326), - f(1186650877), - f(1970038391), - f(1893286029), - f(1249658956), - f(1618951617), - f(419030634), - f(1967997848), + f(1198674230), + f(1468910507), + f(453723651), + f(1663663454), + f(1329515200), + f(85748328), + f(660749682), + f(2010576218), ], ]; let proof = vec![ [ - f(845920358), - f(1201648213), - f(1087654550), - f(264553580), - f(633209321), - f(877945079), - f(1674449089), - f(1062812099), + f(1334234643), + f(588743138), + f(1420323154), + f(735905746), + f(495445129), + f(1544297066), + f(1062502165), + f(1322613112), ], [ - f(5498027), - f(1901489519), - f(179361222), - f(41261871), - f(1546446894), - f(266690586), - f(1882928070), - f(844710372), + f(1882651949), + f(572080113), + f(152683464), + f(14829179), + f(2006886314), + f(133167211), + f(745961821), + f(1681442214), ], [ - f(721245096), - f(388358486), - f(1443363461), - f(1349470697), - f(253624794), - f(1359455861), - f(237485093), - f(1955099141), + f(1157629442), + f(1439934290), + f(1996877031), + f(124179660), + f(1785268039), + f(1531335481), + f(172600848), + f(717903005), ], [ - f(1816731864), - f(402719753), - f(1972161922), - f(693018227), - f(1617207065), - f(1848150948), - f(360933015), - f(669793414), + f(1686363855), + f(364530059), + f(127515555), + f(1313410702), + f(1401384952), + f(1701278059), + f(1934144441), + f(120278217), ], [ - f(1746479395), - f(457185725), - f(1263857148), - f(328668702), - f(1743038915), - f(582282833), - f(927410326), - f(376217274), + f(1819568176), + f(1745841261), + f(211079785), + f(941471227), + f(981333411), + f(1989076935), + f(1836318175), + f(421578048), ], [ - f(1146845382), - f(1117439420), - f(1622226137), - f(1449227765), - f(138752938), - f(1251889563), - f(1266915653), - f(267248408), + f(1722987777), + f(675846798), + f(583189961), + f(1322278935), + f(575957852), + f(238416543), + f(382123109), + f(1551859129), ], [ - f(1992750195), - f(1604624754), - f(1748646393), - f(1777984113), - f(861317745), - f(564150089), - f(1371546358), - f(460033967), + f(1129037404), + f(615143781), + f(1557998657), + f(978670363), + f(325351741), + f(1598221158), + f(60344644), + f(1792544175), ], ]; let mmcs_input = MmcsVerifierInput { From 25c21c01292e2adf672ed53095c6f24dfb9b4621 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 11:04:47 +0800 Subject: [PATCH 39/70] Add comment --- src/basefold_verifier/mmcs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 93c8e60..1f5e88e 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -147,7 +147,7 @@ pub mod tests { f(122995307), ], }; - let dimensions = vec![7, 1, 1]; + let dimensions = vec![7, 1, 1]; // The dimensions are logarithmic of the heights let index = 6; let opened_values = vec![ vec![ From edc4848783181beb578311d65a678c9363b61310 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 13:52:07 +0800 Subject: [PATCH 40/70] Use the same poseidon2 constants as the test data --- src/basefold_verifier/mmcs.rs | 176 +++++++++++++++++----------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 1f5e88e..4ac0c49 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -137,120 +137,120 @@ pub mod tests { > = Vec::new(); let commit = MmcsCommitment { value: [ - f(1680031158), - f(1530464150), - f(442938890), - f(1915006716), - f(1586505947), - f(567492512), - f(78546285), - f(122995307), + f(1093866670), + f(574851776), + f(1247953105), + f(1203327509), + f(1533188113), + f(685187947), + f(1288273829), + f(1115279794), ], }; let dimensions = vec![7, 1, 1]; // The dimensions are logarithmic of the heights let index = 6; let opened_values = vec![ vec![ - f(960601660), - f(1192659670), - f(1578824022), - f(144975148), - f(1177686049), - f(1685481888), - f(743505857), - f(279845322), + f(1900166735), + f(489844155), + f(447015069), + f(1088472428), + f(741990823), + f(629716069), + f(1813856950), + f(993673429), ], vec![ - f(1097397493), - f(887027944), - f(980566941), - f(1572544252), - f(597464337), - f(396275662), - f(819983943), - f(1414101776), + f(461805816), + f(38690267), + f(628409367), + f(326210486), + f(1399484986), + f(1106048341), + f(1653752726), + f(1508026260), ], vec![ - f(1198674230), - f(1468910507), - f(453723651), - f(1663663454), - f(1329515200), - f(85748328), - f(660749682), - f(2010576218), + f(1179435248), + f(589758130), + f(102692717), + f(1240806078), + f(1326867049), + f(1843793614), + f(1140390710), + f(1590488665), ], ]; let proof = vec![ [ - f(1334234643), - f(588743138), - f(1420323154), - f(735905746), - f(495445129), - f(1544297066), - f(1062502165), - f(1322613112), + f(296596654), + f(678058943), + f(1998719115), + f(927782063), + f(2012932188), + f(651079256), + f(106721600), + f(1671237590), ], [ - f(1882651949), - f(572080113), - f(152683464), - f(14829179), - f(2006886314), - f(133167211), - f(745961821), - f(1681442214), + f(1631182650), + f(1639600768), + f(451941478), + f(204140132), + f(471048369), + f(277644394), + f(1867343362), + f(592993761), ], [ - f(1157629442), - f(1439934290), - f(1996877031), - f(124179660), - f(1785268039), - f(1531335481), - f(172600848), - f(717903005), + f(138270113), + f(983240556), + f(868154296), + f(1436014073), + f(1333074616), + f(74821565), + f(220358401), + f(494000015), ], [ - f(1686363855), - f(364530059), - f(127515555), - f(1313410702), - f(1401384952), - f(1701278059), - f(1934144441), - f(120278217), + f(14047213), + f(1523499359), + f(1105004739), + f(222898207), + f(696072743), + f(913719856), + f(411499939), + f(250843350), ], [ - f(1819568176), - f(1745841261), - f(211079785), - f(941471227), - f(981333411), - f(1989076935), - f(1836318175), - f(421578048), + f(1093746185), + f(368171740), + f(1405456697), + f(103797304), + f(1561352958), + f(90154716), + f(154291788), + f(1719437900), ], [ - f(1722987777), - f(675846798), - f(583189961), - f(1322278935), - f(575957852), - f(238416543), - f(382123109), - f(1551859129), + f(1988708734), + f(580748340), + f(1935380041), + f(236593751), + f(230821177), + f(232517197), + f(451633153), + f(423978451), ], [ - f(1129037404), - f(615143781), - f(1557998657), - f(978670363), - f(325351741), - f(1598221158), - f(60344644), - f(1792544175), + f(319678818), + f(1076607925), + f(615756225), + f(508582464), + f(991750834), + f(1175188849), + f(1143234948), + f(1588353893), ], ]; let mmcs_input = MmcsVerifierInput { From c04e4b7872524f840566dfaec2d1fb9faf374b3a Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 14:09:08 +0800 Subject: [PATCH 41/70] Specify branch for test data gen --- src/basefold_verifier/mmcs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 4ac0c49..3ae039b 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -118,7 +118,7 @@ pub mod tests { use super::{mmcs_verify_batch, InnerConfig, MmcsCommitment, MmcsVerifierInput}; /// The witness in this test is produced by: - /// https://github.com/Jiangkm3/Plonky3 + /// https://github.com/Jiangkm3/Plonky3 branch cyte/mmcs-poseidon2-constants /// cargo test --package p3-merkle-tree --lib -- mmcs::tests::size_gaps --exact --show-output #[allow(dead_code)] pub fn build_mmcs_verify_batch() -> (Program, Vec>) { From d5c8baabe2108f2bb37b5f0426990beb44672ca4 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 11 Jun 2025 16:13:43 +0800 Subject: [PATCH 42/70] MMCS test passes --- src/basefold_verifier/mmcs.rs | 178 +++++++++++++++++----------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 3ae039b..d419e46 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -137,120 +137,120 @@ pub mod tests { > = Vec::new(); let commit = MmcsCommitment { value: [ - f(1093866670), - f(574851776), - f(1247953105), - f(1203327509), - f(1533188113), - f(685187947), - f(1288273829), - f(1115279794), + f(414821839), + f(366064801), + f(76927727), + f(1054874897), + f(522043147), + f(638338172), + f(1583746438), + f(941156703), ], }; - let dimensions = vec![7, 1, 1]; // The dimensions are logarithmic of the heights + let dimensions = vec![7, 0, 0]; let index = 6; let opened_values = vec![ vec![ - f(1900166735), - f(489844155), - f(447015069), - f(1088472428), - f(741990823), - f(629716069), - f(1813856950), - f(993673429), + f(783379538), + f(1083745632), + f(1297755122), + f(739705382), + f(1249630435), + f(1794480926), + f(706129135), + f(51286871), ], vec![ - f(461805816), - f(38690267), - f(628409367), - f(326210486), - f(1399484986), - f(1106048341), - f(1653752726), - f(1508026260), + f(1782820525), + f(487690259), + f(1939320991), + f(1236615939), + f(1149125220), + f(1681169264), + f(418636771), + f(1198975790), ], vec![ - f(1179435248), - f(589758130), - f(102692717), - f(1240806078), - f(1326867049), - f(1843793614), - f(1140390710), - f(1590488665), + f(1782820525), + f(487690259), + f(1939320991), + f(1236615939), + f(1149125220), + f(1681169264), + f(418636771), + f(1198975790), ], ]; let proof = vec![ [ - f(296596654), - f(678058943), - f(1998719115), - f(927782063), - f(2012932188), - f(651079256), - f(106721600), - f(1671237590), + f(709175359), + f(862600965), + f(21724453), + f(1644204827), + f(1122851899), + f(902491334), + f(187250228), + f(766400688), ], [ - f(1631182650), - f(1639600768), - f(451941478), - f(204140132), - f(471048369), - f(277644394), - f(1867343362), - f(592993761), + f(1500388444), + f(788589576), + f(699109303), + f(1804289606), + f(295155621), + f(328080503), + f(198482491), + f(1942550078), ], [ - f(138270113), - f(983240556), - f(868154296), - f(1436014073), - f(1333074616), - f(74821565), - f(220358401), - f(494000015), + f(132120813), + f(362247724), + f(635527855), + f(709381234), + f(1331884835), + f(1016275827), + f(962247980), + f(1772849136), ], [ - f(14047213), - f(1523499359), - f(1105004739), - f(222898207), - f(696072743), - f(913719856), - f(411499939), - f(250843350), + f(1707124288), + f(1917010688), + f(261076785), + f(346295418), + f(1637246858), + f(1607442625), + f(777235843), + f(194556598), ], [ - f(1093746185), - f(368171740), - f(1405456697), - f(103797304), - f(1561352958), - f(90154716), - f(154291788), - f(1719437900), + f(1410853257), + f(1598063795), + f(1111574219), + f(1465562989), + f(1102456901), + f(1433687377), + f(1376477958), + f(1087266135), ], [ - f(1988708734), - f(580748340), - f(1935380041), - f(236593751), - f(230821177), - f(232517197), - f(451633153), - f(423978451), + f(278709284), + f(1823086849), + f(1969802325), + f(633552560), + f(1780238760), + f(297873878), + f(421105965), + f(1357131680), ], [ - f(319678818), - f(1076607925), - f(615756225), - f(508582464), - f(991750834), - f(1175188849), - f(1143234948), - f(1588353893), + f(883611536), + f(685305811), + f(56966874), + f(170904280), + f(1353579462), + f(1357636937), + f(1565241058), + f(209109553), ], ]; let mmcs_input = MmcsVerifierInput { From 8029daa429b9cb19183530ea6a4967257789d7bc Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Fri, 13 Jun 2025 09:09:09 +0800 Subject: [PATCH 43/70] Rewrite fold coeff according to current basefold code --- src/basefold_verifier/query_phase.rs | 36 ++++++++++++++++++++----- src/basefold_verifier/rs.rs | 23 +++++++++------- src/basefold_verifier/utils.rs | 40 ++++++++++++++++++++++++++-- 3 files changed, 81 insertions(+), 18 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index f736a71..837a0d2 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -308,6 +308,19 @@ pub(crate) fn batch_verifier_query_phase( inv_2 * C::F::from_canonical_usize(2), C::F::from_canonical_usize(1), ); + let two_adic_generators: Array> = builder.uninit_fixed_array(28); + for (index, val) in [ + 0x1usize, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, + 0x17b56c64, 0x67456167, 0x688442f9, 0x145e952d, 0x4fe61226, 0x4c734715, 0x11c33e2a, + 0x62c3d2b1, 0x77cad399, 0x54c131f4, 0x4cabd6a6, 0x5cf5713f, 0x3e9430e8, 0xba067a3, + 0x18adc27d, 0x21fd55bc, 0x4b859b3d, 0x3bd57996, 0x4483d85a, 0x3a26eef8, 0x1a427a41, + ] + .iter() + .enumerate() + { + let generator = builder.constant(C::F::from_canonical_usize(*val)); + builder.set_value(&two_adic_generators, index, generator); + } // encode_small let final_rmm_values_len = builder.get(&input.final_message, 0).len(); @@ -329,7 +342,7 @@ pub(crate) fn batch_verifier_query_phase( values: final_rmm_values, width: builder.eval(Usize::from(1)), }; - let final_codeword = encode_small(builder, input.vp.clone(), final_rmm); + let final_codeword = encode_small(builder, final_rmm); // can't use witin_comm.log2_max_codeword_size since it's untrusted let log2_witin_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); @@ -403,7 +416,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.witin_comm.commit.clone(), dimensions: witin_dimensions, - index_bits: idx_bits.clone(), // TODO: double check, should be new idx bits here + index_bits: idx_bits.clone(), // TODO: double check, should be new idx bits here ? opened_values: witin_opened_values.clone(), proof: witin_opening_proof, }; @@ -544,8 +557,13 @@ pub(crate) fn batch_verifier_query_phase( let hi = builder.get(&base_codeword_hi, index.clone()); let level: Var = builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); - let coeffs = verifier_folding_coeffs_level(builder, &input.vp, level); - let coeff = builder.get(&coeffs, idx); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators, + level, + &idx_bits, + inv_2, + ); let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); builder.assign(&folded, folded + fold); }); @@ -649,9 +667,13 @@ pub(crate) fn batch_verifier_query_phase( }; ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); - let coeffs = - verifier_folding_coeffs_level(builder, &input.vp, n_d_i_log.clone()); - let coeff = builder.get(&coeffs, idx.clone()); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators, + n_d_i_log.clone(), + &idx_bits, + inv_2, + ); let left = builder.get(&leafs, 0); let right = builder.get(&leafs, 1); let new_folded = diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index d180942..53b4e9d 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -6,9 +6,11 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; use serde::Deserialize; use super::structs::*; +use super::utils::{pow_felt, pow_felt_bits}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -101,10 +103,19 @@ pub fn get_basecode_msg_size_log() -> Usize { pub fn verifier_folding_coeffs_level( builder: &mut Builder, - pp: &RSCodeVerifierParametersVariable, + two_adic_generators: &Array>, level: Var, -) -> Array> { - builder.get(&pp.t_inv_halves, level) + index_bits: &Array>, // In big endian + two_inv: Felt, +) -> Felt { + let level_plus_one = builder.eval::, _>(level + C::N::ONE); + let g = builder.get(two_adic_generators, level_plus_one); + let g_inv = builder.hint_felt(); + builder.assert_eq::>(g_inv * g, C::F::from_canonical_usize(1)); + + let g_inv_index = pow_felt_bits(builder, g_inv, index_bits, level.into()); + + builder.eval(g_inv_index * two_inv) } /// The DIT FFT algorithm. @@ -172,7 +183,6 @@ impl Radix2DitVariable { #[derive(Deserialize)] pub struct RSCodeVerifierParameters { - pub dft: Radix2Dit, pub t_inv_halves: Vec>, pub full_message_size_log: usize, } @@ -181,12 +191,10 @@ impl Hintable for RSCodeVerifierParameters { type HintVariable = RSCodeVerifierParametersVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let dft = Radix2Dit::read(builder); let t_inv_halves = Vec::>::read(builder); let full_message_size_log = Usize::Var(usize::read(builder)); RSCodeVerifierParametersVariable { - dft, t_inv_halves, full_message_size_log, } @@ -194,7 +202,6 @@ impl Hintable for RSCodeVerifierParameters { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.dft.write()); stream.extend(self.t_inv_halves.write()); stream.extend(>::write( &self.full_message_size_log, @@ -205,7 +212,6 @@ impl Hintable for RSCodeVerifierParameters { #[derive(DslVariable, Clone)] pub struct RSCodeVerifierParametersVariable { - pub dft: Radix2DitVariable, pub t_inv_halves: Array>>, pub full_message_size_log: Usize, } @@ -234,7 +240,6 @@ pub(crate) fn encode_small( /// by the expansion rate. pub(crate) fn encode_small( builder: &mut Builder, - _vp: RSCodeVerifierParametersVariable, rmm: RowMajorMatrixVariable, // Assumed to have only one row and one column ) -> RowMajorMatrixVariable { // XXX: nondeterministically supply the results for now diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 8ca0b23..49dc5a2 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -10,6 +10,39 @@ pub fn pow(builder: &mut Builder, base: Var, exponent: Var( + builder: &mut Builder, + base: Felt, + exponent: Var, +) -> Felt { + let value: Felt = builder.constant(C::F::ONE); + builder.range(0, exponent).for_each(|_, builder| { + builder.assign(&value, value * base); + }); + value +} + +// XXX: more efficient pow implementation +pub fn pow_felt_bits( + builder: &mut Builder, + base: Felt, + exponent_bits: &Array>, // In small endian + exponent_len: Usize, +) -> Felt { + let value: Felt = builder.constant(C::F::ONE); + let repeated_squared: Felt = base; + builder.range(0, exponent_len).for_each(|ptr, builder| { + let ptr = ptr[0]; + let bit = builder.get(exponent_bits, ptr); + builder.if_eq(bit, C::N::ONE).then(|builder| { + builder.assign(&value, value * repeated_squared); + }); + builder.assign(&repeated_squared, repeated_squared * repeated_squared); + }); + value +} + pub fn pow_2(builder: &mut Builder, exponent: Var) -> Var { let two: Var = builder.constant(C::N::from_canonical_usize(2)); pow(builder, two, exponent) @@ -203,8 +236,11 @@ pub fn codeword_fold_with_challenge( // original (left, right) = (lo + hi*x, lo - hi*x), lo, hi are codeword, but after times x it's not codeword // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) let lo: Ext = builder.eval((left + right) * inv_2); - let hi: Ext = builder.eval((left - right) * coeff); // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code - // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi (with lo, hi are two codewords), as it match perfectly with raw message in lagrange domain fixed variable + let hi: Ext = builder.eval((left - right) * coeff); + // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code + // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi + // (with lo, hi are two codewords), as it match perfectly with raw message in + // lagrange domain fixed variable let ret: Ext = builder.eval(lo + challenge * (hi - lo)); ret } From c1e92b006d1301e7837fba1ea66d6b5c36f404b8 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Thu, 19 Jun 2025 11:22:23 +0800 Subject: [PATCH 44/70] Fix --- src/basefold_verifier/query_phase.rs | 9 ++------- src/basefold_verifier/rs.rs | 25 ------------------------- 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 837a0d2..3cec547 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -191,7 +191,6 @@ pub struct PointAndEvalsVariable { pub struct QueryPhaseVerifierInput { pub max_num_var: usize, pub indices: Vec, - pub vp: RSCodeVerifierParameters, pub final_message: Vec>, pub batch_coeffs: Vec, pub queries: QueryOpeningProofs, @@ -210,7 +209,6 @@ impl Hintable for QueryPhaseVerifierInput { fn read(builder: &mut Builder) -> Self::HintVariable { let max_num_var = Usize::Var(usize::read(builder)); let indices = Vec::::read(builder); - let vp = RSCodeVerifierParameters::read(builder); let final_message = Vec::>::read(builder); let batch_coeffs = Vec::::read(builder); let queries = QueryOpeningProofs::read(builder); @@ -226,7 +224,6 @@ impl Hintable for QueryPhaseVerifierInput { QueryPhaseVerifierInputVariable { max_num_var, indices, - vp, final_message, batch_coeffs, queries, @@ -245,7 +242,6 @@ impl Hintable for QueryPhaseVerifierInput { let mut stream = Vec::new(); stream.extend(>::write(&self.max_num_var)); stream.extend(self.indices.write()); - stream.extend(self.vp.write()); stream.extend(self.final_message.write()); stream.extend(self.batch_coeffs.write()); stream.extend(self.queries.write()); @@ -284,7 +280,6 @@ impl Hintable for QueryPhaseVerifierInput { pub struct QueryPhaseVerifierInputVariable { pub max_num_var: Usize, pub indices: Array>, - pub vp: RSCodeVerifierParametersVariable, pub final_message: Array>>, pub batch_coeffs: Array>, pub queries: QueryOpeningProofsVariable, @@ -308,7 +303,7 @@ pub(crate) fn batch_verifier_query_phase( inv_2 * C::F::from_canonical_usize(2), C::F::from_canonical_usize(1), ); - let two_adic_generators: Array> = builder.uninit_fixed_array(28); + let two_adic_generators: Array> = builder.dyn_array(28); for (index, val) in [ 0x1usize, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, 0x17b56c64, 0x67456167, 0x688442f9, 0x145e952d, 0x4fe61226, 0x4c734715, 0x11c33e2a, @@ -812,7 +807,7 @@ pub mod tests { > = Vec::new(); // INPUT - let mut f = File::open("input.bin".to_string()).unwrap(); + let mut f = File::open("query_phase_verifier_input.bin".to_string()).unwrap(); let mut content: Vec = Vec::new(); f.read_to_end(&mut content).unwrap(); let input: QueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 53b4e9d..93e676a 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -183,36 +183,11 @@ impl Radix2DitVariable { #[derive(Deserialize)] pub struct RSCodeVerifierParameters { - pub t_inv_halves: Vec>, pub full_message_size_log: usize, } -impl Hintable for RSCodeVerifierParameters { - type HintVariable = RSCodeVerifierParametersVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let t_inv_halves = Vec::>::read(builder); - let full_message_size_log = Usize::Var(usize::read(builder)); - - RSCodeVerifierParametersVariable { - t_inv_halves, - full_message_size_log, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(self.t_inv_halves.write()); - stream.extend(>::write( - &self.full_message_size_log, - )); - stream - } -} - #[derive(DslVariable, Clone)] pub struct RSCodeVerifierParametersVariable { - pub t_inv_halves: Array>>, pub full_message_size_log: Usize, } From e28b4818ece61e3a41d3aeba8e1a28b04a853a95 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 09:45:50 +0800 Subject: [PATCH 45/70] Merge e2e modification --- src/arithmetics/mod.rs | 122 ++++++++++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 37 deletions(-) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 4e5ea70..6888d55 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -7,10 +7,10 @@ use ff_ext::{BabyBearExt4, SmallField}; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::ChallengerVariable; -use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; type E = BabyBearExt4; const HASH_RATE: usize = 8; @@ -35,7 +35,10 @@ pub fn _print_usize_arr(builder: &mut Builder, arr: &Array(builder: &mut Builder, exts: &Array>) -> Array> { +pub unsafe fn exts_to_felts( + builder: &mut Builder, + exts: &Array>, +) -> Array> { let f_len: Usize = builder.eval(exts.len() * Usize::from(C::EF::D)); let f_arr: Array> = Array::Dyn(exts.ptr(), f_len); f_arr @@ -44,15 +47,22 @@ pub unsafe fn exts_to_felts(builder: &mut Builder, exts: &Array( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - arr: &Array> + arr: &Array>, ) { - let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &arr); - builder.assign(&challenger.input_ptr, challenger.io_empty_ptr + next_input_ptr.clone()); - builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(|builder| { - builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); - }, |builder| { - builder.assign(&challenger.output_ptr, challenger.io_full_ptr); - }); + let next_input_ptr = + builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &arr); + builder.assign( + &challenger.input_ptr, + challenger.io_empty_ptr + next_input_ptr.clone(), + ); + builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); + }, + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_full_ptr); + }, + ); } pub fn is_smaller_than( @@ -731,7 +741,7 @@ pub fn max_usize_arr( } pub struct UniPolyExtrapolator { - constants: [Ext; 12], // 0, 1, 2, 3, 4, -1, 1/2, -1/2, 1/6, -1/6, 1/4, 1/24 + constants: [Ext; 12], // 0, 1, 2, 3, 4, -1, 1/2, -1/2, 1/6, -1/6, 1/4, 1/24 } impl UniPolyExtrapolator { @@ -765,39 +775,62 @@ impl UniPolyExtrapolator { neg_six_inverse, four_inverse, twenty_four_inverse, - ] + ], } } - pub fn extrapolate_uni_poly(&mut self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + pub fn extrapolate_uni_poly( + &mut self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { let res: Ext = builder.constant(C::EF::ZERO); - builder.if_eq(p_i.len(), Usize::from(4)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_3(builder, p_i, eval_at); - builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(3)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_2(builder, p_i, eval_at); + builder.if_eq(p_i.len(), Usize::from(4)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_3(builder, p_i, eval_at); builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(2)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_1(builder, p_i, eval_at); - builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(5)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_4(builder, p_i, eval_at); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(3)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_2(builder, p_i, eval_at); builder.assign(&res, ext); - }, |builder| { - builder.error(); - }); - }); - }); - }); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(2)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_1(builder, p_i, eval_at); + builder.assign(&res, ext); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(5)).then_or_else( + |builder| { + let ext = + self.extrapolate_uni_poly_deg_4(builder, p_i, eval_at); + builder.assign(&res, ext); + }, + |builder| { + builder.error(); + }, + ); + }, + ); + }, + ); + }, + ); res } - fn extrapolate_uni_poly_deg_1(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_1( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / (0−1) = -1 // w1 = 1 / (1−0) = 1 let d0: Ext = builder.eval(eval_at - self.constants[0]); @@ -813,7 +846,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1)) } - fn extrapolate_uni_poly_deg_2(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_2( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)) = 1/2 // w1 = 1 / ((1−0)(1−2)) = -1 // w2 = 1 / ((2−0)(2−1)) = 1/2 @@ -834,7 +872,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2)) } - fn extrapolate_uni_poly_deg_3(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_3( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)) = -1/6 // w1 = 1 / ((1−0)(1−2)(1−3)) = 1/2 // w2 = 1 / ((2−0)(2−1)(2−3)) = -1/2 @@ -859,7 +902,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2 + t3)) } - fn extrapolate_uni_poly_deg_4(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_4( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)(0−4)) = 1/24 // w1 = 1 / ((1−0)(1−2)(1−3)(1−4)) = -1/6 // w2 = 1 / ((2−0)(2−1)(2−3)(2−4)) = 1/4 @@ -887,4 +935,4 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2 + t3 + t4)) } -} \ No newline at end of file +} From dd6d6bd2eebe34cd201d60c7df064fe5b8765e88 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 09:49:54 +0800 Subject: [PATCH 46/70] Fix compilation errors from merge --- src/arithmetics/mod.rs | 18 ++++++++- src/tower_verifier/program.rs | 71 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 6888d55..aa1a396 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -180,12 +180,28 @@ pub fn gen_idx_arr(builder: &mut Builder, len: Usize) -> Arr res } +// Evaluate eq polynomial. pub fn eq_eval( builder: &mut Builder, x: &Array>, y: &Array>, + one: Ext, + zero: Ext, ) -> Ext { - eq_eval_with_index::(builder, x, y, Usize::from(0), Usize::from(0), x.len()) + let acc: Ext = builder.eval(zero + one); // simple trick to use AddE + + iter_zip!(builder, x, y).for_each(|idx_vec, builder| { + let ptr_x = idx_vec[0]; + let ptr_y = idx_vec[1]; + let v_x = builder.iter_ptr_get(&x, ptr_x); + let v_y = builder.iter_ptr_get(&y, ptr_y); + // x*y + (1-x)*(1-y) + let xi_yi: Ext = builder.eval(v_x * v_y); + let new_acc: Ext = builder.eval(acc * (xi_yi + xi_yi - v_x - v_y + one)); + builder.assign(&acc, new_acc); + }); + + acc } // Evaluate eq polynomial. diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index 815decb..f8e6e3c 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -14,6 +14,77 @@ use openvm_native_recursion::challenger::{ }; use p3_field::FieldAlgebra; +pub(crate) fn interpolate_uni_poly( + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, +) -> Ext { + let len = p_i.len(); + let evals: Array> = builder.dyn_array(len.clone()); + let prod: Ext = builder.eval(eval_at); + + builder.set(&evals, 0, eval_at); + + // `prod = \prod_{j} (eval_at - j)` + let e: Ext = builder.constant(C::EF::ONE); + let one: Ext = builder.constant(C::EF::ONE); + builder.range(1, len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp: Ext = builder.constant(C::EF::ONE); + builder.assign(&tmp, eval_at - e); + builder.set(&evals, i, tmp); + builder.assign(&prod, prod * tmp); + builder.assign(&e, e + one); + }); + + let denom_up: Ext = builder.constant(C::EF::ONE); + let i: Ext = builder.constant(C::EF::ONE); + builder.assign(&i, i + one); + builder.range(2, len.clone()).for_each(|_i_vec, builder| { + builder.assign(&denom_up, denom_up * i); + builder.assign(&i, i + one); + }); + let denom_down: Ext = builder.constant(C::EF::ONE); + + let idx_vec_len: RVar = builder.eval_expr(len.clone() - RVar::from(1)); + let idx_vec: Array> = builder.dyn_array(idx_vec_len); + let idx_val: Ext = builder.constant(C::EF::ONE); + builder.range(0, idx_vec.len()).for_each(|i_vec, builder| { + builder.set(&idx_vec, i_vec[0], idx_val); + builder.assign(&idx_val, idx_val + one); + }); + let idx_rev = reverse(builder, &idx_vec); + let res = builder.constant(C::EF::ZERO); + + let len_f = idx_val.clone(); + let neg_one: Ext = builder.constant(C::EF::NEG_ONE); + let evals_rev = reverse(builder, &evals); + let p_i_rev = reverse(builder, &p_i); + + let mut idx_pos: RVar = builder.eval_expr(len.clone() - RVar::from(1)); + iter_zip!(builder, idx_rev, evals_rev, p_i_rev).for_each(|ptr_vec, builder| { + let idx = builder.iter_ptr_get(&idx_rev, ptr_vec[0]); + let eval = builder.iter_ptr_get(&evals_rev, ptr_vec[1]); + let up_eval_inv: Ext = builder.eval(denom_up * eval); + builder.assign(&up_eval_inv, up_eval_inv.inverse()); + let p = builder.iter_ptr_get(&p_i_rev, ptr_vec[2]); + + builder.assign(&res, res + p * prod * denom_down * up_eval_inv); + builder.assign(&denom_up, denom_up * (len_f - idx) * neg_one); + builder.assign(&denom_down, denom_down * idx); + + idx_pos = builder.eval_expr(idx_pos - RVar::from(1)); + }); + + let p_i_0 = builder.get(&p_i, 0); + let eval_0 = builder.get(&evals, 0); + let up_eval_inv: Ext = builder.eval(denom_up * eval_0); + builder.assign(&up_eval_inv, up_eval_inv.inverse()); + builder.assign(&res, res + p_i_0 * prod * denom_down * up_eval_inv); + + res +} + // Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this // polynomial at `eval_at`: // From 6ba1fea37667235a9024715e153b7f5ca6e0dd43 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 09:50:47 +0800 Subject: [PATCH 47/70] Fix compilation errors from merge --- src/arithmetics/mod.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index aa1a396..b786a20 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -952,3 +952,21 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2 + t3 + t4)) } } + +pub fn extend( + builder: &mut Builder, + arr: &Array>, + elem: &Ext, +) -> Array> { + let new_len: Var = builder.eval(arr.len() + C::N::ONE); + let out = builder.dyn_array(new_len); + + builder.range(0, arr.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let val = builder.get(arr, i); + builder.set_value(&out, i, val); + }); + builder.set_value(&out, arr.len(), elem.clone()); + + out +} From 5c2b6b50b58a0120f24aae0e7ba0eb87a1be8cc8 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 09:52:47 +0800 Subject: [PATCH 48/70] Fix compilation errors from merge --- src/arithmetics/mod.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index b786a20..0d63942 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -97,6 +97,23 @@ pub fn evaluate_at_point( builder.eval(r * (right - left) + left) } +pub fn fixed_dot_product( + builder: &mut Builder, + a: &[Ext], + b: &Array>, + zero: Ext, +) -> Ext<::F, ::EF> { + // simple trick to prefer AddE(1 cycle) than AddEI(4 cycles) + let acc: Ext = builder.eval(zero + zero); + + for (i, va) in a.iter().enumerate() { + let vb = builder.get(b, i); + builder.assign(&acc, acc + *va * vb); + } + + acc +} + pub fn dot_product( builder: &mut Builder, a: &Array>, From 8cb2ad748c0f09909020c4afe2796297259768a4 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 10:05:32 +0800 Subject: [PATCH 49/70] (WIP) Connecting e2e with query phase --- src/e2e/mod.rs | 16 ++++++++++++++++ src/zkvm_verifier/binding.rs | 11 +++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index 101eacc..a8eb418 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -400,6 +400,21 @@ pub fn parse_zkvm_proof_import( serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); let fixed_commit = verifier.vk.fixed_commit.clone(); + let query_phase_verifier_input = QueryPhaseVerifierInput { + max_num_var, + indices, + final_message, + batch_coeffs, + queries, + fixed_comm, + witin_comm, + circuit_meta, + commits, + fold_challenges, + sumcheck_messages, + point_evals, + }; + ( ZKVMProofInput { raw_pi, @@ -409,6 +424,7 @@ pub fn parse_zkvm_proof_import( witin_commit, fixed_commit, num_instances: zkvm_proof.num_instances.clone(), + query_phase_verifier_input, }, proving_sequence, ) diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index e0e3725..532ab3c 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -1,4 +1,7 @@ use crate::arithmetics::next_pow2_instance_padding; +use crate::basefold_verifier::query_phase::{ + QueryPhaseVerifierInput, QueryPhaseVerifierInputVariable, +}; use crate::{ arithmetics::ceil_log2, tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, @@ -37,6 +40,8 @@ pub struct ZKVMProofInputVariable { pub fixed_commit_trivial_commits: Array>>, pub fixed_commit_log2_max_codeword_size: Felt, pub num_instances: Array>>, + + pub query_phase_verifier_input: QueryPhaseVerifierInputVariable, } #[derive(DslVariable, Clone)] @@ -106,6 +111,7 @@ pub(crate) struct ZKVMProofInput { pub witin_commit: BasefoldCommitment, pub fixed_commit: Option>, pub num_instances: Vec<(usize, usize)>, + pub query_phase_verifier_input: QueryPhaseVerifierInput, } impl Hintable for ZKVMProofInput { type HintVariable = ZKVMProofInputVariable; @@ -128,6 +134,8 @@ impl Hintable for ZKVMProofInput { let num_instances = Vec::>::read(builder); + let query_phase_verifier_input = QueryPhaseVerifierInput::read(builder); + ZKVMProofInputVariable { raw_pi, raw_pi_num_variables, @@ -142,6 +150,7 @@ impl Hintable for ZKVMProofInput { fixed_commit_trivial_commits, fixed_commit_log2_max_codeword_size, num_instances, + query_phase_verifier_input, } } @@ -226,6 +235,8 @@ impl Hintable for ZKVMProofInput { } stream.extend(num_instances_vec.write()); + stream.extend(self.query_phase_verifier_input.write()); + stream } } From 5361f3f86b766ed61ee046eeff8db3bd1ff9d236 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Mon, 23 Jun 2025 14:24:09 +0800 Subject: [PATCH 50/70] WIP --- src/e2e/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index a8eb418..ca97d4a 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -400,6 +400,8 @@ pub fn parse_zkvm_proof_import( serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); let fixed_commit = verifier.vk.fixed_commit.clone(); + let pcs_proof = zkvm_proof.fixed_witin_opening_proof; + let query_phase_verifier_input = QueryPhaseVerifierInput { max_num_var, indices, @@ -409,9 +411,9 @@ pub fn parse_zkvm_proof_import( fixed_comm, witin_comm, circuit_meta, - commits, + commits: pcs_proof.commits, fold_challenges, - sumcheck_messages, + sumcheck_messages: pcs_proof.sumcheck_proof.unwrap(), point_evals, }; From 37fb5f824cc189ee441445a9cc5d5b271edd8116 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 24 Jun 2025 19:45:12 +0800 Subject: [PATCH 51/70] (WIP) transform ceno query phase verifier input to current --- Cargo.lock | 322 ++++++++++++++++----------- Cargo.toml | 14 +- src/basefold_verifier/query_phase.rs | 24 +- 3 files changed, 227 insertions(+), 133 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4606d00..1fef29a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,9 +24,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aes" @@ -117,7 +117,7 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" dependencies = [ - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -128,7 +128,7 @@ checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -215,7 +215,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" dependencies = [ "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -241,7 +241,7 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -291,7 +291,7 @@ checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -328,9 +328,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backtrace" @@ -345,7 +345,7 @@ dependencies = [ "object", "rustc-demangle", "serde", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -390,7 +390,7 @@ checksum = "42b6b4cb608b8282dc3b53d0f4c9ab404655d562674c682db7e6c0458cc83c23" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -508,14 +508,14 @@ checksum = "efb7846e0cb180355c2dec69e721edafa36919850f1a9f52ffba4ebc0393cb71" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] name = "bytemuck" -version = "1.23.0" +version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9134a6ef01ce4b366b50689c94f82c14bc72bc5d0386829828a2e2752ef7958c" +checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" [[package]] name = "byteorder" @@ -537,9 +537,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.26" +version = "1.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "956a5e21988b87f372569b66183b78babf23ebc2e744b733e4350a752c4dafac" +checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" dependencies = [ "shlex", ] @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "glob", ] @@ -596,7 +596,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "anyhow", "ceno_rt", @@ -617,7 +617,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "anyhow", "ceno_emul", @@ -630,7 +630,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -639,7 +639,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "base64", "bincode", @@ -680,9 +680,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "ciborium" @@ -723,9 +723,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.39" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" dependencies = [ "clap_builder", "clap_derive", @@ -733,9 +733,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.39" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" dependencies = [ "anstream", "anstyle", @@ -745,21 +745,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.32" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" +checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] name = "clap_lex" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "colorchoice" @@ -923,7 +923,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -934,7 +934,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -947,7 +947,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -967,7 +967,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", "unicode-xid", ] @@ -1018,7 +1018,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1062,7 +1062,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1074,7 +1074,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1085,12 +1085,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.60.2", ] [[package]] @@ -1151,7 +1151,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "p3", "rand_core", @@ -1204,7 +1204,7 @@ checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", ] [[package]] @@ -1221,14 +1221,14 @@ dependencies = [ [[package]] name = "getset" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3586f256131df87204eb733da72e3d3eb4f343c639f4b7be279ac7c48baeafe" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1239,9 +1239,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glam" -version = "0.30.3" +version = "0.30.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b46b9ca4690308844c644e7c634d68792467260e051c8543e0c7871662b3ba7" +checksum = "50a99dbe56b72736564cfa4b85bf9a33079f16ae8b74983ab06af3b1a3696b11" [[package]] name = "glob" @@ -1407,9 +1407,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "hex" @@ -1462,7 +1462,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1557,9 +1557,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.172" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libm" @@ -1606,9 +1606,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memuse" @@ -1664,9 +1664,9 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", ] @@ -1674,7 +1674,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "aes", "bincode", @@ -1704,7 +1704,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -1717,22 +1717,22 @@ dependencies = [ [[package]] name = "munge" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e22e7961c873e8b305b176d2a4e1d41ce7ba31bc1c52d2a107a89568ec74c55" +checksum = "9cce144fab80fbb74ec5b89d1ca9d41ddf6b644ab7e986f7d3ed0aab31625cb1" dependencies = [ "munge_macro", ] [[package]] name = "munge_macro" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ac7d860b767c6398e88fe93db73ce53eb496057aa6895ffa4d60cb02e1d1c6b" +checksum = "574af9cd5b9971cbfdf535d6a8d533778481b241c447826d976101e0149392a1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1813,7 +1813,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1955,7 +1955,7 @@ source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_mul dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1980,7 +1980,7 @@ source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_mul dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1990,7 +1990,7 @@ source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_mul dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -2016,7 +2016,7 @@ version = "1.1.0" source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" dependencies = [ "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -2074,7 +2074,7 @@ version = "1.1.0" source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" dependencies = [ "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -2269,7 +2269,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2704,7 +2704,7 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "criterion", "ff_ext", @@ -2765,7 +2765,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -2794,20 +2794,20 @@ checksum = "ca414edb151b4c8d125c12566ab0d74dc9cdba36fb80eb7b848c15f495fd32d1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] name = "quanta" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" dependencies = [ "crossbeam-utils", "libc", "once_cell", "raw-cpuid", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -2823,9 +2823,9 @@ dependencies = [ [[package]] name = "r-efi" -version = "5.2.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "radium" @@ -3003,7 +3003,7 @@ checksum = "246b40ac189af6c675d124b802e8ef6d5246c53e17367ce9501f8f66a81abb7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3029,9 +3029,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" [[package]] name = "rustc-hash" @@ -3064,7 +3064,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -3161,7 +3161,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3282,7 +3282,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3307,7 +3307,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3324,14 +3324,14 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "itertools 0.13.0", "p3", "proc-macro2", "quote", "rand", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3347,9 +3347,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.101" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -3372,7 +3372,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -3403,17 +3403,16 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -3513,7 +3512,7 @@ dependencies = [ "serde_spanned", "toml_datetime", "toml_write", - "winnow 0.7.10", + "winnow 0.7.11", ] [[package]] @@ -3535,13 +3534,13 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3599,7 +3598,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3693,9 +3692,9 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" @@ -3728,7 +3727,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -3750,7 +3749,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3777,7 +3776,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "bincode", "blake2", @@ -3830,7 +3829,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -3845,7 +3844,16 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", ] [[package]] @@ -3854,14 +3862,30 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] @@ -3870,48 +3894,96 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.5.40" @@ -3923,9 +3995,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" +checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd" dependencies = [ "memchr", ] @@ -3942,7 +4014,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#f2346a942b630d9af4cc4f88e875fefaf5e9f6d0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" dependencies = [ "ff_ext", "multilinear_extensions", @@ -3963,22 +4035,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -3998,7 +4070,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1ed81e5..935a9a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,13 +38,13 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "transcript" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "transcript" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 3cec547..8cb92b3 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,5 +1,6 @@ // Note: check all XXX comments! +use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{ hints::{Hintable, VecAutoHintable}, @@ -203,6 +204,25 @@ pub struct QueryPhaseVerifierInput { pub point_evals: Vec<(Point, Vec)>, } +impl From> for QueryPhaseVerifierInput { + fn from(input: InnerQueryPhaseVerifierInput) -> Self { + QueryPhaseVerifierInput { + max_num_var: input.max_num_var, + indices: input.indices, + final_message: input.final_message, + batch_coeffs: input.batch_coeffs, + queries: input.queries, + fixed_comm: input.fixed_comm, + witin_comm: input.witin_comm, + circuit_meta: input.circuit_meta, + commits: input.commits, + fold_challenges: input.fold_challenges, + sumcheck_messages: input.sumcheck_messages, + point_evals: input.point_evals, + } + } +} + impl Hintable for QueryPhaseVerifierInput { type HintVariable = QueryPhaseVerifierInputVariable; @@ -773,6 +793,7 @@ pub(crate) fn batch_verifier_query_phase( pub mod tests { use std::{fs::File, io::Read}; + use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; @@ -810,7 +831,8 @@ pub mod tests { let mut f = File::open("query_phase_verifier_input.bin".to_string()).unwrap(); let mut content: Vec = Vec::new(); f.read_to_end(&mut content).unwrap(); - let input: QueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); + let input: InnerQueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); + let input: QueryPhaseVerifierInput = input.into(); witness_stream.extend(input.write()); // PROGRAM From 1feb92c41e1191a7e22be179578c023524406d9e Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Tue, 24 Jun 2025 19:57:40 +0800 Subject: [PATCH 52/70] (WIP) Fix query phase transform --- Cargo.lock | 31 +++++++------ Cargo.toml | 1 + src/basefold_verifier/query_phase.rs | 68 +++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1fef29a..510a5cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "glob", ] @@ -580,6 +580,7 @@ dependencies = [ "p3-challenger", "p3-commit", "p3-field", + "p3-fri", "p3-goldilocks", "p3-matrix", "p3-monty-31", @@ -596,7 +597,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "anyhow", "ceno_rt", @@ -617,7 +618,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "anyhow", "ceno_emul", @@ -630,7 +631,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -639,7 +640,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "base64", "bincode", @@ -1151,7 +1152,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "p3", "rand_core", @@ -1674,7 +1675,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "aes", "bincode", @@ -1704,7 +1705,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -2269,7 +2270,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2704,7 +2705,7 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "criterion", "ff_ext", @@ -3307,7 +3308,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3324,7 +3325,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "itertools 0.13.0", "p3", @@ -3598,7 +3599,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3776,7 +3777,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "bincode", "blake2", @@ -4014,7 +4015,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#16ac8d30410d19481b4a77ae780706fe4a43ead5" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 935a9a8..42e186d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ bincode = "1" tracing = "0.1.40" # Plonky3 +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 8cb92b3..86ef55f 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,5 +1,6 @@ // Note: check all XXX comments! +use ff_ext::{ExtensionField, PoseidonField}; use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{ @@ -7,6 +8,7 @@ use openvm_native_recursion::{ vars::HintSlice, }; use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_commit::ExtensionMmcs; use p3_field::extension::BinomialExtensionField; use p3_field::FieldAlgebra; use serde::Deserialize; @@ -23,12 +25,34 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; +use p3_fri::BatchOpening as InnerBatchOpening; #[derive(Deserialize)] pub struct BatchOpening { pub opened_values: Vec>, pub opening_proof: MmcsProof, } +impl + From< + InnerBatchOpening< + ::BaseField, + <::BaseField as PoseidonField>::MMCS, + >, + > for BatchOpening +{ + fn from( + inner: InnerBatchOpening< + ::BaseField, + <::BaseField as PoseidonField>::MMCS, + >, + ) -> Self { + Self { + opened_values: inner.opened_values, + opening_proof: inner.opening_proof.into(), + } + } +} + impl Hintable for BatchOpening { type HintVariable = BatchOpeningVariable; @@ -64,12 +88,27 @@ pub struct BatchOpeningVariable { pub opening_proof: HintSlice, } +use p3_fri::CommitPhaseProofStep as InnerCommitPhaseProofStep; #[derive(Deserialize)] pub struct CommitPhaseProofStep { pub sibling_value: E, pub opening_proof: MmcsProof, } +pub type ExtMmcs = ExtensionMmcs< + ::BaseField, + E, + <::BaseField as PoseidonField>::MMCS, +>; +impl From>> for CommitPhaseProofStep { + fn from(inner: InnerCommitPhaseProofStep>) -> Self { + Self { + sibling_value: inner.sibling_value, + opening_proof: inner.opening_proof.into(), + } + } +} + impl Hintable for CommitPhaseProofStep { type HintVariable = CommitPhaseProofStepVariable; @@ -114,6 +153,21 @@ pub struct QueryOpeningProof { } type QueryOpeningProofs = Vec; +use mpcs::QueryOpeningProof as InnerQueryOpeningProof; +impl From> for QueryOpeningProof { + fn from(proof: InnerQueryOpeningProof) -> Self { + QueryOpeningProof { + witin_base_proof: proof.witin_base_proof.into(), + fixed_base_proof: proof.fixed_base_proof.map(|p| p.into()), + commit_phase_openings: proof + .commit_phase_openings + .into_iter() + .map(|p| p.into()) + .collect(), + } + } +} + impl Hintable for QueryOpeningProof { type HintVariable = QueryOpeningProofVariable; @@ -211,14 +265,14 @@ impl From> for QueryPhaseVerifierInput { indices: input.indices, final_message: input.final_message, batch_coeffs: input.batch_coeffs, - queries: input.queries, - fixed_comm: input.fixed_comm, - witin_comm: input.witin_comm, - circuit_meta: input.circuit_meta, - commits: input.commits, + queries: input.queries.into_iter().map(|q| q.into()).collect(), + fixed_comm: input.fixed_comm.into(), + witin_comm: input.witin_comm.into(), + circuit_meta: input.circuit_meta.into(), + commits: input.commits.into(), fold_challenges: input.fold_challenges, - sumcheck_messages: input.sumcheck_messages, - point_evals: input.point_evals, + sumcheck_messages: input.sumcheck_messages.into(), + point_evals: input.point_evals.into(), } } } From 410e99660ce1b72bbe8ff76700e89e656b3ef20b Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 07:47:26 +0800 Subject: [PATCH 53/70] (WIP) Fix query phase transform --- src/basefold_verifier/basefold.rs | 19 ++++++++++++++++++- src/basefold_verifier/query_phase.rs | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index b8ed73c..4bb8d77 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -4,6 +4,8 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use serde::Deserialize; +use crate::basefold_verifier::hash::Hash; + use super::{mmcs::*, structs::DIMENSIONS}; pub type F = BabyBear; @@ -18,6 +20,19 @@ pub struct BasefoldCommitment { // pub trivial_commits: Vec, } +use mpcs::BasefoldCommitment as InnerBasefoldCommitment; + +impl From> for BasefoldCommitment { + fn from(value: InnerBasefoldCommitment) -> Self { + Self { + commit: Hash { + value: value.commit().into(), + }, + log2_max_codeword_size: value.log2_max_codeword_size, + } + } +} + impl Hintable for BasefoldCommitment { type HintVariable = BasefoldCommitmentVariable; @@ -36,7 +51,9 @@ impl Hintable for BasefoldCommitment { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.commit.write()); - stream.extend(>::write(&self.log2_max_codeword_size)); + stream.extend(>::write( + &self.log2_max_codeword_size, + )); // stream.extend(self.trivial_commits.write()); stream } diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 86ef55f..2d5482f 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -266,7 +266,7 @@ impl From> for QueryPhaseVerifierInput { final_message: input.final_message, batch_coeffs: input.batch_coeffs, queries: input.queries.into_iter().map(|q| q.into()).collect(), - fixed_comm: input.fixed_comm.into(), + fixed_comm: input.fixed_comm.map(|comm| comm.into()), witin_comm: input.witin_comm.into(), circuit_meta: input.circuit_meta.into(), commits: input.commits.into(), From fbd5f1ff585fbfc05f94c08dafcf59fa672ab65a Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 07:50:40 +0800 Subject: [PATCH 54/70] (WIP) Fix query phase transform --- Cargo.lock | 34 ++++++++++++++-------------- src/basefold_verifier/query_phase.rs | 2 +- src/basefold_verifier/structs.rs | 12 ++++++++++ 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 510a5cd..bee2eb1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,9 +484,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.18.1" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytecheck" @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "glob", ] @@ -597,7 +597,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "anyhow", "ceno_rt", @@ -618,7 +618,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "anyhow", "ceno_emul", @@ -631,7 +631,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -640,7 +640,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "base64", "bincode", @@ -1152,7 +1152,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "p3", "rand_core", @@ -1675,7 +1675,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "aes", "bincode", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -2270,7 +2270,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2705,7 +2705,7 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "criterion", "ff_ext", @@ -3308,7 +3308,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3325,7 +3325,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "itertools 0.13.0", "p3", @@ -3599,7 +3599,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3777,7 +3777,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "bincode", "blake2", @@ -4015,7 +4015,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#a9f074b85711df47b32116d056469a282febfdf8" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 2d5482f..64fb2ed 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -268,7 +268,7 @@ impl From> for QueryPhaseVerifierInput { queries: input.queries.into_iter().map(|q| q.into()).collect(), fixed_comm: input.fixed_comm.map(|comm| comm.into()), witin_comm: input.witin_comm.into(), - circuit_meta: input.circuit_meta.into(), + circuit_meta: input.circuit_meta.into_iter().map(|q| q.into()).collect(), commits: input.commits.into(), fold_challenges: input.fold_challenges, sumcheck_messages: input.sumcheck_messages.into(), diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index 366bda0..a933c5c 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -30,6 +30,18 @@ pub struct CircuitIndexMeta { pub fixed_num_polys: usize, } +use mpcs::CircuitIndexMeta as InnerCircuitIndexMeta; +impl From for CircuitIndexMeta { + fn from(inner: InnerCircuitIndexMeta) -> Self { + Self { + witin_num_vars: inner.witin_num_vars, + witin_num_polys: inner.witin_num_polys, + fixed_num_vars: inner.fixed_num_vars, + fixed_num_polys: inner.fixed_num_polys, + } + } +} + impl Hintable for CircuitIndexMeta { type HintVariable = CircuitIndexMetaVariable; From dd7951ddd055a28b0eb2c2f9b36b94e5aea0aabc Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 07:54:39 +0800 Subject: [PATCH 55/70] (WIP) Fix query phase transform --- src/basefold_verifier/query_phase.rs | 12 ++++++++++-- src/tower_verifier/binding.rs | 15 +++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 64fb2ed..20290dd 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -269,9 +269,17 @@ impl From> for QueryPhaseVerifierInput { fixed_comm: input.fixed_comm.map(|comm| comm.into()), witin_comm: input.witin_comm.into(), circuit_meta: input.circuit_meta.into_iter().map(|q| q.into()).collect(), - commits: input.commits.into(), + commits: input + .commits + .into_iter() + .map(|q| super::hash::Hash { value: q.into() }) + .collect(), fold_challenges: input.fold_challenges, - sumcheck_messages: input.sumcheck_messages.into(), + sumcheck_messages: input + .sumcheck_messages + .into_iter() + .map(|q| q.into()) + .collect(), point_evals: input.point_evals.into(), } } diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 85807c2..6228524 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -78,10 +78,7 @@ impl Hintable for PointAndEval { fn read(builder: &mut Builder) -> Self::HintVariable { let point = Point::read(builder); let eval = E::read(builder); - PointAndEvalVariable { - point, - eval, - } + PointAndEvalVariable { point, eval } } fn write(&self) -> Vec::N>> { @@ -97,6 +94,16 @@ impl VecAutoHintable for PointAndEval {} pub struct IOPProverMessage { pub evaluations: Vec, } + +use ceno_sumcheck::structs::IOPProverMessage as InnerIOPProverMessage; +impl From> for IOPProverMessage { + fn from(value: InnerIOPProverMessage) -> Self { + IOPProverMessage { + evaluations: value.evaluations, + } + } +} + impl Hintable for IOPProverMessage { type HintVariable = IOPProverMessageVariable; From 6f35d7ba5963545e6612ea2f0304301da0870ec6 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 07:57:13 +0800 Subject: [PATCH 56/70] Fix query phase transform --- src/basefold_verifier/query_phase.rs | 6 +++++- src/tower_verifier/binding.rs | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 20290dd..bf107a9 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -280,7 +280,11 @@ impl From> for QueryPhaseVerifierInput { .into_iter() .map(|q| q.into()) .collect(), - point_evals: input.point_evals.into(), + point_evals: input + .point_evals + .into_iter() + .map(|q| (Point { fs: q.0 }, q.1)) + .collect(), } } } diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 6228524..5fd52bd 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -49,7 +49,7 @@ pub struct TowerVerifierInputVariable { #[derive(Clone, Deserialize)] pub struct Point { - pub fs: Vec, + pub fs: Vec, } impl Hintable for Point { type HintVariable = PointVariable; From 478ef692780d93aba995f41a7697d6f83b40dd5f Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 08:03:12 +0800 Subject: [PATCH 57/70] Comment out connecting code temporarily --- src/e2e/mod.rs | 30 +++++++++++++++--------------- src/zkvm_verifier/binding.rs | 4 ++-- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index ca97d4a..505c217 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -402,20 +402,20 @@ pub fn parse_zkvm_proof_import( let pcs_proof = zkvm_proof.fixed_witin_opening_proof; - let query_phase_verifier_input = QueryPhaseVerifierInput { - max_num_var, - indices, - final_message, - batch_coeffs, - queries, - fixed_comm, - witin_comm, - circuit_meta, - commits: pcs_proof.commits, - fold_challenges, - sumcheck_messages: pcs_proof.sumcheck_proof.unwrap(), - point_evals, - }; + // let query_phase_verifier_input = QueryPhaseVerifierInput { + // max_num_var, + // indices, + // final_message, + // batch_coeffs, + // queries, + // fixed_comm, + // witin_comm, + // circuit_meta, + // commits: pcs_proof.commits, + // fold_challenges, + // sumcheck_messages: pcs_proof.sumcheck_proof.unwrap(), + // point_evals, + // }; ( ZKVMProofInput { @@ -426,7 +426,7 @@ pub fn parse_zkvm_proof_import( witin_commit, fixed_commit, num_instances: zkvm_proof.num_instances.clone(), - query_phase_verifier_input, + // query_phase_verifier_input, }, proving_sequence, ) diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index 532ab3c..21cdee2 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -111,7 +111,7 @@ pub(crate) struct ZKVMProofInput { pub witin_commit: BasefoldCommitment, pub fixed_commit: Option>, pub num_instances: Vec<(usize, usize)>, - pub query_phase_verifier_input: QueryPhaseVerifierInput, + // pub query_phase_verifier_input: QueryPhaseVerifierInput, } impl Hintable for ZKVMProofInput { type HintVariable = ZKVMProofInputVariable; @@ -235,7 +235,7 @@ impl Hintable for ZKVMProofInput { } stream.extend(num_instances_vec.write()); - stream.extend(self.query_phase_verifier_input.write()); + // stream.extend(self.query_phase_verifier_input.write()); stream } From 9689d1d7f775720bb8cafcde42d59802b9c7747b Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 09:19:29 +0800 Subject: [PATCH 58/70] Query phase compile successful --- src/basefold_verifier/query_phase.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index bf107a9..9816475 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,5 +1,7 @@ // Note: check all XXX comments! +use std::fmt::Debug; + use ff_ext::{ExtensionField, PoseidonField}; use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; @@ -379,7 +381,7 @@ pub struct QueryPhaseVerifierInputVariable { pub point_evals: Array>, } -pub(crate) fn batch_verifier_query_phase( +pub(crate) fn batch_verifier_query_phase( builder: &mut Builder, input: QueryPhaseVerifierInputVariable, ) { @@ -465,6 +467,7 @@ pub(crate) fn batch_verifier_query_phase( let idx = builder.get(&input.indices, i); let query = builder.get(&input.queries, i); let witin_opened_values = query.witin_base_proof.opened_values; + let witin_opening_proof = query.witin_base_proof.opening_proof; let fixed_is_some = query.fixed_is_some; let fixed_commit = query.fixed_base_proof; @@ -509,6 +512,7 @@ pub(crate) fn batch_verifier_query_phase( .if_eq(fixed_is_some.clone(), Usize::from(1)) .then(|builder| { let fixed_opened_values = fixed_commit.opened_values.clone(); + let fixed_opening_proof = fixed_commit.opening_proof.clone(); // new_idx used by mmcs proof let new_idx: Var = builder.eval(idx); @@ -737,7 +741,7 @@ pub(crate) fn batch_verifier_query_phase( let dimensions = builder.dyn_array(1); // let two: Var<_> = builder.eval(Usize::from(2)); builder.set_value(&dimensions, 0, n_d_i.clone()); - let opened_values = builder.uninit_fixed_array(1); + let opened_values = builder.dyn_array(1); builder.set_value(&opened_values, 0, leafs.clone()); let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { commit: pi_comm.clone(), From 94d4c132e139fe88f4864b51fc64390994d71c47 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Wed, 25 Jun 2025 10:46:26 +0800 Subject: [PATCH 59/70] (WIP) Debugging query phase --- src/basefold_verifier/query_phase.rs | 35 ++++++++++++++++------------ 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 9816475..32b077c 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -73,13 +73,12 @@ impl Hintable for BatchOpening { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.opened_values.write()); - stream.extend( - self.opening_proof - .iter() - .map(|p| p.to_vec()) - .collect::>() - .write(), - ); + stream.extend(vec![ + vec![::N::from_canonical_usize( + self.opening_proof.len(), + )], + self.opening_proof.iter().flatten().copied().collect(), + ]); stream } } @@ -129,13 +128,12 @@ impl Hintable for CommitPhaseProofStep { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.sibling_value.write()); - stream.extend( - self.opening_proof - .iter() - .map(|p| p.to_vec()) - .collect::>() - .write(), - ); + stream.extend(vec![ + vec![::N::from_canonical_usize( + self.opening_proof.len(), + )], + self.opening_proof.iter().flatten().copied().collect(), + ]); stream } } @@ -872,7 +870,7 @@ pub mod tests { use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; - use p3_field::{extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra}; + use p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra}; type SC = BabyBearPoseidon2Config; type F = BabyBear; @@ -903,8 +901,15 @@ pub mod tests { f.read_to_end(&mut content).unwrap(); let input: InnerQueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); let input: QueryPhaseVerifierInput = input.into(); + witness_stream.extend(input.write()); + // TODO: the builder reads some additional hints after reading the query + // phase verifier input. Need to feed them into the stream + + // inv_2 + witness_stream.push(vec![F::TWO.try_inverse().unwrap()]); + // PROGRAM let program: Program< p3_monty_31::MontyField31, From d0a45ca797f188a2c68b633c539c0f5b6a173377 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Thu, 26 Jun 2025 08:08:34 +0800 Subject: [PATCH 60/70] Fix a bug in verifier query phase --- src/basefold_verifier/query_phase.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 32b077c..e0ea0e4 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -406,19 +406,22 @@ pub(crate) fn batch_verifier_query_phase( // encode_small let final_rmm_values_len = builder.get(&input.final_message, 0).len(); let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); + builder .range(0, final_rmm_values_len.clone()) .for_each(|i_vec, builder| { let i = i_vec[0]; - let row = builder.get(&input.final_message, i); + let row_len = input.final_message.len(); let sum = builder.constant(C::EF::ZERO); - builder.range(0, row.len()).for_each(|j_vec, builder| { + builder.range(0, row_len).for_each(|j_vec, builder| { let j = j_vec[0]; - let row_j = builder.get(&row, j); + let row = builder.get(&input.final_message, j); + let row_j = builder.get(&row, i); builder.assign(&sum, sum + row_j); }); builder.set_value(&final_rmm_values, i, sum); }); + let final_rmm = RowMajorMatrixVariable { values: final_rmm_values, width: builder.eval(Usize::from(1)), From 8a61142c9eb2b0774e90991952919eb499006b82 Mon Sep 17 00:00:00 2001 From: Yuncong Zhang Date: Thu, 26 Jun 2025 08:56:18 +0800 Subject: [PATCH 61/70] Read additional hints from binary file --- Cargo.lock | 38 ++++++++++++++-------------- src/basefold_verifier/query_phase.rs | 22 ++++++++++++---- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bee2eb1..bf573cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "glob", ] @@ -597,7 +597,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "anyhow", "ceno_rt", @@ -618,7 +618,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "anyhow", "ceno_emul", @@ -631,7 +631,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -640,7 +640,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "base64", "bincode", @@ -861,9 +861,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" @@ -1152,7 +1152,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "p3", "rand_core", @@ -1570,9 +1570,9 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" dependencies = [ "bitflags", "libc", @@ -1675,7 +1675,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "aes", "bincode", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -2270,7 +2270,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2705,7 +2705,7 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "criterion", "ff_ext", @@ -3308,7 +3308,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3325,7 +3325,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "itertools 0.13.0", "p3", @@ -3599,7 +3599,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3777,7 +3777,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "bincode", "blake2", @@ -4015,7 +4015,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#7aa2a55ebcbf087e16e4a629020cc072691a059e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index e0ea0e4..e533bb5 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -454,6 +454,7 @@ pub(crate) fn batch_verifier_query_phase( // 1. height_order: after sorting by decreasing height, the original index of each entry // 2. num_unique_height: number of different heights // 3. count_per_unique_height: for each unique height, number of dimensions of that height + // builder.assert_nonzero(&Usize::from(0)); let (folding_sorted_order_index, _num_unique_num_vars, count_per_unique_num_var) = sort_with_count( builder, @@ -864,7 +865,7 @@ pub(crate) fn batch_verifier_query_phase( pub mod tests { use std::{fs::File, io::Read}; - use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; + use mpcs::{QueryPhaseAdditionalHint, QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; @@ -907,11 +908,22 @@ pub mod tests { witness_stream.extend(input.write()); - // TODO: the builder reads some additional hints after reading the query + // the builder reads some additional hints after reading the query // phase verifier input. Need to feed them into the stream - - // inv_2 - witness_stream.push(vec![F::TWO.try_inverse().unwrap()]); + let mut f = File::open("query_phase_additional_hint.bin".to_string()).unwrap(); + let mut content: Vec = Vec::new(); + f.read_to_end(&mut content).unwrap(); + let input: QueryPhaseAdditionalHint = bincode::deserialize(&content).unwrap(); + + witness_stream.extend(vec![vec![input.two_inv]]); + witness_stream.extend(vec![vec![F::from_canonical_usize( + input.num_unique_entries, + )]]); + witness_stream.extend(vec![input + .sorting_orders + .iter() + .map(|x| F::from_canonical_usize(*x)) + .collect()]); // PROGRAM let program: Program< From f01ae481b87920d7960f8bbd4e241de79713d599 Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 8 Jul 2025 09:48:57 +0800 Subject: [PATCH 62/70] Fix: batch opening (#28) * comment out * wip * hash read/write unit test passed * wip2 * add inv_2 to input stream * Fix memory out of bound problem * update Cargo.lock * Avoid providing two-adic generators inverses by hint * Replace idx_bits by num2bits_f * Replace idx_len by max_num_vars + rate log * Change index bits to small endian * Try fixing new index check * Fix new index check * Sub one from index len * Identified the cause * Fix index out of bound error * Change comment * Add native verify test * Fix evals shape error * Fix mmcs verify failure * Remove some print lines * fmt * Supply all hints * Fix new index compute * Fix ext mmcs verify dimension * Slice idx bits in ext mmcs verify * Some small fixes * right shift by hint * Fix verifier_folding_coeffs_level * Successfully run to first checkpoint * Fails at last line * Identified unimplemented function build_eq_x_r_vec_sequential_with_offset * batch verifier query phase test passes * Print the cycle count --------- Co-authored-by: Yuncong Zhang --- Cargo.lock | 35 +- Cargo.toml | 27 +- src/arithmetics/mod.rs | 1 + src/basefold_verifier/basefold.rs | 7 +- src/basefold_verifier/extension_mmcs.rs | 18 +- src/basefold_verifier/hash.rs | 66 ++- src/basefold_verifier/mmcs.rs | 1 + src/basefold_verifier/mod.rs | 10 +- src/basefold_verifier/query_phase.rs | 622 +++++++++++++++--------- src/basefold_verifier/rs.rs | 10 +- src/basefold_verifier/structs.rs | 22 +- src/basefold_verifier/utils.rs | 73 ++- src/e2e/mod.rs | 1 + 13 files changed, 561 insertions(+), 332 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf573cb..79fc6b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "glob", ] @@ -592,12 +592,13 @@ dependencies = [ "sumcheck", "tracing", "transcript", + "witness", ] [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "anyhow", "ceno_rt", @@ -618,7 +619,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "anyhow", "ceno_emul", @@ -631,7 +632,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -640,7 +641,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "base64", "bincode", @@ -1152,7 +1153,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "p3", "rand_core", @@ -1438,9 +1439,9 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" [[package]] name = "indexmap" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown 0.15.4", @@ -1675,7 +1676,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "aes", "bincode", @@ -1705,7 +1706,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -2270,7 +2271,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2705,7 +2706,7 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "criterion", "ff_ext", @@ -3308,7 +3309,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3325,7 +3326,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "itertools 0.13.0", "p3", @@ -3599,7 +3600,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3777,7 +3778,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "bincode", "blake2", @@ -4015,7 +4016,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=cyte%2Fgenerate-basefold-verifier-query-phase-test-data#e9d59e268818cd2d26668a2fac014fde25cb07dd" +source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 42e186d..97088ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ bincode = "1" tracing = "0.1.40" # Plonky3 -p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } @@ -30,6 +29,7 @@ p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } p3-goldilocks = { git = "https://github.com/Plonky3/plonky3", rev = "1ba4e5c" } # WHIR @@ -39,15 +39,26 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data", package = "transcript" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "cyte/generate-basefold-verifier-query-phase-test-data" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "transcript" } +ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "witness" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [features] bench-metrics = ["openvm-circuit/bench-metrics"] + +# [patch."https://github.com/scroll-tech/ceno.git"] +# ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } +# ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } +# ceno_transcript = { path = "../ceno/transcript", package = "transcript" } +# ceno_witness = { path = "../ceno/witness", package = "witness" } +# ceno_zkvm = { path = "../ceno/ceno_zkvm" } +# ceno_emul = { path = "../ceno/ceno_emul" } +# mpcs = { path = "../ceno/mpcs" } +# ff_ext = { path = "../ceno/ff_ext" } diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 0d63942..64bd151 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -465,6 +465,7 @@ pub fn build_eq_x_r_vec_sequential_with_offset( // _debug // build_eq_x_r_helper_sequential_offset(r, &mut evals, E::ONE); // unsafe { std::mem::transmute(evals) } + // FIXME: this function is not implemented yet evals } diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index 4bb8d77..e025145 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -17,7 +17,7 @@ pub type HashDigest = MmcsCommitment; pub struct BasefoldCommitment { pub commit: HashDigest, pub log2_max_codeword_size: usize, - // pub trivial_commits: Vec, + pub trivial_commits: Vec, } use mpcs::BasefoldCommitment as InnerBasefoldCommitment; @@ -29,6 +29,11 @@ impl From> for BasefoldCommitment { value: value.commit().into(), }, log2_max_codeword_size: value.log2_max_codeword_size, + trivial_commits: value + .trivial_commits + .into_iter() + .map(|c| c.into()) + .collect(), } } } diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index cf4f7e0..b60794d 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -17,6 +17,15 @@ pub struct ExtMmcsVerifierInput { pub proof: MmcsProof, } +#[derive(DslVariable, Clone)] +pub struct ExtMmcsVerifierInputVariable { + pub commit: MmcsCommitmentVariable, + pub dimensions: Array>, + pub index_bits: Array>, + pub opened_values: Array>>, + pub proof: HintSlice, +} + impl Hintable for ExtMmcsVerifierInput { type HintVariable = ExtMmcsVerifierInputVariable; @@ -61,15 +70,6 @@ impl Hintable for ExtMmcsVerifierInput { } } -#[derive(DslVariable, Clone)] -pub struct ExtMmcsVerifierInputVariable { - pub commit: MmcsCommitmentVariable, - pub dimensions: Array>, - pub index_bits: Array>, - pub opened_values: Array>>, - pub proof: HintSlice, -} - pub(crate) fn ext_mmcs_verify_batch( builder: &mut Builder, input: ExtMmcsVerifierInputVariable, diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index d7b7813..9aee223 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -26,6 +26,19 @@ impl Default for Hash { } } +impl From> for Hash { + fn from(hash: p3_symmetric::Hash) -> Self { + Hash { value: hash.into() } + } +} + +#[derive(DslVariable, Clone)] +pub struct HashVariable { + pub value: Array>, +} + +impl VecAutoHintable for Hash {} + impl Hintable for Hash { type HintVariable = HashVariable; @@ -48,51 +61,36 @@ impl Hintable for Hash { stream } } -impl VecAutoHintable for Hash {} - -#[derive(DslVariable, Clone)] -pub struct HashVariable { - pub value: Array>, -} - -pub fn hash_iter_slices( - builder: &mut Builder, - // _hash: HashVariable, - _values: Array>>, -) -> Array> { - // XXX: verify hash - builder.hint_felts_fixed(DIGEST_ELEMS) -} - -pub fn compress( - builder: &mut Builder, - // _hash: HashVariable, - _values: Array>>, -) -> Array> { - // XXX: verify hash - builder.hint_felts_fixed(DIGEST_ELEMS) -} #[cfg(test)] mod tests { + use openvm_circuit::arch::{SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; - use openvm_native_compiler_derive::iter_zip; - use openvm_stark_backend::config::StarkGenericConfig; - use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; - type SC = BabyBearPoseidon2Config; type F = BabyBear; type E = BinomialExtensionField; - type EF = ::Challenge; use crate::basefold_verifier::basefold::HashDigest; use super::*; #[test] fn test_read_to_hash_variable() { - let mut builder = AsmBuilder::::default(); - - let hint = HashDigest::read(&mut builder); - let dst: HashVariable<_> = builder.uninit(); - builder.assign(&dst, hint); + // simple test program + let mut builder = AsmBuilder::::default(); + let _digest = HashDigest::read(&mut builder); + builder.halt(); + + // configure the VM executor + let system_config = SystemConfig::default().with_max_segment_len(1 << 20); + let config = NativeConfig::new(system_config, Native); + let executor = VmExecutor::new(config); + + // prepare input + let mut input = Vec::new(); + input.extend(Hash::default().write()); + + // execute the program + let program = builder.compile_isa(); + executor.execute(program, input).unwrap(); } } diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index d419e46..5093a9e 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -15,6 +15,7 @@ pub type InnerConfig = AsmConfig; pub type MmcsCommitment = Hash; pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; + pub struct MmcsVerifierInput { pub commit: MmcsCommitment, pub dimensions: Vec, diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index 01ea914..0353225 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -1,9 +1,9 @@ -pub(crate) mod structs; pub(crate) mod basefold; -pub(crate) mod query_phase; -pub(crate) mod rs; pub(crate) mod extension_mmcs; -pub(crate) mod mmcs; pub(crate) mod hash; +pub(crate) mod mmcs; +pub(crate) mod query_phase; +pub(crate) mod rs; +pub(crate) mod structs; // pub(crate) mod field; -pub(crate) mod utils; \ No newline at end of file +pub(crate) mod utils; diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index e533bb5..a35f5bc 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,9 +1,6 @@ // Note: check all XXX comments! -use std::fmt::Debug; - -use ff_ext::{ExtensionField, PoseidonField}; -use mpcs::QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput; +use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{ hints::{Hintable, VecAutoHintable}, @@ -11,23 +8,25 @@ use openvm_native_recursion::{ }; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_commit::ExtensionMmcs; -use p3_field::extension::BinomialExtensionField; -use p3_field::FieldAlgebra; +use p3_field::{Field, FieldAlgebra}; use serde::Deserialize; +use std::fmt::Debug; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; use crate::{ - arithmetics::{ - build_eq_x_r_vec_sequential, build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index, - }, + arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, tower_verifier::{binding::*, program::interpolate_uni_poly}, }; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BabyBearExt4; pub type InnerConfig = AsmConfig; use p3_fri::BatchOpening as InnerBatchOpening; +use p3_fri::CommitPhaseProofStep as InnerCommitPhaseProofStep; + +/// We have to define a struct similar to p3_fri::BatchOpening as +/// the trait `Hintable` is defined in another crate inside OpenVM #[derive(Deserialize)] pub struct BatchOpening { pub opened_values: Vec>, @@ -55,14 +54,18 @@ impl } } +#[derive(DslVariable, Clone)] +pub struct BatchOpeningVariable { + pub opened_values: Array>>, + pub opening_proof: HintSlice, +} + impl Hintable for BatchOpening { type HintVariable = BatchOpeningVariable; fn read(builder: &mut Builder) -> Self::HintVariable { let opened_values = Vec::>::read(builder); - let length = Usize::from(builder.hint_var()); - let id = Usize::from(builder.hint_load()); - let opening_proof = HintSlice { length, id }; + let opening_proof = read_hint_slice(builder); BatchOpeningVariable { opened_values, @@ -74,22 +77,18 @@ impl Hintable for BatchOpening { let mut stream = Vec::new(); stream.extend(self.opened_values.write()); stream.extend(vec![ - vec![::N::from_canonical_usize( - self.opening_proof.len(), - )], - self.opening_proof.iter().flatten().copied().collect(), + vec![F::from_canonical_usize(self.opening_proof.len())], + self.opening_proof + .iter() + .flatten() + .copied() + .collect::>(), ]); stream } } -#[derive(DslVariable, Clone)] -pub struct BatchOpeningVariable { - pub opened_values: Array>>, - pub opening_proof: HintSlice, -} - -use p3_fri::CommitPhaseProofStep as InnerCommitPhaseProofStep; +/// TODO: use `openvm_native_recursion::fri::types::FriCommitPhaseProofStepVariable` instead #[derive(Deserialize)] pub struct CommitPhaseProofStep { pub sibling_value: E, @@ -110,14 +109,20 @@ impl From>> for CommitPhaseProofStep { } } +#[derive(DslVariable, Clone)] +pub struct CommitPhaseProofStepVariable { + pub sibling_value: Ext, + pub opening_proof: HintSlice, +} + +impl VecAutoHintable for CommitPhaseProofStep {} + impl Hintable for CommitPhaseProofStep { type HintVariable = CommitPhaseProofStepVariable; fn read(builder: &mut Builder) -> Self::HintVariable { let sibling_value = E::read(builder); - let length = Usize::from(builder.hint_var()); - let id = Usize::from(builder.hint_load()); - let opening_proof = HintSlice { length, id }; + let opening_proof = read_hint_slice(builder); CommitPhaseProofStepVariable { sibling_value, @@ -129,21 +134,16 @@ impl Hintable for CommitPhaseProofStep { let mut stream = Vec::new(); stream.extend(self.sibling_value.write()); stream.extend(vec![ - vec![::N::from_canonical_usize( - self.opening_proof.len(), - )], - self.opening_proof.iter().flatten().copied().collect(), + vec![F::from_canonical_usize(self.opening_proof.len())], + self.opening_proof + .iter() + .flatten() + .copied() + .collect::>(), ]); stream } } -impl VecAutoHintable for CommitPhaseProofStep {} - -#[derive(DslVariable, Clone)] -pub struct CommitPhaseProofStepVariable { - pub sibling_value: Ext, - pub opening_proof: HintSlice, -} #[derive(Deserialize)] pub struct QueryOpeningProof { @@ -151,23 +151,20 @@ pub struct QueryOpeningProof { pub fixed_base_proof: Option, pub commit_phase_openings: Vec, } -type QueryOpeningProofs = Vec; -use mpcs::QueryOpeningProof as InnerQueryOpeningProof; -impl From> for QueryOpeningProof { - fn from(proof: InnerQueryOpeningProof) -> Self { - QueryOpeningProof { - witin_base_proof: proof.witin_base_proof.into(), - fixed_base_proof: proof.fixed_base_proof.map(|p| p.into()), - commit_phase_openings: proof - .commit_phase_openings - .into_iter() - .map(|p| p.into()) - .collect(), - } - } +#[derive(DslVariable, Clone)] +pub struct QueryOpeningProofVariable { + pub witin_base_proof: BatchOpeningVariable, + pub fixed_is_some: Usize, // 0 <==> false + pub fixed_base_proof: BatchOpeningVariable, + pub commit_phase_openings: Array>, } +type QueryOpeningProofs = Vec; +type QueryOpeningProofsVariable = Array>; + +impl VecAutoHintable for QueryOpeningProof {} + impl Hintable for QueryOpeningProof { type HintVariable = QueryOpeningProofVariable; @@ -202,16 +199,6 @@ impl Hintable for QueryOpeningProof { stream } } -impl VecAutoHintable for QueryOpeningProof {} - -#[derive(DslVariable, Clone)] -pub struct QueryOpeningProofVariable { - pub witin_base_proof: BatchOpeningVariable, - pub fixed_is_some: Usize, // 0 <==> false - pub fixed_base_proof: BatchOpeningVariable, - pub commit_phase_openings: Array>, -} -type QueryOpeningProofsVariable = Array>; // NOTE: Different from PointAndEval in tower_verifier! pub struct PointAndEvals { @@ -244,6 +231,7 @@ pub struct PointAndEvalsVariable { #[derive(Deserialize)] pub struct QueryPhaseVerifierInput { + // pub t_inv_halves: Vec::BaseField>>, pub max_num_var: usize, pub indices: Vec, pub final_message: Vec>, @@ -258,41 +246,11 @@ pub struct QueryPhaseVerifierInput { pub point_evals: Vec<(Point, Vec)>, } -impl From> for QueryPhaseVerifierInput { - fn from(input: InnerQueryPhaseVerifierInput) -> Self { - QueryPhaseVerifierInput { - max_num_var: input.max_num_var, - indices: input.indices, - final_message: input.final_message, - batch_coeffs: input.batch_coeffs, - queries: input.queries.into_iter().map(|q| q.into()).collect(), - fixed_comm: input.fixed_comm.map(|comm| comm.into()), - witin_comm: input.witin_comm.into(), - circuit_meta: input.circuit_meta.into_iter().map(|q| q.into()).collect(), - commits: input - .commits - .into_iter() - .map(|q| super::hash::Hash { value: q.into() }) - .collect(), - fold_challenges: input.fold_challenges, - sumcheck_messages: input - .sumcheck_messages - .into_iter() - .map(|q| q.into()) - .collect(), - point_evals: input - .point_evals - .into_iter() - .map(|q| (Point { fs: q.0 }, q.1)) - .collect(), - } - } -} - impl Hintable for QueryPhaseVerifierInput { type HintVariable = QueryPhaseVerifierInputVariable; fn read(builder: &mut Builder) -> Self::HintVariable { + // let t_inv_halves = Vec::>::read(builder); let max_num_var = Usize::Var(usize::read(builder)); let indices = Vec::::read(builder); let final_message = Vec::>::read(builder); @@ -308,6 +266,7 @@ impl Hintable for QueryPhaseVerifierInput { let point_evals = Vec::::read(builder); QueryPhaseVerifierInputVariable { + // t_inv_halves, max_num_var, indices, final_message, @@ -326,6 +285,7 @@ impl Hintable for QueryPhaseVerifierInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); + // stream.extend(self.t_inv_halves.write()); stream.extend(>::write(&self.max_num_var)); stream.extend(self.indices.write()); stream.extend(self.final_message.write()); @@ -339,7 +299,7 @@ impl Hintable for QueryPhaseVerifierInput { let tmp_comm = BasefoldCommitment { commit: Default::default(), log2_max_codeword_size: 0, - // trivial_commits: Vec::new(), + trivial_commits: vec![], }; stream.extend(tmp_comm.write()); } @@ -364,6 +324,7 @@ impl Hintable for QueryPhaseVerifierInput { #[derive(DslVariable, Clone)] pub struct QueryPhaseVerifierInputVariable { + // pub t_inv_halves: Array>>, pub max_num_var: Usize, pub indices: Array>, pub final_message: Array>>, @@ -389,7 +350,7 @@ pub(crate) fn batch_verifier_query_phase( inv_2 * C::F::from_canonical_usize(2), C::F::from_canonical_usize(1), ); - let two_adic_generators: Array> = builder.dyn_array(28); + let two_adic_generators_inverses: Array> = builder.dyn_array(28); for (index, val) in [ 0x1usize, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, 0x17b56c64, 0x67456167, 0x688442f9, 0x145e952d, 0x4fe61226, 0x4c734715, 0x11c33e2a, @@ -399,8 +360,8 @@ pub(crate) fn batch_verifier_query_phase( .iter() .enumerate() { - let generator = builder.constant(C::F::from_canonical_usize(*val)); - builder.set_value(&two_adic_generators, index, generator); + let generator = builder.constant(C::F::from_canonical_usize(*val).inverse()); + builder.set_value(&two_adic_generators_inverses, index, generator); } // encode_small @@ -455,12 +416,16 @@ pub(crate) fn batch_verifier_query_phase( // 2. num_unique_height: number of different heights // 3. count_per_unique_height: for each unique height, number of dimensions of that height // builder.assert_nonzero(&Usize::from(0)); - let (folding_sorted_order_index, _num_unique_num_vars, count_per_unique_num_var) = - sort_with_count( - builder, - &input.circuit_meta, - |m: CircuitIndexMetaVariable| m.witin_num_vars, - ); + let ( + folding_sorted_order_index, + num_unique_num_vars, + count_per_unique_num_var, + sorted_unique_num_vars, + ) = sort_with_count( + builder, + &input.circuit_meta, + |m: CircuitIndexMetaVariable| m.witin_num_vars, + ); builder .range(0, input.indices.len()) @@ -477,24 +442,25 @@ pub(crate) fn batch_verifier_query_phase( // verify base oracle query proof // refer to prover documentation for the reason of right shift by 1 - // Nondeterministically supply the bits of idx in BIG ENDIAN - // These are not only used by the right shift here but also later on idx_shift - let idx_len = builder.hint_var(); - let idx_bits: Array> = builder.dyn_array(idx_len); - builder.range(0, idx_len).for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_bit = builder.hint_var(); - // Assert that it is a bit - builder.assert_eq::>(next_bit * next_bit, next_bit); - builder.set_value(&idx_bits, j, next_bit); - }); + // The index length is the logarithm of the maximal codeword size. + let idx_len: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); + let idx_felt = builder.unsafe_cast_var_to_felt(idx); + let idx_bits = builder.num2bits_f(idx_felt, C::N::bits() as u32); + builder + .range(idx_len, idx_bits.len()) + .for_each(|i_vec, builder| { + let bit = builder.get(&idx_bits, i_vec[0]); + builder.assert_eq::>(bit, Usize::from(0)); + }); + // Right shift let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); + let idx_half = builder.hint_var(); + let lsb = builder.get(&idx_bits, 0); + builder.assert_var_eq(Usize::from(2) * idx_half + lsb, idx); + builder.assign(&idx_len, idx_len_minus_one); - let new_idx = bin_to_dec(builder, &idx_bits, idx_len); - let last_bit = builder.get(&idx_bits, idx_len); - builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); - builder.assign(&idx, new_idx); + builder.assign(&idx, idx_half); let (witin_dimensions, fixed_dimensions) = get_base_codeword_dimensions(builder, input.circuit_meta.clone()); @@ -502,7 +468,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.witin_comm.commit.clone(), dimensions: witin_dimensions, - index_bits: idx_bits.clone(), // TODO: double check, should be new idx bits here ? + index_bits: idx_bits.clone().slice(builder, 1, idx_len), // Remove the first bit because two entries are grouped into one leaf in the Merkle tree opened_values: witin_opened_values.clone(), proof: witin_opening_proof, }; @@ -557,7 +523,7 @@ pub(crate) fn batch_verifier_query_phase( let mmcs_verifier_input = MmcsVerifierInputVariable { commit: input.fixed_comm.commit.clone(), dimensions: fixed_dimensions.clone(), - index_bits: idx_bits.clone(), // TODO: should be new idx_bits + index_bits: idx_bits.clone().slice(builder, 1, idx_len), opened_values: fixed_opened_values.clone(), proof: fixed_opening_proof, }; @@ -600,7 +566,7 @@ pub(crate) fn batch_verifier_query_phase( .if_ne(fixed_num_vars, Usize::from(0)) .then(|builder| { let fixed_leafs = builder.get(&fixed_commit_leafs, j); - let leafs_len_div_2 = builder.hint_var(); + let leafs_len_div_2: Var<::N> = builder.hint_var(); let two: Var = builder.eval(Usize::from(2)); builder .assert_eq::>(leafs_len_div_2 * two, fixed_leafs.len()); // Can we assume that leafs.len() is even? @@ -628,7 +594,7 @@ pub(crate) fn batch_verifier_query_phase( let cur_num_var: Var = builder.eval(input.max_num_var.clone()); // let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); let n_d_next_log: Var = - builder.eval(cur_num_var - get_rate_log::() - Usize::from(1)); + builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); // let n_d_next = pow_2(builder, n_d_next_log); // first folding challenge @@ -644,13 +610,18 @@ pub(crate) fn batch_verifier_query_phase( let hi = builder.get(&base_codeword_hi, index.clone()); let level: Var = builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); + let sliced_bits = idx_bits.clone().slice(builder, 1, idx_len); + // let expected_coeffs = builder.get(&input.t_inv_halves, level); + // let expected_coeff = builder.get(&expected_coeffs, idx); // TODO: remove this and directly use the result from verifier_folding_coeffs_level function let coeff = verifier_folding_coeffs_level( builder, - &two_adic_generators, + &two_adic_generators_inverses, level, - &idx_bits, + &sliced_bits, inv_2, ); + // builder.assert_eq::>(coeff, expected_coeff); + // builder.halt(); let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); builder.assign(&folded, folded + fold); }); @@ -670,49 +641,66 @@ pub(crate) fn batch_verifier_query_phase( let j = j_vec[0]; let pi_comm = builder.get(&input.commits, j); let j_plus_one = builder.eval_expr(j + RVar::from(1)); + let j_plus_two = builder.eval(j + RVar::from(2)); let r = builder.get(&input.fold_challenges, j_plus_one); let leaf = builder.get(&opening_ext, j).sibling_value; let proof = builder.get(&opening_ext, j).opening_proof; builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); // next folding challenges - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - let is_interpolate_to_right_index = builder.get(&idx_bits, idx_len_minus_one); + let is_interpolate_to_right_index = builder.get(&idx_bits, j_plus_one); let new_involved_codewords: Ext = builder.constant(C::EF::ZERO); - let next_unique_num_vars_count: Var = - builder.get(&count_per_unique_num_var, next_unique_num_vars_index); builder - .range(0, next_unique_num_vars_count) - .for_each(|k_vec, builder| { - let k = builder.eval_expr(k_vec[0] + cumul_num_vars_count); - let index = builder.get(&folding_sorted_order_index, k); - let lo = builder.get(&base_codeword_lo, index.clone()); - let hi = builder.get(&base_codeword_hi, index.clone()); + .if_ne(next_unique_num_vars_index, num_unique_num_vars) + .then(|builder| { + let next_unique_num_vars: Var = + builder.get(&sorted_unique_num_vars, next_unique_num_vars_index); builder - .if_eq(is_interpolate_to_right_index, Usize::from(1)) + .if_eq(next_unique_num_vars, cur_num_var) .then(|builder| { + let next_unique_num_vars_count: Var = builder + .get(&count_per_unique_num_var, next_unique_num_vars_index); + builder.range(0, next_unique_num_vars_count).for_each( + |k_vec, builder| { + let k = + builder.eval_expr(k_vec[0] + cumul_num_vars_count); + let index = builder.get(&folding_sorted_order_index, k); + let lo = builder.get(&base_codeword_lo, index.clone()); + let hi = builder.get(&base_codeword_hi, index.clone()); + builder + .if_eq( + is_interpolate_to_right_index, + Usize::from(1), + ) + .then(|builder| { + builder.assign( + &new_involved_codewords, + new_involved_codewords + hi, + ); + }); + builder + .if_ne( + is_interpolate_to_right_index, + Usize::from(1), + ) + .then(|builder| { + builder.assign( + &new_involved_codewords, + new_involved_codewords + lo, + ); + }); + }, + ); builder.assign( - &new_involved_codewords, - new_involved_codewords + hi, + &cumul_num_vars_count, + cumul_num_vars_count + next_unique_num_vars_count, ); - }); - builder - .if_ne(is_interpolate_to_right_index, Usize::from(1)) - .then(|builder| { builder.assign( - &new_involved_codewords, - new_involved_codewords + lo, + &next_unique_num_vars_index, + next_unique_num_vars_index + Usize::from(1), ); }); }); - builder.assign( - &cumul_num_vars_count, - cumul_num_vars_count + next_unique_num_vars_count, - ); - builder.assign( - &next_unique_num_vars_index, - next_unique_num_vars_index + Usize::from(1), - ); // leafs let leafs = builder.dyn_array(2); @@ -732,33 +720,34 @@ pub(crate) fn batch_verifier_query_phase( // idx >>= 1 let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); builder.assign(&idx_len, idx_len_minus_one); - let new_idx = bin_to_dec(builder, &idx_bits, idx_len); - let last_bit = builder.get(&idx_bits, idx_len); - builder.assert_eq::>(Usize::from(2) * new_idx + last_bit, idx); + let idx_end = builder.eval(input.max_num_var.clone() + get_rate_log::()); + let new_idx = bin_to_dec_le(builder, &idx_bits, j_plus_two, idx_end); + let first_bit = builder.get(&idx_bits, j_plus_one); + builder.assert_eq::>(Usize::from(2) * new_idx + first_bit, idx); builder.assign(&idx, new_idx); // n_d_i >> 1 builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); - let n_d_i = pow_2(builder, n_d_i_log); // mmcs_ext.verify_batch let dimensions = builder.dyn_array(1); // let two: Var<_> = builder.eval(Usize::from(2)); - builder.set_value(&dimensions, 0, n_d_i.clone()); + builder.set_value(&dimensions, 0, n_d_i_log.clone()); let opened_values = builder.dyn_array(1); builder.set_value(&opened_values, 0, leafs.clone()); let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { commit: pi_comm.clone(), dimensions, - index_bits: idx_bits.clone(), // TODO: new idx bits? + index_bits: idx_bits.clone().slice(builder, j_plus_two, idx_end), opened_values, proof, }; - ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); + ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); // FIXME: the Merkle roots do not match + let sliced_bits = idx_bits.clone().slice(builder, j_plus_two, idx_len); let coeff = verifier_folding_coeffs_level( builder, - &two_adic_generators, + &two_adic_generators_inverses, n_d_i_log.clone(), - &idx_bits, + &sliced_bits, inv_2, ); let left = builder.get(&leafs, 0); @@ -844,98 +833,271 @@ pub(crate) fn batch_verifier_query_phase( Usize::Var(ylo), Usize::Var(num_vars_evaluated), ); - let eq = build_eq_x_r_vec_sequential_with_offset::( - builder, - &point.fs, - Usize::Var(num_vars_evaluated), - ); - let eq_coeff = builder.dyn_array(eq.len()); - builder.range(0, eq.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let next_eq = builder.get(&eq, j); - let next_eq_coeff: Ext = builder.eval(next_eq * coeff); - builder.set_value(&eq_coeff, j, next_eq_coeff); - }); - let dot_prod = dot_product(builder, &final_message, &eq_coeff); + // We assume that the final message is of size 1, so the eq poly is just + // vec![one]. + // let eq = build_eq_x_r_vec_sequential_with_offset::( + // builder, + // &point.fs, + // Usize::Var(num_vars_evaluated), + // ); + // eq_coeff = eq * coeff + // let eq_coeff = builder.dyn_array(eq.len()); + // builder.range(0, eq.len()).for_each(|j_vec, builder| { + // let j = j_vec[0]; + // let next_eq = builder.get(&eq, j); + // let next_eq_coeff: Ext = builder.eval(next_eq * coeff); + // builder.set_value(&eq_coeff, j, next_eq_coeff); + // }); + // let dot_prod = dot_product(builder, &final_message, &eq_coeff); + + // Again assuming final message is a single element + let final_message = builder.get(&final_message, 0); + // Again, eq polynomial is just one + let dot_prod: Ext = builder.eval(final_message * coeff); builder.assign(&right, right + dot_prod); }); builder.assert_eq::>(left, right); } +#[cfg(test)] pub mod tests { - use std::{fs::File, io::Read}; - - use mpcs::{QueryPhaseAdditionalHint, QueryPhaseVerifierInput as InnerQueryPhaseVerifierInput}; + use std::{cmp::Reverse, collections::BTreeMap, iter::once}; + + use ceno_mle::mle::MultilinearExtension; + use ceno_transcript::{BasicTranscript, Transcript}; + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use mpcs::pcs_batch_verify; + use mpcs::{ + pcs_batch_commit, pcs_batch_open, pcs_setup, pcs_trim, + util::hash::write_digest_to_transcript, BasefoldDefault, PolynomialCommitmentScheme, + }; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; use openvm_native_recursion::hints::Hintable; - use openvm_stark_backend::config::StarkGenericConfig; - use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, - }; - use p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra}; - type SC = BabyBearPoseidon2Config; + use openvm_stark_sdk::p3_baby_bear::BabyBear; + use p3_field::Field; + use p3_field::FieldAlgebra; + use rand::thread_rng; type F = BabyBear; - type E = BinomialExtensionField; - type EF = ::Challenge; + type E = BabyBearExt4; + type PCS = BasefoldDefault; + + use crate::{ + basefold_verifier::{ + basefold::BasefoldCommitment, + query_phase::{BatchOpening, CommitPhaseProofStep, QueryOpeningProof}, + structs::CircuitIndexMeta, + }, + tower_verifier::binding::{Point, PointAndEval}, + }; use super::{batch_verifier_query_phase, QueryPhaseVerifierInput}; #[allow(dead_code)] - pub fn build_batch_verifier_query_phase() -> (Program, Vec>) { - // OpenVM DSL - let mut builder = AsmBuilder::::default(); - - // Witness inputs + pub fn build_batch_verifier_query_phase( + input: QueryPhaseVerifierInput, + ) -> (Program, Vec>) { + // build test program + let mut builder = AsmBuilder::::default(); let query_phase_input = QueryPhaseVerifierInput::read(&mut builder); batch_verifier_query_phase(&mut builder, query_phase_input); builder.halt(); + let program = builder.compile_isa(); - // Pass in witness stream - let f = |n: usize| F::from_canonical_usize(n); - let mut witness_stream: Vec< - Vec>, - > = Vec::new(); - - // INPUT - let mut f = File::open("query_phase_verifier_input.bin".to_string()).unwrap(); - let mut content: Vec = Vec::new(); - f.read_to_end(&mut content).unwrap(); - let input: InnerQueryPhaseVerifierInput = bincode::deserialize(&content).unwrap(); - let input: QueryPhaseVerifierInput = input.into(); - + // prepare input + let mut witness_stream: Vec> = Vec::new(); witness_stream.extend(input.write()); - - // the builder reads some additional hints after reading the query - // phase verifier input. Need to feed them into the stream - let mut f = File::open("query_phase_additional_hint.bin".to_string()).unwrap(); - let mut content: Vec = Vec::new(); - f.read_to_end(&mut content).unwrap(); - let input: QueryPhaseAdditionalHint = bincode::deserialize(&content).unwrap(); - - witness_stream.extend(vec![vec![input.two_inv]]); - witness_stream.extend(vec![vec![F::from_canonical_usize( - input.num_unique_entries, - )]]); - witness_stream.extend(vec![input - .sorting_orders - .iter() - .map(|x| F::from_canonical_usize(*x)) - .collect()]); - - // PROGRAM - let program: Program< - p3_monty_31::MontyField31, - > = builder.compile_isa(); + witness_stream.push(vec![F::from_canonical_u32(2).inverse()]); + witness_stream.push(vec![F::from_canonical_usize( + input + .circuit_meta + .iter() + .unique_by(|x| x.witin_num_vars) + .count(), + )]); + witness_stream.push( + input + .circuit_meta + .iter() + .enumerate() + .sorted_by_key(|(_, CircuitIndexMeta { witin_num_vars, .. })| { + Reverse(witin_num_vars) + }) + .map(|(index, _)| F::from_canonical_usize(index)) + .collect_vec(), + ); + for (query, idx) in input.queries.iter().zip(input.indices.iter()) { + witness_stream.push(vec![F::from_canonical_usize(idx / 2)]); + if let Some(fixed_comm) = &input.fixed_comm { + let log2_witin_max_codeword_size = input.max_num_var + 1; + if log2_witin_max_codeword_size > fixed_comm.log2_max_codeword_size { + witness_stream.push(vec![F::ZERO]) + } else { + witness_stream.push(vec![F::ONE]) + } + } + for i in 0..input.circuit_meta.len() { + witness_stream.push(vec![F::from_canonical_usize( + query.witin_base_proof.opened_values[i].len() / 2, + )]); + if input.circuit_meta[i].fixed_num_vars > 0 { + witness_stream.push(vec![F::from_canonical_usize( + if let Some(fixed_base_proof) = &query.fixed_base_proof { + fixed_base_proof.opened_values[i].len() / 2 + } else { + 0 + }, + )]); + } + } + } (program, witness_stream) } #[test] fn test_verify_query_phase_batch() { - let (program, witness) = build_batch_verifier_query_phase(); + let mut rng = thread_rng(); + let m1 = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << 10, 10); + let mles_1 = m1.to_mles(); + let matrices = BTreeMap::from_iter(once((0, m1))); + + let pp = pcs_setup::(1 << 20).unwrap(); + let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); + let witin_comm = PCS::get_pure_commitment(&pcs_data); + + let points = vec![E::random_vec(10, &mut rng)]; + let evals = points + .iter() + .map(|p| mles_1.iter().map(|mle| mle.evaluate(p)).collect_vec()) + .collect::>(); + // let evals = mles_1 + // .iter() + // .map(|mle| points.iter().map(|p| mle.evaluate(&p)).collect_vec()) + // .collect::>(); + let mut transcript = BasicTranscript::::new(&[]); + let opening_proof = pcs_batch_open::( + &pp, + &[(0, 1 << 10)], + None, + &pcs_data, + &points, + &evals, + &[(10, 0)], + &mut transcript, + ) + .unwrap(); + + let mut transcript = BasicTranscript::::new(&[]); + pcs_batch_verify::( + &vp, + &[(0, 1 << 10)], + &points, + None, + &witin_comm, + &evals, + &opening_proof, + &[(10, 0)], + &mut transcript, + ) + .expect("Native verification failed"); + + let mut transcript = BasicTranscript::::new(&[]); + let batch_coeffs = transcript.sample_and_append_challenge_pows(10, b"batch coeffs"); + + let max_num_var = 10; + let num_rounds = max_num_var; // The final message is of length 1 + + // prepare folding challenges via sumcheck round msg + FRI commitment + let mut fold_challenges: Vec = Vec::with_capacity(10); + let commits = &opening_proof.commits; + + let sumcheck_messages = opening_proof.sumcheck_proof.as_ref().unwrap(); + for i in 0..num_rounds { + transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice()); + fold_challenges.push( + transcript + .sample_and_append_challenge(b"commit round") + .elements, + ); + if i < num_rounds - 1 { + write_digest_to_transcript(&commits[i], &mut transcript); + } + } + transcript.append_field_element_exts_iter(opening_proof.final_message.iter().flatten()); + + let queries = opening_proof + .query_opening_proof + .iter() + .map(|query| QueryOpeningProof { + witin_base_proof: BatchOpening { + opened_values: query.witin_base_proof.opened_values.clone(), + opening_proof: query.witin_base_proof.opening_proof.clone(), + }, + fixed_base_proof: None, + commit_phase_openings: query + .commit_phase_openings + .iter() + .map(|step| CommitPhaseProofStep { + sibling_value: step.sibling_value.clone(), + opening_proof: step.opening_proof.clone(), + }) + .collect(), + }) + .collect(); + + let query_input = QueryPhaseVerifierInput { + // t_inv_halves: vp.encoding_params.t_inv_halves, + max_num_var: 10, + indices: opening_proof.query_indices, + final_message: opening_proof.final_message, + batch_coeffs, + queries, + fixed_comm: None, + witin_comm: BasefoldCommitment { + commit: witin_comm.commit().into(), + trivial_commits: witin_comm + .trivial_commits + .iter() + .copied() + .map(|c| c.into()) + .collect(), + log2_max_codeword_size: 20, + // This is a dummy value, should be set according to the actual codeword size + }, + circuit_meta: vec![CircuitIndexMeta { + witin_num_vars: 10, + fixed_num_vars: 0, + witin_num_polys: 10, + fixed_num_polys: 0, + }], + commits: opening_proof + .commits + .iter() + .copied() + .map(|c| c.into()) + .collect(), + fold_challenges, + sumcheck_messages: opening_proof + .sumcheck_proof + .as_ref() + .unwrap() + .clone() + .into_iter() + .map(|msg| msg.into()) + .collect(), + point_evals: vec![( + Point { + fs: points[0].clone(), + }, + evals[0].clone(), + )], + }; + let (program, witness) = build_batch_verifier_query_phase(query_input); let system_config = SystemConfig::default() .with_public_values(4) @@ -943,12 +1105,12 @@ pub mod tests { let config = NativeConfig::new(system_config, Native); let executor = VmExecutor::::new(config); - executor.execute(program, witness).unwrap(); + executor.execute(program.clone(), witness.clone()).unwrap(); // _debug - // let results = executor.execute_segments(program, witness).unwrap(); - // for seg in results { - // println!("=> cycle count: {:?}", seg.metrics.cycle_count); - // } + let results = executor.execute_segments(program, witness).unwrap(); + for seg in results { + println!("=> cycle count: {:?}", seg.metrics.cycle_count); + } } } diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 93e676a..a533eaa 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -98,20 +98,18 @@ pub fn get_rate_log() -> Usize { } pub fn get_basecode_msg_size_log() -> Usize { - Usize::from(7) + Usize::from(0) } pub fn verifier_folding_coeffs_level( builder: &mut Builder, - two_adic_generators: &Array>, + two_adic_generators_inverses: &Array>, level: Var, - index_bits: &Array>, // In big endian + index_bits: &Array>, two_inv: Felt, ) -> Felt { let level_plus_one = builder.eval::, _>(level + C::N::ONE); - let g = builder.get(two_adic_generators, level_plus_one); - let g_inv = builder.hint_felt(); - builder.assert_eq::>(g_inv * g, C::F::from_canonical_usize(1)); + let g_inv = builder.get(two_adic_generators_inverses, level_plus_one); let g_inv_index = pow_felt_bits(builder, g_inv, index_bits, level.into()); diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index a933c5c..1c31594 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -30,18 +30,6 @@ pub struct CircuitIndexMeta { pub fixed_num_polys: usize, } -use mpcs::CircuitIndexMeta as InnerCircuitIndexMeta; -impl From for CircuitIndexMeta { - fn from(inner: InnerCircuitIndexMeta) -> Self { - Self { - witin_num_vars: inner.witin_num_vars, - witin_num_polys: inner.witin_num_polys, - fixed_num_vars: inner.fixed_num_vars, - fixed_num_polys: inner.fixed_num_polys, - } - } -} - impl Hintable for CircuitIndexMeta { type HintVariable = CircuitIndexMetaVariable; @@ -126,11 +114,13 @@ pub fn get_base_codeword_dimensions( // wit_dim // let width = builder.eval(witin_num_polys * Usize::from(2)); let height_exp = builder.eval(witin_num_vars + get_rate_log::() - Usize::from(1)); - let height = pow_2(builder, height_exp); + // let height = pow_2(builder, height_exp); // let next_wit: DimensionsVariable = DimensionsVariable { width, height }; // Only keep the height because the width is not needed in the mmcs batch // verify instruction - builder.set_value(&wit_dim, i, height); + // The dimension passed to the mmcs verifier batch is log of the height, not + // the height itself + builder.set_value(&wit_dim, i, height_exp); // fixed_dim // XXX: since fixed_num_vars is usize, fixed_num_vars > 0 is equivalent to fixed_num_vars != 0 @@ -141,9 +131,9 @@ pub fn get_base_codeword_dimensions( let height_exp = builder.eval(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); // XXX: more efficient pow implementation - let height = pow_2(builder, height_exp); + // let height = pow_2(builder, height_exp); // let next_fixed: DimensionsVariable = DimensionsVariable { width, height }; - builder.set_value(&fixed_dim, i, height); + builder.set_value(&fixed_dim, i, height_exp); }); }); (wit_dim, fixed_dim) diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 49dc5a2..6b6bf92 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -1,6 +1,9 @@ use openvm_native_compiler::ir::*; +use openvm_native_recursion::vars::HintSlice; use p3_field::FieldAlgebra; +use crate::basefold_verifier::mmcs::MmcsProof; + // XXX: more efficient pow implementation pub fn pow(builder: &mut Builder, base: Var, exponent: Var) -> Var { let value: Var = builder.constant(C::N::ONE); @@ -27,18 +30,30 @@ pub fn pow_felt( pub fn pow_felt_bits( builder: &mut Builder, base: Felt, - exponent_bits: &Array>, // In small endian + exponent_bits: &Array>, // FIXME: Should be big endian? There is a bit_reverse_rows() in Ceno native code exponent_len: Usize, ) -> Felt { let value: Felt = builder.constant(C::F::ONE); - let repeated_squared: Felt = base; + + // Little endian + // let repeated_squared: Felt = base; + // builder.range(0, exponent_len).for_each(|ptr, builder| { + // let ptr = ptr[0]; + // let bit = builder.get(exponent_bits, ptr); + // builder.if_eq(bit, C::N::ONE).then(|builder| { + // builder.assign(&value, value * repeated_squared); + // }); + // builder.assign(&repeated_squared, repeated_squared * repeated_squared); + // }); + + // Big endian builder.range(0, exponent_len).for_each(|ptr, builder| { let ptr = ptr[0]; + builder.assign(&value, value * value); let bit = builder.get(exponent_bits, ptr); builder.if_eq(bit, C::N::ONE).then(|builder| { - builder.assign(&value, value * repeated_squared); + builder.assign(&value, value * base); }); - builder.assign(&repeated_squared, repeated_squared * repeated_squared); }); value } @@ -123,6 +138,25 @@ pub fn bin_to_dec( value } +// Convert start to end entries of binary to decimal in little endian +pub fn bin_to_dec_le( + builder: &mut Builder, + bin: &Array>, + start: Var, + end: Var, +) -> Var { + let value: Var = builder.constant(C::N::ZERO); + let two: Var = builder.constant(C::N::TWO); + let power_of_two: Var = builder.constant(C::N::ONE); + builder.range(start, end).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_bit = builder.get(bin, i); + builder.assign(&value, value + power_of_two * next_bit); + builder.assign(&power_of_two, power_of_two * two); + }); + value +} + // Sort a list in decreasing order, returns: // 1. The original index of each sorted entry // 2. Number of unique entries @@ -131,7 +165,12 @@ pub fn sort_with_count( builder: &mut Builder, list: &Array, ind: Ind, // Convert loaded out entries into comparable ones -) -> (Array>, Var, Array>) +) -> ( + Array>, + Var, + Array>, + Array>, +) where E: openvm_native_compiler::ir::MemVariable, N: Into::N>> @@ -150,6 +189,7 @@ where // 1. count_per_unique_entry: for each unique entry value, count of entries of that value let num_unique_entries = builder.hint_var(); let count_per_unique_entry = builder.dyn_array(num_unique_entries); + let sorted_unique_num_vars = builder.dyn_array(num_unique_entries); let zero: Ext = builder.constant(C::EF::ZERO); let one: Ext = builder.constant(C::EF::ONE); let entries_sort_surjective: Array> = builder.dyn_array(len.clone()); @@ -199,6 +239,11 @@ where last_unique_entry_index, last_count_per_unique_entry, ); + builder.set( + &sorted_unique_num_vars, + last_unique_entry_index, + last_entry.clone(), + ); builder.assign(&last_entry, next_entry.clone()); builder.assign( &last_unique_entry_index, @@ -216,13 +261,23 @@ where last_unique_entry_index, last_count_per_unique_entry, ); + builder.set( + &sorted_unique_num_vars, + last_unique_entry_index, + last_entry.clone(), + ); builder.assign( &last_unique_entry_index, last_unique_entry_index + Usize::from(1), ); builder.assert_var_eq(last_unique_entry_index, num_unique_entries); - (entries_order, num_unique_entries, count_per_unique_entry) + ( + entries_order, + num_unique_entries, + count_per_unique_entry, + sorted_unique_num_vars, + ) } pub fn codeword_fold_with_challenge( @@ -244,3 +299,9 @@ pub fn codeword_fold_with_challenge( let ret: Ext = builder.eval(lo + challenge * (hi - lo)); ret } + +pub(crate) fn read_hint_slice(builder: &mut Builder) -> HintSlice { + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + HintSlice { length, id } +} diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index 505c217..4afd4b6 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -1,3 +1,4 @@ +use crate::basefold_verifier::query_phase::QueryPhaseVerifierInput; use crate::tower_verifier::binding::IOPProverMessage; use crate::zkvm_verifier::binding::ZKVMProofInput; use crate::zkvm_verifier::binding::{ From 21c4d411bf33cf124638361525e950092ce9aafe Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Mon, 21 Jul 2025 18:46:28 +0800 Subject: [PATCH 63/70] Simplify BaseFold verifier (#34) * [Upgrade] ZKVMProof Verifier Update (#31) * Remove index reversal * Add a cycle tracker * Delete a loop * Better casting * Change verifier logic * Finish opcdoe proof verification debugging * Finish debugging table proof verification * Debug verifier * Finish debugging updated verifier * Remove unnecessary table proof fields * Remove unnecessary parsing * Update Plonky3 * Migrate away from temporary build branch * Switch ceno reliance * Fix compilation errors due to out of date code * Update test query phase batch * Fix query opening proof * Implement basefold proof variable * Update query phase verifier input * Preparing test data for query phase with updated code * Implement basefold proof transform * Prepare query phase verifier input * Prepare query phase verifier input * Fix final message access * Switch ceno reliance to small field support * basefold verifier for one matrix (#35) * wip * wip2 * wip3 * fix test * fix * fmt * fri part of verifying basefold proof for 1 matrix passed * sumcheck part 1 * sumcheck part 2 * sumcheck part 3 * cleanup * more cleanups --------- Co-authored-by: Ray Gao Co-authored-by: xkx --- Cargo.lock | 225 +++--- Cargo.toml | 37 +- rust-toolchain.toml | 2 +- src/arithmetics/mod.rs | 187 +++-- src/basefold_verifier/basefold.rs | 143 +++- src/basefold_verifier/query_phase.rs | 990 +++++++++------------------ src/e2e/mod.rs | 275 ++++---- src/extensions/mod.rs | 14 +- src/tower_verifier/binding.rs | 108 +-- src/tower_verifier/program.rs | 154 ++--- src/zkvm_verifier/binding.rs | 207 +++--- src/zkvm_verifier/verifier.rs | 644 +++++++---------- 12 files changed, 1291 insertions(+), 1695 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 79fc6b2..7c479b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,27 +137,6 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" -[[package]] -name = "ark-ec" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" -dependencies = [ - "ahash", - "ark-ff 0.5.0", - "ark-poly", - "ark-serialize 0.5.0", - "ark-std 0.5.0", - "educe", - "fnv", - "hashbrown 0.15.4", - "itertools 0.13.0", - "num-bigint 0.4.6", - "num-integer", - "num-traits", - "zeroize", -] - [[package]] name = "ark-ff" version = "0.4.2" @@ -537,9 +516,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" dependencies = [ "shlex", ] @@ -547,7 +526,6 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "glob", ] @@ -598,12 +576,13 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "anyhow", "ceno_rt", "elf", + "ff_ext", "itertools 0.13.0", + "multilinear_extensions", "num-bigint 0.4.6", "num-derive", "num-traits", @@ -619,7 +598,6 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "anyhow", "ceno_emul", @@ -632,7 +610,6 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -641,7 +618,6 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "base64", "bincode", @@ -649,8 +625,10 @@ dependencies = [ "ceno_emul", "ceno_host", "clap", + "either", "ff_ext", "generic_static", + "gkr_iop", "glob", "itertools 0.13.0", "mpcs", @@ -725,9 +703,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -735,9 +713,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -747,9 +725,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -1028,6 +1006,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "elf" @@ -1153,8 +1134,8 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ + "once_cell", "p3", "rand_core", "serde", @@ -1239,6 +1220,40 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "gkr_iop" +version = "0.1.0" +dependencies = [ + "ark-std 0.5.0", + "bincode", + "ceno_emul", + "clap", + "either", + "ff_ext", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "ndarray", + "p3", + "p3-field", + "p3-goldilocks", + "p3-util", + "rand", + "rayon", + "serde", + "strum", + "strum_macros", + "sumcheck", + "thiserror", + "thread_local", + "tiny-keccak", + "tracing", + "tracing-forest", + "tracing-subscriber", + "transcript", + "witness", +] + [[package]] name = "glam" version = "0.30.4" @@ -1606,6 +1621,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.5" @@ -1676,11 +1701,11 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "aes", "bincode", "bitvec", + "clap", "ctr", "ff_ext", "generic-array", @@ -1706,8 +1731,8 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ + "either", "ff_ext", "itertools 0.13.0", "p3", @@ -1738,41 +1763,27 @@ dependencies = [ ] [[package]] -name = "nibble_vec" -version = "0.1.0" +name = "ndarray" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" -dependencies = [ - "smallvec", -] - -[[package]] -name = "nimue" -version = "0.2.0" -source = "git+https://github.com/arkworks-rs/nimue?rev=3a83250#3a83250d9e30046be464753901430710232561fc" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" dependencies = [ - "ark-ec", - "ark-ff 0.5.0", - "ark-serialize 0.5.0", - "digest", - "hex", - "keccak", - "log", - "rand", - "zeroize", + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", ] [[package]] -name = "nimue-pow" +name = "nibble_vec" version = "0.1.0" -source = "git+https://github.com/arkworks-rs/nimue?rev=3a83250#3a83250d9e30046be464753901430710232561fc" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" dependencies = [ - "blake3", - "bytemuck", - "keccak", - "nimue", - "rand", - "rayon", + "smallvec", ] [[package]] @@ -1807,6 +1818,15 @@ dependencies = [ "rand", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.4.2" @@ -1909,7 +1929,7 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openvm" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "bytemuck", "num-bigint 0.4.6", @@ -1922,7 +1942,7 @@ dependencies = [ [[package]] name = "openvm-circuit" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "backtrace", "cfg-if", @@ -1953,7 +1973,7 @@ dependencies = [ [[package]] name = "openvm-circuit-derive" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "itertools 0.14.0", "quote", @@ -1963,7 +1983,7 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -1978,7 +1998,7 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives-derive" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "itertools 0.14.0", "quote", @@ -1988,7 +2008,7 @@ dependencies = [ [[package]] name = "openvm-custom-insn" version = "0.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "proc-macro2", "quote", @@ -1998,7 +2018,7 @@ dependencies = [ [[package]] name = "openvm-instructions" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "backtrace", "derive-new 0.6.0", @@ -2015,7 +2035,7 @@ dependencies = [ [[package]] name = "openvm-instructions-derive" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "quote", "syn 2.0.104", @@ -2024,7 +2044,7 @@ dependencies = [ [[package]] name = "openvm-native-circuit" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2051,7 +2071,7 @@ dependencies = [ [[package]] name = "openvm-native-compiler" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "backtrace", "itertools 0.14.0", @@ -2073,7 +2093,7 @@ dependencies = [ [[package]] name = "openvm-native-compiler-derive" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "quote", "syn 2.0.104", @@ -2082,7 +2102,7 @@ dependencies = [ [[package]] name = "openvm-native-recursion" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "cfg-if", "itertools 0.14.0", @@ -2106,7 +2126,7 @@ dependencies = [ [[package]] name = "openvm-platform" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "getrandom 0.2.16", "libm", @@ -2117,7 +2137,7 @@ dependencies = [ [[package]] name = "openvm-poseidon2-air" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "derivative", "lazy_static", @@ -2134,7 +2154,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-circuit" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2157,7 +2177,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-guest" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "openvm-custom-insn", "strum_macros", @@ -2166,7 +2186,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-transpiler" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -2242,7 +2262,7 @@ dependencies = [ [[package]] name = "openvm-transpiler" version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" dependencies = [ "elf", "eyre", @@ -2271,7 +2291,6 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2703,10 +2722,18 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "criterion", "ff_ext", @@ -2894,6 +2921,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -3058,15 +3091,15 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -3309,9 +3342,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "crossbeam-channel", + "either", "ff_ext", "itertools 0.13.0", "multilinear_extensions", @@ -3319,6 +3352,7 @@ dependencies = [ "rayon", "serde", "sumcheck_macro", + "thiserror", "tracing", "transcript", ] @@ -3326,7 +3360,6 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "itertools 0.13.0", "p3", @@ -3514,7 +3547,7 @@ dependencies = [ "serde_spanned", "toml_datetime", "toml_write", - "winnow 0.7.11", + "winnow 0.7.12", ] [[package]] @@ -3600,7 +3633,6 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3778,7 +3810,6 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "bincode", "blake2", @@ -3791,8 +3822,6 @@ dependencies = [ "itertools 0.14.0", "lazy_static", "multilinear_extensions", - "nimue", - "nimue-pow", "p3", "poseidon", "rand", @@ -3997,9 +4026,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -4016,7 +4045,6 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fexport_ff_ext#c2e55fe1a25872cea1ddd916c8ae023d70f29c0c" dependencies = [ "ff_ext", "multilinear_extensions", @@ -4024,6 +4052,7 @@ dependencies = [ "rand", "rayon", "serde", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 97088ca..eea354f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,13 +10,12 @@ openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", bra openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } - openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } rand = { version = "0.8.5", default-features = false } itertools = { version = "0.13.0", default-features = false } -bincode = "1" +bincode = "1.3.3" tracing = "0.1.40" # Plonky3 @@ -39,26 +38,26 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "transcript" } -ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "witness" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "transcript" } +ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "witness" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [features] bench-metrics = ["openvm-circuit/bench-metrics"] -# [patch."https://github.com/scroll-tech/ceno.git"] -# ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } -# ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } -# ceno_transcript = { path = "../ceno/transcript", package = "transcript" } -# ceno_witness = { path = "../ceno/witness", package = "witness" } -# ceno_zkvm = { path = "../ceno/ceno_zkvm" } -# ceno_emul = { path = "../ceno/ceno_emul" } -# mpcs = { path = "../ceno/mpcs" } -# ff_ext = { path = "../ceno/ff_ext" } +[patch."https://github.com/scroll-tech/ceno.git"] +ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } +ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } +ceno_transcript = { path = "../ceno/transcript", package = "transcript" } +ceno_witness = { path = "../ceno/witness", package = "witness" } +ceno_zkvm = { path = "../ceno/ceno_zkvm" } +ceno_emul = { path = "../ceno/ceno_emul" } +mpcs = { path = "../ceno/mpcs" } +ff_ext = { path = "../ceno/ff_ext" } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index da3da89..9661b2d 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,4 +1,4 @@ [toolchain] -channel = "nightly-2025-01-06" +channel = "nightly-2025-03-25" targets = ["riscv32im-unknown-none-elf"] components = ["clippy", "rustfmt", "rust-src"] \ No newline at end of file diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 64bd151..b8ac2bc 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -1,9 +1,10 @@ use crate::tower_verifier::binding::PointAndEvalVariable; use crate::zkvm_verifier::binding::ZKVMOpcodeProofInputVariable; -use ceno_zkvm::expression::{Expression, Fixed, Instance}; +use ceno_mle::{Expression, Fixed, Instance}; use ceno_zkvm::structs::{ChallengeId, WitnessId}; use ff_ext::ExtensionField; use ff_ext::{BabyBearExt4, SmallField}; +use itertools::Either; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::ChallengerVariable; @@ -13,8 +14,9 @@ use openvm_native_recursion::challenger::{ use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; type E = BabyBearExt4; const HASH_RATE: usize = 8; +const MAX_NUM_VARS: usize = 25; -pub fn _print_ext_arr(builder: &mut Builder, arr: &Array>) { +pub fn print_ext_arr(builder: &mut Builder, arr: &Array>) { iter_zip!(builder, arr).for_each(|ptr_vec, builder| { let e = builder.iter_ptr_get(arr, ptr_vec[0]); builder.print_e(e); @@ -28,7 +30,7 @@ pub fn print_felt_arr(builder: &mut Builder, arr: &Array(builder: &mut Builder, arr: &Array>) { +pub fn print_usize_arr(builder: &mut Builder, arr: &Array>) { iter_zip!(builder, arr).for_each(|ptr_vec, builder| { let n = builder.iter_ptr_get(arr, ptr_vec[0]); builder.print_v(n.get_var()); @@ -83,13 +85,11 @@ pub fn is_smaller_than( RVar::from(v) } -pub fn evaluate_at_point( +pub fn evaluate_at_point_degree_1( builder: &mut Builder, evals: &Array>, point: &Array>, ) -> Ext { - // TODO: Dynamic length - // TODO: Sanity checks let left = builder.get(&evals, 0); let right = builder.get(&evals, 1); let r = builder.get(point, 0); @@ -114,6 +114,80 @@ pub fn fixed_dot_product( acc } +pub struct PolyEvaluator { + powers_of_2: Array>, +} + +impl PolyEvaluator { + pub fn new(builder: &mut Builder) -> Self { + let powers_of_2: Array> = builder.dyn_array(MAX_NUM_VARS); + builder.set(&powers_of_2, 0, Usize::from(16777216)); + builder.set(&powers_of_2, 1, Usize::from(8388608)); + builder.set(&powers_of_2, 2, Usize::from(4194304)); + builder.set(&powers_of_2, 3, Usize::from(1048576)); + builder.set(&powers_of_2, 4, Usize::from(2097152)); + builder.set(&powers_of_2, 5, Usize::from(524288)); + builder.set(&powers_of_2, 6, Usize::from(262144)); + builder.set(&powers_of_2, 7, Usize::from(131072)); + builder.set(&powers_of_2, 8, Usize::from(65536)); + builder.set(&powers_of_2, 9, Usize::from(32768)); + builder.set(&powers_of_2, 10, Usize::from(16384)); + builder.set(&powers_of_2, 11, Usize::from(8192)); + builder.set(&powers_of_2, 12, Usize::from(4096)); + builder.set(&powers_of_2, 13, Usize::from(2048)); + builder.set(&powers_of_2, 14, Usize::from(1024)); + builder.set(&powers_of_2, 15, Usize::from(512)); + builder.set(&powers_of_2, 16, Usize::from(256)); + builder.set(&powers_of_2, 17, Usize::from(128)); + builder.set(&powers_of_2, 18, Usize::from(64)); + builder.set(&powers_of_2, 19, Usize::from(32)); + builder.set(&powers_of_2, 20, Usize::from(16)); + builder.set(&powers_of_2, 21, Usize::from(8)); + builder.set(&powers_of_2, 22, Usize::from(4)); + builder.set(&powers_of_2, 23, Usize::from(2)); + builder.set(&powers_of_2, 24, Usize::from(1)); + + Self { powers_of_2 } + } + + pub fn evaluate_base_poly_at_point( + &self, + builder: &mut Builder, + evals: &Array>, + point: &Array>, + ) -> Ext { + let num_var = point.len(); + + let evals_ext: Array> = builder.dyn_array(evals.len()); + iter_zip!(builder, evals, evals_ext).for_each(|ptr_vec, builder| { + let f = builder.iter_ptr_get(&evals, ptr_vec[0]); + let e = builder.ext_from_base_slice(&[f]); + builder.iter_ptr_set(&evals_ext, ptr_vec[1], e); + }); + + let pwr_slice_idx: Usize = builder.eval(Usize::from(25) - num_var); + let pwrs = self.powers_of_2.slice(builder, pwr_slice_idx, MAX_NUM_VARS); + + iter_zip!(builder, point, pwrs).for_each(|ptr_vec, builder| { + let pt = builder.iter_ptr_get(&point, ptr_vec[0]); + let idx_bound = builder.iter_ptr_get(&pwrs, ptr_vec[1]); + + builder.range(0, idx_bound).for_each(|idx_vec, builder| { + let left_idx: Usize = builder.eval(idx_vec[0] * Usize::from(2)); + let right_idx: Usize = + builder.eval(idx_vec[0] * Usize::from(2) + Usize::from(1)); + let left = builder.get(&evals_ext, left_idx); + let right = builder.get(&evals_ext, right_idx); + + let e: Ext = builder.eval(pt * (right - left) + left); + builder.set(&evals_ext, idx_vec[0], e); + }); + }); + + builder.get(&evals_ext, 0) + } +} + pub fn dot_product( builder: &mut Builder, a: &Array>, @@ -261,6 +335,24 @@ pub fn product( acc } +// Multiply all elements in a nested Array +pub fn nested_product( + builder: &mut Builder, + arr: &Array>>, +) -> Ext { + let acc = builder.constant(C::EF::ONE); + iter_zip!(builder, arr).for_each(|ptr_vec, builder| { + let inner_arr = builder.iter_ptr_get(arr, ptr_vec[0]); + + iter_zip!(builder, inner_arr).for_each(|ptr_vec, builder| { + let el = builder.iter_ptr_get(&inner_arr, ptr_vec[0]); + builder.assign(&acc, acc * el); + }); + }); + + acc +} + // Add all elements in the Array pub fn sum( builder: &mut Builder, @@ -334,12 +426,13 @@ pub fn eq_eval_less_or_equal_than( a: &Array>, b: &Array>, ) -> Ext { + builder.cycle_tracker_start("Compute eq_eval_less_or_equal_than"); let eq_bit_decomp: Array> = opcode_proof .num_instances_minus_one_bit_decomposition .slice(builder, 0, b.len()); let one_ext: Ext = builder.constant(C::EF::ONE); - let rp_len = builder.eval_expr(RVar::from(b.len()) + RVar::from(1)); + let rp_len = builder.eval_expr(b.len() + C::N::ONE); let running_product: Array> = builder.dyn_array(rp_len); builder.set(&running_product, 0, one_ext); @@ -353,49 +446,33 @@ pub fn eq_eval_less_or_equal_than( builder.set(&running_product, next_idx, next_v); }); - let running_product2: Array> = builder.dyn_array(rp_len); - builder.set(&running_product2, b.len(), one_ext); - - let eq_bit_decomp_rev = reverse(builder, &eq_bit_decomp); - let idx_arr = gen_idx_arr(builder, b.len()); - let idx_arr_rev = reverse(builder, &idx_arr); - builder.assert_usize_eq(eq_bit_decomp_rev.len(), idx_arr_rev.len()); - - iter_zip!(builder, idx_arr_rev, eq_bit_decomp_rev).for_each(|ptr_vec, builder| { - let i = builder.iter_ptr_get(&idx_arr_rev, ptr_vec[0]); - let bit = builder.iter_ptr_get(&eq_bit_decomp_rev, ptr_vec[1]); - let bit_ext = builder.ext_from_base_slice(&[bit]); - let last_idx = builder.eval_expr(i.clone() + RVar::from(1)); - - let v = builder.get(&running_product2, last_idx); - let a_i = builder.get(a, i.clone()); - let b_i = builder.get(b, i.clone()); - - let next_v: Ext = builder.eval( - v * (a_i * b_i * bit_ext + (one_ext - a_i) * (one_ext - b_i) * (one_ext - bit_ext)), - ); - builder.set(&running_product2, i, next_v); - }); - - // Here is an example of how this works: - // Suppose max_idx = (110101)_2 - // Then ans = eq(a, b) - // - eq(11011, a[1..6], b[1..6])eq(a[0..1], b[0..1]) - // - eq(111, a[3..6], b[3..6])eq(a[0..3], b[0..3]) let ans = builder.get(&running_product, b.len()); - builder.range(0, b.len()).for_each(|idx_vec, builder| { - let bit = builder.get(&eq_bit_decomp, idx_vec[0]); + let running_product2: Ext = builder.constant(C::EF::ONE); + let idx: Var = builder.uninit(); + builder.assign(&idx, b.len() - C::N::ONE); + builder.range(0, b.len()).for_each(|_, builder| { + let bit = builder.get(&eq_bit_decomp, idx); let bit_rvar = RVar::from(builder.cast_felt_to_var(bit)); + let bit_ext: Ext = builder.eval(bit * SymbolicExt::from_f(C::EF::ONE)); - builder.if_ne(bit_rvar, RVar::from(1)).then(|builder| { - let next_idx = builder.eval_expr(idx_vec[0] + RVar::from(1)); - let v1 = builder.get(&running_product, idx_vec[0]); - let v2 = builder.get(&running_product2, next_idx); - let a_i = builder.get(a, idx_vec[0]); - let b_i = builder.get(b, idx_vec[0]); + let a_i = builder.get(a, idx); + let b_i = builder.get(b, idx); - builder.assign(&ans, ans - v1 * v2 * a_i * b_i); + // Suppose max_idx = (110101)_2 + // Then ans = eq(a, b) + // - eq(11011, a[1..6], b[1..6])eq(a[0..1], b[0..1]) + // - eq(111, a[3..6], b[3..6])eq(a[0..3], b[0..3]) + builder.if_ne(bit_rvar, RVar::from(1)).then(|builder| { + let v1 = builder.get(&running_product, idx); + builder.assign(&ans, ans - v1 * running_product2 * a_i * b_i); }); + + builder.assign( + &running_product2, + running_product2 + * (a_i * b_i * bit_ext + (one_ext - a_i) * (one_ext - b_i) * (one_ext - bit_ext)), + ); + builder.assign(&idx, idx - C::N::ONE); }); let a_remainder_arr: Array> = a.slice(builder, b.len(), a.len()); @@ -404,6 +481,8 @@ pub fn eq_eval_less_or_equal_than( builder.assign(&ans, ans * (one_ext - a)); }); + builder.cycle_tracker_end("Compute eq_eval_less_or_equal_than"); + ans } @@ -537,9 +616,14 @@ pub fn eval_ceno_expr_with_instance( res }, &|builder, scalar| { - let res: Ext = - builder.constant(C::EF::from_canonical_u32(scalar.to_canonical_u64() as u32)); - res + let scalar_base_slice = scalar + .as_bases() + .iter() + .map(|b| C::F::from_canonical_u64(b.to_canonical_u64())) + .collect::>(); + let scalar_ext: Ext = + builder.constant(C::EF::from_base_slice(&scalar_base_slice)); + scalar_ext }, &|builder, challenge_id, pow, scalar, offset| { let challenge = builder.get(&challenges, challenge_id as usize); @@ -587,7 +671,7 @@ pub fn evaluate_ceno_expr( wit_in: &impl Fn(&mut Builder, WitnessId) -> T, // witin id structural_wit_in: &impl Fn(&mut Builder, WitnessId, usize, u32, usize) -> T, instance: &impl Fn(&mut Builder, Instance) -> T, - constant: &impl Fn(&mut Builder, ::BaseField) -> T, + constant: &impl Fn(&mut Builder, E) -> T, challenge: &impl Fn(&mut Builder, ChallengeId, usize, E, E) -> T, sum: &impl Fn(&mut Builder, T, T) -> T, product: &impl Fn(&mut Builder, T, T) -> T, @@ -600,7 +684,10 @@ pub fn evaluate_ceno_expr( structural_wit_in(builder, *witness_id, *max_len, *offset, *multi_factor) } Expression::Instance(i) => instance(builder, *i), - Expression::Constant(scalar) => constant(builder, *scalar), + Expression::Constant(scalar) => match scalar { + Either::Left(s) => constant(builder, E::from_base(*s)), + Either::Right(s) => constant(builder, *s), + }, Expression::Sum(a, b) => { let a = evaluate_ceno_expr( builder, diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index e025145..12735ef 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -1,10 +1,19 @@ +use mpcs::basefold::BasefoldProof as InnerBasefoldProof; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; -use openvm_native_recursion::hints::Hintable; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use serde::Deserialize; -use crate::basefold_verifier::hash::Hash; +use crate::{ + basefold_verifier::{ + hash::{Hash, HashVariable}, + query_phase::{ + PointAndEvals, PointAndEvalsVariable, QueryOpeningProofs, QueryOpeningProofsVariable, + }, + }, + tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, +}; use super::{mmcs::*, structs::DIMENSIONS}; @@ -17,7 +26,6 @@ pub type HashDigest = MmcsCommitment; pub struct BasefoldCommitment { pub commit: HashDigest, pub log2_max_codeword_size: usize, - pub trivial_commits: Vec, } use mpcs::BasefoldCommitment as InnerBasefoldCommitment; @@ -29,11 +37,6 @@ impl From> for BasefoldCommitment { value: value.commit().into(), }, log2_max_codeword_size: value.log2_max_codeword_size, - trivial_commits: value - .trivial_commits - .into_iter() - .map(|c| c.into()) - .collect(), } } } @@ -44,12 +47,10 @@ impl Hintable for BasefoldCommitment { fn read(builder: &mut Builder) -> Self::HintVariable { let commit = HashDigest::read(builder); let log2_max_codeword_size = Usize::Var(usize::read(builder)); - // let trivial_commits = Vec::::read(builder); BasefoldCommitmentVariable { commit, log2_max_codeword_size, - // trivial_commits, } } @@ -59,7 +60,6 @@ impl Hintable for BasefoldCommitment { stream.extend(>::write( &self.log2_max_codeword_size, )); - // stream.extend(self.trivial_commits.write()); stream } } @@ -71,3 +71,124 @@ pub struct BasefoldCommitmentVariable { pub log2_max_codeword_size: Usize, // pub trivial_commits: Array>, } + +#[derive(Deserialize)] +pub struct BasefoldProof { + pub commits: Vec, + pub final_message: Vec>, + pub query_opening_proof: QueryOpeningProofs, + pub sumcheck_proof: Vec, +} + +#[derive(DslVariable, Clone)] +pub struct BasefoldProofVariable { + pub commits: Array>, + pub final_message: Array>>, + pub query_opening_proof: QueryOpeningProofsVariable, + pub sumcheck_proof: Array>, +} + +impl Hintable for BasefoldProof { + type HintVariable = BasefoldProofVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let commits = Vec::::read(builder); + let final_message = Vec::>::read(builder); + let query_opening_proof = QueryOpeningProofs::read(builder); + let sumcheck_proof = Vec::::read(builder); + BasefoldProofVariable { + commits, + final_message, + query_opening_proof, + sumcheck_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commits.write()); + stream.extend(self.final_message.write()); + stream.extend(self.query_opening_proof.write()); + stream.extend(self.sumcheck_proof.write()); + stream + } +} + +impl From> for BasefoldProof { + fn from(proof: InnerBasefoldProof) -> Self { + BasefoldProof { + commits: proof.commits.iter().map(|c| c.clone().into()).collect(), + final_message: proof.final_message.into(), + query_opening_proof: proof + .query_opening_proof + .iter() + .map(|proof| proof.clone().into()) + .collect(), + sumcheck_proof: proof.sumcheck_proof.map_or(vec![], |proof| { + proof.into_iter().map(|proof| proof.into()).collect() + }), + } + } +} + +#[derive(Deserialize)] +pub struct RoundOpening { + pub num_var: usize, + pub point_and_evals: PointAndEvals, +} + +#[derive(DslVariable, Clone)] +pub struct RoundOpeningVariable { + pub num_var: Var, + pub point_and_evals: PointAndEvalsVariable, +} + +impl Hintable for RoundOpening { + type HintVariable = RoundOpeningVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let num_var = usize::read(builder); + let point_and_evals = PointAndEvals::read(builder); + RoundOpeningVariable { + num_var, + point_and_evals, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = vec![]; + stream.extend(>::write(&self.num_var)); + stream.extend(self.point_and_evals.write()); + stream + } +} + +impl VecAutoHintable for RoundOpening {} + +#[derive(Deserialize)] +pub struct Round { + pub commit: BasefoldCommitment, + pub openings: Vec, +} + +#[derive(DslVariable, Clone)] +pub struct RoundVariable { + pub commit: BasefoldCommitmentVariable, + pub openings: Array>, +} + +impl Hintable for Round { + type HintVariable = RoundVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = BasefoldCommitment::read(builder); + let openings = Vec::::read(builder); + RoundVariable { commit, openings } + } + + fn write(&self) -> Vec::N>> { + let mut stream = vec![]; + stream.extend(self.commit.write()); + stream.extend(self.openings.write()); + stream + } +} + +impl VecAutoHintable for Round {} diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index a35f5bc..9eb9a57 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,7 +1,10 @@ // Note: check all XXX comments! +use ark_std::log2; use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; +use mpcs::basefold::QueryOpeningProof as InnerQueryOpeningProof; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::{ hints::{Hintable, VecAutoHintable}, vars::HintSlice, @@ -88,6 +91,8 @@ impl Hintable for BatchOpening { } } +impl VecAutoHintable for BatchOpening {} + /// TODO: use `openvm_native_recursion::fri::types::FriCommitPhaseProofStepVariable` instead #[derive(Deserialize)] pub struct CommitPhaseProofStep { @@ -147,21 +152,35 @@ impl Hintable for CommitPhaseProofStep { #[derive(Deserialize)] pub struct QueryOpeningProof { - pub witin_base_proof: BatchOpening, - pub fixed_base_proof: Option, + pub input_proofs: Vec, pub commit_phase_openings: Vec, } +impl From> for QueryOpeningProof { + fn from(proof: InnerQueryOpeningProof) -> Self { + Self { + input_proofs: proof + .input_proofs + .into_iter() + .map(|proof| proof.into()) + .collect(), + commit_phase_openings: proof + .commit_phase_openings + .into_iter() + .map(|proof| proof.into()) + .collect(), + } + } +} + #[derive(DslVariable, Clone)] pub struct QueryOpeningProofVariable { - pub witin_base_proof: BatchOpeningVariable, - pub fixed_is_some: Usize, // 0 <==> false - pub fixed_base_proof: BatchOpeningVariable, + pub input_proofs: Array>, pub commit_phase_openings: Array>, } -type QueryOpeningProofs = Vec; -type QueryOpeningProofsVariable = Array>; +pub(crate) type QueryOpeningProofs = Vec; +pub(crate) type QueryOpeningProofsVariable = Array>; impl VecAutoHintable for QueryOpeningProof {} @@ -169,37 +188,23 @@ impl Hintable for QueryOpeningProof { type HintVariable = QueryOpeningProofVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let witin_base_proof = BatchOpening::read(builder); - let fixed_is_some = Usize::Var(usize::read(builder)); - let fixed_base_proof = BatchOpening::read(builder); + let input_proofs = Vec::::read(builder); let commit_phase_openings = Vec::::read(builder); QueryOpeningProofVariable { - witin_base_proof, - fixed_is_some, - fixed_base_proof, + input_proofs, commit_phase_openings, } } fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.witin_base_proof.write()); - if let Some(fixed_base_proof) = &self.fixed_base_proof { - stream.extend(>::write(&1)); - stream.extend(fixed_base_proof.write()); - } else { - stream.extend(>::write(&0)); - let tmp_proof = BatchOpening { - opened_values: Vec::new(), - opening_proof: Vec::new(), - }; - stream.extend(tmp_proof.write()); - } + stream.extend(self.input_proofs.write()); stream.extend(self.commit_phase_openings.write()); stream } } +#[derive(Deserialize)] // NOTE: Different from PointAndEval in tower_verifier! pub struct PointAndEvals { pub point: Point, @@ -233,17 +238,11 @@ pub struct PointAndEvalsVariable { pub struct QueryPhaseVerifierInput { // pub t_inv_halves: Vec::BaseField>>, pub max_num_var: usize, - pub indices: Vec, - pub final_message: Vec>, pub batch_coeffs: Vec, - pub queries: QueryOpeningProofs, - pub fixed_comm: Option, - pub witin_comm: BasefoldCommitment, - pub circuit_meta: Vec, - pub commits: Vec, pub fold_challenges: Vec, - pub sumcheck_messages: Vec, - pub point_evals: Vec<(Point, Vec)>, + pub indices: Vec, + pub proof: BasefoldProof, + pub rounds: Vec, } impl Hintable for QueryPhaseVerifierInput { @@ -252,34 +251,20 @@ impl Hintable for QueryPhaseVerifierInput { fn read(builder: &mut Builder) -> Self::HintVariable { // let t_inv_halves = Vec::>::read(builder); let max_num_var = Usize::Var(usize::read(builder)); - let indices = Vec::::read(builder); - let final_message = Vec::>::read(builder); let batch_coeffs = Vec::::read(builder); - let queries = QueryOpeningProofs::read(builder); - let fixed_is_some = Usize::Var(usize::read(builder)); - let fixed_comm = BasefoldCommitment::read(builder); - let witin_comm = BasefoldCommitment::read(builder); - let circuit_meta = Vec::::read(builder); - let commits = Vec::::read(builder); let fold_challenges = Vec::::read(builder); - let sumcheck_messages = Vec::::read(builder); - let point_evals = Vec::::read(builder); + let indices = Vec::::read(builder); + let proof = BasefoldProof::read(builder); + let rounds = Vec::::read(builder); QueryPhaseVerifierInputVariable { // t_inv_halves, max_num_var, - indices, - final_message, batch_coeffs, - queries, - fixed_is_some, - fixed_comm, - witin_comm, - circuit_meta, - commits, fold_challenges, - sumcheck_messages, - point_evals, + indices, + proof, + rounds, } } @@ -287,37 +272,11 @@ impl Hintable for QueryPhaseVerifierInput { let mut stream = Vec::new(); // stream.extend(self.t_inv_halves.write()); stream.extend(>::write(&self.max_num_var)); - stream.extend(self.indices.write()); - stream.extend(self.final_message.write()); stream.extend(self.batch_coeffs.write()); - stream.extend(self.queries.write()); - if let Some(fixed_comm) = &self.fixed_comm { - stream.extend(>::write(&1)); - stream.extend(fixed_comm.write()); - } else { - stream.extend(>::write(&0)); - let tmp_comm = BasefoldCommitment { - commit: Default::default(), - log2_max_codeword_size: 0, - trivial_commits: vec![], - }; - stream.extend(tmp_comm.write()); - } - stream.extend(self.witin_comm.write()); - stream.extend(self.circuit_meta.write()); - stream.extend(self.commits.write()); stream.extend(self.fold_challenges.write()); - stream.extend(self.sumcheck_messages.write()); - stream.extend( - self.point_evals - .iter() - .map(|(p, e)| PointAndEvals { - point: p.clone(), - evals: e.clone(), - }) - .collect::>() - .write(), - ); + stream.extend(self.indices.write()); + stream.extend(self.proof.write()); + stream.extend(self.rounds.write()); stream } } @@ -326,30 +285,24 @@ impl Hintable for QueryPhaseVerifierInput { pub struct QueryPhaseVerifierInputVariable { // pub t_inv_halves: Array>>, pub max_num_var: Usize, - pub indices: Array>, - pub final_message: Array>>, pub batch_coeffs: Array>, - pub queries: QueryOpeningProofsVariable, - pub fixed_is_some: Usize, // 0 <==> false - pub fixed_comm: BasefoldCommitmentVariable, - pub witin_comm: BasefoldCommitmentVariable, - pub circuit_meta: Array>, - pub commits: Array>, pub fold_challenges: Array>, - pub sumcheck_messages: Array>, - pub point_evals: Array>, + pub indices: Array>, + pub proof: BasefoldProofVariable, + pub rounds: Array>, +} + +#[derive(DslVariable, Clone)] +pub struct PackedCodeword { + pub low: Ext, + pub high: Ext, } pub(crate) fn batch_verifier_query_phase( builder: &mut Builder, input: QueryPhaseVerifierInputVariable, ) { - // Nondeterministically supply inv_2 - let inv_2 = builder.hint_felt(); - builder.assert_eq::>( - inv_2 * C::F::from_canonical_usize(2), - C::F::from_canonical_usize(1), - ); + let inv_2 = builder.constant(C::F::from_canonical_u32(0x3c000001)); let two_adic_generators_inverses: Array> = builder.dyn_array(28); for (index, val) in [ 0x1usize, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, @@ -365,18 +318,19 @@ pub(crate) fn batch_verifier_query_phase( } // encode_small - let final_rmm_values_len = builder.get(&input.final_message, 0).len(); + let final_message = &input.proof.final_message; + let final_rmm_values_len = builder.get(final_message, 0).len(); let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); builder .range(0, final_rmm_values_len.clone()) .for_each(|i_vec, builder| { let i = i_vec[0]; - let row_len = input.final_message.len(); + let row_len = final_message.len(); let sum = builder.constant(C::EF::ZERO); builder.range(0, row_len).for_each(|j_vec, builder| { let j = j_vec[0]; - let row = builder.get(&input.final_message, j); + let row = builder.get(final_message, j); let row_j = builder.get(&row, i); builder.assign(&sum, sum + row_j); }); @@ -388,406 +342,241 @@ pub(crate) fn batch_verifier_query_phase( width: builder.eval(Usize::from(1)), }; let final_codeword = encode_small(builder, final_rmm); - // can't use witin_comm.log2_max_codeword_size since it's untrusted - let log2_witin_max_codeword_size: Var = - builder.eval(input.max_num_var.clone() + get_rate_log::()); - - // Nondeterministically supply the index folding_sorted_order - // Check that: - // 1. It has the same length as input.circuit_meta (checked by requesting folding_len hints) - // 2. It does not contain the same index twice (checked via a correspondence array) - // 3. Indexed witin_num_vars are sorted in decreasing order - // Infer witin_num_vars through index - let folding_len = input.circuit_meta.len(); - let zero: Ext = builder.constant(C::EF::ZERO); - let folding_sort_surjective: Array> = - builder.dyn_array(folding_len.clone()); - builder - .range(0, folding_len.clone()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - builder.set(&folding_sort_surjective, i, zero.clone()); - }); - - // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index - // for reverse lookup when retrieving next base codeword to involve into batching - // Sort input.dimensions by height, returns - // 1. height_order: after sorting by decreasing height, the original index of each entry - // 2. num_unique_height: number of different heights - // 3. count_per_unique_height: for each unique height, number of dimensions of that height - // builder.assert_nonzero(&Usize::from(0)); - let ( - folding_sorted_order_index, - num_unique_num_vars, - count_per_unique_num_var, - sorted_unique_num_vars, - ) = sort_with_count( - builder, - &input.circuit_meta, - |m: CircuitIndexMetaVariable| m.witin_num_vars, - ); - - builder - .range(0, input.indices.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let idx = builder.get(&input.indices, i); - let query = builder.get(&input.queries, i); - let witin_opened_values = query.witin_base_proof.opened_values; - let witin_opening_proof = query.witin_base_proof.opening_proof; - let fixed_is_some = query.fixed_is_some; - let fixed_commit = query.fixed_base_proof; - let opening_ext = query.commit_phase_openings; + let log2_max_codeword_size: Var = + builder.eval(input.max_num_var.clone() + get_rate_log::()); - // verify base oracle query proof - // refer to prover documentation for the reason of right shift by 1 - // The index length is the logarithm of the maximal codeword size. - let idx_len: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); - let idx_felt = builder.unsafe_cast_var_to_felt(idx); - let idx_bits = builder.num2bits_f(idx_felt, C::N::bits() as u32); + iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each( + |ptr_vec, builder| { + // TODO: change type of input.indices to be `Array>>` + let idx = builder.iter_ptr_get(&input.indices, ptr_vec[0]); + let idx = builder.unsafe_cast_var_to_felt(idx); + let idx_bits = builder.num2bits_f(idx, C::N::bits() as u32); + // assert idx_bits[log2_max_codeword_size..] == 0 builder - .range(idx_len, idx_bits.len()) + .range(log2_max_codeword_size, idx_bits.len()) .for_each(|i_vec, builder| { let bit = builder.get(&idx_bits, i_vec[0]); builder.assert_eq::>(bit, Usize::from(0)); }); - - // Right shift - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - let idx_half = builder.hint_var(); - let lsb = builder.get(&idx_bits, 0); - builder.assert_var_eq(Usize::from(2) * idx_half + lsb, idx); - - builder.assign(&idx_len, idx_len_minus_one); - builder.assign(&idx, idx_half); - - let (witin_dimensions, fixed_dimensions) = - get_base_codeword_dimensions(builder, input.circuit_meta.clone()); - // verify witness - let mmcs_verifier_input = MmcsVerifierInputVariable { - commit: input.witin_comm.commit.clone(), - dimensions: witin_dimensions, - index_bits: idx_bits.clone().slice(builder, 1, idx_len), // Remove the first bit because two entries are grouped into one leaf in the Merkle tree - opened_values: witin_opened_values.clone(), - proof: witin_opening_proof, - }; - mmcs_verify_batch(builder, mmcs_verifier_input); - - // verify fixed - let fixed_commit_leafs = builder.dyn_array(0); - builder - .if_eq(fixed_is_some.clone(), Usize::from(1)) - .then(|builder| { - let fixed_opened_values = fixed_commit.opened_values.clone(); - - let fixed_opening_proof = fixed_commit.opening_proof.clone(); - // new_idx used by mmcs proof - let new_idx: Var = builder.eval(idx); - // Nondeterministically supply a hint: - // 0: input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size - // 1: >= - let branch_le = builder.hint_var(); - builder.if_eq(branch_le, Usize::from(0)).then(|builder| { - // input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size - builder.assert_less_than_slow_small_rhs( - input.fixed_comm.log2_max_codeword_size.clone(), - log2_witin_max_codeword_size, - ); - // idx >> idx_shift - let idx_shift_remain: Var = builder.eval( - idx_len - - (log2_witin_max_codeword_size - - input.fixed_comm.log2_max_codeword_size.clone()), - ); - let tmp_idx = bin_to_dec(builder, &idx_bits, idx_shift_remain); - builder.assign(&new_idx, tmp_idx); - }); - builder.if_ne(branch_le, Usize::from(0)).then(|builder| { - // input.fixed_comm.log2_max_codeword_size >= log2_witin_max_codeword_size - let input_codeword_size_plus_one: Var = builder - .eval(input.fixed_comm.log2_max_codeword_size.clone() + Usize::from(1)); - builder.assert_less_than_slow_small_rhs( - log2_witin_max_codeword_size, - input_codeword_size_plus_one, - ); - // idx << -idx_shift - let idx_shift = builder.eval( - input.fixed_comm.log2_max_codeword_size.clone() - - log2_witin_max_codeword_size, - ); - let idx_factor = pow_2(builder, idx_shift); - builder.assign(&new_idx, new_idx * idx_factor); + let idx_bits = idx_bits.slice(builder, 1, log2_max_codeword_size.clone()); + + let reduced_codeword_by_height: Array> = + builder.dyn_array(log2_max_codeword_size.clone()); + + let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]); + let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); + + iter_zip!(builder, query.input_proofs, input.rounds).for_each(|ptr_vec, builder| { + let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]); + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]); + let opened_values = batch_opening.opened_values; + let perm_opened_values = builder.dyn_array(opened_values.len()); + let dimensions = builder.dyn_array(opened_values.len()); + let opening_proof = batch_opening.opening_proof; + + // reorder (opened values, dimension) according to the permutation + builder + .range(0, opened_values.len()) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let mat_j = builder.get(&opened_values, j); + let num_var_j = builder.get(&round.openings, j).num_var; + let height_j = + builder.eval(num_var_j + get_rate_log::() - Usize::from(1)); + + // TODO: use permutation to get the index + // let permuted_index = builder.get(&perm, j); + let permuted_j = j; + + builder.set_value(&perm_opened_values, permuted_j, mat_j); + builder.set_value(&dimensions, permuted_j, height_j); }); - // verify witness - let mmcs_verifier_input = MmcsVerifierInputVariable { - commit: input.fixed_comm.commit.clone(), - dimensions: fixed_dimensions.clone(), - index_bits: idx_bits.clone().slice(builder, 1, idx_len), - opened_values: fixed_opened_values.clone(), - proof: fixed_opening_proof, - }; - mmcs_verify_batch(builder, mmcs_verifier_input); - builder.assign(&fixed_commit_leafs, fixed_opened_values); - }); - // base_codeword_lo_hi - let base_codeword_lo = builder.dyn_array(folding_len.clone()); - let base_codeword_hi = builder.dyn_array(folding_len.clone()); - builder - .range(0, folding_len.clone()) - .for_each(|j_vec, builder| { - let j = j_vec[0]; - let circuit_meta = builder.get(&input.circuit_meta, j); - let witin_num_polys = circuit_meta.witin_num_polys; - let fixed_num_vars = circuit_meta.fixed_num_vars; - let fixed_num_polys = circuit_meta.fixed_num_polys; - let witin_leafs = builder.get(&witin_opened_values, j); - // lo_wit, hi_wit - let leafs_len_div_2 = builder.hint_var(); - let two: Var = builder.eval(Usize::from(2)); - builder.assert_eq::>(leafs_len_div_2 * two, witin_leafs.len()); // Can we assume that leafs.len() is even? - // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) - let lo_wit = dot_product(builder, &input.batch_coeffs, &witin_leafs); - let hi_wit = dot_product_with_index( + // i >>= (log2_max_codeword_size - commit.log2_max_codeword_size); + let bits_shift: Var = builder + .eval(log2_max_codeword_size.clone() - round.commit.log2_max_codeword_size); + let reduced_idx_bits = idx_bits.slice(builder, bits_shift, idx_bits.len()); + + // verify input mmcs + let mmcs_verifier_input = MmcsVerifierInputVariable { + commit: round.commit.commit.clone(), + dimensions: dimensions, + index_bits: reduced_idx_bits, + opened_values: perm_opened_values, + proof: opening_proof, + }; + + mmcs_verify_batch(builder, mmcs_verifier_input); + + // TODO: optimize this procedure + iter_zip!(builder, opened_values, round.openings).for_each(|ptr_vec, builder| { + let opened_value = builder.iter_ptr_get(&opened_values, ptr_vec[0]); + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]); + let log2_height: Var = + builder.eval(opening.num_var + get_rate_log::() - Usize::from(1)); + let width = opening.point_and_evals.evals.len(); + + let batch_coeffs_next_offset: Var = + builder.eval(batch_coeffs_offset + width.clone()); + let coeffs = input.batch_coeffs.slice( builder, - &input.batch_coeffs, - &witin_leafs, - Usize::from(0), - Usize::Var(leafs_len_div_2), - witin_num_polys.clone(), + batch_coeffs_offset.clone(), + batch_coeffs_next_offset.clone(), + ); + let low_values = opened_value.slice(builder, 0, width.clone()); + let high_values = + opened_value.slice(builder, width.clone(), opened_value.len()); + let low: Ext = builder.constant(C::EF::ZERO); + let high: Ext = builder.constant(C::EF::ZERO); + + iter_zip!(builder, coeffs, low_values, high_values).for_each( + |ptr_vec, builder| { + let coeff = builder.iter_ptr_get(&coeffs, ptr_vec[0]); + let low_value = builder.iter_ptr_get(&low_values, ptr_vec[1]); + let high_value = builder.iter_ptr_get(&high_values, ptr_vec[2]); + + builder.assign(&low, low + coeff * low_value); + builder.assign(&high, high + coeff * high_value); + }, ); - // lo_fixed, hi_fixed - let lo_fixed: Ext = - builder.constant(C::EF::from_canonical_usize(0)); - let hi_fixed: Ext = - builder.constant(C::EF::from_canonical_usize(0)); - builder - .if_ne(fixed_num_vars, Usize::from(0)) - .then(|builder| { - let fixed_leafs = builder.get(&fixed_commit_leafs, j); - let leafs_len_div_2: Var<::N> = builder.hint_var(); - let two: Var = builder.eval(Usize::from(2)); - builder - .assert_eq::>(leafs_len_div_2 * two, fixed_leafs.len()); // Can we assume that leafs.len() is even? - // Actual dot_product only computes the first num_polys terms (can we assume leafs_len_div_2 == num_polys?) - let tmp_lo_fixed = - dot_product(builder, &input.batch_coeffs, &fixed_leafs); - let tmp_hi_fixed = dot_product_with_index( - builder, - &input.batch_coeffs, - &fixed_leafs, - Usize::from(0), - Usize::Var(leafs_len_div_2), - fixed_num_polys.clone(), - ); - builder.assign(&lo_fixed, tmp_lo_fixed); - builder.assign(&hi_fixed, tmp_hi_fixed); - }); - let lo: Ext = builder.eval(lo_wit + lo_fixed); - let hi: Ext = builder.eval(hi_wit + hi_fixed); - builder.set_value(&base_codeword_lo, j, lo); - builder.set_value(&base_codeword_hi, j, hi); + let codeword = PackedCodeword { low, high }; + builder.set_value(&reduced_codeword_by_height, log2_height, codeword); + builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset); }); + }); + + let opening_ext = query.commit_phase_openings; - // fold and query + // fold 1st codeword let cur_num_var: Var = builder.eval(input.max_num_var.clone()); - // let rounds: Var = builder.eval(cur_num_var - get_basecode_msg_size_log::() - Usize::from(1)); - let n_d_next_log: Var = + let log2_height: Var = builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); - // let n_d_next = pow_2(builder, n_d_next_log); - // first folding challenge let r = builder.get(&input.fold_challenges, 0); - let next_unique_num_vars_count: Var = builder.get(&count_per_unique_num_var, 0); - let folded: Ext = builder.constant(C::EF::ZERO); - builder - .range(0, next_unique_num_vars_count) - .for_each(|j_vec, builder| { - let j = j_vec[0]; - let index = builder.get(&folding_sorted_order_index, j); - let lo = builder.get(&base_codeword_lo, index.clone()); - let hi = builder.get(&base_codeword_hi, index.clone()); - let level: Var = - builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); - let sliced_bits = idx_bits.clone().slice(builder, 1, idx_len); - // let expected_coeffs = builder.get(&input.t_inv_halves, level); - // let expected_coeff = builder.get(&expected_coeffs, idx); // TODO: remove this and directly use the result from verifier_folding_coeffs_level function - let coeff = verifier_folding_coeffs_level( - builder, - &two_adic_generators_inverses, - level, - &sliced_bits, - inv_2, - ); - // builder.assert_eq::>(coeff, expected_coeff); - // builder.halt(); - let fold = codeword_fold_with_challenge::(builder, lo, hi, r, coeff, inv_2); - builder.assign(&folded, folded + fold); - }); - let next_unique_num_vars_index: Var = builder.eval(Usize::from(1)); - let cumul_num_vars_count: Var = builder.eval(next_unique_num_vars_count); - let n_d_i_log: Var = builder.eval(n_d_next_log); - // let n_d_i: Var = builder.eval(n_d_next); - // zip_eq + let codeword = builder.get(&reduced_codeword_by_height, log2_height); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators_inverses, + log2_height, + &idx_bits, + inv_2, + ); + let folded = codeword_fold_with_challenge::( + builder, + codeword.low, + codeword.high, + r, + coeff, + inv_2, + ); + + // check commit phases + let commits = &input.proof.commits; builder.assert_eq::>( - input.commits.len() + Usize::from(1), + commits.len() + Usize::from(1), input.fold_challenges.len(), ); - builder.assert_eq::>(input.commits.len(), opening_ext.len()); + builder.assert_eq::>(commits.len(), opening_ext.len()); + builder.range(0, commits.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let commit = builder.get(&commits, i); + let commit_phase_step = builder.get(&opening_ext, i); + let i_plus_one = builder.eval_expr(i + Usize::from(1)); + + let sibling_value = commit_phase_step.sibling_value; + let proof = commit_phase_step.opening_proof; + + builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); + builder.assign(&log2_height, log2_height - Usize::from(1)); + + let folded_idx = builder.get(&idx_bits, i); + // TODO: absorb smaller codeword + let new_involved_codeword: Ext = builder.constant(C::EF::ZERO); + + // leafs + let leafs = builder.dyn_array(2); + let sibling_idx = builder.eval_expr(RVar::from(1) - folded_idx); + builder.assign(&folded, folded + new_involved_codeword); + builder.set_value(&leafs, folded_idx, folded); + builder.set_value(&leafs, sibling_idx, sibling_value); + + // idx >>= 1 + let idx_pair = idx_bits.slice(builder, i_plus_one, idx_bits.len()); + + // mmcs_ext.verify_batch + let dimensions = builder.dyn_array(1); + let opened_values = builder.dyn_array(1); + builder.set_value(&opened_values, 0, leafs.clone()); + builder.set_value(&dimensions, 0, log2_height.clone()); + let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { + commit: commit.clone(), + dimensions, + index_bits: idx_pair.clone(), + opened_values, + proof, + }; + ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); + + let r = builder.get(&input.fold_challenges, i_plus_one); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators_inverses, + log2_height, + &idx_pair, + inv_2, + ); + let left = builder.get(&leafs, 0); + let right = builder.get(&leafs, 1); + let new_folded = + codeword_fold_with_challenge(builder, left, right, r, coeff, inv_2); + builder.assign(&folded, new_folded); + }); + + // assert that final_value[i] = folded + let final_idx: Var = builder.constant(C::N::ZERO); builder - .range(0, input.commits.len()) - .for_each(|j_vec, builder| { - let j = j_vec[0]; - let pi_comm = builder.get(&input.commits, j); - let j_plus_one = builder.eval_expr(j + RVar::from(1)); - let j_plus_two = builder.eval(j + RVar::from(2)); - let r = builder.get(&input.fold_challenges, j_plus_one); - let leaf = builder.get(&opening_ext, j).sibling_value; - let proof = builder.get(&opening_ext, j).opening_proof; - builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); - - // next folding challenges - let is_interpolate_to_right_index = builder.get(&idx_bits, j_plus_one); - let new_involved_codewords: Ext = builder.constant(C::EF::ZERO); - builder - .if_ne(next_unique_num_vars_index, num_unique_num_vars) - .then(|builder| { - let next_unique_num_vars: Var = - builder.get(&sorted_unique_num_vars, next_unique_num_vars_index); - builder - .if_eq(next_unique_num_vars, cur_num_var) - .then(|builder| { - let next_unique_num_vars_count: Var = builder - .get(&count_per_unique_num_var, next_unique_num_vars_index); - builder.range(0, next_unique_num_vars_count).for_each( - |k_vec, builder| { - let k = - builder.eval_expr(k_vec[0] + cumul_num_vars_count); - let index = builder.get(&folding_sorted_order_index, k); - let lo = builder.get(&base_codeword_lo, index.clone()); - let hi = builder.get(&base_codeword_hi, index.clone()); - builder - .if_eq( - is_interpolate_to_right_index, - Usize::from(1), - ) - .then(|builder| { - builder.assign( - &new_involved_codewords, - new_involved_codewords + hi, - ); - }); - builder - .if_ne( - is_interpolate_to_right_index, - Usize::from(1), - ) - .then(|builder| { - builder.assign( - &new_involved_codewords, - new_involved_codewords + lo, - ); - }); - }, - ); - builder.assign( - &cumul_num_vars_count, - cumul_num_vars_count + next_unique_num_vars_count, - ); - builder.assign( - &next_unique_num_vars_index, - next_unique_num_vars_index + Usize::from(1), - ); - }); - }); - - // leafs - let leafs = builder.dyn_array(2); - let new_leaf = builder.eval(folded + new_involved_codewords); - builder - .if_eq(is_interpolate_to_right_index, Usize::from(1)) - .then(|builder| { - builder.set_value(&leafs, 0, leaf); - builder.set_value(&leafs, 1, new_leaf); - }); - builder - .if_ne(is_interpolate_to_right_index, Usize::from(1)) - .then(|builder| { - builder.set_value(&leafs, 0, new_leaf); - builder.set_value(&leafs, 1, leaf); - }); - // idx >>= 1 - let idx_len_minus_one: Var = builder.eval(idx_len - Usize::from(1)); - builder.assign(&idx_len, idx_len_minus_one); - let idx_end = builder.eval(input.max_num_var.clone() + get_rate_log::()); - let new_idx = bin_to_dec_le(builder, &idx_bits, j_plus_two, idx_end); - let first_bit = builder.get(&idx_bits, j_plus_one); - builder.assert_eq::>(Usize::from(2) * new_idx + first_bit, idx); - builder.assign(&idx, new_idx); - // n_d_i >> 1 - builder.assign(&n_d_i_log, n_d_i_log - Usize::from(1)); - // mmcs_ext.verify_batch - let dimensions = builder.dyn_array(1); - // let two: Var<_> = builder.eval(Usize::from(2)); - builder.set_value(&dimensions, 0, n_d_i_log.clone()); - let opened_values = builder.dyn_array(1); - builder.set_value(&opened_values, 0, leafs.clone()); - let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { - commit: pi_comm.clone(), - dimensions, - index_bits: idx_bits.clone().slice(builder, j_plus_two, idx_end), - opened_values, - proof, - }; - ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); // FIXME: the Merkle roots do not match - - let sliced_bits = idx_bits.clone().slice(builder, j_plus_two, idx_len); - let coeff = verifier_folding_coeffs_level( - builder, - &two_adic_generators_inverses, - n_d_i_log.clone(), - &sliced_bits, - inv_2, + .range(commits.len(), idx_bits.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let bit = builder.get(&idx_bits, i); + builder.assign( + &final_idx, + final_idx * SymbolicVar::from(C::N::from_canonical_u16(2)) + bit, ); - let left = builder.get(&leafs, 0); - let right = builder.get(&leafs, 1); - let new_folded = - codeword_fold_with_challenge(builder, left, right, r.clone(), coeff, inv_2); - builder.assign(&folded, new_folded); }); - let final_value = builder.get(&final_codeword.values, idx.clone()); + let final_value = builder.get(&final_codeword.values, final_idx); builder.assert_eq::>(final_value, folded); - }); + }, + ); // 1. check initial claim match with first round sumcheck value - let points = builder.dyn_array(input.batch_coeffs.len()); - let next_point_index: Var = builder.eval(Usize::from(0)); - builder - .range(0, input.point_evals.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let evals = builder.get(&input.point_evals, i).evals; - let witin_num_vars = builder.get(&input.circuit_meta, i).witin_num_vars; - // we need to scale up with scalar for witin_num_vars < max_num_var - let scale_log = builder.eval(input.max_num_var.clone() - witin_num_vars); - let scale = pow_2(builder, scale_log); - // Transform scale into a field element - let scale = builder.unsafe_cast_var_to_felt(scale); - builder.range(0, evals.len()).for_each(|j_vec, builder| { - let j = j_vec[0]; - let eval = builder.get(&evals, j); - let scaled_eval: Ext = builder.eval(eval * scale); - builder.set_value(&points, next_point_index, scaled_eval); - builder.assign(&next_point_index, next_point_index + Usize::from(1)); + let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); + let expected_sum: Ext = builder.constant(C::EF::ZERO); + iter_zip!(builder, input.rounds).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[0]); + iter_zip!(builder, round.openings).for_each(|ptr_vec, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[0]); + // TODO: filter out openings with num_var >= get_basecode_msg_size_log::() + let var_diff: Var = builder.eval(input.max_num_var.get_var() - opening.num_var); + let scalar_var = pow_2(builder, var_diff); + let scalar = builder.unsafe_cast_var_to_felt(scalar_var); + iter_zip!(builder, opening.point_and_evals.evals).for_each(|ptr_vec, builder| { + let eval = builder.iter_ptr_get(&opening.point_and_evals.evals, ptr_vec[0]); + let coeff = builder.get(&input.batch_coeffs, batch_coeffs_offset); + let val: Ext = builder.eval(eval * coeff * scalar); + builder.assign(&expected_sum, expected_sum + val); + builder.assign(&batch_coeffs_offset, batch_coeffs_offset + Usize::from(1)); }); }); - let left = dot_product(builder, &input.batch_coeffs, &points); - let next_sumcheck_evals = builder.get(&input.sumcheck_messages, 0).evaluations; - let eval0 = builder.get(&next_sumcheck_evals, 0); - let eval1 = builder.get(&next_sumcheck_evals, 1); - let right: Ext = builder.eval(eval0 + eval1); - builder.assert_eq::>(left, right); + }); + let sum: Ext = { + let sumcheck_evals = builder.get(&input.proof.sumcheck_proof, 0).evaluations; + let eval0 = builder.get(&sumcheck_evals, 0); + let eval1 = builder.get(&sumcheck_evals, 1); + builder.eval(eval0 + eval1) + }; + builder.assert_eq::>(expected_sum, sum); // 2. check every round of sumcheck match with prev claims let fold_len_minus_one: Var = builder.eval(input.fold_challenges.len() - Usize::from(1)); @@ -795,12 +584,12 @@ pub(crate) fn batch_verifier_query_phase( .range(0, fold_len_minus_one) .for_each(|i_vec, builder| { let i = i_vec[0]; - let evals = builder.get(&input.sumcheck_messages, i).evaluations; + let evals = builder.get(&input.proof.sumcheck_proof, i).evaluations; let challenge = builder.get(&input.fold_challenges, i); let left = interpolate_uni_poly(builder, &evals, challenge); let i_plus_one = builder.eval_expr(i + Usize::from(1)); let next_evals = builder - .get(&input.sumcheck_messages, i_plus_one) + .get(&input.proof.sumcheck_proof, i_plus_one) .evaluations; let eval0 = builder.get(&next_evals, 0); let eval1 = builder.get(&next_evals, 1); @@ -810,20 +599,27 @@ pub(crate) fn batch_verifier_query_phase( // 3. check final evaluation are correct let final_evals = builder - .get(&input.sumcheck_messages, fold_len_minus_one.clone()) + .get(&input.proof.sumcheck_proof, fold_len_minus_one.clone()) .evaluations; let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one.clone()); let left = interpolate_uni_poly(builder, &final_evals, final_challenge); let right: Ext = builder.constant(C::EF::ZERO); - builder - .range(0, input.final_message.len()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let final_message = builder.get(&input.final_message, i); - let point = builder.get(&input.point_evals, i).point; - // coeff is the eq polynomial evaluated at the first challenge.len() variables + let one: Var = builder.constant(C::N::ONE); + let j: Var = builder.constant(C::N::ZERO); + // \sum_i eq(p, [r,i]) * f(r,i) + iter_zip!(builder, input.rounds,).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[0]); + // TODO: filter out openings with num_var >= get_basecode_msg_size_log::() + iter_zip!(builder, round.openings).for_each(|ptr_vec, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[0]); + let point_and_evals = &opening.point_and_evals; + let point = &point_and_evals.point; + let num_vars_evaluated: Var = builder.eval(point.fs.len() - get_basecode_msg_size_log::()); + let final_message = builder.get(&input.proof.final_message, j); + + // coeff is the eq polynomial evaluated at the first challenge.len() variables let ylo = builder.eval(input.fold_challenges.len() - num_vars_evaluated); let coeff = eq_eval_with_index( builder, @@ -833,49 +629,38 @@ pub(crate) fn batch_verifier_query_phase( Usize::Var(ylo), Usize::Var(num_vars_evaluated), ); - // We assume that the final message is of size 1, so the eq poly is just - // vec![one]. - // let eq = build_eq_x_r_vec_sequential_with_offset::( - // builder, - // &point.fs, - // Usize::Var(num_vars_evaluated), - // ); - // eq_coeff = eq * coeff - // let eq_coeff = builder.dyn_array(eq.len()); - // builder.range(0, eq.len()).for_each(|j_vec, builder| { - // let j = j_vec[0]; - // let next_eq = builder.get(&eq, j); - // let next_eq_coeff: Ext = builder.eval(next_eq * coeff); - // builder.set_value(&eq_coeff, j, next_eq_coeff); - // }); - // let dot_prod = dot_product(builder, &final_message, &eq_coeff); - - // Again assuming final message is a single element + + // compute \sum_i eq(p[..num_vars_evaluated], r) * eq(p[num_vars_evaluated..], i) * f(r,i) + // + // We always assume that num_vars_evaluated is equal to p.len() + // so that the above sum only has one item and the final evaluation vector has only one element. + builder.assert_eq::>(final_message.len(), one); let final_message = builder.get(&final_message, 0); - // Again, eq polynomial is just one let dot_prod: Ext = builder.eval(final_message * coeff); builder.assign(&right, right + dot_prod); + + builder.assign(&j, j + Usize::from(1)); }); + }); + builder.assert_eq::>(j, input.proof.final_message.len()); builder.assert_eq::>(left, right); } #[cfg(test)] pub mod tests { - use std::{cmp::Reverse, collections::BTreeMap, iter::once}; - - use ceno_mle::mle::MultilinearExtension; use ceno_transcript::{BasicTranscript, Transcript}; use ff_ext::{BabyBearExt4, FromUniformBytes}; use itertools::Itertools; - use mpcs::pcs_batch_verify; use mpcs::{ - pcs_batch_commit, pcs_batch_open, pcs_setup, pcs_trim, - util::hash::write_digest_to_transcript, BasefoldDefault, PolynomialCommitmentScheme, + pcs_batch_commit, pcs_setup, pcs_trim, util::hash::write_digest_to_transcript, + BasefoldDefault, PolynomialCommitmentScheme, }; + use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::p3_challenger::GrindingChallenger; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::Field; use p3_field::FieldAlgebra; @@ -885,14 +670,9 @@ pub mod tests { type E = BabyBearExt4; type PCS = BasefoldDefault; - use crate::{ - basefold_verifier::{ - basefold::BasefoldCommitment, - query_phase::{BatchOpening, CommitPhaseProofStep, QueryOpeningProof}, - structs::CircuitIndexMeta, - }, - tower_verifier::binding::{Point, PointAndEval}, - }; + use crate::basefold_verifier::basefold::{Round, RoundOpening}; + use crate::basefold_verifier::query_phase::PointAndEvals; + use crate::tower_verifier::binding::{Point, PointAndEval}; use super::{batch_verifier_query_phase, QueryPhaseVerifierInput}; @@ -910,50 +690,6 @@ pub mod tests { // prepare input let mut witness_stream: Vec> = Vec::new(); witness_stream.extend(input.write()); - witness_stream.push(vec![F::from_canonical_u32(2).inverse()]); - witness_stream.push(vec![F::from_canonical_usize( - input - .circuit_meta - .iter() - .unique_by(|x| x.witin_num_vars) - .count(), - )]); - witness_stream.push( - input - .circuit_meta - .iter() - .enumerate() - .sorted_by_key(|(_, CircuitIndexMeta { witin_num_vars, .. })| { - Reverse(witin_num_vars) - }) - .map(|(index, _)| F::from_canonical_usize(index)) - .collect_vec(), - ); - for (query, idx) in input.queries.iter().zip(input.indices.iter()) { - witness_stream.push(vec![F::from_canonical_usize(idx / 2)]); - if let Some(fixed_comm) = &input.fixed_comm { - let log2_witin_max_codeword_size = input.max_num_var + 1; - if log2_witin_max_codeword_size > fixed_comm.log2_max_codeword_size { - witness_stream.push(vec![F::ZERO]) - } else { - witness_stream.push(vec![F::ONE]) - } - } - for i in 0..input.circuit_meta.len() { - witness_stream.push(vec![F::from_canonical_usize( - query.witin_base_proof.opened_values[i].len() / 2, - )]); - if input.circuit_meta[i].fixed_num_vars > 0 { - witness_stream.push(vec![F::from_canonical_usize( - if let Some(fixed_base_proof) = &query.fixed_base_proof { - fixed_base_proof.opened_values[i].len() / 2 - } else { - 0 - }, - )]); - } - } - } (program, witness_stream) } @@ -963,48 +699,28 @@ pub mod tests { let mut rng = thread_rng(); let m1 = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << 10, 10); let mles_1 = m1.to_mles(); - let matrices = BTreeMap::from_iter(once((0, m1))); + let matrices = vec![m1]; - let pp = pcs_setup::(1 << 20).unwrap(); + let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); - let witin_comm = PCS::get_pure_commitment(&pcs_data); + let comm = PCS::get_pure_commitment(&pcs_data); + + let point = E::random_vec(10, &mut rng); + let evals = mles_1.iter().map(|mle| mle.evaluate(&point)).collect_vec(); - let points = vec![E::random_vec(10, &mut rng)]; - let evals = points - .iter() - .map(|p| mles_1.iter().map(|mle| mle.evaluate(p)).collect_vec()) - .collect::>(); // let evals = mles_1 // .iter() // .map(|mle| points.iter().map(|p| mle.evaluate(&p)).collect_vec()) // .collect::>(); let mut transcript = BasicTranscript::::new(&[]); - let opening_proof = pcs_batch_open::( - &pp, - &[(0, 1 << 10)], - None, - &pcs_data, - &points, - &evals, - &[(10, 0)], - &mut transcript, - ) - .unwrap(); + let rounds = vec![(&pcs_data, vec![(point.clone(), evals.clone())])]; + let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); let mut transcript = BasicTranscript::::new(&[]); - pcs_batch_verify::( - &vp, - &[(0, 1 << 10)], - &points, - None, - &witin_comm, - &evals, - &opening_proof, - &[(10, 0)], - &mut transcript, - ) - .expect("Native verification failed"); + let rounds = vec![(comm, vec![(point.len(), (point, evals.clone()))])]; + PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) + .expect("Native verification failed"); let mut transcript = BasicTranscript::::new(&[]); let batch_coeffs = transcript.sample_and_append_challenge_pows(10, b"batch coeffs"); @@ -1030,72 +746,44 @@ pub mod tests { } transcript.append_field_element_exts_iter(opening_proof.final_message.iter().flatten()); - let queries = opening_proof - .query_opening_proof - .iter() - .map(|query| QueryOpeningProof { - witin_base_proof: BatchOpening { - opened_values: query.witin_base_proof.opened_values.clone(), - opening_proof: query.witin_base_proof.opening_proof.clone(), - }, - fixed_base_proof: None, - commit_phase_openings: query - .commit_phase_openings - .iter() - .map(|step| CommitPhaseProofStep { - sibling_value: step.sibling_value.clone(), - opening_proof: step.opening_proof.clone(), - }) - .collect(), - }) - .collect(); + // check pow + let pow_bits = vp.get_pow_bits_by_level(mpcs::PowStrategy::FriPow); + if pow_bits > 0 { + assert!(transcript.check_witness(pow_bits, opening_proof.pow_witness)); + } + + let queries: Vec<_> = transcript.sample_bits_and_append_vec( + b"query indices", + >::get_number_queries(), + max_num_var + >::get_rate_log(), + ); let query_input = QueryPhaseVerifierInput { // t_inv_halves: vp.encoding_params.t_inv_halves, max_num_var: 10, - indices: opening_proof.query_indices, - final_message: opening_proof.final_message, + fold_challenges, batch_coeffs, - queries, - fixed_comm: None, - witin_comm: BasefoldCommitment { - commit: witin_comm.commit().into(), - trivial_commits: witin_comm - .trivial_commits - .iter() - .copied() - .map(|c| c.into()) - .collect(), - log2_max_codeword_size: 20, - // This is a dummy value, should be set according to the actual codeword size - }, - circuit_meta: vec![CircuitIndexMeta { - witin_num_vars: 10, - fixed_num_vars: 0, - witin_num_polys: 10, - fixed_num_polys: 0, - }], - commits: opening_proof - .commits + indices: queries, + proof: opening_proof.into(), + rounds: rounds .iter() - .copied() - .map(|c| c.into()) - .collect(), - fold_challenges, - sumcheck_messages: opening_proof - .sumcheck_proof - .as_ref() - .unwrap() - .clone() - .into_iter() - .map(|msg| msg.into()) + .map(|round| Round { + commit: round.0.clone().into(), + openings: round + .1 + .iter() + .map(|opening| RoundOpening { + num_var: opening.0, + point_and_evals: PointAndEvals { + point: Point { + fs: opening.1.clone().0, + }, + evals: opening.1.clone().1, + }, + }) + .collect(), + }) .collect(), - point_evals: vec![( - Point { - fs: points[0].clone(), - }, - evals[0].clone(), - )], }; let (program, witness) = build_batch_verifier_query_phase(query_input); diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index 4afd4b6..5d5aead 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -5,6 +5,7 @@ use crate::zkvm_verifier::binding::{ TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, }; use crate::zkvm_verifier::verifier::verify_zkvm_proof; +use ceno_mle::util::ceil_log2; use ff_ext::BabyBearExt4; use itertools::Itertools; use mpcs::BasefoldCommitment; @@ -24,6 +25,7 @@ use openvm_stark_sdk::{ }; use std::collections::HashMap; use std::fs::File; +use std::thread; type SC = BabyBearPoseidon2Config; type EF = ::Challenge; @@ -33,7 +35,7 @@ use ceno_zkvm::{ structs::ZKVMVerifyingKey, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SubcircuitParams { pub id: usize, pub order_idx: usize, @@ -49,50 +51,10 @@ pub fn parse_zkvm_proof_import( ) -> (ZKVMProofInput, Vec) { let subcircuit_names = verifier.vk.circuit_vks.keys().collect_vec(); - let mut opcode_num_instances_lookup: HashMap = HashMap::new(); - let mut table_num_instances_lookup: HashMap = HashMap::new(); - for (index, num_instances) in &zkvm_proof.num_instances { - if let Some(_opcode_proof) = zkvm_proof.opcode_proofs.get(index) { - opcode_num_instances_lookup.insert(index.clone(), num_instances.clone()); - } else if let Some(_table_proof) = zkvm_proof.table_proofs.get(index) { - table_num_instances_lookup.insert(index.clone(), num_instances.clone()); - } else { - unreachable!("respective proof of index {} should exist", index) - } - } - let mut order_idx: usize = 0; let mut opcode_order_idx: usize = 0; let mut table_order_idx: usize = 0; let mut proving_sequence: Vec = vec![]; - for (index, _) in &zkvm_proof.num_instances { - let name = subcircuit_names[*index].clone(); - if zkvm_proof.opcode_proofs.get(index).is_some() { - proving_sequence.push(SubcircuitParams { - id: *index, - order_idx: order_idx.clone(), - type_order_idx: opcode_order_idx.clone(), - name: name.clone(), - num_instances: opcode_num_instances_lookup.get(index).unwrap().clone(), - is_opcode: true, - }); - opcode_order_idx += 1; - } else if zkvm_proof.table_proofs.get(index).is_some() { - proving_sequence.push(SubcircuitParams { - id: *index, - order_idx: order_idx.clone(), - type_order_idx: table_order_idx.clone(), - name: name.clone(), - num_instances: table_num_instances_lookup.get(index).unwrap().clone(), - is_opcode: false, - }); - table_order_idx += 1; - } else { - unreachable!("respective proof of index {} should exist", index) - } - - order_idx += 1; - } let raw_pi = zkvm_proof .raw_pi @@ -119,31 +81,42 @@ pub fn parse_zkvm_proof_import( .collect::>(); let mut opcode_proofs_vec: Vec = vec![]; - for (opcode_id, opcode_proof) in &zkvm_proof.opcode_proofs { - let mut record_r_out_evals: Vec = vec![]; - let mut record_w_out_evals: Vec = vec![]; - for v in &opcode_proof.record_r_out_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - record_r_out_evals.push(v_e); + /* + for (opcode_id, opcode_proof) in &zkvm_proof.chip_proofs { + let mut record_r_out_evals: Vec> = vec![]; + let mut record_w_out_evals: Vec> = vec![]; + let mut record_lk_out_evals: Vec> = vec![]; + + let record_r_out_evals_len: usize = opcode_proof.r_out_evals.len(); + for v_vec in &opcode_proof.r_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); + } + record_r_out_evals.push(arr); } - for v in &opcode_proof.record_w_out_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - record_w_out_evals.push(v_e); + let record_w_out_evals_len: usize = opcode_proof.w_out_evals.len(); + for v_vec in &opcode_proof.w_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); + } + record_w_out_evals.push(arr); + } + let record_lk_out_evals_len: usize = opcode_proof.lk_out_evals.len(); + for v_vec in &opcode_proof.lk_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); + } + record_lk_out_evals.push(arr); } - - // logup sum at layer 1 - let lk_p1_out_eval: E = - serde_json::from_value(serde_json::to_value(opcode_proof.lk_p1_out_eval).unwrap()) - .unwrap(); - let lk_p2_out_eval: E = - serde_json::from_value(serde_json::to_value(opcode_proof.lk_p2_out_eval).unwrap()) - .unwrap(); - let lk_q1_out_eval: E = - serde_json::from_value(serde_json::to_value(opcode_proof.lk_q1_out_eval).unwrap()) - .unwrap(); - let lk_q2_out_eval: E = - serde_json::from_value(serde_json::to_value(opcode_proof.lk_q2_out_eval).unwrap()) - .unwrap(); // Tower proof let mut tower_proof = TowerProofInput::default(); @@ -205,106 +178,84 @@ pub fn parse_zkvm_proof_import( tower_proof.logup_specs_eval = logup_specs_eval; // main constraint and select sumcheck proof - let mut main_sel_sumcheck_proofs: Vec = vec![]; - for m in &opcode_proof.main_sel_sumcheck_proofs { - let mut evaluations_vec: Vec = vec![]; - for v in &m.evaluations { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - evaluations_vec.push(v_e); + let mut main_sumcheck_proofs: Vec = vec![]; + if opcode_proof.main_sumcheck_proofs.is_some() { + for m in opcode_proof.main_sumcheck_proofs.as_ref().unwrap() { + let mut evaluations_vec: Vec = vec![]; + for v in &m.evaluations { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + evaluations_vec.push(v_e); + } + main_sumcheck_proofs.push(IOPProverMessage { + evaluations: evaluations_vec, + }); } - main_sel_sumcheck_proofs.push(IOPProverMessage { - evaluations: evaluations_vec, - }); - } - let mut r_records_in_evals: Vec = vec![]; - for v in &opcode_proof.r_records_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - r_records_in_evals.push(v_e); - } - let mut w_records_in_evals: Vec = vec![]; - for v in &opcode_proof.w_records_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - w_records_in_evals.push(v_e); - } - let mut lk_records_in_evals: Vec = vec![]; - for v in &opcode_proof.lk_records_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - lk_records_in_evals.push(v_e); } + let mut wits_in_evals: Vec = vec![]; for v in &opcode_proof.wits_in_evals { let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); wits_in_evals.push(v_e); } + let mut fixed_in_evals: Vec = vec![]; + for v in &opcode_proof.fixed_in_evals { + let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + fixed_in_evals.push(v_e); + } + opcode_proofs_vec.push(ZKVMOpcodeProofInput { idx: opcode_id.clone(), num_instances: opcode_num_instances_lookup.get(opcode_id).unwrap().clone(), + record_r_out_evals_len, + record_w_out_evals_len, + record_lk_out_evals_len, record_r_out_evals, record_w_out_evals, - lk_p1_out_eval, - lk_p2_out_eval, - lk_q1_out_eval, - lk_q2_out_eval, + record_lk_out_evals, tower_proof, - main_sel_sumcheck_proofs, - r_records_in_evals, - w_records_in_evals, - lk_records_in_evals, + main_sumcheck_proofs, wits_in_evals, + fixed_in_evals, }); } let mut table_proofs_vec: Vec = vec![]; for (table_id, table_proof) in &zkvm_proof.table_proofs { - let mut r_out_evals: Vec = vec![]; - let mut w_out_evals: Vec = vec![]; - let mut lk_out_evals: Vec = vec![]; - - for v in &table_proof.r_out_evals { - r_out_evals.push(serde_json::from_value(serde_json::to_value(v[0]).unwrap()).unwrap()); - r_out_evals.push(serde_json::from_value(serde_json::to_value(v[1]).unwrap()).unwrap()); - } - for v in &table_proof.w_out_evals { - w_out_evals.push(serde_json::from_value(serde_json::to_value(v[0]).unwrap()).unwrap()); - w_out_evals.push(serde_json::from_value(serde_json::to_value(v[1]).unwrap()).unwrap()); - } - let compressed_rw_out_len: usize = r_out_evals.len() / 2; - for v in &table_proof.lk_out_evals { - lk_out_evals.push(serde_json::from_value(serde_json::to_value(v[0]).unwrap()).unwrap()); - lk_out_evals.push(serde_json::from_value(serde_json::to_value(v[1]).unwrap()).unwrap()); - lk_out_evals.push(serde_json::from_value(serde_json::to_value(v[2]).unwrap()).unwrap()); - lk_out_evals.push(serde_json::from_value(serde_json::to_value(v[3]).unwrap()).unwrap()); - } - let compressed_lk_out_len: usize = lk_out_evals.len() / 4; - - let mut has_same_r_sumcheck_proofs: usize = 0; - let mut same_r_sumcheck_proofs: Vec = vec![]; - if table_proof.same_r_sumcheck_proofs.is_some() { - for m in table_proof.same_r_sumcheck_proofs.as_ref().unwrap() { - let mut evaluation_vec: Vec = vec![]; - for v in &m.evaluations { - let v_e: E = serde_json::from_value(serde_json::to_value(v).unwrap()).unwrap(); - evaluation_vec.push(v_e); - } - same_r_sumcheck_proofs.push(IOPProverMessage { - evaluations: evaluation_vec, - }); + let mut record_r_out_evals: Vec> = vec![]; + let mut record_w_out_evals: Vec> = vec![]; + let mut record_lk_out_evals: Vec> = vec![]; + + let record_r_out_evals_len: usize = table_proof.r_out_evals.len(); + for v_vec in &table_proof.r_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); } - } else { - has_same_r_sumcheck_proofs = 0; + record_r_out_evals.push(arr); } - - let mut rw_in_evals: Vec = vec![]; - for v in &table_proof.rw_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v).unwrap()).unwrap(); - rw_in_evals.push(v_e); + let record_w_out_evals_len: usize = table_proof.w_out_evals.len(); + for v_vec in &table_proof.w_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); + } + record_w_out_evals.push(arr); } - let mut lk_in_evals: Vec = vec![]; - for v in &table_proof.lk_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v).unwrap()).unwrap(); - lk_in_evals.push(v_e); + let record_lk_out_evals_len: usize = table_proof.lk_out_evals.len(); + for v_vec in &table_proof.lk_out_evals { + let mut arr: Vec = vec![]; + for v in v_vec { + let v_e: E = + serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); + arr.push(v_e); + } + record_lk_out_evals.push(arr); } // Tower proof @@ -382,26 +333,24 @@ pub fn parse_zkvm_proof_import( table_proofs_vec.push(ZKVMTableProofInput { idx: table_id.clone(), num_instances, - r_out_evals, - w_out_evals, - compressed_rw_out_len, - lk_out_evals, - compressed_lk_out_len, - has_same_r_sumcheck_proofs, - same_r_sumcheck_proofs, - rw_in_evals, - lk_in_evals, + record_r_out_evals_len, + record_w_out_evals_len, + record_lk_out_evals_len, + record_r_out_evals, + record_w_out_evals, + record_lk_out_evals, tower_proof, fixed_in_evals, wits_in_evals, }); } + */ let witin_commit: BasefoldCommitment = serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); let fixed_commit = verifier.vk.fixed_commit.clone(); - let pcs_proof = zkvm_proof.fixed_witin_opening_proof; + let pcs_proof = zkvm_proof.opening_proof; // let query_phase_verifier_input = QueryPhaseVerifierInput { // max_num_var, @@ -422,20 +371,20 @@ pub fn parse_zkvm_proof_import( ZKVMProofInput { raw_pi, pi_evals, - opcode_proofs: opcode_proofs_vec, - table_proofs: table_proofs_vec, + opcode_proofs: vec![], + table_proofs: vec![], witin_commit, fixed_commit, - num_instances: zkvm_proof.num_instances.clone(), - // query_phase_verifier_input, + num_instances: vec![], // TODO: Fixme + // query_phase_verifier_input, }, proving_sequence, ) } -#[test] -pub fn test_zkvm_proof_verifier_from_bincode_exports() { +pub fn inner_test_thread() { setup_tracing_with_log_level(tracing::Level::WARN); + let proof_path = "./src/e2e/encoded/proof.bin"; let vk_path = "./src/e2e/encoded/vk.bin"; @@ -502,3 +451,15 @@ pub fn test_zkvm_proof_verifier_from_bincode_exports() { println!("=> segment {:?} metrics: {:?}", i, seg.metrics); } } + +#[test] +pub fn test_zkvm_proof_verifier_from_bincode_exports() { + let stack_size = 64 * 1024 * 1024; // 64 MB + + let handler = thread::Builder::new() + .stack_size(stack_size) + .spawn(inner_test_thread) + .expect("Failed to spawn thread"); + + handler.join().expect("Thread panicked"); +} diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 9a37a6c..438284d 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -11,17 +11,13 @@ use crate::zkvm_verifier::verifier::verify_zkvm_proof; use crate::{ arithmetics::{ build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, - eq_eval_less_or_equal_than, eval_ceno_expr_with_instance, eval_wellform_address_vec, - gen_alpha_pows, max_usize_arr, max_usize_vec, next_pow2_instance_padding, product, - sum as ext_sum, - }, - tower_verifier::{ - binding::{PointVariable, TowerVerifierInputVariable}, - program::iop_verifier_state_verify, + eq_eval_less_or_equal_than, eval_wellform_address_vec, gen_alpha_pows, max_usize_arr, + max_usize_vec, next_pow2_instance_padding, product, sum as ext_sum, }, + tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, }; -use ceno_zkvm::circuit_builder::SetTableSpec; -use ceno_zkvm::{expression::StructuralWitIn, scheme::verifier::ZKVMVerifier}; +use ceno_mle::expression::StructuralWitIn; +use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; use ff_ext::BabyBearExt4; use itertools::interleave; use itertools::max; diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 5fd52bd..26c888a 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -11,7 +11,7 @@ pub type InnerConfig = AsmConfig; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; #[derive(DslVariable, Clone)] pub struct PointVariable { @@ -29,24 +29,6 @@ pub struct IOPProverMessageVariable { pub evaluations: Array>, } -#[derive(DslVariable, Clone)] -pub struct TowerVerifierInputVariable { - pub prod_out_evals: Array>>, - pub logup_out_evals: Array>>, - pub num_variables: Array>, - pub num_fanin: Usize, - - // TowerProofVariable - pub num_proofs: Usize, - pub num_prod_specs: Usize, - pub num_logup_specs: Usize, - pub max_num_variables: Usize, - - pub proofs: Array>>, - pub prod_specs_eval: Array>>>, - pub logup_specs_eval: Array>>>, -} - #[derive(Clone, Deserialize)] pub struct Point { pub fs: Vec, @@ -137,91 +119,3 @@ pub struct TowerVerifierInput { pub prod_specs_eval: Vec>>, pub logup_specs_eval: Vec>>, } - -impl Hintable for TowerVerifierInput { - type HintVariable = TowerVerifierInputVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let prod_out_evals = Vec::>::read(builder); - let logup_out_evals = Vec::>::read(builder); - let num_variables_var = Vec::::read(builder); - let num_variables = builder.dyn_array(num_variables_var.len()); - iter_zip!(builder, num_variables_var, num_variables).for_each(|ptr_vec, builder| { - let v = builder.iter_ptr_get(&num_variables_var, ptr_vec[0]); - let v_usize: Usize<::N> = Usize::from(v); - builder.iter_ptr_set(&num_variables, ptr_vec[1], v_usize); - }); - - let num_fanin = Usize::Var(usize::read(builder)); - let num_proofs = Usize::Var(usize::read(builder)); - let num_prod_specs = Usize::Var(usize::read(builder)); - let num_logup_specs = Usize::Var(usize::read(builder)); - let max_num_variables = Usize::Var(usize::read(builder)); - - let proofs = builder.dyn_array(num_proofs.clone()); - let prod_specs_eval = builder.dyn_array(num_prod_specs.clone()); - let logup_specs_eval = builder.dyn_array(num_logup_specs.clone()); - - iter_zip!(builder, proofs).for_each(|idx_vec, builder| { - let ptr = idx_vec[0]; - let proof = Vec::::read(builder); - builder.iter_ptr_set(&proofs, ptr, proof); - }); - - iter_zip!(builder, prod_specs_eval).for_each(|idx_vec, builder| { - let ptr = idx_vec[0]; - let evals = Vec::>::read(builder); - builder.iter_ptr_set(&prod_specs_eval, ptr, evals); - }); - - iter_zip!(builder, logup_specs_eval).for_each(|idx_vec, builder| { - let ptr = idx_vec[0]; - let evals = Vec::>::read(builder); - builder.iter_ptr_set(&logup_specs_eval, ptr, evals); - }); - - TowerVerifierInputVariable { - prod_out_evals, - logup_out_evals, - num_variables, - num_fanin, - num_proofs, - num_prod_specs, - num_logup_specs, - max_num_variables, - proofs, - prod_specs_eval, - logup_specs_eval, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(self.prod_out_evals.write()); - stream.extend(self.logup_out_evals.write()); - stream.extend(self.num_variables.write()); - stream.extend(>::write(&self.num_fanin)); - stream.extend(>::write(&self.num_proofs)); - stream.extend(>::write( - &self.num_prod_specs, - )); - stream.extend(>::write( - &self.num_logup_specs, - )); - - let max_num_variables = self.num_variables.iter().max().unwrap().clone(); - stream.extend(>::write(&max_num_variables)); - - for p in &self.proofs { - stream.extend(p.write()); - } - for evals in &self.prod_specs_eval { - stream.extend(evals.write()); - } - for evals in &self.logup_specs_eval { - stream.extend(evals.write()); - } - - stream - } -} diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index f8e6e3c..09e973c 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -1,11 +1,12 @@ -use super::binding::{ - IOPProverMessageVariable, PointAndEvalVariable, PointVariable, TowerVerifierInputVariable, -}; +use super::binding::{IOPProverMessageVariable, PointAndEvalVariable, PointVariable}; use crate::arithmetics::{ - challenger_multi_observe, dot_product, eq_eval, evaluate_at_point, extend, exts_to_felts, - fixed_dot_product, gen_alpha_pows, is_smaller_than, reverse, UniPolyExtrapolator, + challenger_multi_observe, dot_product, eq_eval, evaluate_at_point_degree_1, extend, + exts_to_felts, fixed_dot_product, gen_alpha_pows, is_smaller_than, print_ext_arr, reverse, + UniPolyExtrapolator, }; use crate::transcript::transcript_observe_label; +use crate::zkvm_verifier::binding::TowerProofInputVariable; +use ceno_zkvm::scheme::constants::NUM_FANIN; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::ChallengerVariable; @@ -144,11 +145,14 @@ pub fn iop_verifier_state_verify( ) { // TODO: either store it in a global cache or pass them as parameters let zero: Ext = builder.constant(C::EF::ZERO); + let zero_f: Felt = builder.constant(C::F::ZERO); let max_num_variables_usize: Usize = Usize::from(builder.cast_felt_to_var(max_num_variables.clone())); challenger.observe(builder, max_num_variables); + challenger.observe(builder, zero_f); challenger.observe(builder, max_degree); + challenger.observe(builder, zero_f); builder.assert_var_eq(max_num_variables_usize.get_var(), prover_messages.len()); @@ -193,7 +197,15 @@ pub fn iop_verifier_state_verify( pub fn verify_tower_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - tower_verifier_input: TowerVerifierInputVariable, + prod_out_evals: Array>>, + logup_out_evals: &Array>>, + num_variables: Array>, + num_fanin: Usize, + + // TowerProofVariable + max_num_variables: Usize, + + proof: &TowerProofInputVariable, unipoly_extrapolator: &mut UniPolyExtrapolator, ) -> ( PointVariable, @@ -201,42 +213,32 @@ pub fn verify_tower_proof( Array>, Array>, ) { - let num_fanin: usize = 2; - builder.assert_usize_eq(tower_verifier_input.num_fanin, RVar::from(num_fanin)); - let num_prod_spec = tower_verifier_input.prod_out_evals.len(); - let num_logup_spec = tower_verifier_input.logup_out_evals.len(); + let num_prod_spec = prod_out_evals.len(); + let num_logup_spec = logup_out_evals.len(); let one: Ext = builder.constant(C::EF::ONE); let zero: Ext = builder.constant(C::EF::ZERO); - builder.assert_usize_eq( - tower_verifier_input.prod_specs_eval.len(), - num_prod_spec.clone(), - ); - builder.assert_usize_eq( - tower_verifier_input.logup_specs_eval.len(), - num_logup_spec.clone(), - ); - builder.assert_usize_eq( - tower_verifier_input.num_variables.len(), - num_prod_spec.clone() + num_logup_spec.clone(), - ); - - iter_zip!(builder, tower_verifier_input.prod_out_evals).for_each(|ptr_vec, builder| { + builder.assert_usize_eq(proof.prod_specs_eval.len(), num_prod_spec.clone()); + iter_zip!(builder, prod_out_evals).for_each(|ptr_vec, builder| { let ptr = ptr_vec[0]; - let evals = builder.iter_ptr_get(&tower_verifier_input.prod_out_evals, ptr); - builder.assert_usize_eq(evals.len(), RVar::from(num_fanin)); + let evals = builder.iter_ptr_get(&prod_out_evals, ptr); + builder.assert_usize_eq(evals.len(), num_fanin.clone()); }); - iter_zip!(builder, tower_verifier_input.logup_out_evals).for_each(|ptr_vec, builder| { + builder.assert_usize_eq(proof.logup_specs_eval.len(), num_logup_spec.clone()); + iter_zip!(builder, logup_out_evals).for_each(|ptr_vec, builder| { let ptr = ptr_vec[0]; - let evals = builder.iter_ptr_get(&tower_verifier_input.logup_out_evals, ptr); + let evals = builder.iter_ptr_get(&logup_out_evals, ptr); builder.assert_usize_eq(evals.len(), RVar::from(4)); }); + builder.assert_usize_eq( + num_variables.len(), + num_prod_spec.clone() + num_logup_spec.clone(), + ); - let num_specs: Var = builder.eval(num_prod_spec.get_var() + num_logup_spec.get_var()); let var_zero: Var = builder.constant(C::N::ZERO); let var_one: Var = builder.constant(C::N::ONE); - + let num_specs: Var = builder.eval(num_prod_spec.get_var() + num_logup_spec.get_var()); let should_skip: Array> = builder.dyn_array(num_specs); builder.range(0, num_specs).for_each(|i_vec, builder| { let i = i_vec[0]; @@ -255,7 +257,7 @@ pub fn verify_tower_proof( // out_j[rt] := (logup_q{j}[rt]) let log2_num_fanin = 1usize; builder.cycle_tracker_start("initial sum"); - let initial_rt: Array> = builder.dyn_array(RVar::from(log2_num_fanin)); + let initial_rt: Array> = builder.dyn_array(log2_num_fanin); transcript_observe_label(builder, challenger, b"product_sum"); builder .range(0, initial_rt.len()) @@ -267,15 +269,11 @@ pub fn verify_tower_proof( let prod_spec_point_n_eval: Array> = builder.dyn_array(num_prod_spec.clone()); - iter_zip!( - builder, - tower_verifier_input.prod_out_evals, - prod_spec_point_n_eval - ) - .for_each(|ptr_vec, builder| { + + iter_zip!(builder, prod_out_evals, prod_spec_point_n_eval).for_each(|ptr_vec, builder| { let ptr = ptr_vec[0]; - let evals = builder.iter_ptr_get(&tower_verifier_input.prod_out_evals, ptr); - let e = evaluate_at_point(builder, &evals, &initial_rt); + let evals = builder.iter_ptr_get(&prod_out_evals, ptr); + let e = evaluate_at_point_degree_1(builder, &evals, &initial_rt); let p_ptr = ptr_vec[1]; builder.iter_ptr_set( &prod_spec_point_n_eval, @@ -296,19 +294,19 @@ pub fn verify_tower_proof( iter_zip!( builder, - tower_verifier_input.logup_out_evals, + logup_out_evals, logup_spec_p_point_n_eval, logup_spec_q_point_n_eval ) .for_each(|ptr_vec, builder| { let ptr = ptr_vec[0]; - let evals = builder.iter_ptr_get(&tower_verifier_input.prod_out_evals, ptr); + let evals = builder.iter_ptr_get(&prod_out_evals, ptr); let p_slice = evals.slice(builder, 0, 2); let q_slice = evals.slice(builder, 2, 4); - let e1 = evaluate_at_point(builder, &p_slice, &initial_rt); - let e2 = evaluate_at_point(builder, &q_slice, &initial_rt); + let e1 = evaluate_at_point_degree_1(builder, &p_slice, &initial_rt); + let e2 = evaluate_at_point_degree_1(builder, &q_slice, &initial_rt); let p_ptr = ptr_vec[1]; let q_ptr = ptr_vec[2]; @@ -366,13 +364,11 @@ pub fn verify_tower_proof( builder.assign(&initial_claim, initial_claim + logup_eval.eval * alpha_acc); builder.assign(&alpha_acc, alpha_acc * alpha); }); - builder.cycle_tracker_end("initial sum"); let curr_pt = initial_rt.clone(); let curr_eval = initial_claim.clone(); - let op_range: RVar = - builder.eval_expr(tower_verifier_input.max_num_variables - Usize::from(1)); + let op_range: RVar = builder.eval_expr(max_num_variables - Usize::from(1)); let round: Felt = builder.constant(C::F::ZERO); let mut next_rt = PointAndEvalVariable { @@ -386,11 +382,9 @@ pub fn verify_tower_proof( .range(0, op_range.clone()) .for_each(|i_vec, builder| { let round_var = i_vec[0]; - let out_rt = &curr_pt; let out_claim = &curr_eval; - - let prover_messages = builder.get(&tower_verifier_input.proofs, round_var); + let prover_messages = builder.get(&proof.proofs, round_var); let max_num_variables: Felt = builder.constant(C::F::ONE); builder.assign(&max_num_variables, max_num_variables + round); @@ -421,7 +415,7 @@ pub fn verify_tower_proof( builder.cycle_tracker_start("accumulate expected eval for prod specs"); let spec_index = i_vec[0]; let skip = builder.get(&should_skip, spec_index.clone()); - let max_round = builder.get(&tower_verifier_input.num_variables, spec_index); + let max_round = builder.get(&num_variables, spec_index); let round_limit: RVar = builder.eval_expr(max_round - RVar::from(1)); let prod: Ext = builder.eval(zero + zero); @@ -433,11 +427,10 @@ pub fn verify_tower_proof( builder.if_eq(skip, var_zero.clone()).then(|builder| { builder.if_ne(round_var, round_limit).then_or_else( |builder| { - let prod_slice = - builder.get(&tower_verifier_input.prod_specs_eval, spec_index); + let prod_slice = builder.get(&proof.prod_specs_eval, spec_index); let prod_round_slice = builder.get(&prod_slice, round_var); builder.assign(&prod, one * one); - for j in 0..num_fanin { + for j in 0..NUM_FANIN { let prod_j = builder.get(&prod_round_slice, j); builder.assign(&prod, prod * prod_j); } @@ -453,12 +446,9 @@ pub fn verify_tower_proof( builder.cycle_tracker_end("accumulate expected eval for prod specs"); }); - let num_variables_len = tower_verifier_input.num_variables.len(); - let logup_num_variables_slice = tower_verifier_input.num_variables.slice( - builder, - num_prod_spec.clone(), - num_variables_len.clone(), - ); + let num_variables_len = num_variables.len(); + let logup_num_variables_slice = + num_variables.slice(builder, num_prod_spec.clone(), num_variables_len.clone()); builder .range(0, num_logup_spec.clone()) @@ -483,8 +473,7 @@ pub fn verify_tower_proof( builder.if_eq(skip, var_zero).then(|builder| { builder.if_ne(round_var, round_limit).then_or_else( |builder| { - let prod_slice = - builder.get(&tower_verifier_input.logup_specs_eval, spec_index); + let prod_slice = builder.get(&proof.logup_specs_eval, spec_index); let prod_round_slice = builder.get(&prod_slice, round_var); let p1 = builder.get(&prod_round_slice, 0); @@ -543,14 +532,12 @@ pub fn verify_tower_proof( builder.cycle_tracker_start("derive next layer for prod specs"); let spec_index = i_vec[0]; let skip = builder.get(&should_skip, spec_index.clone()); - let max_round = - builder.get(&tower_verifier_input.num_variables, spec_index.clone()); + let max_round = builder.get(&num_variables, spec_index.clone()); let round_limit: RVar = builder.eval_expr(max_round - RVar::from(1)); // now skip is 0 if and only if current round_var is smaller than round_limit. builder.if_eq(skip, var_zero.clone()).then(|builder| { - let prod_slice = - builder.get(&tower_verifier_input.prod_specs_eval, spec_index); + let prod_slice = builder.get(&proof.prod_specs_eval, spec_index); let prod_round_slice = builder.get(&prod_slice, round_var); let evals = fixed_dot_product(builder, &coeffs, &prod_round_slice, zero); @@ -584,11 +571,8 @@ pub fn verify_tower_proof( let next_logup_spec_evals: Ext<::F, ::EF> = builder.eval(zero + zero); - let logup_num_variables_slice = tower_verifier_input.num_variables.slice( - builder, - num_prod_spec.clone(), - num_variables_len.clone(), - ); + let logup_num_variables_slice = + num_variables.slice(builder, num_prod_spec.clone(), num_variables_len.clone()); builder .range(0, num_logup_spec.clone()) @@ -608,8 +592,7 @@ pub fn verify_tower_proof( // now skip is 0 if and only if current round_var is smaller than round_limit. builder.if_eq(skip, var_zero).then(|builder| { - let prod_slice = - builder.get(&tower_verifier_input.logup_specs_eval, spec_index); + let prod_slice = builder.get(&proof.logup_specs_eval, spec_index); let prod_round_slice = builder.get(&prod_slice, round_var); let p1 = builder.get(&prod_round_slice, 0); let p2 = builder.get(&prod_round_slice, 1); @@ -682,6 +665,7 @@ pub fn verify_tower_proof( ) } +/* #[cfg(test)] mod tests { use crate::arithmetics::UniPolyExtrapolator; @@ -689,15 +673,14 @@ mod tests { use crate::tower_verifier::binding::TowerVerifierInput; use crate::tower_verifier::program::iop_verifier_state_verify; use crate::tower_verifier::program::verify_tower_proof; - use ceno_mle::mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}; - use ceno_mle::virtual_poly::ArcMultilinearExtension; + use ceno_mle::mle::ArcMultilinearExtension; + use ceno_mle::mle::{IntoMLE, MultilinearExtension}; use ceno_mle::virtual_polys::VirtualPolynomials; use ceno_sumcheck::structs::IOPProverState; use ceno_transcript::BasicTranscript; use ceno_zkvm::scheme::constants::NUM_FANIN; - use ceno_zkvm::scheme::utils::infer_tower_logup_witness; - use ceno_zkvm::scheme::utils::infer_tower_product_witness; - use ceno_zkvm::structs::TowerProver; + use ceno_zkvm::scheme::hal::TowerProver; + use ceno_zkvm::scheme::hal::TowerProverSpec; use ff_ext::BabyBearExt4; use ff_ext::FieldFrom; use ff_ext::FromUniformBytes; @@ -771,9 +754,8 @@ mod tests { // run sumcheck prover to get sumcheck proof let mut rng = thread_rng(); - let (mles, expected_sum) = - DenseMultilinearExtension::::random_mle_list(nv, degree, &mut rng); - let mles: Vec> = + let (mles, expected_sum) = MultilinearExtension::::random_mle_list(nv, degree, &mut rng); + let mles: Vec> = mles.into_iter().map(|mle| mle as _).collect_vec(); let mut virtual_poly: VirtualPolynomials<'_, E> = VirtualPolynomials::new(1, nv); virtual_poly.add_mle_list(mles.iter().collect_vec(), E::from_v(1)); @@ -843,9 +825,9 @@ mod tests { setup_tracing_with_log_level(tracing::Level::WARN); - let records: Vec> = (0..num_prod_specs) + let records: Vec> = (0..num_prod_specs) .map(|_| { - DenseMultilinearExtension::from_evaluations_ext_vec( + MultilinearExtension::from_evaluations_ext_vec( nv - 1, E::random_vec(1 << (nv - 1), &mut rng), ) @@ -853,10 +835,7 @@ mod tests { .collect_vec(); let denom_records = (0..num_logup_specs) .map(|_| { - DenseMultilinearExtension::from_evaluations_ext_vec( - nv, - E::random_vec(1 << nv, &mut rng), - ) + MultilinearExtension::from_evaluations_ext_vec(nv, E::random_vec(1 << nv, &mut rng)) }) .collect_vec(); @@ -897,7 +876,7 @@ mod tests { first.to_vec().into_mle().into(), second.to_vec().into_mle().into(), ]; - ceno_zkvm::structs::TowerProverSpec { + TowerProverSpec { witness: infer_tower_logup_witness(None, last_layer), } }) @@ -1015,3 +994,4 @@ mod tests { unsafe { Vec::from_raw_parts(new_ptr, length, capacity) } } } +*/ diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index 21cdee2..782658a 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -62,21 +62,19 @@ pub struct ZKVMOpcodeProofInputVariable { pub num_instances_minus_one_bit_decomposition: Array>, pub log2_num_instances: Usize, - pub record_r_out_evals: Array>, - pub record_w_out_evals: Array>, + pub record_r_out_evals_len: Usize, + pub record_w_out_evals_len: Usize, + pub record_lk_out_evals_len: Usize, - pub lk_p1_out_eval: Ext, - pub lk_p2_out_eval: Ext, - pub lk_q1_out_eval: Ext, - pub lk_q2_out_eval: Ext, + pub record_r_out_evals: Array>>, + pub record_w_out_evals: Array>>, + pub record_lk_out_evals: Array>>, pub tower_proof: TowerProofInputVariable, pub main_sel_sumcheck_proofs: Array>, - pub r_records_in_evals: Array>, - pub w_records_in_evals: Array>, - pub lk_records_in_evals: Array>, pub wits_in_evals: Array>, + pub fixed_in_evals: Array>, } #[derive(DslVariable, Clone)] @@ -86,16 +84,13 @@ pub struct ZKVMTableProofInputVariable { pub num_instances: Usize, pub log2_num_instances: Usize, - pub r_out_evals: Array>, // Vec<[E; 2]>, - pub w_out_evals: Array>, // Vec<[E; 2]>, - pub compressed_rw_out_len: Usize, - pub lk_out_evals: Array>, // Vec<[E; 4]>, - pub compressed_lk_out_len: Usize, + pub record_r_out_evals_len: Usize, + pub record_w_out_evals_len: Usize, + pub record_lk_out_evals_len: Usize, - pub has_same_r_sumcheck_proofs: Usize, // Either 1 or 0 - pub same_r_sumcheck_proofs: Array>, // Could be empty - pub rw_in_evals: Array>, - pub lk_in_evals: Array>, + pub record_r_out_evals: Array>>, + pub record_w_out_evals: Array>>, + pub record_lk_out_evals: Array>>, pub tower_proof: TowerProofInputVariable, pub fixed_in_evals: Array>, @@ -160,7 +155,7 @@ impl Hintable for ZKVMProofInput { let mut raw_pi_num_variables: Vec = vec![]; for v in &self.raw_pi { - raw_pi_num_variables.push(v.len().next_power_of_two()); + raw_pi_num_variables.push(ceil_log2(v.len().next_power_of_two())); } stream.extend(raw_pi_num_variables.write()); @@ -175,15 +170,15 @@ impl Hintable for ZKVMProofInput { cmt_vec.push(f); }); let mut witin_commit_trivial_commits: Vec> = vec![]; - for trivial_commit in &self.witin_commit.trivial_commits { - let mut t_cmt_vec: Vec = vec![]; - trivial_commit.iter().for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - t_cmt_vec.push(f); - }); - witin_commit_trivial_commits.push(t_cmt_vec); - } + // for trivial_commit in &self.witin_commit.trivial_commits { + // let mut t_cmt_vec: Vec = vec![]; + // trivial_commit.1.iter().for_each(|x| { + // let f: F = + // serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); + // t_cmt_vec.push(f); + // }); + // witin_commit_trivial_commits.push(t_cmt_vec); + // } let witin_commit_log2_max_codeword_size = F::from_canonical_u32(self.witin_commit.log2_max_codeword_size as u32); stream.extend(cmt_vec.write()); @@ -207,15 +202,15 @@ impl Hintable for ZKVMProofInput { fixed_commit_vec.push(f); }); - for trivial_commit in &self.fixed_commit.as_ref().unwrap().trivial_commits { - let mut t_cmt_vec: Vec = vec![]; - trivial_commit.iter().for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - t_cmt_vec.push(f); - }); - fixed_commit_trivial_commits.push(t_cmt_vec); - } + // for trivial_commit in &self.fixed_commit.as_ref().unwrap().trivial_commits { + // let mut t_cmt_vec: Vec = vec![]; + // trivial_commit.1.iter().for_each(|x| { + // let f: F = + // serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); + // t_cmt_vec.push(f); + // }); + // fixed_commit_trivial_commits.push(t_cmt_vec); + // } fixed_commit_log2_max_codeword_size = F::from_canonical_u32( self.fixed_commit.as_ref().unwrap().log2_max_codeword_size as u32, ); @@ -317,23 +312,19 @@ pub struct ZKVMOpcodeProofInput { pub num_instances: usize, // product constraints - pub record_r_out_evals: Vec, - pub record_w_out_evals: Vec, - - // logup sum at layer 1 - pub lk_p1_out_eval: E, - pub lk_p2_out_eval: E, - pub lk_q1_out_eval: E, - pub lk_q2_out_eval: E, + pub record_r_out_evals_len: usize, + pub record_w_out_evals_len: usize, + pub record_lk_out_evals_len: usize, + pub record_r_out_evals: Vec>, + pub record_w_out_evals: Vec>, + pub record_lk_out_evals: Vec>, pub tower_proof: TowerProofInput, // main constraint and select sumcheck proof - pub main_sel_sumcheck_proofs: Vec, - pub r_records_in_evals: Vec, - pub w_records_in_evals: Vec, - pub lk_records_in_evals: Vec, + pub main_sumcheck_proofs: Vec, pub wits_in_evals: Vec, + pub fixed_in_evals: Vec, } impl VecAutoHintable for ZKVMOpcodeProofInput {} impl Hintable for ZKVMOpcodeProofInput { @@ -345,18 +336,19 @@ impl Hintable for ZKVMOpcodeProofInput { let num_instances = Usize::Var(usize::read(builder)); let num_instances_minus_one_bit_decomposition = Vec::::read(builder); let log2_num_instances = Usize::Var(usize::read(builder)); - let record_r_out_evals = Vec::::read(builder); - let record_w_out_evals = Vec::::read(builder); - let lk_p1_out_eval = E::read(builder); - let lk_p2_out_eval = E::read(builder); - let lk_q1_out_eval = E::read(builder); - let lk_q2_out_eval = E::read(builder); + + let record_r_out_evals_len = Usize::Var(usize::read(builder)); + let record_w_out_evals_len = Usize::Var(usize::read(builder)); + let record_lk_out_evals_len = Usize::Var(usize::read(builder)); + + let record_r_out_evals = Vec::>::read(builder); + let record_w_out_evals = Vec::>::read(builder); + let record_lk_out_evals = Vec::>::read(builder); + let tower_proof = TowerProofInput::read(builder); let main_sel_sumcheck_proofs = Vec::::read(builder); - let r_records_in_evals = Vec::::read(builder); - let w_records_in_evals = Vec::::read(builder); - let lk_records_in_evals = Vec::::read(builder); let wits_in_evals = Vec::::read(builder); + let fixed_in_evals = Vec::::read(builder); ZKVMOpcodeProofInputVariable { idx, @@ -364,18 +356,16 @@ impl Hintable for ZKVMOpcodeProofInput { num_instances, num_instances_minus_one_bit_decomposition, log2_num_instances, + record_r_out_evals_len, + record_w_out_evals_len, + record_lk_out_evals_len, record_r_out_evals, record_w_out_evals, - lk_p1_out_eval, - lk_p2_out_eval, - lk_q1_out_eval, - lk_q2_out_eval, + record_lk_out_evals, tower_proof, main_sel_sumcheck_proofs, - r_records_in_evals, - w_records_in_evals, - lk_records_in_evals, wits_in_evals, + fixed_in_evals, } } @@ -399,18 +389,24 @@ impl Hintable for ZKVMOpcodeProofInput { let log2_num_instances = ceil_log2(next_pow2_instance); stream.extend(>::write(&log2_num_instances)); + stream.extend(>::write( + &self.record_r_out_evals_len, + )); + stream.extend(>::write( + &self.record_w_out_evals_len, + )); + stream.extend(>::write( + &self.record_lk_out_evals_len, + )); + stream.extend(self.record_r_out_evals.write()); stream.extend(self.record_w_out_evals.write()); - stream.extend(>::write(&self.lk_p1_out_eval)); - stream.extend(>::write(&self.lk_p2_out_eval)); - stream.extend(>::write(&self.lk_q1_out_eval)); - stream.extend(>::write(&self.lk_q2_out_eval)); + stream.extend(self.record_lk_out_evals.write()); + stream.extend(self.tower_proof.write()); - stream.extend(self.main_sel_sumcheck_proofs.write()); - stream.extend(self.r_records_in_evals.write()); - stream.extend(self.w_records_in_evals.write()); - stream.extend(self.lk_records_in_evals.write()); + stream.extend(self.main_sumcheck_proofs.write()); stream.extend(self.wits_in_evals.write()); + stream.extend(self.fixed_in_evals.write()); stream } @@ -421,16 +417,12 @@ pub struct ZKVMTableProofInput { pub num_instances: usize, // tower evaluation at layer 1 - pub r_out_evals: Vec, // Vec<[E; 2]> - pub w_out_evals: Vec, // Vec<[E; 2]> - pub compressed_rw_out_len: usize, - pub lk_out_evals: Vec, // Vec<[E; 4]> - pub compressed_lk_out_len: usize, - - pub has_same_r_sumcheck_proofs: usize, - pub same_r_sumcheck_proofs: Vec, // Could be empty - pub rw_in_evals: Vec, - pub lk_in_evals: Vec, + pub record_r_out_evals_len: usize, + pub record_w_out_evals_len: usize, + pub record_lk_out_evals_len: usize, + pub record_r_out_evals: Vec>, + pub record_w_out_evals: Vec>, + pub record_lk_out_evals: Vec>, pub tower_proof: TowerProofInput, @@ -448,16 +440,14 @@ impl Hintable for ZKVMTableProofInput { let num_instances = Usize::Var(usize::read(builder)); let log2_num_instances = Usize::Var(usize::read(builder)); - let r_out_evals = Vec::::read(builder); - let w_out_evals = Vec::::read(builder); - let compressed_rw_out_len = Usize::Var(usize::read(builder)); - let lk_out_evals = Vec::::read(builder); - let compressed_lk_out_len = Usize::Var(usize::read(builder)); + let record_r_out_evals_len = Usize::Var(usize::read(builder)); + let record_w_out_evals_len = Usize::Var(usize::read(builder)); + let record_lk_out_evals_len = Usize::Var(usize::read(builder)); + + let record_r_out_evals = Vec::>::read(builder); + let record_w_out_evals = Vec::>::read(builder); + let record_lk_out_evals = Vec::>::read(builder); - let has_same_r_sumcheck_proofs = Usize::Var(usize::read(builder)); - let same_r_sumcheck_proofs = Vec::::read(builder); - let rw_in_evals = Vec::::read(builder); - let lk_in_evals = Vec::::read(builder); let tower_proof = TowerProofInput::read(builder); let fixed_in_evals = Vec::::read(builder); let wits_in_evals = Vec::::read(builder); @@ -467,15 +457,12 @@ impl Hintable for ZKVMTableProofInput { idx_felt, num_instances, log2_num_instances, - r_out_evals, - w_out_evals, - compressed_rw_out_len, - lk_out_evals, - compressed_lk_out_len, - has_same_r_sumcheck_proofs, - same_r_sumcheck_proofs, - rw_in_evals, - lk_in_evals, + record_r_out_evals_len, + record_w_out_evals_len, + record_lk_out_evals_len, + record_r_out_evals, + record_w_out_evals, + record_lk_out_evals, tower_proof, fixed_in_evals, wits_in_evals, @@ -493,22 +480,20 @@ impl Hintable for ZKVMTableProofInput { let log2_num_instances = ceil_log2(self.num_instances); stream.extend(>::write(&log2_num_instances)); - stream.extend(self.r_out_evals.write()); - stream.extend(self.w_out_evals.write()); stream.extend(>::write( - &self.compressed_rw_out_len, + &self.record_r_out_evals_len, )); - stream.extend(self.lk_out_evals.write()); stream.extend(>::write( - &self.compressed_lk_out_len, + &self.record_w_out_evals_len, )); - stream.extend(>::write( - &self.has_same_r_sumcheck_proofs, + &self.record_lk_out_evals_len, )); - stream.extend(self.same_r_sumcheck_proofs.write()); - stream.extend(self.rw_in_evals.write()); - stream.extend(self.lk_in_evals.write()); + + stream.extend(self.record_r_out_evals.write()); + stream.extend(self.record_w_out_evals.write()); + stream.extend(self.record_lk_out_evals.write()); + stream.extend(self.tower_proof.write()); stream.extend(self.fixed_in_evals.write()); stream.extend(self.wits_in_evals.write()); diff --git a/src/zkvm_verifier/verifier.rs b/src/zkvm_verifier/verifier.rs index 10e768d..17f7d82 100644 --- a/src/zkvm_verifier/verifier.rs +++ b/src/zkvm_verifier/verifier.rs @@ -1,24 +1,23 @@ use super::binding::{ ZKVMOpcodeProofInputVariable, ZKVMProofInputVariable, ZKVMTableProofInputVariable, }; -use crate::arithmetics::{challenger_multi_observe, UniPolyExtrapolator}; +use crate::arithmetics::{ + challenger_multi_observe, eval_ceno_expr_with_instance, print_ext_arr, print_felt_arr, + PolyEvaluator, UniPolyExtrapolator, +}; use crate::e2e::SubcircuitParams; use crate::tower_verifier::program::verify_tower_proof; use crate::transcript::transcript_observe_label; use crate::{ arithmetics::{ build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, - eq_eval_less_or_equal_than, eval_ceno_expr_with_instance, eval_wellform_address_vec, - gen_alpha_pows, max_usize_arr, max_usize_vec, next_pow2_instance_padding, product, - sum as ext_sum, - }, - tower_verifier::{ - binding::{PointVariable, TowerVerifierInputVariable}, - program::iop_verifier_state_verify, + eq_eval_less_or_equal_than, eval_wellform_address_vec, gen_alpha_pows, max_usize_arr, + max_usize_vec, nested_product, next_pow2_instance_padding, product, sum as ext_sum, }, + tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, }; -use ceno_zkvm::circuit_builder::SetTableSpec; -use ceno_zkvm::{expression::StructuralWitIn, scheme::verifier::ZKVMVerifier}; +use ceno_mle::expression::{Instance, StructuralWitIn}; +use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; use ff_ext::BabyBearExt4; use itertools::interleave; use itertools::max; @@ -102,7 +101,6 @@ pub fn verify_zkvm_proof( ); challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.fixed_commit); - iter_zip!(builder, zkvm_proof_input.fixed_commit_trivial_commits).for_each( |ptr_vec, builder| { let trivial_cmt = @@ -115,12 +113,16 @@ pub fn verify_zkvm_proof( zkvm_proof_input.fixed_commit_log2_max_codeword_size, ); + let zero_f: Felt = builder.constant(C::F::ZERO); iter_zip!(builder, zkvm_proof_input.num_instances).for_each(|ptr_vec, builder| { let ns = builder.iter_ptr_get(&zkvm_proof_input.num_instances, ptr_vec[0]); let circuit_size = builder.get(&ns, 0); let num_var = builder.get(&ns, 1); + challenger.observe(builder, circuit_size); + challenger.observe(builder, zero_f); challenger.observe(builder, num_var); + challenger.observe(builder, zero_f); }); challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.witin_commit); @@ -145,10 +147,15 @@ pub fn verify_zkvm_proof( builder.set(&challenges, 1, beta.clone()); let mut unipoly_extrapolator = UniPolyExtrapolator::new(builder); + let mut poly_evaluator = PolyEvaluator::new(builder); let dummy_table_item = alpha.clone(); let dummy_table_item_multiplicity: Ext = builder.constant(C::EF::ZERO); + let mut rt_points: Vec>> = Vec::with_capacity(proving_sequence.len()); + let mut evaluations: Vec>> = + Vec::with_capacity(2 * proving_sequence.len()); // witin + fixed thus *2 + for subcircuit_params in proving_sequence { if subcircuit_params.is_opcode { let opcode_proof = builder.get( @@ -159,7 +166,8 @@ pub fn verify_zkvm_proof( builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); challenger.observe(builder, id_f); - verify_opcode_proof( + builder.cycle_tracker_start("Verify opcode proof"); + let input_opening_point = verify_opcode_proof( builder, &mut challenger, &opcode_proof, @@ -169,37 +177,40 @@ pub fn verify_zkvm_proof( &ceno_constraint_system, &mut unipoly_extrapolator, ); + builder.cycle_tracker_end("Verify opcode proof"); - let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); - let num_lks = cs.lk_expressions.len(); + rt_points.push(input_opening_point); + evaluations.push(opcode_proof.wits_in_evals); + // getting the number of dummy padding item that we used in this opcode circuit + let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); let num_instances = subcircuit_params.num_instances; - let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; + let num_lks = cs.zkvm_v1_css.lk_expressions.len(); let num_padded_instance = next_pow2_instance_padding(num_instances) - num_instances; - let new_multiplicity: Ext = builder.constant(C::EF::from_canonical_usize( - num_padded_lks_per_instance * num_instances - + num_lks.next_power_of_two() * num_padded_instance, - )); + let new_multiplicity: Ext = + builder.constant(C::EF::from_canonical_usize(num_lks * num_padded_instance)); builder.assign( &dummy_table_item_multiplicity, dummy_table_item_multiplicity + new_multiplicity, ); - let record_r_out_evals_prod = product(builder, &opcode_proof.record_r_out_evals); + let record_r_out_evals_prod = nested_product(builder, &opcode_proof.record_r_out_evals); builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - let record_w_out_evals_prod = product(builder, &opcode_proof.record_w_out_evals); + let record_w_out_evals_prod = nested_product(builder, &opcode_proof.record_w_out_evals); builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - builder.assign( - &logup_sum, - logup_sum + opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.inverse(), - ); - builder.assign( - &logup_sum, - logup_sum + opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.inverse(), - ); + iter_zip!(builder, opcode_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { + let evals = builder.iter_ptr_get(&opcode_proof.record_lk_out_evals, ptr_vec[0]); + let p1 = builder.get(&evals, 0); + let p2 = builder.get(&evals, 1); + let q1 = builder.get(&evals, 2); + let q2 = builder.get(&evals, 3); + + builder.assign(&logup_sum, logup_sum + p1 * q1.inverse()); + builder.assign(&logup_sum, logup_sum + p2 * q2.inverse()); + }); } else { let table_proof = builder.get( &zkvm_proof_input.table_proofs, @@ -209,7 +220,7 @@ pub fn verify_zkvm_proof( builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); challenger.observe(builder, id_f); - verify_table_proof( + let input_opening_point = verify_table_proof( builder, &mut challenger, &table_proof, @@ -220,31 +231,32 @@ pub fn verify_zkvm_proof( &subcircuit_params, ceno_constraint_system, &mut unipoly_extrapolator, + &mut poly_evaluator, ); - let step = C::N::from_canonical_usize(4); - builder - .range_with_step(0, table_proof.lk_out_evals.len(), step) - .for_each(|idx_vec, builder| { - let p2_idx: Usize = builder.eval(idx_vec[0] + RVar::from(1)); - let q1_idx: Usize = builder.eval(p2_idx.clone() + RVar::from(1)); - let q2_idx: Usize = builder.eval(q1_idx.clone() + RVar::from(1)); - - let p1 = builder.get(&table_proof.lk_out_evals, idx_vec[0]); - let p2 = builder.get(&table_proof.lk_out_evals, p2_idx); - let q1 = builder.get(&table_proof.lk_out_evals, q1_idx); - let q2 = builder.get(&table_proof.lk_out_evals, q2_idx); - - builder.assign( - &logup_sum, - logup_sum - p1 * q1.inverse() - p2 * q2.inverse(), - ); - }); - - let w_out_evals_prod = product(builder, &table_proof.w_out_evals); - builder.assign(&prod_w, prod_w * w_out_evals_prod); - let r_out_evals_prod = product(builder, &table_proof.r_out_evals); - builder.assign(&prod_r, prod_r * r_out_evals_prod); + rt_points.push(input_opening_point); + evaluations.push(table_proof.wits_in_evals); + let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); + if cs.num_fixed() > 0 { + evaluations.push(table_proof.fixed_in_evals); + } + + iter_zip!(builder, table_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { + let evals = builder.iter_ptr_get(&table_proof.record_lk_out_evals, ptr_vec[0]); + let p1 = builder.get(&evals, 0); + let p2 = builder.get(&evals, 1); + let q1 = builder.get(&evals, 2); + let q2 = builder.get(&evals, 3); + builder.assign( + &logup_sum, + logup_sum - p1 * q1.inverse() - p2 * q2.inverse(), + ); + }); + + let record_w_out_evals_prod = nested_product(builder, &table_proof.record_w_out_evals); + builder.assign(&prod_w, prod_w * record_w_out_evals_prod); + let record_r_out_evals_prod = nested_product(builder, &table_proof.record_r_out_evals); + builder.assign(&prod_r, prod_r * record_r_out_evals_prod); } } @@ -253,8 +265,7 @@ pub fn verify_zkvm_proof( logup_sum - dummy_table_item_multiplicity * dummy_table_item.inverse(), ); - /* TODO: Basefold e2e - // verify mpcs + /* TODO: MPCS PCS::batch_verify( &self.vk.vp, &vm_proof.num_instances, @@ -292,7 +303,9 @@ pub fn verify_zkvm_proof( ); builder.assign(&prod_r, prod_r * finalize_global_state); - // builder.assert_ext_eq(prod_r, prod_w); + /* TODO: Temporarily disable product check for missing subcircuits + builder.assert_ext_eq(prod_r, prod_w); + */ } pub fn verify_opcode_proof( @@ -304,135 +317,106 @@ pub fn verify_opcode_proof( subcircuit_params: &SubcircuitParams, cs: &ZKVMVerifier, unipoly_extrapolator: &mut UniPolyExtrapolator, -) { +) -> Array> { let cs = &cs.vk.circuit_vks[&subcircuit_params.name].cs; let one: Ext = builder.constant(C::EF::ONE); let zero: Ext = builder.constant(C::EF::ZERO); - let r_len = cs.r_expressions.len(); - let w_len = cs.w_expressions.len(); - let lk_len = cs.lk_expressions.len(); + let r_len = cs.zkvm_v1_css.r_expressions.len(); + let w_len = cs.zkvm_v1_css.w_expressions.len(); + let lk_len = cs.zkvm_v1_css.lk_expressions.len(); - let max_expr_len = *max([r_len, w_len, lk_len].iter()).unwrap(); + let num_batched = r_len + w_len + lk_len; + let chip_record_alpha: Ext = builder.get(challenges, 0); let r_counts_per_instance: Usize = Usize::from(r_len); let w_counts_per_instance: Usize = Usize::from(w_len); let lk_counts_per_instance: Usize = Usize::from(lk_len); + let num_batched: Usize = Usize::from(num_batched); let log2_r_count: Usize = Usize::from(ceil_log2(r_len)); let log2_w_count: Usize = Usize::from(ceil_log2(w_len)); let log2_lk_count: Usize = Usize::from(ceil_log2(lk_len)); + let log2_num_instances = opcode_proof.log2_num_instances.clone(); - let num_variables: Array> = builder.dyn_array(3); - let num_variables_r: Usize = - builder.eval(log2_num_instances.clone() + log2_r_count.clone()); - builder.set(&num_variables, 0, num_variables_r); - let num_variables_w: Usize = - builder.eval(log2_num_instances.clone() + log2_w_count.clone()); - builder.set(&num_variables, 1, num_variables_w); - let num_variables_lk: Usize = - builder.eval(log2_num_instances.clone() + log2_lk_count.clone()); - builder.set(&num_variables, 2, num_variables_lk); - let max_num_variables: Usize = - builder.eval(log2_num_instances.clone() + Usize::from(ceil_log2(max_expr_len))); let tower_proof = &opcode_proof.tower_proof; - let num_proofs = tower_proof.proofs.len(); - let num_prod_specs = tower_proof.prod_specs_eval.len(); - let num_logup_specs = tower_proof.logup_specs_eval.len(); - let num_fanin: Usize = Usize::from(NUM_FANIN); + let num_variables: Array> = builder.dyn_array(num_batched); + builder + .range(0, num_variables.len()) + .for_each(|idx_vec, builder| { + builder.set(&num_variables, idx_vec[0], log2_num_instances.clone()); + }); - let prod_out_evals: Array>> = builder.dyn_array(2); - builder.set(&prod_out_evals, 0, opcode_proof.record_r_out_evals.clone()); - builder.set(&prod_out_evals, 1, opcode_proof.record_w_out_evals.clone()); + let prod_out_evals: Array>> = concat( + builder, + &opcode_proof.record_r_out_evals, + &opcode_proof.record_w_out_evals, + ); - let logup_out_evals: Array>> = builder.dyn_array(1); - let logup_inner_evals: Array> = builder.dyn_array(4); - builder.set(&logup_inner_evals, 0, opcode_proof.lk_p1_out_eval); - builder.set(&logup_inner_evals, 1, opcode_proof.lk_p2_out_eval); - builder.set(&logup_inner_evals, 2, opcode_proof.lk_q1_out_eval); - builder.set(&logup_inner_evals, 3, opcode_proof.lk_q2_out_eval); - builder.set(&logup_out_evals, 0, logup_inner_evals); + let num_fanin: Usize = Usize::from(NUM_FANIN); + let max_expr_len = *max([r_len, w_len, lk_len].iter()).unwrap(); builder.cycle_tracker_start("verify tower proof for opcode"); let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = verify_tower_proof( builder, challenger, - TowerVerifierInputVariable { - prod_out_evals, - logup_out_evals, - num_variables, - num_fanin, - num_proofs, - num_prod_specs, - num_logup_specs, - max_num_variables, - proofs: tower_proof.proofs.clone(), - prod_specs_eval: tower_proof.prod_specs_eval.clone(), - logup_specs_eval: tower_proof.logup_specs_eval.clone(), - }, + prod_out_evals, + &opcode_proof.record_lk_out_evals, + num_variables, + num_fanin, + log2_num_instances.clone(), + tower_proof, unipoly_extrapolator, ); builder.cycle_tracker_end("verify tower proof for opcode"); - let rt_non_lc_sumcheck: Array> = - rt_tower - .fs - .clone() - .slice(builder, 0, log2_num_instances.clone()); - - builder.assert_usize_eq(record_evals.len(), Usize::from(2)); - builder.assert_usize_eq(logup_q_evals.len(), Usize::from(1)); - builder.assert_usize_eq(logup_p_evals.len(), Usize::from(1)); - let logup_p_eval = builder.get(&logup_p_evals, 0).eval; builder.assert_ext_eq(logup_p_eval, one); - let [rt_r, rt_w, rt_lk]: [Array>; 3] = [ - builder.get(&record_evals, 0).point.fs, - builder.get(&record_evals, 1).point.fs, - builder.get(&logup_q_evals, 0).point.fs, - ]; - - let zero_sumcheck_expr_len: usize = cs.assert_zero_sumcheck_expressions.len(); + // verify zero statement (degree > 1) + sel sumcheck + let rt = builder.get(&record_evals, 0); + let num_rw_records: Usize = builder.eval(r_counts_per_instance + w_counts_per_instance); + builder.assert_usize_eq(record_evals.len(), num_rw_records.clone()); - let alpha_len = MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + zero_sumcheck_expr_len; - let alpha_len: Usize = Usize::from(alpha_len); + let alpha_len = builder.eval( + num_rw_records.clone() + + lk_counts_per_instance + + Usize::from(cs.zkvm_v1_css.assert_zero_sumcheck_expressions.len()), + ); transcript_observe_label(builder, challenger, b"combine subset evals"); let alpha_pow = gen_alpha_pows(builder, challenger, alpha_len); - let [alpha_read, alpha_write, alpha_lk]: [Ext; 3] = [ - builder.get(&alpha_pow, 0), - builder.get(&alpha_pow, 1), - builder.get(&alpha_pow, 2), - ]; - // alpha_read * (out_r[rt] - 1) + alpha_write * (out_w[rt] - 1) + alpha_lk * (out_lk_q - chip_record_alpha) // + 0 // 0 come from zero check let claim_sum: Ext = builder.constant(C::EF::ZERO); - let record_eval_0 = builder.get(&record_evals, 0).eval; - let record_eval_1 = builder.get(&record_evals, 1).eval; - let logup_q_eval_0 = builder.get(&logup_q_evals, 0).eval; - let alpha = builder.get(challenges, 0); + let rw_logup_len: Usize = builder.eval(num_rw_records.clone() + logup_q_evals.len()); + let alpha_rw_slice = alpha_pow.slice(builder, 0, num_rw_records.clone()); + iter_zip!(builder, alpha_rw_slice, record_evals).for_each(|ptr_vec, builder| { + let alpha = builder.iter_ptr_get(&alpha_rw_slice, ptr_vec[0]); + let eval = builder.iter_ptr_get(&record_evals, ptr_vec[1]); - builder.assign( - &claim_sum, - alpha_read * (record_eval_0 - one) - + alpha_write * (record_eval_1 - one) - + alpha_lk * (logup_q_eval_0 - alpha), - ); + builder.assign(&claim_sum, claim_sum + alpha * (eval.eval - one)); + }); + let alpha_logup_slice = alpha_pow.slice(builder, num_rw_records.clone(), rw_logup_len); + iter_zip!(builder, alpha_logup_slice, logup_q_evals).for_each(|ptr_vec, builder| { + let alpha = builder.iter_ptr_get(&alpha_logup_slice, ptr_vec[0]); + let eval = builder.iter_ptr_get(&logup_q_evals, ptr_vec[1]); + builder.assign( + &claim_sum, + claim_sum + alpha * (eval.eval - chip_record_alpha), + ); + }); let log2_num_instances_var: Var = RVar::from(log2_num_instances.clone()).variable(); let log2_num_instances_f: Felt = builder.unsafe_cast_var_to_felt(log2_num_instances_var); - - let max_non_lc_degree: usize = cs.max_non_lc_degree; + let max_non_lc_degree: usize = cs.zkvm_v1_css.max_non_lc_degree; let main_sel_subclaim_max_degree: Felt = builder.constant(C::F::from_canonical_u32( SEL_DEGREE.max(max_non_lc_degree + 1) as u32, )); - builder.cycle_tracker_start("main sumcheck"); - let main_sel_subclaim = iop_verifier_state_verify( + let (input_opening_point, expected_evaluation) = iop_verifier_state_verify( builder, challenger, &claim_sum, @@ -443,150 +427,43 @@ pub fn verify_opcode_proof( ); builder.cycle_tracker_end("main sumcheck"); - let input_opening_point = PointVariable { - fs: main_sel_subclaim.0, - }; - let expected_evaluation: Ext = main_sel_subclaim.1; - - let rt_r_eq = rt_r.slice(builder, 0, log2_r_count.clone()); - let eq_r = build_eq_x_r_vec_sequential(builder, &rt_r_eq); - let rt_w_eq = rt_w.slice(builder, 0, log2_w_count.clone()); - let eq_w = build_eq_x_r_vec_sequential(builder, &rt_w_eq); - let rt_lk_eq = rt_lk.slice(builder, 0, log2_lk_count.clone()); - let eq_lk = build_eq_x_r_vec_sequential(builder, &rt_lk_eq); - - let rt_r_eq_less = rt_r.slice(builder, log2_r_count.clone(), rt_r.len()); - let rt_w_eq_less = rt_w.slice(builder, log2_w_count.clone(), rt_w.len()); - let rt_lk_eq_less = rt_lk.slice(builder, log2_lk_count.clone(), rt_lk.len()); - - let sel_r = eq_eval_less_or_equal_than( - builder, - challenger, - opcode_proof, - &input_opening_point.fs, - &rt_r_eq_less, - ); - let sel_w = eq_eval_less_or_equal_than( + // sel(rt, t) + let sel = eq_eval_less_or_equal_than( builder, challenger, opcode_proof, - &input_opening_point.fs, - &rt_w_eq_less, - ); - let sel_lk = eq_eval_less_or_equal_than( - builder, - challenger, - opcode_proof, - &input_opening_point.fs, - &rt_lk_eq_less, + &input_opening_point, + &rt.point.fs, ); - let sel_non_lc_zero_sumcheck: Ext = builder.constant(C::EF::ZERO); + // derive r_records, w_records, lk_records from witness's evaluations + let alpha_idx: Var = builder.uninit(); + builder.assign(&alpha_idx, Usize::from(0)); + let empty_arr: Array> = builder.dyn_array(0); - let zero_sumcheck_expressions_len = RVar::from(cs.assert_zero_sumcheck_expressions.len()); - builder - .if_ne(zero_sumcheck_expressions_len, RVar::from(0)) - .then(|builder| { - let sel_sumcheck = eq_eval_less_or_equal_than( + let rw_expressions_sum: Ext = builder.constant(C::EF::ZERO); + cs.zkvm_v1_css + .r_expressions + .iter() + .chain(cs.zkvm_v1_css.w_expressions.iter()) + .for_each(|expr| { + let e = eval_ceno_expr_with_instance( builder, - challenger, - opcode_proof, - &input_opening_point.fs, - &rt_non_lc_sumcheck, + &empty_arr, + &opcode_proof.wits_in_evals, + &empty_arr, + pi_evals, + challenges, + expr, ); - builder.assign(&sel_non_lc_zero_sumcheck, sel_sumcheck); + let alpha = builder.get(&alpha_pow, alpha_idx); + builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); + builder.assign(&rw_expressions_sum, rw_expressions_sum + alpha * (e - one)) }); + builder.assign(&rw_expressions_sum, rw_expressions_sum * sel); - let r_records_slice = - opcode_proof - .r_records_in_evals - .slice(builder, 0, r_counts_per_instance.clone()); - let eq_r_slice = eq_r.slice(builder, 0, r_counts_per_instance.clone()); - let eq_r_rest_slice = eq_r.slice(builder, r_counts_per_instance.clone(), eq_r.len()); - let r_prod = ext_dot_product(builder, &r_records_slice, &eq_r_slice); - let eq_r_rest_sum = ext_sum(builder, &eq_r_rest_slice); - let r_eval: Ext = - builder.eval(alpha_read * sel_r * (r_prod + eq_r_rest_sum - one)); - - let w_records_slice = - opcode_proof - .w_records_in_evals - .slice(builder, 0, w_counts_per_instance.clone()); - let eq_w_slice = eq_w.slice(builder, 0, w_counts_per_instance.clone()); - let eq_w_rest_slice = eq_w.slice(builder, w_counts_per_instance.clone(), eq_w.len()); - let w_prod = ext_dot_product(builder, &w_records_slice, &eq_w_slice); - let eq_w_rest_sum = ext_sum(builder, &eq_w_rest_slice); - let w_eval: Ext = - builder.eval(alpha_write * sel_w * (w_prod + eq_w_rest_sum - one)); - - let lk_records_slice = - opcode_proof - .lk_records_in_evals - .slice(builder, 0, lk_counts_per_instance.clone()); - let eq_lk_slice = eq_lk.slice(builder, 0, lk_counts_per_instance.clone()); - let eq_lk_rest_slice = eq_lk.slice(builder, lk_counts_per_instance.clone(), eq_lk.len()); - let lk_prod = ext_dot_product(builder, &lk_records_slice, &eq_lk_slice); - let eq_lk_rest_sum = ext_sum(builder, &eq_lk_rest_slice); - let lk_eval: Ext = - builder.eval(alpha_lk * sel_lk * (lk_prod + alpha * (eq_lk_rest_sum - one))); - - let computed_eval: Ext = builder.eval(r_eval + w_eval + lk_eval); - let empty_arr: Array> = builder.dyn_array(0); - - // sel(rt_non_lc_sumcheck, main_sel_eval_point) * \sum_j (alpha{j} * expr(main_sel_eval_point)) - let sel_sum: Ext = builder.constant(C::EF::ZERO); - let alpha_pow_sel_sum = alpha_pow.slice(builder, 3, alpha_pow.len()); - for i in 0..cs.assert_zero_sumcheck_expressions.len() { - let expr = &cs.assert_zero_sumcheck_expressions[i]; - let al = builder.get(&alpha_pow_sel_sum, i); - - let expr_eval = eval_ceno_expr_with_instance( - builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, - ); - - builder.assign(&sel_sum, sel_sum + al * expr_eval); - } - - builder.assign( - &computed_eval, - computed_eval + sel_sum * sel_non_lc_zero_sumcheck, - ); - builder.assert_ext_eq(computed_eval, expected_evaluation); - - // verify records (degree = 1) statement, thus no sumcheck - let r_records = &opcode_proof - .r_records_in_evals - .slice(builder, 0, r_counts_per_instance); - let w_records = &opcode_proof - .w_records_in_evals - .slice(builder, 0, w_counts_per_instance); - let lk_records = &opcode_proof - .lk_records_in_evals - .slice(builder, 0, lk_counts_per_instance); - - let _ = &cs.r_expressions.iter().enumerate().for_each(|(idx, expr)| { - let expected_eval = builder.get(&r_records, idx); - let e = eval_ceno_expr_with_instance( - builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, - ); - - builder.assert_ext_eq(e, expected_eval); - }); - - let _ = &cs.w_expressions.iter().enumerate().for_each(|(idx, expr)| { - let expected_eval = builder.get(&w_records, idx); + let lk_expressions_sum: Ext = builder.constant(C::EF::ZERO); + cs.zkvm_v1_css.lk_expressions.iter().for_each(|expr| { let e = eval_ceno_expr_with_instance( builder, &empty_arr, @@ -596,16 +473,21 @@ pub fn verify_opcode_proof( challenges, expr, ); - - builder.assert_ext_eq(e, expected_eval); + let alpha = builder.get(&alpha_pow, alpha_idx); + builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); + builder.assign( + &lk_expressions_sum, + lk_expressions_sum + alpha * (e - chip_record_alpha), + ) }); + builder.assign(&lk_expressions_sum, lk_expressions_sum * sel); - let _ = &cs - .lk_expressions + let zero_expressions_sum: Ext = builder.constant(C::EF::ZERO); + cs.zkvm_v1_css + .assert_zero_sumcheck_expressions .iter() - .enumerate() - .for_each(|(idx, expr)| { - let expected_eval = builder.get(&lk_records, idx); + .for_each(|expr| { + // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening let e = eval_ceno_expr_with_instance( builder, &empty_arr, @@ -615,14 +497,21 @@ pub fn verify_opcode_proof( challenges, expr, ); - - builder.assert_ext_eq(e, expected_eval); + let alpha = builder.get(&alpha_pow, alpha_idx); + builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); + builder.assign(&zero_expressions_sum, zero_expressions_sum + alpha * e); }); + builder.assign(&zero_expressions_sum, zero_expressions_sum * sel); + + let computed_eval: Ext = + builder.eval(rw_expressions_sum + lk_expressions_sum + zero_expressions_sum); + builder.assert_ext_eq(computed_eval, expected_evaluation); - cs.assert_zero_expressions + // verify zero expression (degree = 1) statement, thus no sumcheck + cs.zkvm_v1_css + .assert_zero_expressions .iter() - .enumerate() - .for_each(|(_idx, expr)| { + .for_each(|expr| { let e = eval_ceno_expr_with_instance( builder, &empty_arr, @@ -632,30 +521,33 @@ pub fn verify_opcode_proof( challenges, expr, ); - builder.assert_ext_eq(e, zero); }); + + input_opening_point } pub fn verify_table_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, table_proof: &ZKVMTableProofInputVariable, - _raw_pi: &Array>>, - _raw_pi_num_variables: &Array>, + raw_pi: &Array>>, + raw_pi_num_variables: &Array>, pi_evals: &Array>, challenges: &Array>, subcircuit_params: &SubcircuitParams, cs: &ZKVMVerifier, unipoly_extrapolator: &mut UniPolyExtrapolator, -) { + poly_evaluator: &mut PolyEvaluator, +) -> Array> { let cs = cs.vk.circuit_vks[&subcircuit_params.name].get_cs(); let tower_proof: &super::binding::TowerProofInputVariable = &table_proof.tower_proof; let r_expected_rounds: Array> = - builder.dyn_array(cs.r_table_expressions.len() * 2); + builder.dyn_array(cs.zkvm_v1_css.r_table_expressions.len() * 2); cs // only iterate r set, as read/write set round should match + .zkvm_v1_css .r_table_expressions .iter() .enumerate() @@ -679,8 +571,9 @@ pub fn verify_table_proof( }); let lk_expected_rounds: Array> = - builder.dyn_array(cs.lk_table_expressions.len()); - cs.lk_table_expressions + builder.dyn_array(cs.zkvm_v1_css.lk_table_expressions.len()); + cs.zkvm_v1_css + .lk_table_expressions .iter() .enumerate() .for_each(|(idx, expr)| { @@ -702,135 +595,60 @@ pub fn verify_table_proof( }); let expected_rounds = concat(builder, &r_expected_rounds, &lk_expected_rounds); let max_expected_rounds = max_usize_arr(builder, &expected_rounds); - - let prod_out_evals: Array>> = - builder.dyn_array(table_proof.r_out_evals.len()); - builder - .range_with_step( - 0, - table_proof.r_out_evals.len().clone(), - C::N::from_canonical_usize(2), - ) - .for_each(|idx_vec, builder| { - let r2_idx: Usize = builder.eval(idx_vec[0] + Usize::from(1)); - let r1 = builder.get(&table_proof.r_out_evals, idx_vec[0]); - let r2 = builder.get(&table_proof.r_out_evals, r2_idx.clone()); - let w1 = builder.get(&table_proof.w_out_evals, idx_vec[0]); - let w2 = builder.get(&table_proof.w_out_evals, r2_idx.clone()); - - let r_vec: Array> = builder.dyn_array(2); - let w_vec: Array> = builder.dyn_array(2); - - builder.set(&r_vec, 0, r1); - builder.set(&r_vec, 1, r2); - builder.set(&w_vec, 0, w1); - builder.set(&w_vec, 1, w2); - - builder.set(&prod_out_evals, idx_vec[0], r_vec); - builder.set(&prod_out_evals, r2_idx, w_vec); - }); - - let logup_out_evals: Array>> = - builder.dyn_array(table_proof.compressed_lk_out_len.clone()); - builder - .range(0, logup_out_evals.len()) - .for_each(|idx_vec, builder| { - let lk_vec: Array> = builder.dyn_array(4); - let lk2_idx: Usize = builder.eval(idx_vec[0] + Usize::from(1)); - let lk3_idx: Usize = builder.eval(lk2_idx.clone() + Usize::from(1)); - let lk4_idx: Usize = builder.eval(lk3_idx.clone() + Usize::from(1)); - - let lk1 = builder.get(&table_proof.lk_out_evals, idx_vec[0]); - let lk2 = builder.get(&table_proof.lk_out_evals, lk2_idx); - let lk3 = builder.get(&table_proof.lk_out_evals, lk3_idx); - let lk4 = builder.get(&table_proof.lk_out_evals, lk4_idx); - - builder.set(&lk_vec, 0, lk1); - builder.set(&lk_vec, 1, lk2); - builder.set(&lk_vec, 2, lk3); - builder.set(&lk_vec, 3, lk4); - - builder.set(&logup_out_evals, idx_vec[0], lk_vec); - }); - let num_fanin: Usize = Usize::from(NUM_FANIN); - let num_proofs: Usize = tower_proof.proofs.len(); - let num_prod_specs: Usize = tower_proof.prod_specs_eval.len(); - let num_logup_specs: Usize = tower_proof.logup_specs_eval.len(); let max_num_variables: Usize = Usize::from(max_expected_rounds); + let prod_out_evals: Array>> = concat( + builder, + &table_proof.record_r_out_evals, + &table_proof.record_w_out_evals, + ); builder.cycle_tracker_start("verify tower proof"); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = verify_tower_proof( builder, challenger, - TowerVerifierInputVariable { - prod_out_evals, - logup_out_evals, - num_variables: expected_rounds, - num_fanin, - num_proofs, - num_prod_specs, - num_logup_specs, - max_num_variables, - proofs: tower_proof.proofs.clone(), - prod_specs_eval: tower_proof.prod_specs_eval.clone(), - logup_specs_eval: tower_proof.logup_specs_eval.clone(), - }, + prod_out_evals, + &table_proof.record_lk_out_evals, + expected_rounds, + num_fanin, + max_num_variables, + tower_proof, unipoly_extrapolator, ); builder.cycle_tracker_end("verify tower proof"); builder.assert_usize_eq( logup_q_point_and_eval.len(), - Usize::from(cs.lk_table_expressions.len()), + Usize::from(cs.zkvm_v1_css.lk_table_expressions.len()), ); builder.assert_usize_eq( logup_p_point_and_eval.len(), - Usize::from(cs.lk_table_expressions.len()), + Usize::from(cs.zkvm_v1_css.lk_table_expressions.len()), ); builder.assert_usize_eq( prod_point_and_eval.len(), - Usize::from(cs.r_table_expressions.len() + cs.w_table_expressions.len()), + Usize::from( + cs.zkvm_v1_css.r_table_expressions.len() + cs.zkvm_v1_css.w_table_expressions.len(), + ), ); - // TODO: Current Ceno verifier specification // in table proof, we always skip same point sumcheck for now // as tower sumcheck batch product argument/logup in same length let _is_skip_same_point_sumcheck = true; - let input_opening_point = rt_tower.clone().fs; - let in_evals_len: Usize = builder.eval( - prod_point_and_eval.len() + logup_p_point_and_eval.len() + logup_q_point_and_eval.len(), - ); - let in_evals: Array> = builder.dyn_array(in_evals_len); - builder - .range(0, prod_point_and_eval.len()) - .for_each(|idx_vec, builder| { - let e = builder.get(&prod_point_and_eval, idx_vec[0]).eval; - builder.set(&in_evals, idx_vec[0], e); - }); - builder - .range(0, logup_p_point_and_eval.len()) - .for_each(|idx_vec, builder| { - let p_e = builder.get(&logup_p_point_and_eval, idx_vec[0]).eval; - let q_e = builder.get(&logup_q_point_and_eval, idx_vec[0]).eval; - - let p_idx: Usize = - builder.eval(prod_point_and_eval.len() + idx_vec[0] * Usize::from(2)); - let q_idx: Usize = builder - .eval(prod_point_and_eval.len() + idx_vec[0] * Usize::from(2) + Usize::from(1)); - - builder.set(&in_evals, p_idx, p_e); - builder.set(&in_evals, q_idx, q_e); - }); - // evaluate structural witness from verifier let set_table_exprs = cs + .zkvm_v1_css .r_table_expressions .iter() .map(|r| &r.table_spec) - .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) + .chain( + cs.zkvm_v1_css + .lk_table_expressions + .iter() + .map(|r| &r.table_spec), + ) .collect::>(); let structural_witnesses_vec: Vec> = set_table_exprs .iter() @@ -864,7 +682,7 @@ pub fn verify_table_proof( builder, offset as u32, multi_factor as u32, - &input_opening_point, + &rt_tower.fs, *descending, ) }, @@ -881,14 +699,40 @@ pub fn verify_table_proof( builder.set(&structural_witnesses, idx, e); }); + let in_evals_len: Usize = builder.eval( + prod_point_and_eval.len() + logup_p_point_and_eval.len() + logup_q_point_and_eval.len(), + ); + let in_evals: Array> = builder.dyn_array(in_evals_len); + builder + .range(0, prod_point_and_eval.len()) + .for_each(|idx_vec, builder| { + let e = builder.get(&prod_point_and_eval, idx_vec[0]).eval; + builder.set(&in_evals, idx_vec[0], e); + }); + builder + .range(0, logup_p_point_and_eval.len()) + .for_each(|idx_vec, builder| { + let p_e = builder.get(&logup_p_point_and_eval, idx_vec[0]).eval; + let q_e = builder.get(&logup_q_point_and_eval, idx_vec[0]).eval; + + let p_idx: Usize = + builder.eval(prod_point_and_eval.len() + idx_vec[0] * Usize::from(2)); + let q_idx: Usize = builder + .eval(prod_point_and_eval.len() + idx_vec[0] * Usize::from(2) + Usize::from(1)); + + builder.set(&in_evals, p_idx, p_e); + builder.set(&in_evals, q_idx, q_e); + }); + // verify records (degree = 1) statement, thus no sumcheck interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w + &cs.zkvm_v1_css.r_table_expressions, // r + &cs.zkvm_v1_css.w_table_expressions, // w ) .map(|rw| &rw.expr) .chain( - cs.lk_table_expressions + cs.zkvm_v1_css + .lk_table_expressions .iter() .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q ) @@ -907,4 +751,16 @@ pub fn verify_table_proof( let expected_evals = builder.get(&in_evals, idx); builder.assert_ext_eq(e, expected_evals); }); + + // assume public io is tiny vector, so we evaluate it directly without PCS + for &Instance(idx) in cs.instance_name_map().keys() { + let poly = builder.get(raw_pi, idx); + let poly_num_vars = builder.get(raw_pi_num_variables, idx); + let eval_point = rt_tower.fs.slice(builder, 0, poly_num_vars); + let expected_eval = poly_evaluator.evaluate_base_poly_at_point(builder, &poly, &eval_point); + let eval = builder.get(&pi_evals, idx); + builder.assert_ext_eq(eval, expected_eval); + } + + rt_tower.fs } From 1c246a1c79b5dfde42bdff6acd4c65da186efa38 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 21 Jul 2025 21:39:46 +0800 Subject: [PATCH 64/70] add multiple matrices as inputs to the basefold's unit test --- src/basefold_verifier/query_phase.rs | 71 ++++++++++++++++++---------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 9eb9a57..08fb0b3 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -648,6 +648,7 @@ pub(crate) fn batch_verifier_query_phase( #[cfg(test)] pub mod tests { + use ceno_mle::mle; use ceno_transcript::{BasicTranscript, Transcript}; use ff_ext::{BabyBearExt4, FromUniformBytes}; use itertools::Itertools; @@ -697,39 +698,64 @@ pub mod tests { #[test] fn test_verify_query_phase_batch() { let mut rng = thread_rng(); - let m1 = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << 10, 10); - let mles_1 = m1.to_mles(); - let matrices = vec![m1]; + // setup PCS let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + + let (matrices, mles): (Vec<_>, Vec<_>) = vec![(14, 20), (13, 30), (12, 10), (11, 15)] + .into_iter() + .map(|(num_vars, width)| { + let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); + let mles = m.to_mles(); + + (m, mles) + }) + .unzip(); + + // commit to matrices let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); let comm = PCS::get_pure_commitment(&pcs_data); - let point = E::random_vec(10, &mut rng); - let evals = mles_1.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + let point_and_evals = mles + .iter() + .map(|mles| { + let point = E::random_vec(mles[0].num_vars(), &mut rng); + let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + + (point, evals) + }) + .collect_vec(); - // let evals = mles_1 - // .iter() - // .map(|mle| points.iter().map(|p| mle.evaluate(&p)).collect_vec()) - // .collect::>(); + // batch open let mut transcript = BasicTranscript::::new(&[]); - let rounds = vec![(&pcs_data, vec![(point.clone(), evals.clone())])]; + let rounds = vec![(&pcs_data, point_and_evals.clone())]; let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); + // batch verify let mut transcript = BasicTranscript::::new(&[]); - let rounds = vec![(comm, vec![(point.len(), (point, evals.clone()))])]; + let rounds = vec![( + comm, + point_and_evals + .iter() + .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) + .collect_vec(), + )]; PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) .expect("Native verification failed"); let mut transcript = BasicTranscript::::new(&[]); let batch_coeffs = transcript.sample_and_append_challenge_pows(10, b"batch coeffs"); - let max_num_var = 10; + let max_num_var = point_and_evals + .iter() + .map(|(point, _)| point.len()) + .max() + .unwrap(); let num_rounds = max_num_var; // The final message is of length 1 // prepare folding challenges via sumcheck round msg + FRI commitment - let mut fold_challenges: Vec = Vec::with_capacity(10); + let mut fold_challenges: Vec = Vec::with_capacity(num_rounds); let commits = &opening_proof.commits; let sumcheck_messages = opening_proof.sumcheck_proof.as_ref().unwrap(); @@ -759,26 +785,23 @@ pub mod tests { ); let query_input = QueryPhaseVerifierInput { - // t_inv_halves: vp.encoding_params.t_inv_halves, - max_num_var: 10, + max_num_var, fold_challenges, batch_coeffs, indices: queries, proof: opening_proof.into(), rounds: rounds - .iter() + .into_iter() .map(|round| Round { - commit: round.0.clone().into(), + commit: round.0.into(), openings: round .1 - .iter() - .map(|opening| RoundOpening { - num_var: opening.0, + .into_iter() + .map(|(num_var, (point, evals))| RoundOpening { + num_var, point_and_evals: PointAndEvals { - point: Point { - fs: opening.1.clone().0, - }, - evals: opening.1.clone().1, + point: Point { fs: point }, + evals, }, }) .collect(), From ba4ff6916c0eac9dfc47c3d068895b78abeaeb73 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 22 Jul 2025 00:00:20 +0800 Subject: [PATCH 65/70] fix --- src/basefold_verifier/query_phase.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 08fb0b3..c8c1f51 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -703,11 +703,13 @@ pub mod tests { let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + let mut num_total_polys = 0; let (matrices, mles): (Vec<_>, Vec<_>) = vec![(14, 20), (13, 30), (12, 10), (11, 15)] .into_iter() .map(|(num_vars, width)| { let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); let mles = m.to_mles(); + num_total_polys += width; (m, mles) }) @@ -745,7 +747,7 @@ pub mod tests { .expect("Native verification failed"); let mut transcript = BasicTranscript::::new(&[]); - let batch_coeffs = transcript.sample_and_append_challenge_pows(10, b"batch coeffs"); + let batch_coeffs = transcript.sample_and_append_challenge_pows(num_total_polys, b"batch coeffs"); let max_num_var = point_and_evals .iter() From c5f67c51b441b35f1631af802eb5bc922a30bc14 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 21 Jul 2025 23:12:01 +0800 Subject: [PATCH 66/70] unit test passed --- src/basefold_verifier/query_phase.rs | 41 +++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index c8c1f51..7798203 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -346,6 +346,20 @@ pub(crate) fn batch_verifier_query_phase( let log2_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); + // this array is shared among all indices + let reduced_codeword_by_height: Array> = + builder.dyn_array(log2_max_codeword_size.clone()); + + let zero: Ext = builder.constant(C::EF::ZERO); + // initialize reduced_codeword_by_height with zeroes + iter_zip!(builder, reduced_codeword_by_height).for_each(|ptr_vec, builder| { + let zero_codeword = PackedCodeword { + low: zero.clone(), + high: zero.clone(), + }; + builder.set_value(&reduced_codeword_by_height, ptr_vec[0], zero_codeword); + }); + iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each( |ptr_vec, builder| { // TODO: change type of input.indices to be `Array>>` @@ -361,9 +375,6 @@ pub(crate) fn batch_verifier_query_phase( }); let idx_bits = idx_bits.slice(builder, 1, log2_max_codeword_size.clone()); - let reduced_codeword_by_height: Array> = - builder.dyn_array(log2_max_codeword_size.clone()); - let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]); let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); @@ -491,13 +502,21 @@ pub(crate) fn batch_verifier_query_phase( builder.assign(&log2_height, log2_height - Usize::from(1)); let folded_idx = builder.get(&idx_bits, i); - // TODO: absorb smaller codeword - let new_involved_codeword: Ext = builder.constant(C::EF::ZERO); + let new_involved_packed_codeword = + builder.get(&reduced_codeword_by_height, log2_height.clone()); + + builder.if_eq(folded_idx, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&folded, folded + new_involved_packed_codeword.low); + }, + |builder| { + builder.assign(&folded, folded + new_involved_packed_codeword.high); + }, + ); // leafs let leafs = builder.dyn_array(2); let sibling_idx = builder.eval_expr(RVar::from(1) - folded_idx); - builder.assign(&folded, folded + new_involved_codeword); builder.set_value(&leafs, folded_idx, folded); builder.set_value(&leafs, sibling_idx, sibling_value); @@ -648,13 +667,12 @@ pub(crate) fn batch_verifier_query_phase( #[cfg(test)] pub mod tests { - use ceno_mle::mle; use ceno_transcript::{BasicTranscript, Transcript}; use ff_ext::{BabyBearExt4, FromUniformBytes}; use itertools::Itertools; use mpcs::{ - pcs_batch_commit, pcs_setup, pcs_trim, util::hash::write_digest_to_transcript, - BasefoldDefault, PolynomialCommitmentScheme, + pcs_batch_commit, pcs_trim, util::hash::write_digest_to_transcript, BasefoldDefault, + PolynomialCommitmentScheme, }; use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; @@ -663,8 +681,6 @@ pub mod tests { use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::p3_challenger::GrindingChallenger; use openvm_stark_sdk::p3_baby_bear::BabyBear; - use p3_field::Field; - use p3_field::FieldAlgebra; use rand::thread_rng; type F = BabyBear; @@ -673,11 +689,10 @@ pub mod tests { use crate::basefold_verifier::basefold::{Round, RoundOpening}; use crate::basefold_verifier::query_phase::PointAndEvals; - use crate::tower_verifier::binding::{Point, PointAndEval}; + use crate::tower_verifier::binding::Point; use super::{batch_verifier_query_phase, QueryPhaseVerifierInput}; - #[allow(dead_code)] pub fn build_batch_verifier_query_phase( input: QueryPhaseVerifierInput, ) -> (Program, Vec>) { From 24ba27e6e9ad4c5813d9c6b8b7723a2172bdfa71 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 22 Jul 2025 00:18:02 +0800 Subject: [PATCH 67/70] support matrices that have same height --- src/basefold_verifier/query_phase.rs | 37 ++++++++++++++++------------ 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 7798203..c92d7bd 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -346,19 +346,7 @@ pub(crate) fn batch_verifier_query_phase( let log2_max_codeword_size: Var = builder.eval(input.max_num_var.clone() + get_rate_log::()); - // this array is shared among all indices - let reduced_codeword_by_height: Array> = - builder.dyn_array(log2_max_codeword_size.clone()); - let zero: Ext = builder.constant(C::EF::ZERO); - // initialize reduced_codeword_by_height with zeroes - iter_zip!(builder, reduced_codeword_by_height).for_each(|ptr_vec, builder| { - let zero_codeword = PackedCodeword { - low: zero.clone(), - high: zero.clone(), - }; - builder.set_value(&reduced_codeword_by_height, ptr_vec[0], zero_codeword); - }); iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each( |ptr_vec, builder| { @@ -375,6 +363,16 @@ pub(crate) fn batch_verifier_query_phase( }); let idx_bits = idx_bits.slice(builder, 1, log2_max_codeword_size.clone()); + let reduced_codeword_by_height: Array> = + builder.dyn_array(log2_max_codeword_size.clone()); + // initialize reduced_codeword_by_height with zeroes + iter_zip!(builder, reduced_codeword_by_height).for_each(|ptr_vec, builder| { + let zero_codeword = PackedCodeword { + low: zero.clone(), + high: zero.clone(), + }; + builder.set_value(&reduced_codeword_by_height, ptr_vec[0], zero_codeword); + }); let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]); let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); @@ -451,8 +449,14 @@ pub(crate) fn batch_verifier_query_phase( builder.assign(&high, high + coeff * high_value); }, ); - let codeword = PackedCodeword { low, high }; - builder.set_value(&reduced_codeword_by_height, log2_height, codeword); + let codeword: PackedCodeword = PackedCodeword { low, high }; + let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height); + + // reduced_openings[log2_height] += codeword + builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low); + builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high); + + builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc); builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset); }); }); @@ -719,7 +723,7 @@ pub mod tests { let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); let mut num_total_polys = 0; - let (matrices, mles): (Vec<_>, Vec<_>) = vec![(14, 20), (13, 30), (12, 10), (11, 15)] + let (matrices, mles): (Vec<_>, Vec<_>) = vec![(14, 20), (14, 30), (13, 30), (12, 10), (11, 15)] .into_iter() .map(|(num_vars, width)| { let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); @@ -762,7 +766,8 @@ pub mod tests { .expect("Native verification failed"); let mut transcript = BasicTranscript::::new(&[]); - let batch_coeffs = transcript.sample_and_append_challenge_pows(num_total_polys, b"batch coeffs"); + let batch_coeffs = + transcript.sample_and_append_challenge_pows(num_total_polys, b"batch coeffs"); let max_num_var = point_and_evals .iter() From f2faf8ea237832608d6633b4fb4098415b17259c Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 22 Jul 2025 00:28:02 +0800 Subject: [PATCH 68/70] refactor tests --- src/basefold_verifier/query_phase.rs | 29 +++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index c92d7bd..8cc8e7b 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -714,8 +714,7 @@ pub mod tests { (program, witness_stream) } - #[test] - fn test_verify_query_phase_batch() { + fn construct_test(dimensions: Vec<(usize, usize)>) { let mut rng = thread_rng(); // setup PCS @@ -723,7 +722,7 @@ pub mod tests { let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); let mut num_total_polys = 0; - let (matrices, mles): (Vec<_>, Vec<_>) = vec![(14, 20), (14, 30), (13, 30), (12, 10), (11, 15)] + let (matrices, mles): (Vec<_>, Vec<_>) = dimensions .into_iter() .map(|(num_vars, width)| { let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); @@ -846,4 +845,28 @@ pub mod tests { println!("=> cycle count: {:?}", seg.metrics.cycle_count); } } + + #[test] + fn test_simple_batch() { + for num_var in 5..20 { + construct_test(vec![(num_var, 20)]); + } + } + + #[test] + fn test_decreasing_batch() { + construct_test(vec![ + (14, 20), + (14, 40), + (13, 30), + (12, 30), + (11, 10), + (10, 15), + ]); + } + + #[test] + fn test_random_batch() { + construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]); + } } From eba50fdd5efe411a4bdc3a380fd84486c73220cd Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Thu, 24 Jul 2025 11:31:27 +0800 Subject: [PATCH 69/70] Complete Basefold verifier and Basefold e2e integration (#36) * Switch ceno reliance * Fix compilation errors due to out of date code * Update test query phase batch * Fix query opening proof * Implement basefold proof variable * Update query phase verifier input * Preparing test data for query phase with updated code * Implement basefold proof transform * Prepare query phase verifier input * Prepare query phase verifier input * Fix final message access * Switch ceno reliance to small field support * Create basefold verifier function * Check final message sizes * Fix final message size * Fix final message size * Check query opening proof len * Compute total number of polys * Sample batch coeffs * Compute max_num_var * Write sumcheck messages and commits to transcript * Write final message to transcript * Complete the code for batch verifier * Add verifier test * Try to fix some compilation errors in e2e * Connecting pcs with e2e * Fix some issues after merge * Make compilation pass temporarily * Make test pass before query phase * Supply the permutation and make the random case pass * Try fixing transcript inconsistency * Use bin to dec le * Add pow witness * Basefold verifier passes for simple case * Update dependency * Basefold verifier passes decreasing and random batches * update ceno dependencies * comment out patch * refactor * the computation of max_num_var is simplified * put perm to RoundVariable * remove debug routines * rename * clean * cleanup * ignore e2e test --------- Co-authored-by: kunxian xia --- Cargo.lock | 181 ++++++++----- Cargo.toml | 60 ++--- src/basefold_verifier/basefold.rs | 33 ++- src/basefold_verifier/mod.rs | 1 + src/basefold_verifier/query_phase.rs | 17 +- src/basefold_verifier/rs.rs | 12 +- src/basefold_verifier/structs.rs | 4 +- src/basefold_verifier/verifier.rs | 376 +++++++++++++++++++++++++++ src/e2e/mod.rs | 20 +- src/transcript/mod.rs | 17 +- src/zkvm_verifier/binding.rs | 11 +- src/zkvm_verifier/verifier.rs | 18 +- 12 files changed, 596 insertions(+), 154 deletions(-) create mode 100644 src/basefold_verifier/verifier.rs diff --git a/Cargo.lock b/Cargo.lock index 7c479b7..a62a669 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -516,9 +516,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.29" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "shlex", ] @@ -526,6 +526,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "glob", ] @@ -576,6 +577,7 @@ dependencies = [ [[package]] name = "ceno_emul" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "anyhow", "ceno_rt", @@ -598,6 +600,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "anyhow", "ceno_emul", @@ -610,6 +613,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -618,6 +622,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "base64", "bincode", @@ -631,8 +636,10 @@ dependencies = [ "gkr_iop", "glob", "itertools 0.13.0", + "keccakf", "mpcs", "multilinear_extensions", + "ndarray", "num-traits", "p3", "parse-size", @@ -1134,6 +1141,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "once_cell", "p3", @@ -1223,6 +1231,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr_iop" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "ark-std 0.5.0", "bincode", @@ -1233,7 +1242,6 @@ dependencies = [ "itertools 0.13.0", "mpcs", "multilinear_extensions", - "ndarray", "p3", "p3-field", "p3-goldilocks", @@ -1563,6 +1571,15 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "keccakf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d4ade81a4c9327bf19dcd0bd45784b99f86243edca6be0de19fc2d3aa8a4de2" +dependencies = [ + "crunchy", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1586,9 +1603,9 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" +checksum = "4488594b9328dee448adb906d8b126d9b7deb7cf5c22161ee591610bb1be83c0" dependencies = [ "bitflags", "libc", @@ -1701,6 +1718,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "aes", "bincode", @@ -1731,6 +1749,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "either", "ff_ext", @@ -1928,8 +1947,8 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openvm" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "bytemuck", "num-bigint 0.4.6", @@ -1941,8 +1960,8 @@ dependencies = [ [[package]] name = "openvm-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "cfg-if", @@ -1972,8 +1991,8 @@ dependencies = [ [[package]] name = "openvm-circuit-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "itertools 0.14.0", "quote", @@ -1982,8 +2001,8 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -1997,8 +2016,8 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "itertools 0.14.0", "quote", @@ -2008,7 +2027,7 @@ dependencies = [ [[package]] name = "openvm-custom-insn" version = "0.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "proc-macro2", "quote", @@ -2017,8 +2036,8 @@ dependencies = [ [[package]] name = "openvm-instructions" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "derive-new 0.6.0", @@ -2034,8 +2053,8 @@ dependencies = [ [[package]] name = "openvm-instructions-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "quote", "syn 2.0.104", @@ -2043,8 +2062,8 @@ dependencies = [ [[package]] name = "openvm-native-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2070,8 +2089,8 @@ dependencies = [ [[package]] name = "openvm-native-compiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "itertools 0.14.0", @@ -2092,8 +2111,8 @@ dependencies = [ [[package]] name = "openvm-native-compiler-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "quote", "syn 2.0.104", @@ -2101,8 +2120,8 @@ dependencies = [ [[package]] name = "openvm-native-recursion" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "cfg-if", "itertools 0.14.0", @@ -2125,10 +2144,9 @@ dependencies = [ [[package]] name = "openvm-platform" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ - "getrandom 0.2.16", "libm", "openvm-custom-insn", "openvm-rv32im-guest", @@ -2136,8 +2154,8 @@ dependencies = [ [[package]] name = "openvm-poseidon2-air" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derivative", "lazy_static", @@ -2153,8 +2171,8 @@ dependencies = [ [[package]] name = "openvm-rv32im-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2176,17 +2194,18 @@ dependencies = [ [[package]] name = "openvm-rv32im-guest" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "openvm-custom-insn", + "p3-field", "strum_macros", ] [[package]] name = "openvm-rv32im-transpiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -2201,8 +2220,8 @@ dependencies = [ [[package]] name = "openvm-stark-backend" -version = "1.0.0" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.0#884f8e6aabf72bde00dc51f1f1121277bff73b1e" +version = "1.1.1" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=f48090c9febd021f8ee0349bc929a775fb1fa3ad#f48090c9febd021f8ee0349bc929a775fb1fa3ad" dependencies = [ "bitcode", "cfg-if", @@ -2226,8 +2245,8 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" -version = "1.0.0" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.0#884f8e6aabf72bde00dc51f1f1121277bff73b1e" +version = "1.1.1" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=f48090c9febd021f8ee0349bc929a775fb1fa3ad#f48090c9febd021f8ee0349bc929a775fb1fa3ad" dependencies = [ "derivative", "derive_more 0.99.20", @@ -2244,6 +2263,7 @@ dependencies = [ "p3-fri", "p3-goldilocks", "p3-keccak", + "p3-koala-bear", "p3-merkle-tree", "p3-poseidon", "p3-poseidon2", @@ -2261,8 +2281,8 @@ dependencies = [ [[package]] name = "openvm-transpiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#96c432c2fdbe7e1516315c3de6beba22b2714f43" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "elf", "eyre", @@ -2291,6 +2311,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2312,7 +2333,7 @@ dependencies = [ [[package]] name = "p3-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-matrix", @@ -2321,7 +2342,7 @@ dependencies = [ [[package]] name = "p3-baby-bear" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-mds", @@ -2335,7 +2356,7 @@ dependencies = [ [[package]] name = "p3-blake3" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "blake3", "p3-symmetric", @@ -2345,7 +2366,7 @@ dependencies = [ [[package]] name = "p3-bn254-fr" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "ff 0.13.1", "halo2curves 0.8.0", @@ -2360,7 +2381,7 @@ dependencies = [ [[package]] name = "p3-challenger" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -2372,7 +2393,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-challenger", @@ -2386,7 +2407,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2399,7 +2420,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "num-bigint 0.4.6", @@ -2416,7 +2437,7 @@ dependencies = [ [[package]] name = "p3-fri" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-challenger", @@ -2435,7 +2456,7 @@ dependencies = [ [[package]] name = "p3-goldilocks" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "num-bigint 0.4.6", "p3-dft", @@ -2452,7 +2473,7 @@ dependencies = [ [[package]] name = "p3-interpolation" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-matrix", @@ -2463,7 +2484,7 @@ dependencies = [ [[package]] name = "p3-keccak" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2472,10 +2493,24 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "p3-koala-bear" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-field", + "p3-mds", + "p3-monty-31", + "p3-poseidon2", + "p3-symmetric", + "rand", + "serde", +] + [[package]] name = "p3-matrix" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2490,7 +2525,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "rayon", ] @@ -2498,7 +2533,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-dft", @@ -2512,7 +2547,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-commit", @@ -2529,7 +2564,7 @@ dependencies = [ [[package]] name = "p3-monty-31" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "num-bigint 0.4.6", @@ -2550,7 +2585,7 @@ dependencies = [ [[package]] name = "p3-poseidon" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-mds", @@ -2561,7 +2596,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "gcd", "p3-field", @@ -2573,7 +2608,7 @@ dependencies = [ [[package]] name = "p3-poseidon2-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-air", "p3-field", @@ -2589,7 +2624,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2599,7 +2634,7 @@ dependencies = [ [[package]] name = "p3-uni-stark" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-air", @@ -2617,7 +2652,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "serde", ] @@ -2734,6 +2769,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "criterion", "ff_ext", @@ -3201,9 +3237,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "itoa", "memchr", @@ -3342,6 +3378,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "crossbeam-channel", "either", @@ -3360,6 +3397,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "itertools 0.13.0", "p3", @@ -3633,6 +3671,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3810,6 +3849,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "bincode", "blake2", @@ -4045,6 +4085,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index eea354f..41488ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", bra openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "f48090c9febd021f8ee0349bc929a775fb1fa3ad", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "f48090c9febd021f8ee0349bc929a775fb1fa3ad", default-features = false } rand = { version = "0.8.5", default-features = false } itertools = { version = "0.13.0", default-features = false } @@ -19,17 +19,17 @@ bincode = "1.3.3" tracing = "0.1.40" # Plonky3 -p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-goldilocks = { git = "https://github.com/Plonky3/plonky3", rev = "1ba4e5c" } +p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-goldilocks = { git = "https://github.com/Plonky3/plonky3", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } # WHIR ark-std = { version = "0.5", features = ["std"] } @@ -38,26 +38,26 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "transcript" } -ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "witness" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "transcript" } +ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "witness" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [features] bench-metrics = ["openvm-circuit/bench-metrics"] -[patch."https://github.com/scroll-tech/ceno.git"] -ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } -ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } -ceno_transcript = { path = "../ceno/transcript", package = "transcript" } -ceno_witness = { path = "../ceno/witness", package = "witness" } -ceno_zkvm = { path = "../ceno/ceno_zkvm" } -ceno_emul = { path = "../ceno/ceno_emul" } -mpcs = { path = "../ceno/mpcs" } -ff_ext = { path = "../ceno/ff_ext" } +# [patch."https://github.com/scroll-tech/ceno.git"] +# ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } +# ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } +# ceno_transcript = { path = "../ceno/transcript", package = "transcript" } +# ceno_witness = { path = "../ceno/witness", package = "witness" } +# ceno_zkvm = { path = "../ceno/ceno_zkvm" } +# ceno_emul = { path = "../ceno/ceno_emul" } +# mpcs = { path = "../ceno/mpcs" } +# ff_ext = { path = "../ceno/ff_ext" } diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index 12735ef..e8af6b8 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -1,5 +1,9 @@ +use std::collections::BTreeMap; + +use itertools::Itertools; use mpcs::basefold::BasefoldProof as InnerBasefoldProof; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; @@ -78,6 +82,7 @@ pub struct BasefoldProof { pub final_message: Vec>, pub query_opening_proof: QueryOpeningProofs, pub sumcheck_proof: Vec, + pub pow_witness: F, } #[derive(DslVariable, Clone)] @@ -86,6 +91,7 @@ pub struct BasefoldProofVariable { pub final_message: Array>>, pub query_opening_proof: QueryOpeningProofsVariable, pub sumcheck_proof: Array>, + pub pow_witness: Felt, } impl Hintable for BasefoldProof { @@ -95,11 +101,13 @@ impl Hintable for BasefoldProof { let final_message = Vec::>::read(builder); let query_opening_proof = QueryOpeningProofs::read(builder); let sumcheck_proof = Vec::::read(builder); + let pow_witness = F::read(builder); BasefoldProofVariable { commits, final_message, query_opening_proof, sumcheck_proof, + pow_witness, } } @@ -109,6 +117,7 @@ impl Hintable for BasefoldProof { stream.extend(self.final_message.write()); stream.extend(self.query_opening_proof.write()); stream.extend(self.sumcheck_proof.write()); + stream.extend(self.pow_witness.write()); stream } } @@ -126,6 +135,7 @@ impl From> for BasefoldProof { sumcheck_proof: proof.sumcheck_proof.map_or(vec![], |proof| { proof.into_iter().map(|proof| proof.into()).collect() }), + pow_witness: proof.pow_witness, } } } @@ -173,6 +183,7 @@ pub struct Round { pub struct RoundVariable { pub commit: BasefoldCommitmentVariable, pub openings: Array>, + pub perm: Array>, } impl Hintable for Round { @@ -180,13 +191,33 @@ impl Hintable for Round { fn read(builder: &mut Builder) -> Self::HintVariable { let commit = BasefoldCommitment::read(builder); let openings = Vec::::read(builder); - RoundVariable { commit, openings } + let perm = Vec::::read(builder); + RoundVariable { + commit, + openings, + perm, + } } fn write(&self) -> Vec::N>> { + let mut perm = vec![0; self.openings.len()]; + self.openings + .iter() + .enumerate() + // the original order + .map(|(i, opening)| (i, opening.num_var)) + .sorted_by(|(_, nv_a), (_, nv_b)| Ord::cmp(nv_b, nv_a)) + .enumerate() + // j is the new index where i is the original index + .map(|(j, (i, _))| (i, j)) + .for_each(|(i, j)| { + perm[i] = j; + }); let mut stream = vec![]; stream.extend(self.commit.write()); stream.extend(self.openings.write()); + stream.extend(perm.write()); + stream } } diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs index 0353225..fac53c8 100644 --- a/src/basefold_verifier/mod.rs +++ b/src/basefold_verifier/mod.rs @@ -7,3 +7,4 @@ pub(crate) mod rs; pub(crate) mod structs; // pub(crate) mod field; pub(crate) mod utils; +pub(crate) mod verifier; diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index 8cc8e7b..c7ff136 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -298,7 +298,7 @@ pub struct PackedCodeword { pub high: Ext, } -pub(crate) fn batch_verifier_query_phase( +pub(crate) fn batch_verifier_query_phase( builder: &mut Builder, input: QueryPhaseVerifierInputVariable, ) { @@ -344,7 +344,7 @@ pub(crate) fn batch_verifier_query_phase( let final_codeword = encode_small(builder, final_rmm); let log2_max_codeword_size: Var = - builder.eval(input.max_num_var.clone() + get_rate_log::()); + builder.eval(input.max_num_var.clone() + Usize::from(get_rate_log())); let zero: Ext = builder.constant(C::EF::ZERO); @@ -391,12 +391,9 @@ pub(crate) fn batch_verifier_query_phase( let j = j_vec[0]; let mat_j = builder.get(&opened_values, j); let num_var_j = builder.get(&round.openings, j).num_var; - let height_j = - builder.eval(num_var_j + get_rate_log::() - Usize::from(1)); + let height_j = builder.eval(num_var_j + Usize::from(get_rate_log() - 1)); - // TODO: use permutation to get the index - // let permuted_index = builder.get(&perm, j); - let permuted_j = j; + let permuted_j = builder.get(&round.perm, j); builder.set_value(&perm_opened_values, permuted_j, mat_j); builder.set_value(&dimensions, permuted_j, height_j); @@ -423,7 +420,7 @@ pub(crate) fn batch_verifier_query_phase( let opened_value = builder.iter_ptr_get(&opened_values, ptr_vec[0]); let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]); let log2_height: Var = - builder.eval(opening.num_var + get_rate_log::() - Usize::from(1)); + builder.eval(opening.num_var + Usize::from(get_rate_log() - 1)); let width = opening.point_and_evals.evals.len(); let batch_coeffs_next_offset: Var = @@ -466,7 +463,7 @@ pub(crate) fn batch_verifier_query_phase( // fold 1st codeword let cur_num_var: Var = builder.eval(input.max_num_var.clone()); let log2_height: Var = - builder.eval(cur_num_var + get_rate_log::() - Usize::from(1)); + builder.eval(cur_num_var + Usize::from(get_rate_log() - 1)); let r = builder.get(&input.fold_challenges, 0); let codeword = builder.get(&reduced_codeword_by_height, log2_height); @@ -639,7 +636,7 @@ pub(crate) fn batch_verifier_query_phase( let point = &point_and_evals.point; let num_vars_evaluated: Var = - builder.eval(point.fs.len() - get_basecode_msg_size_log::()); + builder.eval(point.fs.len() - Usize::from(get_basecode_msg_size_log())); let final_message = builder.get(&input.proof.final_message, j); // coeff is the eq polynomial evaluated at the first challenge.len() variables diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index a533eaa..0f4db95 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -93,12 +93,16 @@ impl DenseMatrixVariable { } } -pub fn get_rate_log() -> Usize { - Usize::from(1) +pub fn get_rate_log() -> usize { + 1 } -pub fn get_basecode_msg_size_log() -> Usize { - Usize::from(0) +pub fn get_basecode_msg_size_log() -> usize { + 0 +} + +pub fn get_num_queries() -> usize { + 100 } pub fn verifier_folding_coeffs_level( diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index 1c31594..724ff19 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -113,7 +113,7 @@ pub fn get_base_codeword_dimensions( let fixed_num_polys = tmp.fixed_num_polys; // wit_dim // let width = builder.eval(witin_num_polys * Usize::from(2)); - let height_exp = builder.eval(witin_num_vars + get_rate_log::() - Usize::from(1)); + let height_exp = builder.eval(witin_num_vars + Usize::from(get_rate_log() - 1)); // let height = pow_2(builder, height_exp); // let next_wit: DimensionsVariable = DimensionsVariable { width, height }; // Only keep the height because the width is not needed in the mmcs batch @@ -129,7 +129,7 @@ pub fn get_base_codeword_dimensions( .then(|builder| { // let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); let height_exp = - builder.eval(fixed_num_vars.clone() + get_rate_log::() - Usize::from(1)); + builder.eval(fixed_num_vars.clone() + Usize::from(get_rate_log() - 1)); // XXX: more efficient pow implementation // let height = pow_2(builder, height_exp); // let next_fixed: DimensionsVariable = DimensionsVariable { width, height }; diff --git a/src/basefold_verifier/verifier.rs b/src/basefold_verifier/verifier.rs new file mode 100644 index 0000000..5eb39de --- /dev/null +++ b/src/basefold_verifier/verifier.rs @@ -0,0 +1,376 @@ +use crate::{ + basefold_verifier::query_phase::{batch_verifier_query_phase, QueryPhaseVerifierInputVariable}, + transcript::{transcript_check_pow_witness, transcript_observe_label}, +}; + +use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; +use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; +use openvm_native_compiler::{asm::AsmConfig, ir::FromConstant, prelude::*}; +use openvm_native_compiler_derive::iter_zip; +use openvm_native_recursion::{ + challenger::{ + duplex::DuplexChallengerVariable, CanObserveDigest, CanObserveVariable, + CanSampleBitsVariable, CanSampleVariable, FeltChallenger, + }, + hints::{Hintable, VecAutoHintable}, + vars::HintSlice, +}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::FieldAlgebra; + +pub type F = BabyBear; +pub type E = BabyBearExt4; +pub type InnerConfig = AsmConfig; + +pub fn batch_verify( + builder: &mut Builder, + max_num_var: Var, + rounds: Array>, + proof: BasefoldProofVariable, + challenger: &mut DuplexChallengerVariable, +) { + builder.assert_nonzero(&proof.final_message.len()); + builder.assert_nonzero(&proof.sumcheck_proof.len()); + + // we don't support early stopping for now + iter_zip!(builder, proof.final_message).for_each(|ptr_vec, builder| { + let final_message_len = builder.iter_ptr_get(&proof.final_message, ptr_vec[0]).len(); + builder.assert_eq::>( + final_message_len, + Usize::from(1 << get_basecode_msg_size_log()), + ); + }); + + builder.assert_eq::>( + proof.query_opening_proof.len(), + Usize::from(get_num_queries()), + ); + + // Compute the total number of polynomials across all rounds + let total_num_polys: Var = builder.constant(C::N::ZERO); + iter_zip!(builder, rounds).for_each(|ptr_vec, builder| { + let openings = builder.iter_ptr_get(&rounds, ptr_vec[0]).openings; + iter_zip!(builder, openings).for_each(|ptr_vec_openings, builder| { + let evals_num = builder + .iter_ptr_get(&openings, ptr_vec_openings[0]) + .point_and_evals + .evals + .len(); + builder.assign(&total_num_polys, total_num_polys + evals_num); + }); + }); + + // get batch coeffs + transcript_observe_label(builder, challenger, b"batch coeffs"); + let batch_coeff = challenger.sample_ext(builder); + let running_coeff = + as FromConstant>::constant(C::EF::from_canonical_usize(1), builder); + let batch_coeffs: Array> = builder.dyn_array(total_num_polys); + + iter_zip!(builder, batch_coeffs).for_each(|ptr_vec_batch_coeffs, builder| { + builder.iter_ptr_set(&batch_coeffs, ptr_vec_batch_coeffs[0], running_coeff); + builder.assign(&running_coeff, running_coeff * batch_coeff); + }); + + // The max num var is provided by the prover and not guaranteed to be correct. + // Check that + // 1. it is greater than or equal to every num var; + // 2. it is equal to at least one of the num vars by multiplying all the differences + // together and assert the product is zero. + let diff_product: Var = builder.eval(Usize::from(1)); + iter_zip!(builder, rounds).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&rounds, ptr_vec[0]); + + iter_zip!(builder, round.openings).for_each(|ptr_vec_opening, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec_opening[0]); + let diff: Var = builder.eval(max_num_var.clone() - opening.num_var); + // num_var is always smaller than 32. + builder.range_check_var(diff, 5); + builder.assign(&diff_product, diff_product * diff); + }); + }); + // Check that at least one num_var is equal to max_num_var + let zero: Var = builder.eval(C::N::ZERO); + builder.assert_eq::>(diff_product, zero); + + let num_rounds: Var = + builder.eval(max_num_var - Usize::from(get_basecode_msg_size_log())); + + let fold_challenges: Array> = builder.dyn_array(max_num_var); + builder.range(0, num_rounds).for_each(|index_vec, builder| { + let sumcheck_message = builder.get(&proof.sumcheck_proof, index_vec[0]).evaluations; + iter_zip!(builder, sumcheck_message).for_each(|ptr_vec_sumcheck_message, builder| { + let elem = builder.iter_ptr_get(&sumcheck_message, ptr_vec_sumcheck_message[0]); + let elem_felts = builder.ext2felt(elem); + challenger.observe_slice(builder, elem_felts); + }); + + transcript_observe_label(builder, challenger, b"commit round"); + let challenge = challenger.sample_ext(builder); + builder.set(&fold_challenges, index_vec[0], challenge); + builder + .if_ne(index_vec[0], num_rounds - Usize::from(1)) + .then(|builder| { + let commit = builder.get(&proof.commits, index_vec[0]); + challenger.observe_digest(builder, commit.value.into()); + }); + }); + + iter_zip!(builder, proof.final_message).for_each(|ptr_vec_sumcheck_message, builder| { + // Each final message should contain a single element, since the final + // message size log is assumed to be zero + let elems = builder.iter_ptr_get(&proof.final_message, ptr_vec_sumcheck_message[0]); + let elem = builder.get(&elems, 0); + let elem_felts = builder.ext2felt(elem); + challenger.observe_slice(builder, elem_felts); + }); + + transcript_check_pow_witness(builder, challenger, 16, proof.pow_witness); // TODO: avoid hardcoding pow bits + transcript_observe_label(builder, challenger, b"query indices"); + let queries: Array> = builder.dyn_array(get_num_queries()); + builder + .range(0, get_num_queries()) + .for_each(|index_vec, builder| { + let number_of_bits = builder.eval_expr(max_num_var + Usize::from(get_rate_log())); + let query = challenger.sample_bits(builder, number_of_bits); + // TODO: the index will need to be split back to bits in query phase, so it's + // probably better to avoid converting bits to integer altogether + let number_of_bits = builder.eval(max_num_var + Usize::from(get_rate_log())); + let query = bin_to_dec_le(builder, &query, zero, number_of_bits); + builder.set(&queries, index_vec[0], query); + }); + + let input = QueryPhaseVerifierInputVariable { + max_num_var: builder.eval(max_num_var), + batch_coeffs, + fold_challenges, + indices: queries, + proof, + rounds, + }; + batch_verifier_query_phase(builder, input); +} + +#[cfg(test)] +pub mod tests { + use std::{cmp::Reverse, collections::BTreeMap, iter::once}; + + use ceno_mle::mle::MultilinearExtension; + use ceno_transcript::{BasicTranscript, Transcript}; + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use mpcs::{ + pcs_batch_commit, pcs_setup, pcs_trim, util::hash::write_digest_to_transcript, + BasefoldDefault, PolynomialCommitmentScheme, + }; + use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam}; + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::challenger::duplex::DuplexChallengerVariable; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::p3_challenger::GrindingChallenger; + use openvm_stark_sdk::config::baby_bear_poseidon2::Challenger; + use openvm_stark_sdk::p3_baby_bear::BabyBear; + use p3_field::Field; + use p3_field::FieldAlgebra; + use rand::thread_rng; + use serde::Deserialize; + + type F = BabyBear; + type E = BabyBearExt4; + type PCS = BasefoldDefault; + + use super::{batch_verify, BasefoldProof, BasefoldProofVariable, InnerConfig, RoundVariable}; + use crate::basefold_verifier::basefold::{Round, RoundOpening}; + use crate::basefold_verifier::query_phase::PointAndEvals; + use crate::{ + basefold_verifier::{ + basefold::BasefoldCommitment, + query_phase::{BatchOpening, CommitPhaseProofStep, QueryOpeningProof}, + structs::CircuitIndexMeta, + }, + tower_verifier::binding::{Point, PointAndEval}, + }; + use openvm_native_compiler::{asm::AsmConfig, prelude::*}; + + #[derive(Deserialize)] + pub struct VerifierInput { + pub max_num_var: usize, + pub proof: BasefoldProof, + pub rounds: Vec, + } + + impl Hintable for VerifierInput { + type HintVariable = VerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let max_num_var = usize::read(builder); + let proof = BasefoldProof::read(builder); + let rounds = Vec::::read(builder); + + VerifierInputVariable { + max_num_var, + proof, + rounds, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.max_num_var)); + stream.extend(self.proof.write()); + stream.extend(self.rounds.write()); + stream + } + } + + #[derive(DslVariable, Clone)] + pub struct VerifierInputVariable { + pub max_num_var: Var, + pub proof: BasefoldProofVariable, + pub rounds: Array>, + } + + #[allow(dead_code)] + pub fn build_batch_verifier(input: VerifierInput) -> (Program, Vec>) { + // build test program + let mut builder = AsmBuilder::::default(); + let mut challenger = DuplexChallengerVariable::new(&mut builder); + let verifier_input = VerifierInput::read(&mut builder); + batch_verify( + &mut builder, + verifier_input.max_num_var, + verifier_input.rounds, + verifier_input.proof, + &mut challenger, + ); + builder.halt(); + let program = builder.compile_isa(); + + let mut witness_stream: Vec> = Vec::new(); + witness_stream.extend(input.write()); + + (program, witness_stream) + } + + fn construct_test(dimensions: Vec<(usize, usize)>) { + let mut rng = thread_rng(); + + // setup PCS + let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + + let mut num_total_polys = 0; + let (matrices, mles): (Vec<_>, Vec<_>) = dimensions + .into_iter() + .map(|(num_vars, width)| { + let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); + let mles = m.to_mles(); + num_total_polys += width; + + (m, mles) + }) + .unzip(); + + // commit to matrices + let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); + let comm = PCS::get_pure_commitment(&pcs_data); + + let point_and_evals = mles + .iter() + .map(|mles| { + let point = E::random_vec(mles[0].num_vars(), &mut rng); + let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + + (point, evals) + }) + .collect_vec(); + + // batch open + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![(&pcs_data, point_and_evals.clone())]; + let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); + + // batch verify + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![( + comm, + point_and_evals + .iter() + .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) + .collect_vec(), + )]; + PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) + .expect("Native verification failed"); + + let max_num_var = point_and_evals + .iter() + .map(|(point, _)| point.len()) + .max() + .unwrap(); + + let verifier_input = VerifierInput { + max_num_var, + rounds: rounds + .into_iter() + .map(|(commit, openings)| Round { + commit: commit.into(), + openings: openings + .into_iter() + .map(|(num_var, (point, evals))| RoundOpening { + num_var, + point_and_evals: PointAndEvals { + point: Point { + fs: point, + }, + evals, + }, + }) + .collect(), + }) + .collect(), + proof: opening_proof.into(), + }; + + let (program, witness) = build_batch_verifier(verifier_input); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program.clone(), witness.clone()).unwrap(); + + // _debug + let results = executor.execute_segments(program, witness).unwrap(); + for seg in results { + println!("=> cycle count: {:?}", seg.metrics.cycle_count); + } + } + + #[test] + fn test_simple_batch() { + for num_var in 5..20 { + construct_test(vec![(num_var, 20)]); + } + } + + #[test] + fn test_decreasing_batch() { + construct_test(vec![ + (14, 20), + (14, 40), + (13, 30), + (12, 30), + (11, 10), + (10, 15), + ]); + } + + #[test] + fn test_random_batch() { + construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]); + } +} diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index 5d5aead..b966742 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -350,22 +350,7 @@ pub fn parse_zkvm_proof_import( serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); let fixed_commit = verifier.vk.fixed_commit.clone(); - let pcs_proof = zkvm_proof.opening_proof; - - // let query_phase_verifier_input = QueryPhaseVerifierInput { - // max_num_var, - // indices, - // final_message, - // batch_coeffs, - // queries, - // fixed_comm, - // witin_comm, - // circuit_meta, - // commits: pcs_proof.commits, - // fold_challenges, - // sumcheck_messages: pcs_proof.sumcheck_proof.unwrap(), - // point_evals, - // }; + let pcs_proof = zkvm_proof.opening_proof.into(); ( ZKVMProofInput { @@ -376,7 +361,7 @@ pub fn parse_zkvm_proof_import( witin_commit, fixed_commit, num_instances: vec![], // TODO: Fixme - // query_phase_verifier_input, + pcs_proof, }, proving_sequence, ) @@ -453,6 +438,7 @@ pub fn inner_test_thread() { } #[test] +#[ignore = "e2e does not work for now"] pub fn test_zkvm_proof_verifier_from_bincode_exports() { let stack_size = 64 * 1024 * 1024; // 64 MB diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index 9a61b80..45af2e2 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -1,9 +1,9 @@ use ff_ext::{BabyBearExt4, ExtensionField as CenoExtensionField, SmallField}; use openvm_native_compiler::prelude::*; -use openvm_native_recursion::challenger::ChallengerVariable; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use openvm_native_recursion::challenger::{CanSampleBitsVariable, ChallengerVariable}; use p3_field::FieldAlgebra; pub fn transcript_observe_label( @@ -17,3 +17,18 @@ pub fn transcript_observe_label( challenger.observe(builder, f); } } + +pub fn transcript_check_pow_witness( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + nbits: usize, + witness: Felt, +) { + let nbits = builder.eval_expr(Usize::from(nbits)); + challenger.observe(builder, witness); + let bits = challenger.sample_bits(builder, nbits); + builder.range(0, nbits).for_each(|index_vec, builder| { + let bit = builder.get(&bits, index_vec[0]); + builder.assert_eq::>(bit, Usize::from(0)); + }); +} diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index 782658a..d0c3e37 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -1,4 +1,5 @@ use crate::arithmetics::next_pow2_instance_padding; +use crate::basefold_verifier::basefold::{BasefoldProof, BasefoldProofVariable}; use crate::basefold_verifier::query_phase::{ QueryPhaseVerifierInput, QueryPhaseVerifierInputVariable, }; @@ -41,7 +42,7 @@ pub struct ZKVMProofInputVariable { pub fixed_commit_log2_max_codeword_size: Felt, pub num_instances: Array>>, - pub query_phase_verifier_input: QueryPhaseVerifierInputVariable, + pub pcs_proof: BasefoldProofVariable, } #[derive(DslVariable, Clone)] @@ -106,7 +107,7 @@ pub(crate) struct ZKVMProofInput { pub witin_commit: BasefoldCommitment, pub fixed_commit: Option>, pub num_instances: Vec<(usize, usize)>, - // pub query_phase_verifier_input: QueryPhaseVerifierInput, + pub pcs_proof: BasefoldProof, } impl Hintable for ZKVMProofInput { type HintVariable = ZKVMProofInputVariable; @@ -129,7 +130,7 @@ impl Hintable for ZKVMProofInput { let num_instances = Vec::>::read(builder); - let query_phase_verifier_input = QueryPhaseVerifierInput::read(builder); + let pcs_proof = BasefoldProof::read(builder); ZKVMProofInputVariable { raw_pi, @@ -145,7 +146,7 @@ impl Hintable for ZKVMProofInput { fixed_commit_trivial_commits, fixed_commit_log2_max_codeword_size, num_instances, - query_phase_verifier_input, + pcs_proof, } } @@ -230,7 +231,7 @@ impl Hintable for ZKVMProofInput { } stream.extend(num_instances_vec.write()); - // stream.extend(self.query_phase_verifier_input.write()); + stream.extend(self.pcs_proof.write()); stream } diff --git a/src/zkvm_verifier/verifier.rs b/src/zkvm_verifier/verifier.rs index 17f7d82..ed3b6b8 100644 --- a/src/zkvm_verifier/verifier.rs +++ b/src/zkvm_verifier/verifier.rs @@ -5,6 +5,7 @@ use crate::arithmetics::{ challenger_multi_observe, eval_ceno_expr_with_instance, print_ext_arr, print_felt_arr, PolyEvaluator, UniPolyExtrapolator, }; +use crate::basefold_verifier::verifier::batch_verify; use crate::e2e::SubcircuitParams; use crate::tower_verifier::program::verify_tower_proof; use crate::transcript::transcript_observe_label; @@ -265,20 +266,9 @@ pub fn verify_zkvm_proof( logup_sum - dummy_table_item_multiplicity * dummy_table_item.inverse(), ); - /* TODO: MPCS - PCS::batch_verify( - &self.vk.vp, - &vm_proof.num_instances, - &rt_points, - self.vk.fixed_commit.as_ref(), - &vm_proof.witin_commit, - &evaluations, - &vm_proof.fixed_witin_opening_proof, - &self.vk.circuit_num_polys, - &mut transcript, - ) - .map_err(ZKVMError::PCSError)?; - */ + // TODO: prepare rounds and uncomment this + + // batch_verifier(builder, rounds, zkvm_proof_input.pcs_proof, &mut challenger); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( From 5e4f289816b07a16e898f082605d2f3d97013dd6 Mon Sep 17 00:00:00 2001 From: xkx Date: Wed, 30 Jul 2025 17:32:24 +0800 Subject: [PATCH 70/70] Feat: integrate BaseFold module to verify zkVM proof (#43) * wip * fix * update zkvm verifier * e2e pass without basefold * cleanup * cleanup 2 * verify witin openings but still failed * remove debug logs * add fixed opening * pass fri check * turn on input_opening_point length check * pass e2e test --- src/arithmetics/mod.rs | 4 +- src/basefold_verifier/basefold.rs | 7 +- src/basefold_verifier/extension_mmcs.rs | 2 +- src/basefold_verifier/hash.rs | 30 +- src/basefold_verifier/mmcs.rs | 94 +++--- src/basefold_verifier/query_phase.rs | 17 +- src/basefold_verifier/rs.rs | 2 +- src/basefold_verifier/structs.rs | 206 +------------- src/basefold_verifier/utils.rs | 8 + src/basefold_verifier/verifier.rs | 5 +- src/e2e/mod.rs | 225 +++------------ src/extensions/mod.rs | 220 +++++++------- src/tower_verifier/program.rs | 45 +-- src/zkvm_verifier/binding.rs | 303 +++++--------------- src/zkvm_verifier/verifier.rs | 364 +++++++++++++++--------- 15 files changed, 517 insertions(+), 1015 deletions(-) diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index b8ac2bc..c79494d 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -1,5 +1,5 @@ use crate::tower_verifier::binding::PointAndEvalVariable; -use crate::zkvm_verifier::binding::ZKVMOpcodeProofInputVariable; +use crate::zkvm_verifier::binding::ZKVMChipProofInputVariable; use ceno_mle::{Expression, Fixed, Instance}; use ceno_zkvm::structs::{ChallengeId, WitnessId}; use ff_ext::ExtensionField; @@ -422,7 +422,7 @@ pub fn gen_alpha_pows( pub fn eq_eval_less_or_equal_than( builder: &mut Builder, _challenger: &mut DuplexChallengerVariable, - opcode_proof: &ZKVMOpcodeProofInputVariable, + opcode_proof: &ZKVMChipProofInputVariable, a: &Array>, b: &Array>, ) -> Ext { diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs index e8af6b8..8340e54 100644 --- a/src/basefold_verifier/basefold.rs +++ b/src/basefold_verifier/basefold.rs @@ -19,14 +19,14 @@ use crate::{ tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, }; -use super::{mmcs::*, structs::DIMENSIONS}; +use super::{mmcs::*, structs::DEGREE}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub type HashDigest = MmcsCommitment; -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct BasefoldCommitment { pub commit: HashDigest, pub log2_max_codeword_size: usize, @@ -73,7 +73,6 @@ pub type HashDigestVariable = MmcsCommitmentVariable; pub struct BasefoldCommitmentVariable { pub commit: HashDigestVariable, pub log2_max_codeword_size: Usize, - // pub trivial_commits: Array>, } #[derive(Deserialize)] diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs index b60794d..6e44863 100644 --- a/src/basefold_verifier/extension_mmcs.rs +++ b/src/basefold_verifier/extension_mmcs.rs @@ -6,7 +6,7 @@ use p3_field::extension::BinomialExtensionField; use super::{mmcs::*, structs::*}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub struct ExtMmcsVerifierInput { diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs index 9aee223..8cede68 100644 --- a/src/basefold_verifier/hash.rs +++ b/src/basefold_verifier/hash.rs @@ -2,30 +2,21 @@ use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use p3_field::FieldAlgebra; use serde::Deserialize; -use super::structs::DIMENSIONS; +use super::structs::DEGREE; pub const DIGEST_ELEMS: usize = 8; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -#[derive(Deserialize)] +#[derive(Deserialize, Default, Debug)] pub struct Hash { pub value: [F; DIGEST_ELEMS], } -impl Default for Hash { - fn default() -> Self { - Hash { - value: [F::ZERO; DIGEST_ELEMS], - } - } -} - impl From> for Hash { fn from(hash: p3_symmetric::Hash) -> Self { Hash { value: hash.into() } @@ -43,22 +34,13 @@ impl Hintable for Hash { type HintVariable = HashVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let value = builder.dyn_array(DIGEST_ELEMS); - for i in 0..DIGEST_ELEMS { - let tmp = F::read(builder); - builder.set(&value, i, tmp); - } + let value = builder.hint_felts_fixed(DIGEST_ELEMS); HashVariable { value } } fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - // Write out each entries - for i in 0..DIGEST_ELEMS { - stream.extend(self.value[i].write()); - } - stream + self.value.map(|felt| vec![felt]).to_vec() } } @@ -67,8 +49,6 @@ mod tests { use openvm_circuit::arch::{SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; - type F = BabyBear; - type E = BinomialExtensionField; use crate::basefold_verifier::basefold::HashDigest; diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs index 5093a9e..1f030dc 100644 --- a/src/basefold_verifier/mmcs.rs +++ b/src/basefold_verifier/mmcs.rs @@ -1,16 +1,14 @@ -// Note: check all XXX comments! - -use std::marker::PhantomData; - use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use crate::basefold_verifier::utils::{read_hint_slice, write_mmcs_proof}; + use super::{hash::*, structs::*}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub type MmcsCommitment = Hash; @@ -24,6 +22,17 @@ pub struct MmcsVerifierInput { pub proof: MmcsProof, } +pub type MmcsCommitmentVariable = HashVariable; + +#[derive(DslVariable, Clone)] +pub struct MmcsVerifierInputVariable { + pub commit: MmcsCommitmentVariable, + pub dimensions: Array>, + pub index_bits: Array>, + pub opened_values: Array>>, + pub proof: HintSlice, +} + impl Hintable for MmcsVerifierInput { type HintVariable = MmcsVerifierInputVariable; @@ -32,9 +41,7 @@ impl Hintable for MmcsVerifierInput { let dimensions = Vec::::read(builder); let index_bits = Vec::::read(builder); let opened_values = Vec::>::read(builder); - let length = Usize::from(builder.hint_var()); - let id = Usize::from(builder.hint_load()); - let proof = HintSlice { length, id }; + let proof = read_hint_slice(builder); MmcsVerifierInputVariable { commit, @@ -47,45 +54,27 @@ impl Hintable for MmcsVerifierInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - // Split index into bits + + let idx_bits = (0..self.proof.len()) + .scan(self.index, |acc, _| { + let bit = *acc & 0x01; + *acc >>= 1; + + Some(bit) + }) + .collect::>(); + stream.extend(self.commit.write()); stream.extend(self.dimensions.write()); - let mut index_bits = Vec::new(); - let mut index = self.index; - for _ in 0..self.proof.len() { - index_bits.push(index % 2); - index /= 2; - } - // index_bits.reverse(); // Index bits is big endian ? - stream.extend( as Hintable>::write(&index_bits)); + stream.extend(idx_bits.write()); stream.extend(self.opened_values.write()); - stream.extend(>::write(&(self.proof.len()))); // According to openvm extensions/native/recursion/src/fri/hints.rs - stream.extend( - self.proof - .iter() - .flat_map(|p| p.iter().copied()) - .collect::>() - .write(), - ); // According to openvm extensions/native/recursion/src/fri/hints.rs + stream.extend(write_mmcs_proof(&self.proof)); + stream } } -pub type MmcsCommitmentVariable = HashVariable; - -#[derive(DslVariable, Clone)] -pub struct MmcsVerifierInputVariable { - pub commit: MmcsCommitmentVariable, - pub dimensions: Array>, - pub index_bits: Array>, - pub opened_values: Array>>, - pub proof: HintSlice, -} - -pub(crate) fn mmcs_verify_batch( - builder: &mut Builder, - input: MmcsVerifierInputVariable, -) { +pub fn mmcs_verify_batch(builder: &mut Builder, input: MmcsVerifierInputVariable) { let dimensions = match input.dimensions { Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), _ => panic!("Expected a dynamic array of felts"), @@ -99,32 +88,23 @@ pub(crate) fn mmcs_verify_batch( ); } +#[cfg(test)] pub mod tests { use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; use openvm_native_recursion::hints::Hintable; - use openvm_stark_backend::config::StarkGenericConfig; - use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, - }; - use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; - type SC = BabyBearPoseidon2Config; - - type F = BabyBear; - type E = BinomialExtensionField; - type EF = ::Challenge; - use crate::basefold_verifier::structs::Dimensions; + use p3_field::FieldAlgebra; - use super::{mmcs_verify_batch, InnerConfig, MmcsCommitment, MmcsVerifierInput}; + use super::{mmcs_verify_batch, MmcsCommitment, MmcsVerifierInput, E, F}; /// The witness in this test is produced by: /// https://github.com/Jiangkm3/Plonky3 branch cyte/mmcs-poseidon2-constants /// cargo test --package p3-merkle-tree --lib -- mmcs::tests::size_gaps --exact --show-output #[allow(dead_code)] - pub fn build_mmcs_verify_batch() -> (Program, Vec>) { + pub fn build_mmcs_verify_batch() -> (Program, Vec>) { // OpenVM DSL - let mut builder = AsmBuilder::::default(); + let mut builder = AsmBuilder::::default(); // Witness inputs let mmcs_input = MmcsVerifierInput::read(&mut builder); @@ -264,9 +244,7 @@ pub mod tests { witness_stream.extend(mmcs_input.write()); // PROGRAM - let program: Program< - p3_monty_31::MontyField31, - > = builder.compile_isa(); + let program: Program = builder.compile_isa(); (program, witness_stream) } @@ -280,7 +258,7 @@ pub mod tests { .with_max_segment_len((1 << 25) - 100); let config = NativeConfig::new(system_config, Native); - let executor = VmExecutor::::new(config); + let executor = VmExecutor::::new(config); executor.execute(program, witness).unwrap(); // _debug diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs index c7ff136..2497666 100644 --- a/src/basefold_verifier/query_phase.rs +++ b/src/basefold_verifier/query_phase.rs @@ -1,6 +1,5 @@ // Note: check all XXX comments! -use ark_std::log2; use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; use mpcs::basefold::QueryOpeningProof as InnerQueryOpeningProof; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; @@ -13,11 +12,10 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_commit::ExtensionMmcs; use p3_field::{Field, FieldAlgebra}; use serde::Deserialize; -use std::fmt::Debug; -use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; +use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, utils::*}; use crate::{ - arithmetics::{build_eq_x_r_vec_sequential_with_offset, eq_eval_with_index}, + arithmetics::eq_eval_with_index, tower_verifier::{binding::*, program::interpolate_uni_poly}, }; @@ -79,14 +77,8 @@ impl Hintable for BatchOpening { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); stream.extend(self.opened_values.write()); - stream.extend(vec![ - vec![F::from_canonical_usize(self.opening_proof.len())], - self.opening_proof - .iter() - .flatten() - .copied() - .collect::>(), - ]); + stream.extend(write_mmcs_proof(&self.opening_proof)); + stream } } @@ -376,6 +368,7 @@ pub(crate) fn batch_verifier_query_phase( let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]); let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); + builder.assert_usize_eq(query.input_proofs.len(), input.rounds.len()); iter_zip!(builder, query.input_proofs, input.rounds).for_each(|ptr_vec, builder| { let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]); let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]); diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs index 0f4db95..b2730c7 100644 --- a/src/basefold_verifier/rs.rs +++ b/src/basefold_verifier/rs.rs @@ -13,7 +13,7 @@ use super::structs::*; use super::utils::{pow_felt, pow_felt_bits}; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; pub struct DenseMatrix { diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs index 724ff19..1f4ffa8 100644 --- a/src/basefold_verifier/structs.rs +++ b/src/basefold_verifier/structs.rs @@ -3,69 +3,13 @@ use openvm_native_compiler_derive::DslVariable; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use serde::Deserialize; -pub const DIMENSIONS: usize = 4; +pub const DEGREE: usize = 4; pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; pub type InnerConfig = AsmConfig; -use super::rs::get_rate_log; -use super::utils::pow_2; - -#[derive(DslVariable, Clone)] -pub struct CircuitIndexMetaVariable { - pub witin_num_vars: Usize, - pub witin_num_polys: Usize, - pub fixed_num_vars: Usize, - pub fixed_num_polys: Usize, -} - -#[derive(Deserialize)] -pub struct CircuitIndexMeta { - pub witin_num_vars: usize, - pub witin_num_polys: usize, - pub fixed_num_vars: usize, - pub fixed_num_polys: usize, -} - -impl Hintable for CircuitIndexMeta { - type HintVariable = CircuitIndexMetaVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let witin_num_vars = Usize::Var(usize::read(builder)); - let witin_num_polys = Usize::Var(usize::read(builder)); - let fixed_num_vars = Usize::Var(usize::read(builder)); - let fixed_num_polys = Usize::Var(usize::read(builder)); - - CircuitIndexMetaVariable { - witin_num_vars, - witin_num_polys, - fixed_num_vars, - fixed_num_polys, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(>::write( - &self.witin_num_vars, - )); - stream.extend(>::write( - &self.witin_num_polys, - )); - stream.extend(>::write( - &self.fixed_num_vars, - )); - stream.extend(>::write( - &self.fixed_num_polys, - )); - stream - } -} -impl VecAutoHintable for CircuitIndexMeta {} - #[derive(DslVariable, Clone)] pub struct DimensionsVariable { pub width: Var, @@ -77,6 +21,8 @@ pub struct Dimensions { pub height: usize, } +impl VecAutoHintable for Dimensions {} + impl Hintable for Dimensions { type HintVariable = DimensionsVariable; @@ -94,147 +40,3 @@ impl Hintable for Dimensions { stream } } -impl VecAutoHintable for Dimensions {} - -pub fn get_base_codeword_dimensions( - builder: &mut Builder, - circuit_meta_map: Array>, -) -> (Array>, Array>) { - let dim_len = circuit_meta_map.len(); - let wit_dim: Array> = builder.dyn_array(dim_len.clone()); - let fixed_dim: Array> = builder.dyn_array(dim_len.clone()); - - builder.range(0, dim_len).for_each(|i_vec, builder| { - let i = i_vec[0]; - let tmp = builder.get(&circuit_meta_map, i); - let witin_num_vars = tmp.witin_num_vars; - let witin_num_polys = tmp.witin_num_polys; - let fixed_num_vars = tmp.fixed_num_vars; - let fixed_num_polys = tmp.fixed_num_polys; - // wit_dim - // let width = builder.eval(witin_num_polys * Usize::from(2)); - let height_exp = builder.eval(witin_num_vars + Usize::from(get_rate_log() - 1)); - // let height = pow_2(builder, height_exp); - // let next_wit: DimensionsVariable = DimensionsVariable { width, height }; - // Only keep the height because the width is not needed in the mmcs batch - // verify instruction - // The dimension passed to the mmcs verifier batch is log of the height, not - // the height itself - builder.set_value(&wit_dim, i, height_exp); - - // fixed_dim - // XXX: since fixed_num_vars is usize, fixed_num_vars > 0 is equivalent to fixed_num_vars != 0 - builder - .if_ne(fixed_num_vars.clone(), Usize::from(0)) - .then(|builder| { - // let width = builder.eval(fixed_num_polys.clone() * Usize::from(2)); - let height_exp = - builder.eval(fixed_num_vars.clone() + Usize::from(get_rate_log() - 1)); - // XXX: more efficient pow implementation - // let height = pow_2(builder, height_exp); - // let next_fixed: DimensionsVariable = DimensionsVariable { width, height }; - builder.set_value(&fixed_dim, i, height_exp); - }); - }); - (wit_dim, fixed_dim) -} - -pub mod tests { - use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; - use openvm_native_circuit::{Native, NativeConfig}; - use openvm_native_compiler::asm::AsmBuilder; - use openvm_native_compiler::prelude::*; - use openvm_native_recursion::hints::Hintable; - use openvm_stark_backend::config::StarkGenericConfig; - use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, - }; - use p3_field::extension::BinomialExtensionField; - type SC = BabyBearPoseidon2Config; - - type F = BabyBear; - type E = BinomialExtensionField; - type EF = ::Challenge; - use crate::basefold_verifier::structs::*; - - use super::{get_base_codeword_dimensions, InnerConfig}; - - #[allow(dead_code)] - pub fn build_test_get_base_codeword_dimensions() -> (Program, Vec>) { - // OpenVM DSL - let mut builder = AsmBuilder::::default(); - - // Witness inputs - let map_len = Usize::Var(usize::read(&mut builder)); - let circuit_meta_map = builder.dyn_array(map_len.clone()); - builder - .range(0, map_len.clone()) - .for_each(|i_vec, builder| { - let i = i_vec[0]; - let next_meta = CircuitIndexMeta::read(builder); - builder.set(&circuit_meta_map, i, next_meta); - }); - - let (wit_dim, fixed_dim) = get_base_codeword_dimensions(&mut builder, circuit_meta_map); - builder.range(0, map_len).for_each(|i_vec, builder| { - let i = i_vec[0]; - let wit = builder.get(&wit_dim, i); - let fixed = builder.get(&fixed_dim, i); - let i_val: Var<_> = builder.eval(i); - builder.print_v(i_val); - // let ww_val: Var<_> = builder.eval(wit.width); - // let wh_val: Var<_> = builder.eval(wit.height); - // let fw_val: Var<_> = builder.eval(fixed.width); - // let fh_val: Var<_> = builder.eval(fixed.height); - // builder.print_v(ww_val); - builder.print_v(wit); - // builder.print_v(fw_val); - builder.print_v(fixed); - }); - builder.halt(); - - // Pass in witness stream - let mut witness_stream: Vec< - Vec>, - > = Vec::new(); - - // Map length - let map_len = 5; - witness_stream.extend(>::write(&map_len)); - for i in 0..map_len { - // Individual metas - let circuit_meta = CircuitIndexMeta { - witin_num_vars: i, - witin_num_polys: i, - fixed_num_vars: i, - fixed_num_polys: i, - }; - witness_stream.extend(circuit_meta.write()); - } - - let program: Program< - p3_monty_31::MontyField31, - > = builder.compile_isa(); - - (program, witness_stream) - } - - #[test] - fn test_dense_matrix_pad() { - let (program, witness) = build_test_get_base_codeword_dimensions(); - - let system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); - let config = NativeConfig::new(system_config, Native); - - let executor = VmExecutor::::new(config); - executor.execute(program, witness).unwrap(); - - // _debug - // let results = executor.execute_segments(program, witness).unwrap(); - // for seg in results { - // println!("=> cycle count: {:?}", seg.metrics.cycle_count); - // } - } -} diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs index 6b6bf92..86f877a 100644 --- a/src/basefold_verifier/utils.rs +++ b/src/basefold_verifier/utils.rs @@ -1,5 +1,6 @@ use openvm_native_compiler::ir::*; use openvm_native_recursion::vars::HintSlice; +use p3_baby_bear::BabyBear; use p3_field::FieldAlgebra; use crate::basefold_verifier::mmcs::MmcsProof; @@ -305,3 +306,10 @@ pub(crate) fn read_hint_slice(builder: &mut Builder) -> HintSlice< let id = Usize::from(builder.hint_load()); HintSlice { length, id } } + +pub(crate) fn write_mmcs_proof(proof: &MmcsProof) -> Vec> { + vec![ + vec![BabyBear::from_canonical_usize(proof.len())], + proof.iter().flatten().copied().collect::>(), + ] +} diff --git a/src/basefold_verifier/verifier.rs b/src/basefold_verifier/verifier.rs index 5eb39de..1ea6c12 100644 --- a/src/basefold_verifier/verifier.rs +++ b/src/basefold_verifier/verifier.rs @@ -188,7 +188,6 @@ pub mod tests { basefold_verifier::{ basefold::BasefoldCommitment, query_phase::{BatchOpening, CommitPhaseProofStep, QueryOpeningProof}, - structs::CircuitIndexMeta, }, tower_verifier::binding::{Point, PointAndEval}, }; @@ -321,9 +320,7 @@ pub mod tests { .map(|(num_var, (point, evals))| RoundOpening { num_var, point_and_evals: PointAndEvals { - point: Point { - fs: point, - }, + point: Point { fs: point }, evals, }, }) diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index b966742..b707d72 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -1,14 +1,12 @@ +use crate::basefold_verifier::basefold::BasefoldCommitment; use crate::basefold_verifier::query_phase::QueryPhaseVerifierInput; use crate::tower_verifier::binding::IOPProverMessage; use crate::zkvm_verifier::binding::ZKVMProofInput; -use crate::zkvm_verifier::binding::{ - TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, -}; +use crate::zkvm_verifier::binding::{TowerProofInput, ZKVMChipProofInput, E, F}; use crate::zkvm_verifier::verifier::verify_zkvm_proof; use ceno_mle::util::ceil_log2; use ff_ext::BabyBearExt4; use itertools::Itertools; -use mpcs::BasefoldCommitment; use mpcs::{Basefold, BasefoldRSParams}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; @@ -23,9 +21,7 @@ use openvm_stark_sdk::config::setup_tracing_with_log_level; use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; -use std::collections::HashMap; use std::fs::File; -use std::thread; type SC = BabyBearPoseidon2Config; type EF = ::Challenge; @@ -35,27 +31,10 @@ use ceno_zkvm::{ structs::ZKVMVerifyingKey, }; -#[derive(Debug, Clone)] -pub struct SubcircuitParams { - pub id: usize, - pub order_idx: usize, - pub type_order_idx: usize, - pub name: String, - pub num_instances: usize, - pub is_opcode: bool, -} - pub fn parse_zkvm_proof_import( zkvm_proof: ZKVMProof>, verifier: &ZKVMVerifier>, -) -> (ZKVMProofInput, Vec) { - let subcircuit_names = verifier.vk.circuit_vks.keys().collect_vec(); - - let mut order_idx: usize = 0; - let mut opcode_order_idx: usize = 0; - let mut table_order_idx: usize = 0; - let mut proving_sequence: Vec = vec![]; - +) -> ZKVMProofInput { let raw_pi = zkvm_proof .raw_pi .iter() @@ -80,15 +59,14 @@ pub fn parse_zkvm_proof_import( }) .collect::>(); - let mut opcode_proofs_vec: Vec = vec![]; - /* - for (opcode_id, opcode_proof) in &zkvm_proof.chip_proofs { + let mut chip_proofs: Vec = vec![]; + for (chip_id, chip_proof) in &zkvm_proof.chip_proofs { let mut record_r_out_evals: Vec> = vec![]; let mut record_w_out_evals: Vec> = vec![]; let mut record_lk_out_evals: Vec> = vec![]; - let record_r_out_evals_len: usize = opcode_proof.r_out_evals.len(); - for v_vec in &opcode_proof.r_out_evals { + let record_r_out_evals_len: usize = chip_proof.r_out_evals.len(); + for v_vec in &chip_proof.r_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -97,8 +75,8 @@ pub fn parse_zkvm_proof_import( } record_r_out_evals.push(arr); } - let record_w_out_evals_len: usize = opcode_proof.w_out_evals.len(); - for v_vec in &opcode_proof.w_out_evals { + let record_w_out_evals_len: usize = chip_proof.w_out_evals.len(); + for v_vec in &chip_proof.w_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -107,8 +85,8 @@ pub fn parse_zkvm_proof_import( } record_w_out_evals.push(arr); } - let record_lk_out_evals_len: usize = opcode_proof.lk_out_evals.len(); - for v_vec in &opcode_proof.lk_out_evals { + let record_lk_out_evals_len: usize = chip_proof.lk_out_evals.len(); + for v_vec in &chip_proof.lk_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -122,7 +100,7 @@ pub fn parse_zkvm_proof_import( let mut tower_proof = TowerProofInput::default(); let mut proofs: Vec> = vec![]; - for proof in &opcode_proof.tower_proof.proofs { + for proof in &chip_proof.tower_proof.proofs { let mut proof_messages: Vec = vec![]; for m in proof { let mut evaluations_vec: Vec = vec![]; @@ -142,7 +120,7 @@ pub fn parse_zkvm_proof_import( tower_proof.proofs = proofs; let mut prod_specs_eval: Vec>> = vec![]; - for inner_val in &opcode_proof.tower_proof.prod_specs_eval { + for inner_val in &chip_proof.tower_proof.prod_specs_eval { let mut inner_v: Vec> = vec![]; for inner_evals_val in inner_val { let mut inner_evals_v: Vec = vec![]; @@ -160,7 +138,7 @@ pub fn parse_zkvm_proof_import( tower_proof.prod_specs_eval = prod_specs_eval; let mut logup_specs_eval: Vec>> = vec![]; - for inner_val in &opcode_proof.tower_proof.logup_specs_eval { + for inner_val in &chip_proof.tower_proof.logup_specs_eval { let mut inner_v: Vec> = vec![]; for inner_evals_val in inner_val { let mut inner_evals_v: Vec = vec![]; @@ -179,8 +157,8 @@ pub fn parse_zkvm_proof_import( // main constraint and select sumcheck proof let mut main_sumcheck_proofs: Vec = vec![]; - if opcode_proof.main_sumcheck_proofs.is_some() { - for m in opcode_proof.main_sumcheck_proofs.as_ref().unwrap() { + if chip_proof.main_sumcheck_proofs.is_some() { + for m in chip_proof.main_sumcheck_proofs.as_ref().unwrap() { let mut evaluations_vec: Vec = vec![]; for v in &m.evaluations { let v_e: E = @@ -194,20 +172,20 @@ pub fn parse_zkvm_proof_import( } let mut wits_in_evals: Vec = vec![]; - for v in &opcode_proof.wits_in_evals { + for v in &chip_proof.wits_in_evals { let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); wits_in_evals.push(v_e); } let mut fixed_in_evals: Vec = vec![]; - for v in &opcode_proof.fixed_in_evals { + for v in &chip_proof.fixed_in_evals { let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); fixed_in_evals.push(v_e); } - opcode_proofs_vec.push(ZKVMOpcodeProofInput { - idx: opcode_id.clone(), - num_instances: opcode_num_instances_lookup.get(opcode_id).unwrap().clone(), + chip_proofs.push(ZKVMChipProofInput { + idx: chip_id.clone(), + num_instances: chip_proof.num_instances, record_r_out_evals_len, record_w_out_evals_len, record_lk_out_evals_len, @@ -221,150 +199,19 @@ pub fn parse_zkvm_proof_import( }); } - let mut table_proofs_vec: Vec = vec![]; - for (table_id, table_proof) in &zkvm_proof.table_proofs { - let mut record_r_out_evals: Vec> = vec![]; - let mut record_w_out_evals: Vec> = vec![]; - let mut record_lk_out_evals: Vec> = vec![]; - - let record_r_out_evals_len: usize = table_proof.r_out_evals.len(); - for v_vec in &table_proof.r_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_r_out_evals.push(arr); - } - let record_w_out_evals_len: usize = table_proof.w_out_evals.len(); - for v_vec in &table_proof.w_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_w_out_evals.push(arr); - } - let record_lk_out_evals_len: usize = table_proof.lk_out_evals.len(); - for v_vec in &table_proof.lk_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_lk_out_evals.push(arr); - } - - // Tower proof - let mut tower_proof = TowerProofInput::default(); - let mut proofs: Vec> = vec![]; - - for proof in &table_proof.tower_proof.proofs { - let mut proof_messages: Vec = vec![]; - for m in proof { - let mut evaluations_vec: Vec = vec![]; - - for v in &m.evaluations { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - evaluations_vec.push(v_e); - } - proof_messages.push(IOPProverMessage { - evaluations: evaluations_vec, - }); - } - proofs.push(proof_messages); - } - tower_proof.num_proofs = proofs.len(); - tower_proof.proofs = proofs; - - let mut prod_specs_eval: Vec>> = vec![]; - for inner_val in &table_proof.tower_proof.prod_specs_eval { - let mut inner_v: Vec> = vec![]; - for inner_evals_val in inner_val { - let mut inner_evals_v: Vec = vec![]; - - for v in inner_evals_val { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - inner_evals_v.push(v_e); - } - inner_v.push(inner_evals_v); - } - prod_specs_eval.push(inner_v); - } - tower_proof.num_prod_specs = prod_specs_eval.len(); - tower_proof.prod_specs_eval = prod_specs_eval; - - let mut logup_specs_eval: Vec>> = vec![]; - for inner_val in &table_proof.tower_proof.logup_specs_eval { - let mut inner_v: Vec> = vec![]; - for inner_evals_val in inner_val { - let mut inner_evals_v: Vec = vec![]; - - for v in inner_evals_val { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - inner_evals_v.push(v_e); - } - inner_v.push(inner_evals_v); - } - logup_specs_eval.push(inner_v); - } - tower_proof.num_logup_specs = logup_specs_eval.len(); - tower_proof.logup_specs_eval = logup_specs_eval; - - let mut fixed_in_evals: Vec = vec![]; - for v in &table_proof.fixed_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - fixed_in_evals.push(v_e); - } - let mut wits_in_evals: Vec = vec![]; - for v in &table_proof.wits_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - wits_in_evals.push(v_e); - } - - let num_instances = table_num_instances_lookup.get(table_id).unwrap().clone(); - - table_proofs_vec.push(ZKVMTableProofInput { - idx: table_id.clone(), - num_instances, - record_r_out_evals_len, - record_w_out_evals_len, - record_lk_out_evals_len, - record_r_out_evals, - record_w_out_evals, - record_lk_out_evals, - tower_proof, - fixed_in_evals, - wits_in_evals, - }); - } - */ - - let witin_commit: BasefoldCommitment = + let witin_commit: mpcs::BasefoldCommitment = serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); - let fixed_commit = verifier.vk.fixed_commit.clone(); + let witin_commit: BasefoldCommitment = witin_commit.into(); let pcs_proof = zkvm_proof.opening_proof.into(); - ( - ZKVMProofInput { - raw_pi, - pi_evals, - opcode_proofs: vec![], - table_proofs: vec![], - witin_commit, - fixed_commit, - num_instances: vec![], // TODO: Fixme - pcs_proof, - }, - proving_sequence, - ) + ZKVMProofInput { + raw_pi, + pi_evals, + chip_proofs, + witin_commit, + pcs_proof, + } } pub fn inner_test_thread() { @@ -382,19 +229,14 @@ pub fn inner_test_thread() { .expect("Failed to deserialize vk file"); let verifier = ZKVMVerifier::new(vk); - let (zkvm_proof_input, proving_sequence) = parse_zkvm_proof_import(zkvm_proof, &verifier); + let zkvm_proof_input = parse_zkvm_proof_import(zkvm_proof, &verifier); // OpenVM DSL let mut builder = AsmBuilder::::default(); // Obtain witness inputs let zkvm_proof_input_variables = ZKVMProofInput::read(&mut builder); - verify_zkvm_proof( - &mut builder, - zkvm_proof_input_variables, - &verifier, - proving_sequence, - ); + verify_zkvm_proof(&mut builder, zkvm_proof_input_variables, &verifier); builder.halt(); // Pass in witness stream @@ -438,11 +280,10 @@ pub fn inner_test_thread() { } #[test] -#[ignore = "e2e does not work for now"] pub fn test_zkvm_proof_verifier_from_bincode_exports() { let stack_size = 64 * 1024 * 1024; // 64 MB - let handler = thread::Builder::new() + let handler = std::thread::Builder::new() .stack_size(stack_size) .spawn(inner_test_thread) .expect("Failed to spawn thread"); diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 438284d..347f556 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,126 +1,114 @@ -use crate::arithmetics::{challenger_multi_observe, exts_to_felts, print_felt_arr}; -use crate::e2e::SubcircuitParams; -use crate::tower_verifier::binding::IOPProverMessage; -use crate::tower_verifier::program::verify_tower_proof; -use crate::transcript::transcript_observe_label; -use crate::zkvm_verifier::binding::ZKVMProofInput; -use crate::zkvm_verifier::binding::{ - TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, -}; -use crate::zkvm_verifier::verifier::verify_zkvm_proof; -use crate::{ - arithmetics::{ - build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, - eq_eval_less_or_equal_than, eval_wellform_address_vec, gen_alpha_pows, max_usize_arr, - max_usize_vec, next_pow2_instance_padding, product, sum as ext_sum, - }, - tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, -}; -use ceno_mle::expression::StructuralWitIn; -use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; -use ff_ext::BabyBearExt4; -use itertools::interleave; -use itertools::max; -use itertools::Itertools; -use mpcs::BasefoldCommitment; -use mpcs::{Basefold, BasefoldRSParams}; -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; -use openvm_native_circuit::{Native, NativeConfig}; -use openvm_native_compiler::conversion::convert_program; -use openvm_native_compiler::prelude::*; -use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; -use openvm_native_compiler_derive::iter_zip; -use openvm_native_recursion::challenger::{self, CanSampleVariable}; -use openvm_native_recursion::challenger::{ - duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, -}; -use openvm_native_recursion::hints::Hintable; -use openvm_stark_backend::config::StarkGenericConfig; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, -}; -use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; -use std::collections::HashMap; -use std::fs::File; -use std::marker::PhantomData; - -type Pcs = Basefold; -const NUM_FANIN: usize = 2; -const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup -const SEL_DEGREE: usize = 2; - -type SC = BabyBearPoseidon2Config; -type EF = ::Challenge; - -#[test] -pub fn test_native_multi_observe() { - // OpenVM DSL - let mut builder = AsmBuilder::::default(); - - vm_program(&mut builder); - - builder.halt(); - - // Pass in witness stream - let witness_stream: Vec< - Vec>, - > = Vec::new(); - - // Compile program - let options = CompilerOptions::default().with_cycle_tracker(); - let mut compiler = AsmCompiler::new(options.word_size); - compiler.build(builder.operations); - let asm_code = compiler.code(); - let program = convert_program(asm_code, options); - - let mut system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); - system_config.profiling = true; - let config = NativeConfig::new(system_config, Native); - - let executor = VmExecutor::::new(config); - - // Alternative execution - // executor.execute(program, witness_stream).unwrap(); - - let res = executor - .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) - .unwrap(); - - for (i, seg) in res.iter().enumerate() { - println!("=> segment {:?} metrics: {:?}", i, seg.metrics); +#[cfg(test)] +mod tests { + + use crate::arithmetics::{challenger_multi_observe, exts_to_felts}; + + use crate::zkvm_verifier::binding::{E, F}; + use ceno_mle::expression::StructuralWitIn; + use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; + use ff_ext::BabyBearExt4; + use itertools::interleave; + use itertools::max; + use itertools::Itertools; + use mpcs::BasefoldCommitment; + use mpcs::{Basefold, BasefoldRSParams}; + use openvm_circuit::arch::SystemConfig; + use openvm_circuit::arch::VmExecutor; + use openvm_native_circuit::Native; + use openvm_native_circuit::NativeConfig; + use openvm_native_compiler::conversion::convert_program; + use openvm_native_compiler::prelude::*; + use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; + use openvm_native_compiler_derive::iter_zip; + use openvm_native_recursion::challenger::{self, CanSampleVariable}; + use openvm_native_recursion::challenger::{ + duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, + }; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, + }; + use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; + + type Pcs = Basefold; + const NUM_FANIN: usize = 2; + const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup + const SEL_DEGREE: usize = 2; + + type SC = BabyBearPoseidon2Config; + type EF = ::Challenge; + + #[test] + pub fn test_native_multi_observe() { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + vm_program(&mut builder); + + builder.halt(); + + // Pass in witness stream + let witness_stream: Vec< + Vec>, + > = Vec::new(); + + // Compile program + let options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + let program = convert_program(asm_code, options); + + let mut system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + system_config.profiling = true; + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + + // Alternative execution + // executor.execute(program, witness_stream).unwrap(); + + let res = executor + .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) + .unwrap(); + + for (i, seg) in res.iter().enumerate() { + println!("=> segment {:?} metrics: {:?}", i, seg.metrics); + } } -} -fn vm_program(builder: &mut Builder) { - let e1: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(16)); - let e2: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(32)); - let e3: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(64)); - let e4: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(128)); - let e5: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(256)); - let len: usize = 5; + fn vm_program(builder: &mut Builder) { + let e1: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(16)); + let e2: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(32)); + let e3: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(64)); + let e4: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(128)); + let e5: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(256)); + let len: usize = 5; - let e_arr: Array> = builder.dyn_array(len); - builder.set(&e_arr, 0, e1); - builder.set(&e_arr, 1, e2); - builder.set(&e_arr, 2, e3); - builder.set(&e_arr, 3, e4); - builder.set(&e_arr, 4, e5); + let e_arr: Array> = builder.dyn_array(len); + builder.set(&e_arr, 0, e1); + builder.set(&e_arr, 1, e2); + builder.set(&e_arr, 2, e3); + builder.set(&e_arr, 3, e4); + builder.set(&e_arr, 4, e5); - unsafe { - let mut c1 = DuplexChallengerVariable::new(builder); - let mut c2 = DuplexChallengerVariable::new(builder); + unsafe { + let mut c1 = DuplexChallengerVariable::new(builder); + let mut c2 = DuplexChallengerVariable::new(builder); - let f_arr1 = exts_to_felts(builder, &e_arr); - let f_arr2 = f_arr1.clone(); + let f_arr1 = exts_to_felts(builder, &e_arr); + let f_arr2 = f_arr1.clone(); - challenger_multi_observe(builder, &mut c1, &f_arr1); - let test_e1 = c1.sample(builder); + challenger_multi_observe(builder, &mut c1, &f_arr1); + let test_e1 = c1.sample(builder); - c2.observe_slice(builder, f_arr2); - let test_e2 = c2.sample(builder); + c2.observe_slice(builder, f_arr2); + let test_e2 = c2.sample(builder); - builder.assert_felt_eq(test_e1, test_e2); + builder.assert_felt_eq(test_e1, test_e2); + } } } diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index 09e973c..85d097b 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -551,15 +551,17 @@ pub fn verify_tower_proof( }, // update point and eval only for last layer |builder| { - builder.set( - &prod_spec_point_n_eval, - spec_index, - PointAndEvalVariable { + let point_and_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: evals, - }, + }); + builder.set_value( + &prod_spec_point_n_eval, + spec_index, + point_and_eval, ); }, ); @@ -617,26 +619,22 @@ pub fn verify_tower_proof( }, // update point and eval only for last layer |builder| { - builder.set( - &logup_spec_p_point_n_eval, - spec_index, - PointAndEvalVariable { + let p_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: p_eval, - }, - ); - builder.set( - &logup_spec_q_point_n_eval, - spec_index, - PointAndEvalVariable { + }); + let q_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: q_eval, - }, - ); + }); + builder.set_value(&logup_spec_p_point_n_eval, spec_index, p_eval); + builder.set_value(&logup_spec_q_point_n_eval, spec_index, q_eval); }, ); }); @@ -649,12 +647,15 @@ pub fn verify_tower_proof( builder.cycle_tracker_end("derive next layer's expected sum"); - next_rt = PointAndEvalVariable { - point: PointVariable { - fs: rt_prime.clone(), + builder.assign( + &next_rt, + PointAndEvalVariable { + point: PointVariable { + fs: rt_prime.clone(), + }, + eval: curr_eval.clone(), }, - eval: curr_eval.clone(), - }; + ); }); ( diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index d0c3e37..5e35664 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -1,5 +1,7 @@ use crate::arithmetics::next_pow2_instance_padding; -use crate::basefold_verifier::basefold::{BasefoldProof, BasefoldProofVariable}; +use crate::basefold_verifier::basefold::{ + BasefoldCommitment, BasefoldCommitmentVariable, BasefoldProof, BasefoldProofVariable, +}; use crate::basefold_verifier::query_phase::{ QueryPhaseVerifierInput, QueryPhaseVerifierInputVariable, }; @@ -9,7 +11,7 @@ use crate::{ }; use ark_std::iterable::Iterable; use ff_ext::BabyBearExt4; -use mpcs::BasefoldCommitment; +use itertools::Itertools; use openvm_native_compiler::{ asm::AsmConfig, ir::{Array, Builder, Config, Felt}, @@ -29,19 +31,11 @@ pub struct ZKVMProofInputVariable { pub raw_pi: Array>>, pub raw_pi_num_variables: Array>, pub pi_evals: Array>, - pub opcode_proofs: Array>, - pub table_proofs: Array>, - - pub witin_commit: Array>, - pub witin_commit_trivial_commits: Array>>, - pub witin_commit_log2_max_codeword_size: Felt, - - pub has_fixed_commit: Usize, - pub fixed_commit: Array>, - pub fixed_commit_trivial_commits: Array>>, - pub fixed_commit_log2_max_codeword_size: Felt, - pub num_instances: Array>>, - + pub chip_proofs: Array>, + pub max_num_var: Var, + pub witin_commit: BasefoldCommitmentVariable, + pub witin_perm: Array>, + pub fixed_perm: Array>, pub pcs_proof: BasefoldProofVariable, } @@ -56,7 +50,7 @@ pub struct TowerProofInputVariable { } #[derive(DslVariable, Clone)] -pub struct ZKVMOpcodeProofInputVariable { +pub struct ZKVMChipProofInputVariable { pub idx: Usize, pub idx_felt: Felt, pub num_instances: Usize, @@ -78,37 +72,15 @@ pub struct ZKVMOpcodeProofInputVariable { pub fixed_in_evals: Array>, } -#[derive(DslVariable, Clone)] -pub struct ZKVMTableProofInputVariable { - pub idx: Usize, - pub idx_felt: Felt, - pub num_instances: Usize, - pub log2_num_instances: Usize, - - pub record_r_out_evals_len: Usize, - pub record_w_out_evals_len: Usize, - pub record_lk_out_evals_len: Usize, - - pub record_r_out_evals: Array>>, - pub record_w_out_evals: Array>>, - pub record_lk_out_evals: Array>>, - - pub tower_proof: TowerProofInputVariable, - pub fixed_in_evals: Array>, - pub wits_in_evals: Array>, -} - pub(crate) struct ZKVMProofInput { pub raw_pi: Vec>, // Evaluation of raw_pi. pub pi_evals: Vec, - pub opcode_proofs: Vec, - pub table_proofs: Vec, - pub witin_commit: BasefoldCommitment, - pub fixed_commit: Option>, - pub num_instances: Vec<(usize, usize)>, + pub chip_proofs: Vec, + pub witin_commit: BasefoldCommitment, pub pcs_proof: BasefoldProof, } + impl Hintable for ZKVMProofInput { type HintVariable = ZKVMProofInputVariable; @@ -116,121 +88,70 @@ impl Hintable for ZKVMProofInput { let raw_pi = Vec::>::read(builder); let raw_pi_num_variables = Vec::::read(builder); let pi_evals = Vec::::read(builder); - let opcode_proofs = Vec::::read(builder); - let table_proofs = Vec::::read(builder); - - let witin_commit = Vec::::read(builder); - let witin_commit_trivial_commits = Vec::>::read(builder); - let witin_commit_log2_max_codeword_size = F::read(builder); - - let has_fixed_commit = Usize::Var(usize::read(builder)); - let fixed_commit = Vec::::read(builder); - let fixed_commit_trivial_commits = Vec::>::read(builder); - let fixed_commit_log2_max_codeword_size = F::read(builder); - - let num_instances = Vec::>::read(builder); - + let chip_proofs = Vec::::read(builder); + let max_num_var = usize::read(builder); + let witin_commit = BasefoldCommitment::read(builder); + let witin_perm = Vec::::read(builder); + let fixed_perm = Vec::::read(builder); let pcs_proof = BasefoldProof::read(builder); ZKVMProofInputVariable { raw_pi, raw_pi_num_variables, pi_evals, - opcode_proofs, - table_proofs, + chip_proofs, + max_num_var, witin_commit, - witin_commit_trivial_commits, - witin_commit_log2_max_codeword_size, - has_fixed_commit, - fixed_commit, - fixed_commit_trivial_commits, - fixed_commit_log2_max_codeword_size, - num_instances, + witin_perm, + fixed_perm, pcs_proof, } } fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.raw_pi.write()); + let raw_pi_num_variables: Vec = self + .raw_pi + .iter() + .map(|v| ceil_log2(v.len().next_power_of_two())) + .collect(); + let witin_num_vars = self + .chip_proofs + .iter() + .map(|proof| ceil_log2(proof.num_instances).max(1)) + .collect::>(); + let fixed_num_vars = self + .chip_proofs + .iter() + .filter(|proof| proof.fixed_in_evals.len() > 0) + .map(|proof| ceil_log2(proof.num_instances).max(1)) + .collect::>(); + let max_num_var = witin_num_vars.iter().map(|x| *x).max().unwrap_or(0); + let get_perm = |v: Vec| { + let mut perm = vec![0; v.len()]; + v.into_iter() + // the original order + .enumerate() + .sorted_by(|(_, nv_a), (_, nv_b)| Ord::cmp(nv_b, nv_a)) + .enumerate() + // j is the new index where i is the original index + .map(|(j, (i, _))| (i, j)) + .for_each(|(i, j)| { + perm[i] = j; + }); + perm + }; + let witin_perm = get_perm(witin_num_vars); + let fixed_perm = get_perm(fixed_num_vars); - let mut raw_pi_num_variables: Vec = vec![]; - for v in &self.raw_pi { - raw_pi_num_variables.push(ceil_log2(v.len().next_power_of_two())); - } + stream.extend(self.raw_pi.write()); stream.extend(raw_pi_num_variables.write()); - stream.extend(self.pi_evals.write()); - stream.extend(self.opcode_proofs.write()); - stream.extend(self.table_proofs.write()); - - // Write in witin_commit - let mut cmt_vec: Vec = vec![]; - self.witin_commit.commit().iter().for_each(|x| { - let f: F = serde_json::from_value(serde_json::to_value(&x).unwrap()).unwrap(); - cmt_vec.push(f); - }); - let mut witin_commit_trivial_commits: Vec> = vec![]; - // for trivial_commit in &self.witin_commit.trivial_commits { - // let mut t_cmt_vec: Vec = vec![]; - // trivial_commit.1.iter().for_each(|x| { - // let f: F = - // serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - // t_cmt_vec.push(f); - // }); - // witin_commit_trivial_commits.push(t_cmt_vec); - // } - let witin_commit_log2_max_codeword_size = - F::from_canonical_u32(self.witin_commit.log2_max_codeword_size as u32); - stream.extend(cmt_vec.write()); - stream.extend(witin_commit_trivial_commits.write()); - stream.extend(witin_commit_log2_max_codeword_size.write()); - - // Write in fixed_commit - let has_fixed_commit: usize = if self.fixed_commit.is_some() { 1 } else { 0 }; - let mut fixed_commit_vec: Vec = vec![]; - let mut fixed_commit_trivial_commits: Vec> = vec![]; - let mut fixed_commit_log2_max_codeword_size: F = F::ZERO.clone(); - if has_fixed_commit > 0 { - self.fixed_commit - .as_ref() - .unwrap() - .commit() - .iter() - .for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - fixed_commit_vec.push(f); - }); - - // for trivial_commit in &self.fixed_commit.as_ref().unwrap().trivial_commits { - // let mut t_cmt_vec: Vec = vec![]; - // trivial_commit.1.iter().for_each(|x| { - // let f: F = - // serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - // t_cmt_vec.push(f); - // }); - // fixed_commit_trivial_commits.push(t_cmt_vec); - // } - fixed_commit_log2_max_codeword_size = F::from_canonical_u32( - self.fixed_commit.as_ref().unwrap().log2_max_codeword_size as u32, - ); - } - stream.extend(>::write(&has_fixed_commit)); - stream.extend(fixed_commit_vec.write()); - stream.extend(fixed_commit_trivial_commits.write()); - stream.extend(fixed_commit_log2_max_codeword_size.write()); - - // Write num_instances - let mut num_instances_vec: Vec> = vec![]; - for (circuit_size, num_var) in &self.num_instances { - num_instances_vec.push(vec![ - F::from_canonical_usize(*circuit_size), - F::from_canonical_usize(*num_var), - ]); - } - stream.extend(num_instances_vec.write()); - + stream.extend(self.chip_proofs.write()); + stream.extend(>::write(&max_num_var)); + stream.extend(self.witin_commit.write()); + stream.extend(witin_perm.write()); + stream.extend(fixed_perm.write()); stream.extend(self.pcs_proof.write()); stream @@ -248,6 +169,7 @@ pub struct TowerProofInput { pub num_logup_specs: usize, pub logup_specs_eval: Vec>>, } + impl Hintable for TowerProofInput { type HintVariable = TowerProofInputVariable; @@ -308,7 +230,7 @@ impl Hintable for TowerProofInput { } } -pub struct ZKVMOpcodeProofInput { +pub struct ZKVMChipProofInput { pub idx: usize, pub num_instances: usize, @@ -327,9 +249,11 @@ pub struct ZKVMOpcodeProofInput { pub wits_in_evals: Vec, pub fixed_in_evals: Vec, } -impl VecAutoHintable for ZKVMOpcodeProofInput {} -impl Hintable for ZKVMOpcodeProofInput { - type HintVariable = ZKVMOpcodeProofInputVariable; + +impl VecAutoHintable for ZKVMChipProofInput {} + +impl Hintable for ZKVMChipProofInput { + type HintVariable = ZKVMChipProofInputVariable; fn read(builder: &mut Builder) -> Self::HintVariable { let idx = Usize::Var(usize::read(builder)); @@ -351,7 +275,7 @@ impl Hintable for ZKVMOpcodeProofInput { let wits_in_evals = Vec::::read(builder); let fixed_in_evals = Vec::::read(builder); - ZKVMOpcodeProofInputVariable { + ZKVMChipProofInputVariable { idx, idx_felt, num_instances, @@ -412,92 +336,3 @@ impl Hintable for ZKVMOpcodeProofInput { stream } } - -pub struct ZKVMTableProofInput { - pub idx: usize, - pub num_instances: usize, - - // tower evaluation at layer 1 - pub record_r_out_evals_len: usize, - pub record_w_out_evals_len: usize, - pub record_lk_out_evals_len: usize, - pub record_r_out_evals: Vec>, - pub record_w_out_evals: Vec>, - pub record_lk_out_evals: Vec>, - - pub tower_proof: TowerProofInput, - - pub fixed_in_evals: Vec, - pub wits_in_evals: Vec, -} -impl VecAutoHintable for ZKVMTableProofInput {} -impl Hintable for ZKVMTableProofInput { - type HintVariable = ZKVMTableProofInputVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let idx = Usize::Var(usize::read(builder)); - let idx_felt = F::read(builder); - - let num_instances = Usize::Var(usize::read(builder)); - let log2_num_instances = Usize::Var(usize::read(builder)); - - let record_r_out_evals_len = Usize::Var(usize::read(builder)); - let record_w_out_evals_len = Usize::Var(usize::read(builder)); - let record_lk_out_evals_len = Usize::Var(usize::read(builder)); - - let record_r_out_evals = Vec::>::read(builder); - let record_w_out_evals = Vec::>::read(builder); - let record_lk_out_evals = Vec::>::read(builder); - - let tower_proof = TowerProofInput::read(builder); - let fixed_in_evals = Vec::::read(builder); - let wits_in_evals = Vec::::read(builder); - - ZKVMTableProofInputVariable { - idx, - idx_felt, - num_instances, - log2_num_instances, - record_r_out_evals_len, - record_w_out_evals_len, - record_lk_out_evals_len, - record_r_out_evals, - record_w_out_evals, - record_lk_out_evals, - tower_proof, - fixed_in_evals, - wits_in_evals, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(>::write(&self.idx)); - - let idx_u32: F = F::from_canonical_u32(self.idx as u32); - stream.extend(idx_u32.write()); - - stream.extend(>::write(&self.num_instances)); - let log2_num_instances = ceil_log2(self.num_instances); - stream.extend(>::write(&log2_num_instances)); - - stream.extend(>::write( - &self.record_r_out_evals_len, - )); - stream.extend(>::write( - &self.record_w_out_evals_len, - )); - stream.extend(>::write( - &self.record_lk_out_evals_len, - )); - - stream.extend(self.record_r_out_evals.write()); - stream.extend(self.record_w_out_evals.write()); - stream.extend(self.record_lk_out_evals.write()); - - stream.extend(self.tower_proof.write()); - stream.extend(self.fixed_in_evals.write()); - stream.extend(self.wits_in_evals.write()); - stream - } -} diff --git a/src/zkvm_verifier/verifier.rs b/src/zkvm_verifier/verifier.rs index ed3b6b8..2fabd37 100644 --- a/src/zkvm_verifier/verifier.rs +++ b/src/zkvm_verifier/verifier.rs @@ -1,12 +1,15 @@ -use super::binding::{ - ZKVMOpcodeProofInputVariable, ZKVMProofInputVariable, ZKVMTableProofInputVariable, -}; +use super::binding::{ZKVMChipProofInputVariable, ZKVMProofInputVariable}; use crate::arithmetics::{ challenger_multi_observe, eval_ceno_expr_with_instance, print_ext_arr, print_felt_arr, PolyEvaluator, UniPolyExtrapolator, }; +use crate::basefold_verifier::basefold::{ + BasefoldCommitmentVariable, RoundOpeningVariable, RoundVariable, +}; +use crate::basefold_verifier::mmcs::MmcsCommitmentVariable; +use crate::basefold_verifier::query_phase::PointAndEvalsVariable; +use crate::basefold_verifier::utils::pow_2; use crate::basefold_verifier::verifier::batch_verify; -use crate::e2e::SubcircuitParams; use crate::tower_verifier::program::verify_tower_proof; use crate::transcript::transcript_observe_label; use crate::{ @@ -18,20 +21,25 @@ use crate::{ tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, }; use ceno_mle::expression::{Instance, StructuralWitIn}; +use ceno_zkvm::e2e::B; +use ceno_zkvm::structs::VerifyingKey; use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; use ff_ext::BabyBearExt4; -use itertools::interleave; use itertools::max; +use itertools::{interleave, Itertools}; use mpcs::{Basefold, BasefoldRSParams}; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use p3_baby_bear::BabyBear; use p3_field::{Field, FieldAlgebra}; +type F = BabyBear; type E = BabyBearExt4; type Pcs = Basefold; + const NUM_FANIN: usize = 2; const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup const SEL_DEGREE: usize = 2; @@ -69,11 +77,10 @@ pub fn transcript_group_sample_ext( e } -pub fn verify_zkvm_proof( +pub fn verify_zkvm_proof>( builder: &mut Builder, zkvm_proof_input: ZKVMProofInputVariable, - ceno_constraint_system: &ZKVMVerifier, - proving_sequence: Vec, + vk: &ZKVMVerifier, ) { let mut challenger = DuplexChallengerVariable::new(builder); transcript_observe_label(builder, &mut challenger, b"riscv"); @@ -101,44 +108,63 @@ pub fn verify_zkvm_proof( }, ); - challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.fixed_commit); - iter_zip!(builder, zkvm_proof_input.fixed_commit_trivial_commits).for_each( - |ptr_vec, builder| { - let trivial_cmt = - builder.iter_ptr_get(&zkvm_proof_input.fixed_commit_trivial_commits, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &trivial_cmt); - }, - ); - challenger.observe( - builder, - zkvm_proof_input.fixed_commit_log2_max_codeword_size, - ); + let fixed_commit = if let Some(fixed_commit) = vk.vk.fixed_commit.as_ref() { + let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); + let commit_array: Array> = builder.dyn_array(commit.value.len()); + commit.value.into_iter().enumerate().for_each(|(i, v)| { + let v = builder.constant(v); + // TODO: put fixed commit to public values + // builder.commit_public_value(v); + + builder.set_value(&commit_array, i, v); + }); + challenger_multi_observe(builder, &mut challenger, &commit_array); + + // FIXME: do not hardcode this in the program + let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); + let log2_max_codeword_size: Var = builder.constant(C::N::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); + + challenger.observe(builder, log2_max_codeword_size_felt); + + Some(BasefoldCommitmentVariable { + commit: MmcsCommitmentVariable { + value: commit_array, + }, + log2_max_codeword_size: log2_max_codeword_size.into(), + }) + } else { + None + }; let zero_f: Felt = builder.constant(C::F::ZERO); - iter_zip!(builder, zkvm_proof_input.num_instances).for_each(|ptr_vec, builder| { - let ns = builder.iter_ptr_get(&zkvm_proof_input.num_instances, ptr_vec[0]); - let circuit_size = builder.get(&ns, 0); - let num_var = builder.get(&ns, 1); + iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| { + let chip_proof = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]); + let num_instances = builder.unsafe_cast_var_to_felt(chip_proof.num_instances.get_var()); - challenger.observe(builder, circuit_size); + challenger.observe(builder, chip_proof.idx_felt); challenger.observe(builder, zero_f); - challenger.observe(builder, num_var); + challenger.observe(builder, num_instances); challenger.observe(builder, zero_f); }); - challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.witin_commit); - - iter_zip!(builder, zkvm_proof_input.witin_commit_trivial_commits).for_each( - |ptr_vec, builder| { - let trivial_cmt = - builder.iter_ptr_get(&zkvm_proof_input.witin_commit_trivial_commits, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &trivial_cmt); - }, - ); - challenger.observe( + challenger_multi_observe( builder, - zkvm_proof_input.witin_commit_log2_max_codeword_size, + &mut challenger, + &zkvm_proof_input.witin_commit.commit.value, ); + { + let log2_max_codeword_size = builder.unsafe_cast_var_to_felt( + zkvm_proof_input + .witin_commit + .log2_max_codeword_size + .get_var(), + ); + challenger.observe(builder, log2_max_codeword_size); + } let alpha = challenger.sample_ext(builder); let beta = challenger.sample_ext(builder); @@ -151,124 +177,178 @@ pub fn verify_zkvm_proof( let mut poly_evaluator = PolyEvaluator::new(builder); let dummy_table_item = alpha.clone(); - let dummy_table_item_multiplicity: Ext = builder.constant(C::EF::ZERO); - - let mut rt_points: Vec>> = Vec::with_capacity(proving_sequence.len()); - let mut evaluations: Vec>> = - Vec::with_capacity(2 * proving_sequence.len()); // witin + fixed thus *2 - - for subcircuit_params in proving_sequence { - if subcircuit_params.is_opcode { - let opcode_proof = builder.get( - &zkvm_proof_input.opcode_proofs, - subcircuit_params.type_order_idx, - ); - let id_f: Felt = - builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); - challenger.observe(builder, id_f); - - builder.cycle_tracker_start("Verify opcode proof"); - let input_opening_point = verify_opcode_proof( - builder, - &mut challenger, - &opcode_proof, - &zkvm_proof_input.pi_evals, - &challenges, - &subcircuit_params, - &ceno_constraint_system, - &mut unipoly_extrapolator, - ); - builder.cycle_tracker_end("Verify opcode proof"); + let dummy_table_item_multiplicity: Var = builder.constant(C::N::ZERO); + + let num_fixed_opening = vk + .vk + .circuit_vks + .values() + .filter(|c| c.get_cs().num_fixed() > 0) + .count(); + let witin_openings: Array> = + builder.dyn_array(zkvm_proof_input.chip_proofs.len()); + let fixed_openings: Array> = + builder.dyn_array(Usize::from(num_fixed_opening)); + let num_chips_verified: Usize = builder.eval(C::N::ZERO); + let num_chips_have_fixed: Usize = builder.eval(C::N::ZERO); + + let chip_indices: Array> = builder.dyn_array(zkvm_proof_input.chip_proofs.len()); + builder + .range(0, chip_indices.len()) + .for_each(|idx_vec, builder| { + let i = idx_vec[0]; + let chip_proof = builder.get(&zkvm_proof_input.chip_proofs, i); + builder.set(&chip_indices, i, chip_proof.idx); + }); - rt_points.push(input_opening_point); - evaluations.push(opcode_proof.wits_in_evals); + // iterate over all chips + for (i, chip_vk) in vk.vk.circuit_vks.values().enumerate() { + let chip_id: Var = builder.get(&chip_indices, num_chips_verified.get_var()); + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proof = + builder.get(&zkvm_proof_input.chip_proofs, num_chips_verified.get_var()); + challenger.observe(builder, chip_proof.idx_felt); + + builder.cycle_tracker_start("Verify chip proof"); + let input_opening_point = if chip_vk.get_cs().is_opcode_circuit() { + verify_opcode_proof( + builder, + &mut challenger, + &chip_proof, + &zkvm_proof_input.pi_evals, + &challenges, + &chip_vk, + &mut unipoly_extrapolator, + ) + } else { + verify_table_proof( + builder, + &mut challenger, + &chip_proof, + &zkvm_proof_input.pi_evals, + &challenges, + &chip_vk, + &mut unipoly_extrapolator, + ) + }; + builder.cycle_tracker_end("Verify chip proof"); // getting the number of dummy padding item that we used in this opcode circuit - let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); - let num_instances = subcircuit_params.num_instances; - let num_lks = cs.zkvm_v1_css.lk_expressions.len(); - let num_padded_instance = next_pow2_instance_padding(num_instances) - num_instances; - - let new_multiplicity: Ext = - builder.constant(C::EF::from_canonical_usize(num_lks * num_padded_instance)); - builder.assign( - &dummy_table_item_multiplicity, - dummy_table_item_multiplicity + new_multiplicity, - ); + if chip_vk.get_cs().is_opcode_circuit() { + let num_lks = chip_vk.get_cs().num_lks(); + // FIXME: use builder to compute this + let num_instances = pow_2(builder, chip_proof.log2_num_instances.get_var()); + let num_padded_instance: Var = + builder.eval(num_instances - chip_proof.num_instances); + + let new_multiplicity: Usize = + builder.eval(Usize::from(num_lks) * Usize::from(num_padded_instance)); + builder.assign( + &dummy_table_item_multiplicity, + dummy_table_item_multiplicity + new_multiplicity, + ); + } - let record_r_out_evals_prod = nested_product(builder, &opcode_proof.record_r_out_evals); + let record_r_out_evals_prod = nested_product(builder, &chip_proof.record_r_out_evals); builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - let record_w_out_evals_prod = nested_product(builder, &opcode_proof.record_w_out_evals); + let record_w_out_evals_prod = nested_product(builder, &chip_proof.record_w_out_evals); builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - iter_zip!(builder, opcode_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { - let evals = builder.iter_ptr_get(&opcode_proof.record_lk_out_evals, ptr_vec[0]); + let sign: Ext = if chip_vk.get_cs().is_opcode_circuit() { + builder.constant(C::EF::ONE) + } else { + builder.constant(-C::EF::ONE) + }; + + iter_zip!(builder, chip_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { + let evals = builder.iter_ptr_get(&chip_proof.record_lk_out_evals, ptr_vec[0]); let p1 = builder.get(&evals, 0); let p2 = builder.get(&evals, 1); let q1 = builder.get(&evals, 2); let q2 = builder.get(&evals, 3); - builder.assign(&logup_sum, logup_sum + p1 * q1.inverse()); - builder.assign(&logup_sum, logup_sum + p2 * q2.inverse()); + builder.assign(&logup_sum, logup_sum + sign * p1 * q1.inverse()); + builder.assign(&logup_sum, logup_sum + sign * p2 * q2.inverse()); }); - } else { - let table_proof = builder.get( - &zkvm_proof_input.table_proofs, - subcircuit_params.type_order_idx, - ); - let id_f: Felt = - builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); - challenger.observe(builder, id_f); - let input_opening_point = verify_table_proof( - builder, - &mut challenger, - &table_proof, - &zkvm_proof_input.raw_pi, - &zkvm_proof_input.raw_pi_num_variables, - &zkvm_proof_input.pi_evals, - &challenges, - &subcircuit_params, - ceno_constraint_system, - &mut unipoly_extrapolator, - &mut poly_evaluator, + builder.assert_usize_eq( + chip_proof.log2_num_instances.clone(), + input_opening_point.len(), ); - rt_points.push(input_opening_point); - evaluations.push(table_proof.wits_in_evals); - let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); - if cs.num_fixed() > 0 { - evaluations.push(table_proof.fixed_in_evals); - } - - iter_zip!(builder, table_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { - let evals = builder.iter_ptr_get(&table_proof.record_lk_out_evals, ptr_vec[0]); - let p1 = builder.get(&evals, 0); - let p2 = builder.get(&evals, 1); - let q1 = builder.get(&evals, 2); - let q2 = builder.get(&evals, 3); - builder.assign( - &logup_sum, - logup_sum - p1 * q1.inverse() - p2 * q2.inverse(), - ); + let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: chip_proof.log2_num_instances.get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { + fs: input_opening_point.clone(), + }, + evals: chip_proof.wits_in_evals, + }, }); + builder.set_value(&witin_openings, num_chips_verified.get_var(), witin_round); + + if chip_vk.get_cs().num_fixed() > 0 { + let fixed_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: chip_proof.log2_num_instances.get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { + fs: input_opening_point.clone(), + }, + evals: chip_proof.fixed_in_evals, + }, + }); - let record_w_out_evals_prod = nested_product(builder, &table_proof.record_w_out_evals); - builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - let record_r_out_evals_prod = nested_product(builder, &table_proof.record_r_out_evals); - builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - } + builder.set_value(&fixed_openings, num_chips_have_fixed.get_var(), fixed_round); + + builder.inc(&num_chips_have_fixed); + } + + builder.inc(&num_chips_verified); + }); } + builder.assert_usize_eq(num_chips_have_fixed, Usize::from(num_fixed_opening)); + builder.assert_eq::>(num_chips_verified, chip_indices.len()); + let dummy_table_item_multiplicity = + builder.unsafe_cast_var_to_felt(dummy_table_item_multiplicity); builder.assign( &logup_sum, logup_sum - dummy_table_item_multiplicity * dummy_table_item.inverse(), ); - // TODO: prepare rounds and uncomment this - - // batch_verifier(builder, rounds, zkvm_proof_input.pcs_proof, &mut challenger); + let rounds = if num_fixed_opening > 0 { + builder.dyn_array(2) + } else { + builder.dyn_array(1) + }; + builder.set( + &rounds, + 0, + RoundVariable { + commit: zkvm_proof_input.witin_commit, + openings: witin_openings, + perm: zkvm_proof_input.witin_perm.clone(), + }, + ); + if num_fixed_opening > 0 { + builder.set( + &rounds, + 1, + RoundVariable { + commit: fixed_commit.unwrap(), + openings: fixed_openings, + perm: zkvm_proof_input.fixed_perm, + }, + ); + } + batch_verify( + builder, + zkvm_proof_input.max_num_var, + rounds, + zkvm_proof_input.pcs_proof, + &mut challenger, + ); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -278,7 +358,7 @@ pub fn verify_zkvm_proof( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &ceno_constraint_system.vk.initial_global_state_expr, + &vk.vk.initial_global_state_expr, ); builder.assign(&prod_w, prod_w * initial_global_state); @@ -289,7 +369,7 @@ pub fn verify_zkvm_proof( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &ceno_constraint_system.vk.finalize_global_state_expr, + &vk.vk.finalize_global_state_expr, ); builder.assign(&prod_r, prod_r * finalize_global_state); @@ -301,14 +381,13 @@ pub fn verify_zkvm_proof( pub fn verify_opcode_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - opcode_proof: &ZKVMOpcodeProofInputVariable, + opcode_proof: &ZKVMChipProofInputVariable, pi_evals: &Array>, challenges: &Array>, - subcircuit_params: &SubcircuitParams, - cs: &ZKVMVerifier, + vk: &VerifyingKey, unipoly_extrapolator: &mut UniPolyExtrapolator, ) -> Array> { - let cs = &cs.vk.circuit_vks[&subcircuit_params.name].cs; + let cs = vk.get_cs(); let one: Ext = builder.constant(C::EF::ONE); let zero: Ext = builder.constant(C::EF::ZERO); @@ -520,17 +599,16 @@ pub fn verify_opcode_proof( pub fn verify_table_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - table_proof: &ZKVMTableProofInputVariable, - raw_pi: &Array>>, - raw_pi_num_variables: &Array>, + table_proof: &ZKVMChipProofInputVariable, + // raw_pi: &Array>>, + // raw_pi_num_variables: &Array>, pi_evals: &Array>, challenges: &Array>, - subcircuit_params: &SubcircuitParams, - cs: &ZKVMVerifier, + vk: &VerifyingKey, unipoly_extrapolator: &mut UniPolyExtrapolator, - poly_evaluator: &mut PolyEvaluator, + // poly_evaluator: &mut PolyEvaluator, ) -> Array> { - let cs = cs.vk.circuit_vks[&subcircuit_params.name].get_cs(); + let cs = vk.get_cs(); let tower_proof: &super::binding::TowerProofInputVariable = &table_proof.tower_proof; let r_expected_rounds: Array> = @@ -742,6 +820,7 @@ pub fn verify_table_proof( builder.assert_ext_eq(e, expected_evals); }); + /* TODO: enable this // assume public io is tiny vector, so we evaluate it directly without PCS for &Instance(idx) in cs.instance_name_map().keys() { let poly = builder.get(raw_pi, idx); @@ -751,6 +830,7 @@ pub fn verify_table_proof( let eval = builder.get(&pi_evals, idx); builder.assert_ext_eq(eval, expected_eval); } + */ rt_tower.fs }