From 4eef780387bd62ac77b97c20c48879240a5ded16 Mon Sep 17 00:00:00 2001 From: soham Date: Thu, 2 Jan 2025 08:10:48 +0530 Subject: [PATCH 01/12] refactor `assert_satisfied_full` (#649) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - uses MockProver utils inside `assert_satisfied_full` which enables additional missing checks for things like assert zero expressions. - removes duplicate code to check of multiplicity - enables MockProver to check for multiplicity of all instances (previously only instance 0 was considered) --------- Co-authored-by: Matthias Görgens Co-authored-by: sm.wu --- ceno_emul/src/elf.rs | 14 +- ceno_zkvm/examples/riscv_opcodes.rs | 13 +- ceno_zkvm/src/circuit_builder.rs | 12 +- ceno_zkvm/src/e2e.rs | 35 +- ceno_zkvm/src/expression.rs | 4 + ceno_zkvm/src/scheme/mock_prover.rs | 609 +++++++++++++++------------- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/structs.rs | 18 +- ceno_zkvm/src/witness.rs | 170 ++++++-- 9 files changed, 519 insertions(+), 358 deletions(-) diff --git a/ceno_emul/src/elf.rs b/ceno_emul/src/elf.rs index ee59d3de3..ce849fde7 100644 --- a/ceno_emul/src/elf.rs +++ b/ceno_emul/src/elf.rs @@ -18,7 +18,7 @@ extern crate alloc; use alloc::collections::BTreeMap; -use crate::{addr::WORD_SIZE, disassemble::transpile, rv32im::Instruction}; +use crate::{CENO_PLATFORM, addr::WORD_SIZE, disassemble::transpile, rv32im::Instruction}; use anyhow::{Context, Result, anyhow, bail}; use elf::{ ElfBytes, @@ -40,6 +40,17 @@ pub struct Program { pub image: BTreeMap, } +impl From<&[Instruction]> for Program { + fn from(insn_codes: &[Instruction]) -> Program { + Self { + entry: CENO_PLATFORM.pc_base(), + base_address: CENO_PLATFORM.pc_base(), + instructions: insn_codes.to_vec(), + image: Default::default(), + } + } +} + impl Program { /// Create program pub fn new( @@ -55,6 +66,7 @@ impl Program { image, } } + /// Initialize a RISC Zero Program from an appropriate ELF file pub fn load_elf(input: &[u8], max_mem: u32) -> Result { let mut instructions: Vec = Vec::new(); diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 54813659e..e7240c2c1 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -76,6 +76,7 @@ fn main() { program_code, Default::default(), ); + let program = Arc::new(program); let mem_addresses = CENO_PLATFORM.heap.clone(); let io_addresses = CENO_PLATFORM.public_io.clone(); @@ -163,7 +164,7 @@ fn main() { // init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop let public_io_init = init_public_io(&[1, u32::MAX, step_loop]); - let mut vm = VMState::new(CENO_PLATFORM, Arc::new(program.clone())); + let mut vm = VMState::new(CENO_PLATFORM, program.clone()); // init memory mapped IO for record in &public_io_init { @@ -201,7 +202,7 @@ fn main() { config .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, all_records) .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); + zkvm_witness.finalize_lk_multiplicities(false); // Find the final register values and cycles. let reg_final = reg_init @@ -275,7 +276,13 @@ fn main() { trace_report.save_json("report.json"); trace_report.save_table("report.txt"); - MockProver::assert_satisfied_full(&zkvm_cs, zkvm_fixed_traces.clone(), &zkvm_witness, &pi); + MockProver::assert_satisfied_full( + &zkvm_cs, + zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + &program, + ); let timer = Instant::now(); diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 2e3b585c3..5c3e2dc2c 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -116,9 +116,9 @@ pub struct ConstraintSystem { /// lookup expression pub lk_expressions: Vec>, - pub lk_expressions_namespace_map: Vec, pub lk_table_expressions: Vec>, - pub lk_table_expressions_namespace_map: Vec, + pub lk_expressions_namespace_map: Vec, + pub lk_expressions_items_map: Vec<(ROMType, Vec>)>, /// main constraints zero expression pub assert_zero_expressions: Vec>, @@ -136,7 +136,6 @@ pub struct ConstraintSystem { pub chip_record_beta: Expression, pub debug_map: HashMap>>, - pub lk_expressions_items_map: Vec<(ROMType, Vec>)>, pub(crate) phantom: PhantomData, } @@ -164,9 +163,9 @@ impl ConstraintSystem { w_table_expressions: vec![], w_table_expressions_namespace_map: vec![], lk_expressions: vec![], - lk_expressions_namespace_map: vec![], lk_table_expressions: vec![], - lk_table_expressions_namespace_map: vec![], + lk_expressions_namespace_map: vec![], + lk_expressions_items_map: vec![], assert_zero_expressions: vec![], assert_zero_expressions_namespace_map: vec![], assert_zero_sumcheck_expressions: vec![], @@ -176,7 +175,6 @@ impl ConstraintSystem { chip_record_beta: Expression::Challenge(1, 1, E::ONE, E::ZERO), debug_map: HashMap::new(), - lk_expressions_items_map: vec![], phantom: std::marker::PhantomData, } @@ -326,7 +324,7 @@ impl ConstraintSystem { table_len, }); let path = self.ns.compute_path(name_fn().into()); - self.lk_table_expressions_namespace_map.push(path); + self.lk_expressions_namespace_map.push(path); // Since lk_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. self.lk_expressions_items_map.push((rom_type, record)); diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 0105e2344..31cf9d55c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1,8 +1,11 @@ use crate::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, scheme::{ - PublicValues, ZKVMProof, constants::MAX_NUM_VARIABLES, mock_prover::MockProver, - prover::ZKVMProver, verifier::ZKVMVerifier, + PublicValues, ZKVMProof, + constants::MAX_NUM_VARIABLES, + mock_prover::{LkMultiplicityKey, MockProver}, + prover::ZKVMProver, + verifier::ZKVMVerifier, }, state::GlobalState, structs::{ @@ -21,7 +24,6 @@ use mpcs::PolynomialCommitmentScheme; use std::{ collections::{BTreeSet, HashMap, HashSet}, iter::zip, - ops::Deref, sync::Arc, }; use transcript::BasicTranscript as Transcript; @@ -296,6 +298,7 @@ pub fn generate_witness( system_config: &ConstraintSystemConfig, emul_result: EmulationResult, program: &Program, + is_mock_proving: bool, ) -> ZKVMWitnesses { let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits @@ -311,7 +314,7 @@ pub fn generate_witness( .dummy_config .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); + zkvm_witness.finalize_lk_multiplicities(is_mock_proving); // assign table circuits system_config @@ -370,7 +373,10 @@ pub type IntermediateState = (ZKVMProof, ZKVMVerifier); // state external to this pipeline (e.g, sanity check in bin/e2e.rs) #[allow(clippy::type_complexity)] -pub fn run_e2e_with_checkpoint + 'static>( +pub fn run_e2e_with_checkpoint< + E: ExtensionField + LkMultiplicityKey, + PCS: PolynomialCommitmentScheme + 'static, +>( program: Program, platform: Platform, hints: Vec, @@ -414,6 +420,8 @@ pub fn run_e2e_with_checkpoint>( +pub fn run_e2e_proof>( program: Arc, max_steps: usize, init_full_mem: InitMemState, @@ -490,6 +499,7 @@ pub fn run_e2e_proof>( system_config: &ConstraintSystemConfig, pk: ZKVMProvingKey, zkvm_fixed_traces: ZKVMFixedTraces, + is_mock_proving: bool, ) -> ZKVMProof { // Emulate program let emul_result = emulate_program(program.clone(), max_steps, init_full_mem, &platform, hints); @@ -498,17 +508,18 @@ pub fn run_e2e_proof>( let pi = emul_result.pi.clone(); // Generate witness - let zkvm_witness = generate_witness(system_config, emul_result, program.deref()); + let zkvm_witness = generate_witness(system_config, emul_result, &program, is_mock_proving); // proving let prover = ZKVMProver::new(pk); - if std::env::var("MOCK_PROVING").is_ok() { + if is_mock_proving { MockProver::assert_satisfied_full( &system_config.zkvm_cs, zkvm_fixed_traces.clone(), &zkvm_witness, &pi, + &program, ); tracing::info!("Mock proving passed"); } diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 82fb9fe97..6f0757a93 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -216,6 +216,10 @@ impl Expression { self.to_monomial_form_inner() } + pub fn is_constant(&self) -> bool { + matches!(self, Expression::Constant(_)) + } + fn is_zero_expr(expr: &Expression) -> bool { match expr { Expression::Fixed(_) => false, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 054a5b3cf..494fc6134 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -13,7 +13,7 @@ use crate::{ AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable, TableCircuit, U5Table, U8Table, U14Table, U16Table, XorTable, }, - witness::{LkMultiplicity, RowMajorMatrix}, + witness::{LkMultiplicity, LkMultiplicityRaw, RowMajorMatrix}, }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD}; @@ -21,12 +21,14 @@ use ceno_emul::{ByteAddr, CENO_PLATFORM, Platform, Program}; use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; -use goldilocks::SmallField; -use itertools::{Itertools, enumerate, izip}; +use goldilocks::{GoldilocksExt2, SmallField}; +use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ - collections::{HashMap, HashSet}, + cmp::max, + collections::{BTreeSet, HashMap, HashSet}, + fmt::Debug, fs::File, hash::Hash, io::{BufReader, ErrorKind}, @@ -51,9 +53,27 @@ pub const MOCK_PC_START: ByteAddr = ByteAddr({ CENO_PLATFORM.pc_base() }); +/// Allow LK Multiplicity's key to be used with `u64` and `GoldilocksExt2`. +pub trait LkMultiplicityKey: Copy + Clone + Debug + Eq + Hash + Send { + /// If key is u64, return Some(u64), otherwise None. + fn to_u64(&self) -> Option; +} + +impl LkMultiplicityKey for u64 { + fn to_u64(&self) -> Option { + Some(*self) + } +} + +impl LkMultiplicityKey for GoldilocksExt2 { + fn to_u64(&self) -> Option { + None + } +} + #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone)] -pub enum MockProverError { +pub enum MockProverError { AssertZeroError { expression: Expression, evaluated: E::BaseField, @@ -74,6 +94,7 @@ pub enum MockProverError { name: String, }, LookupError { + rom_type: ROMType, expression: Expression, evaluated: E, name: String, @@ -84,13 +105,12 @@ pub enum MockProverError { // w_expressions LkMultiplicityError { rom_type: ROMType, - key: u64, + key: K, count: isize, // +ve => missing in cs, -ve => missing in assignments - inst_id: usize, }, } -impl PartialEq for MockProverError { +impl PartialEq for MockProverError { // Compare errors based on the content, ignoring the inst_id fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -154,13 +174,25 @@ impl PartialEq for MockProverError { && left_evaluated == right_evaluated && left_name == right_name } + ( + MockProverError::LkMultiplicityError { + rom_type: left_rom_type, + key: left_key, + count: left_count, + }, + MockProverError::LkMultiplicityError { + rom_type: right_rom_type, + key: right_key, + count: right_count, + }, + ) => (left_rom_type, left_key, left_count) == (right_rom_type, right_key, right_count), _ => false, } } } -impl MockProverError { - pub fn print(&self, wits_in: &[ArcMultilinearExtension], wits_in_name: &[String]) { +impl MockProverError { + fn print(&self, wits_in: &[ArcMultilinearExtension], wits_in_name: &[String]) { let mut wtns = vec![]; match self { @@ -214,6 +246,7 @@ impl MockProverError { ); } Self::LookupError { + rom_type, expression, evaluated, name, @@ -224,6 +257,7 @@ impl MockProverError { let eval_fmt = fmt::field(evaluated); println!( "\nLookupError {name:#?}: Evaluated expression does not exist in T vector\n\ + ROM Type: {rom_type:?}\n\ Expression: {expression_fmt}\n\ Evaluation: {eval_fmt}\n\ Inst[{inst_id}]:\n{wtns_fmt}\n", @@ -240,36 +274,49 @@ impl MockProverError { } else { "Lookup".to_string() }; - let location = if *count > 0 { - "constraint system" + + let (location, element) = if let Some(key) = key.to_u64() { + let location = if *count > 0 { + "constraint system" + } else { + "assignments" + }; + let element = match rom_type { + ROMType::U5 | ROMType::U8 | ROMType::U14 | ROMType::U16 => { + format!("Element: {key:?}") + } + ROMType::And => { + let (a, b) = AndTable::unpack(key); + format!("Element: {a} && {b}") + } + ROMType::Or => { + let (a, b) = OrTable::unpack(key); + format!("Element: {a} || {b}") + } + ROMType::Xor => { + let (a, b) = XorTable::unpack(key); + format!("Element: {a} ^ {b}") + } + ROMType::Ltu => { + let (a, b) = LtuTable::unpack(key); + format!("Element: {a} < {b}") + } + ROMType::Pow => { + let (a, b) = PowTable::unpack(key); + format!("Element: {a} ** {b}") + } + ROMType::Instruction => format!("PC: {key}"), + }; + (location, element) } else { - "assignments" - }; - let element = match rom_type { - ROMType::U5 | ROMType::U8 | ROMType::U14 | ROMType::U16 => { - format!("Element: {key}") - } - ROMType::And => { - let (a, b) = AndTable::unpack(*key); - format!("Element: {a} < {b}") - } - ROMType::Or => { - let (a, b) = OrTable::unpack(*key); - format!("Element: {a} || {b}") - } - ROMType::Xor => { - let (a, b) = XorTable::unpack(*key); - format!("Element: {a} ^ {b}") - } - ROMType::Ltu => { - let (a, b) = LtuTable::unpack(*key); - format!("Element: {a} < {b}") - } - ROMType::Pow => { - let (a, b) = PowTable::unpack(*key); - format!("Element: {a} ** {b}") - } - ROMType::Instruction => format!("PC: {key}"), + ( + if *count > 0 { + "combined_lkm_tables" + } else { + "combined_lkm_opcodes" + }, + format!("Element: {key:?}"), + ) }; println!( "\nLkMultiplicityError:\n\ @@ -285,9 +332,8 @@ impl MockProverError { match self { Self::AssertZeroError { inst_id, .. } | Self::AssertEqualError { inst_id, .. } - | Self::LookupError { inst_id, .. } - | Self::LkMultiplicityError { inst_id, .. } => *inst_id, - Self::DegreeTooHigh { .. } => unreachable!(), + | Self::LookupError { inst_id, .. } => *inst_id, + Self::DegreeTooHigh { .. } | Self::LkMultiplicityError { .. } => unreachable!(), } } @@ -300,15 +346,18 @@ pub struct MockProver { _phantom: PhantomData, } -fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> HashSet> { +fn load_tables( + cs: &ConstraintSystem, + challenge: [E; 2], +) -> HashSet> { fn load_range_table( t_vec: &mut Vec>, - cb: &CircuitBuilder, + cs: &ConstraintSystem, challenge: [E; 2], ) { for i in RANGE::content() { let rlc_record = - cb.rlc_chip_record(vec![(RANGE::ROM_TYPE as usize).into(), (i as usize).into()]); + cs.rlc_chip_record(vec![(RANGE::ROM_TYPE as usize).into(), (i as usize).into()]); let rlc_record = eval_by_expr(&[], &[], &challenge, &rlc_record); t_vec.push(rlc_record.to_canonical_u64_vec()); } @@ -316,11 +365,11 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> fn load_op_table( t_vec: &mut Vec>, - cb: &CircuitBuilder, + cs: &ConstraintSystem, challenge: [E; 2], ) { for [a, b, c] in OP::content() { - let rlc_record = cb.rlc_chip_record(vec![ + let rlc_record = cs.rlc_chip_record(vec![ (OP::ROM_TYPE as usize).into(), (a as usize).into(), (b as usize).into(), @@ -332,15 +381,15 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } let mut table_vec = vec![]; - load_range_table::(&mut table_vec, cb, challenge); - load_range_table::(&mut table_vec, cb, challenge); - load_range_table::(&mut table_vec, cb, challenge); - load_range_table::(&mut table_vec, cb, challenge); - load_op_table::(&mut table_vec, cb, challenge); - load_op_table::(&mut table_vec, cb, challenge); - load_op_table::(&mut table_vec, cb, challenge); - load_op_table::(&mut table_vec, cb, challenge); - load_op_table::(&mut table_vec, cb, challenge); + load_range_table::(&mut table_vec, cs, challenge); + load_range_table::(&mut table_vec, cs, challenge); + load_range_table::(&mut table_vec, cs, challenge); + load_range_table::(&mut table_vec, cs, challenge); + load_op_table::(&mut table_vec, cs, challenge); + load_op_table::(&mut table_vec, cs, challenge); + load_op_table::(&mut table_vec, cs, challenge); + load_op_table::(&mut table_vec, cs, challenge); + load_op_table::(&mut table_vec, cs, challenge); HashSet::from_iter(table_vec) } @@ -349,7 +398,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> // return challenge and table #[allow(clippy::type_complexity)] fn load_once_tables( - cb: &CircuitBuilder, + cs: &ConstraintSystem, ) -> ([E; 2], HashSet>) { static CACHE: OnceLock; 2], HashSet>)>> = OnceLock::new(); let cache = CACHE.get_or_init(StaticTypeMap::new); @@ -373,7 +422,7 @@ fn load_once_tables( let mut file = tempfile::NamedTempFile::new_in(".").unwrap(); // load new table and seserialize to file for later use - let table = load_tables(cb, challenge); + let table = load_tables(cs, challenge); serde_json::to_writer(&mut file, &table).unwrap(); // Persist the file to the target location // This is an atomic operation on Posix-like systems, so we don't have to worry @@ -404,58 +453,56 @@ impl<'a, E: ExtensionField + Hash> MockProver { wits_in: &[ArcMultilinearExtension<'a, E>], challenge: [E; 2], lkm: Option, - ) -> Result<(), Vec>> { + ) -> Result<(), Vec>> { Self::run_maybe_challenge(cb, wits_in, &[], &[], Some(challenge), lkm) } pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[ceno_emul::Instruction], + program: &[ceno_emul::Instruction], lkm: Option, - ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, programs, &[], None, lkm) + ) -> Result<(), Vec>> { + Self::run_maybe_challenge(cb, wits_in, program, &[], None, lkm) } fn run_maybe_challenge( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - input_programs: &[ceno_emul::Instruction], + program: &[ceno_emul::Instruction], pi: &[ArcMultilinearExtension<'a, E>], challenge: Option<[E; 2]>, lkm: Option, - ) -> Result<(), Vec>> { - let program = Program::new( - CENO_PLATFORM.pc_base(), - CENO_PLATFORM.pc_base(), - input_programs.to_vec(), - Default::default(), - ); + ) -> Result<(), Vec>> { + let program = Program::from(program); + let (table, challenge) = Self::load_tables_with_program(cb.cs, &program, challenge); - // load tables - let (challenge, mut table) = if let Some(challenge) = challenge { - (challenge, load_tables(cb, challenge)) - } else { - load_once_tables(cb) - }; - let mut prog_table = vec![]; - Self::load_program_table(&mut prog_table, &program, challenge); - for prog in prog_table { - table.insert(prog); - } + Self::run_maybe_challenge_with_table(cb.cs, &table, wits_in, pi, 1, challenge, lkm) + .map(|_| ()) + } + #[allow(clippy::too_many_arguments)] + fn run_maybe_challenge_with_table( + cs: &ConstraintSystem, + table: &HashSet>, + wits_in: &[ArcMultilinearExtension<'a, E>], + pi: &[ArcMultilinearExtension<'a, E>], + num_instances: usize, + challenge: [E; 2], + expected_lkm: Option, + ) -> Result, Vec>> { + let mut shared_lkm = LkMultiplicityRaw::::default(); let mut errors = vec![]; + // Assert zero expressions - for (expr, name) in cb - .cs + for (expr, name) in cs .assert_zero_expressions .iter() - .chain(&cb.cs.assert_zero_sumcheck_expressions) + .chain(&cs.assert_zero_sumcheck_expressions) .zip_eq( - cb.cs - .assert_zero_expressions_namespace_map + cs.assert_zero_expressions_namespace_map .iter() - .chain(&cb.cs.assert_zero_sumcheck_expressions_namespace_map), + .chain(&cs.assert_zero_sumcheck_expressions_namespace_map), ) { if expr.degree() > MAX_CONSTRAINT_DEGREE { @@ -513,19 +560,20 @@ impl<'a, E: ExtensionField + Hash> MockProver { } // Lookup expressions - for (expr, name) in cb - .cs + for ((expr, name), (rom_type, _)) in cs .lk_expressions .iter() - .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) + .zip_eq(cs.lk_expressions_namespace_map.iter()) + .zip_eq(cs.lk_expressions_items_map.iter()) { let expr_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, expr); - let expr_evaluated = expr_evaluated.get_ext_field_vec(); + let expr_evaluated = &expr_evaluated.get_ext_field_vec()[..num_instances]; // Check each lookup expr exists in t vec for (inst_id, element) in enumerate(expr_evaluated) { if !table.contains(&element.to_canonical_u64_vec()) { errors.push(MockProverError::LookupError { + rom_type: *rom_type, expression: expr.clone(), evaluated: *element, name: name.clone(), @@ -533,114 +581,95 @@ impl<'a, E: ExtensionField + Hash> MockProver { }); } } + + // Increment shared LK Multiplicity + for element in expr_evaluated { + shared_lkm.increment(*rom_type, *element); + } } // LK Multiplicity check - if let Some(lkm_from_assignment) = lkm { + if let Some(lkm_from_assignment) = expected_lkm { // Infer LK Multiplicity from constraint system. - let lkm_from_cs = cb - .cs - .lk_expressions_items_map - .iter() - .map(|(rom_type, items)| { - ( - rom_type, - items + let mut lkm_from_cs = LkMultiplicity::default(); + for (rom_type, args) in &cs.lk_expressions_items_map { + let args_eval: Vec<_> = args + .iter() + .map(|arg_expr| { + let arg_eval = + wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, arg_expr); + let mut arg_eval = arg_eval + .get_base_field_vec() .iter() - .map(|expr| { - // TODO generalized to all inst_id - let inst_id = 0; - wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, expr) - .get_base_field_vec()[inst_id] - .to_canonical_u64() - }) - .collect::>(), - ) - }) - .fold(LkMultiplicity::default(), |mut lkm, (rom_type, args)| { + .map(SmallField::to_canonical_u64) + .take(num_instances) + .collect_vec(); + + // Constant terms will have single element in `args_expr_evaluated`, so let's fix that. + if arg_expr.is_constant() { + assert_eq!(arg_eval.len(), 1); + arg_eval.resize(num_instances, arg_eval[0]) + } + arg_eval + }) + .collect(); + + // Count lookups infered from ConstraintSystem from all instances into lkm_from_cs. + for inst_id in 0..num_instances { match rom_type { - ROMType::U5 => lkm.assert_ux::<5>(args[0]), - ROMType::U8 => lkm.assert_ux::<8>(args[0]), - ROMType::U14 => lkm.assert_ux::<14>(args[0]), - ROMType::U16 => lkm.assert_ux::<16>(args[0]), - ROMType::And => lkm.lookup_and_byte(args[0], args[1]), - ROMType::Or => lkm.lookup_or_byte(args[0], args[1]), - ROMType::Xor => lkm.lookup_xor_byte(args[0], args[1]), - ROMType::Ltu => lkm.lookup_ltu_byte(args[0], args[1]), + ROMType::U5 => lkm_from_cs.assert_ux::<5>(args_eval[0][inst_id]), + ROMType::U8 => lkm_from_cs.assert_ux::<8>(args_eval[0][inst_id]), + ROMType::U14 => lkm_from_cs.assert_ux::<14>(args_eval[0][inst_id]), + ROMType::U16 => lkm_from_cs.assert_ux::<16>(args_eval[0][inst_id]), + ROMType::And => lkm_from_cs + .lookup_and_byte(args_eval[0][inst_id], args_eval[1][inst_id]), + ROMType::Or => { + lkm_from_cs.lookup_or_byte(args_eval[0][inst_id], args_eval[1][inst_id]) + } + ROMType::Xor => lkm_from_cs + .lookup_xor_byte(args_eval[0][inst_id], args_eval[1][inst_id]), + ROMType::Ltu => lkm_from_cs + .lookup_ltu_byte(args_eval[0][inst_id], args_eval[1][inst_id]), ROMType::Pow => { - assert_eq!(args[0], 2); - lkm.lookup_pow2(args[1]) + assert_eq!(args_eval[0][inst_id], 2); + lkm_from_cs.lookup_pow2(args_eval[1][inst_id]) } - ROMType::Instruction => lkm.fetch(args[0] as u32), + ROMType::Instruction => lkm_from_cs.fetch(args_eval[0][inst_id] as u32), }; - - lkm - }); - - let lkm_from_cs = lkm_from_cs.into_finalize_result(); - let lkm_from_assignment = lkm_from_assignment.into_finalize_result(); - - // Compare each LK Multiplicity. - - for (rom_type, cs_map, ass_map) in - izip!(ROMType::iter(), &lkm_from_cs, &lkm_from_assignment) - { - if *cs_map != *ass_map { - let cs_keys: HashSet<_> = cs_map.keys().collect(); - let ass_keys: HashSet<_> = ass_map.keys().collect(); - - // lookup missing in lkm Constraint System. - ass_keys.difference(&cs_keys).for_each(|k| { - let count_ass = ass_map.get(k).unwrap(); - errors.push(MockProverError::LkMultiplicityError { - rom_type, - key: **k, - count: *count_ass as isize, - inst_id: 0, - }) - }); - - // lookup missing in lkm Assignments. - cs_keys.difference(&ass_keys).for_each(|k| { - let count_cs = cs_map.get(k).unwrap(); - errors.push(MockProverError::LkMultiplicityError { - rom_type, - key: **k, - count: -(*count_cs as isize), - inst_id: 0, - }) - }); - - // count of specific lookup differ lkm assignments and lkm cs - cs_keys.intersection(&ass_keys).for_each(|k| { - let count_cs = cs_map.get(k).unwrap(); - let count_ass = ass_map.get(k).unwrap(); - - if count_cs != count_ass { - errors.push(MockProverError::LkMultiplicityError { - rom_type, - key: **k, - count: (*count_ass as isize) - (*count_cs as isize), - inst_id: 0, - }) - } - }); } } + + errors.extend(compare_lkm(lkm_from_cs, lkm_from_assignment)); } if errors.is_empty() { - Ok(()) + Ok(shared_lkm) } else { Err(errors) } } - fn load_program_table(t_vec: &mut Vec>, program: &Program, challenge: [E; 2]) { + fn load_tables_with_program( + cs: &ConstraintSystem, + program: &Program, + challenge: Option<[E; 2]>, + ) -> (HashSet>, [E; 2]) { + // load tables + let (challenge, mut table) = if let Some(challenge) = challenge { + (challenge, load_tables(cs, challenge)) + } else { + load_once_tables(cs) + }; + table.extend(Self::load_program_table(program, challenge)); + (table, challenge) + } + + fn load_program_table(program: &Program, challenge: [E; 2]) -> Vec> { + let mut t_vec = vec![]; let mut cs = ConstraintSystem::::new(|| "mock_program"); let mut cb = CircuitBuilder::new_with_params(&mut cs, ProgramParams { platform: CENO_PLATFORM, - program_size: MOCK_PROGRAM_SIZE, + program_size: max(program.instructions.len(), MOCK_PROGRAM_SIZE), ..ProgramParams::default() }); let config = ProgramTableCircuit::<_>::construct_circuit(&mut cb).unwrap(); @@ -648,12 +677,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { for table_expr in &cs.lk_table_expressions { for row in fixed.iter_rows() { // TODO: Find a better way to obtain the row content. - let row = row.iter().map(|v| (*v).into()).collect::>(); + let row: Vec = row.iter().map(|v| (*v).into()).collect(); let rlc_record = eval_by_expr_with_fixed(&row, &[], &[], &challenge, &table_expr.values); t_vec.push(rlc_record.to_canonical_u64_vec()); } } + t_vec } /// Run and check errors @@ -663,7 +693,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { pub fn assert_with_expected_errors( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[ceno_emul::Instruction], + program: &[ceno_emul::Instruction], constraint_names: &[&str], challenge: Option<[E; 2]>, lkm: Option, @@ -671,7 +701,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let error_groups = if let Some(challenge) = challenge { Self::run_with_challenge(cb, wits_in, challenge, lkm) } else { - Self::run(cb, wits_in, programs, lkm) + Self::run(cb, wits_in, program, lkm) } .err() .into_iter() @@ -691,15 +721,7 @@ Hints: " ); - for (count, error) in errors.iter().dedup_with_count() { - error.print(wits_in, &cb.cs.witin_namespace_map); - if count > 1 { - println!("Error: {} duplicates hidden.", count - 1); - } - } - println!("Error: {} constraints not satisfied", errors.len()); - println!("======================================================"); - panic!("(Unexpected) Constraints not satisfied"); + print_errors(errors, wits_in, &cb.cs.witin_namespace_map, true); } for constraint_name in constraint_names { // Expected errors didn't happen: @@ -715,7 +737,7 @@ Hints: pub fn assert_satisfied_raw( cb: &CircuitBuilder, raw_witin: RowMajorMatrix, - programs: &[ceno_emul::Instruction], + program: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, lkm: Option, ) { @@ -724,17 +746,17 @@ Hints: .into_iter() .map(|v| v.into()) .collect_vec(); - Self::assert_satisfied(cb, &wits_in, programs, challenge, lkm); + Self::assert_satisfied(cb, &wits_in, program, challenge, lkm); } pub fn assert_satisfied( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[ceno_emul::Instruction], + program: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, lkm: Option, ) { - Self::assert_with_expected_errors(cb, wits_in, programs, &[], challenge, lkm); + Self::assert_with_expected_errors(cb, wits_in, program, &[], challenge, lkm); } pub fn assert_satisfied_full( @@ -742,7 +764,10 @@ Hints: mut fixed_trace: ZKVMFixedTraces, witnesses: &ZKVMWitnesses, pi: &PublicValues, - ) { + program: &Program, + ) where + E: LkMultiplicityKey, + { let instance = pi .to_vec::() .concat() @@ -758,13 +783,21 @@ Hints: let mut rng = thread_rng(); let challenges = [0u8; 2].map(|_| E::random(&mut rng)); + // Load lookup table. + let (lookup_table, _) = Self::load_tables_with_program( + &ConstraintSystem::::new(|| "temp for loading table"), + program, + Some(challenges), + ); + let mut wit_mles = HashMap::new(); let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); - // Lookup errors - let mut rom_inputs = - HashMap::, String, String, Vec>)>>::new(); - let mut rom_tables = HashMap::>::new(); + + let mut lkm_tables = LkMultiplicityRaw::::default(); + let mut lkm_opcodes = LkMultiplicityRaw::::default(); + + // Process all circuits. for (circuit_name, cs) in &cs.circuit_css { let is_opcode = cs.lk_table_expressions.is_empty() && cs.r_table_expressions.is_empty() @@ -795,7 +828,6 @@ Hints: .circuit_fixed_traces .remove(circuit_name) .and_then(|fixed| fixed) - // .expect(format!("circuit {}'s fixed traces should not be None", circuit_name).as_str()) .map_or(vec![], |fixed| { fixed .into_mles() @@ -805,30 +837,35 @@ Hints: }); if is_opcode { tracing::info!( - "preprocessing opcode {} with {} entries", + "Mock proving opcode {} with {} entries", circuit_name, num_rows ); - // gather lookup inputs - for (expr, annotation, (rom_type, values)) in izip!( - &cs.lk_expressions, - &cs.lk_expressions_namespace_map, - &cs.lk_expressions_items_map + // Assert opcode and check single opcode lk multiplicity + // Also combine multiplicity in lkm_opcodes + let lkm_from_assignments = witnesses + .get_lk_mlt(circuit_name) + .map(LkMultiplicityRaw::deep_clone); + match Self::run_maybe_challenge_with_table( + cs, + &lookup_table, + &witness, + &[], + num_rows, + challenges, + lkm_from_assignments, ) { - let lk_input = - (wit_infer_by_expr(&fixed, &witness, &[], &pi_mles, &challenges, expr) - .get_ext_field_vec())[..num_rows] - .to_vec(); - rom_inputs.entry(*rom_type).or_default().push(( - lk_input, - circuit_name.clone(), - annotation.clone(), - values.clone(), - )); + Ok(multiplicities) => { + lkm_opcodes += multiplicities; + } + Err(errors) => { + tracing::error!("Mock proving failed for opcode {}", circuit_name); + print_errors(&errors, &witness, &cs.witin_namespace_map, true); + } } } else { tracing::info!( - "preprocessing table {} with {} entries", + "Mock proving table {} with {} entries", circuit_name, num_rows ); @@ -858,16 +895,13 @@ Hints: .get_base_field_vec() .to_vec(); - assert!( - rom_tables - .insert( - *rom_type, - izip!(lk_table, multiplicity).collect::>(), - ) - .is_none(), - "cannot assign to rom table {:?} twice", - rom_type - ); + for (key, multiplicity) in izip!(lk_table, multiplicity) { + lkm_tables.set_count( + *rom_type, + key, + multiplicity.to_canonical_u64() as usize, + ); + } } } wit_mles.insert(circuit_name.clone(), witness); @@ -875,72 +909,14 @@ Hints: num_instances.insert(circuit_name.clone(), num_rows); } - for (rom_type, inputs) in rom_inputs { - let table = rom_tables.get_mut(&rom_type).unwrap(); - for (lk_input_values, circuit_name, lk_input_annotation, input_value_exprs) in inputs { - // counting multiplicity in rom_input - let mut lk_input_values_multiplicity = HashMap::new(); - for (row, input_value) in enumerate(&lk_input_values) { - // we only keep first row to restore debug information - lk_input_values_multiplicity - .entry(input_value) - .or_insert([0u64, row as u64])[0] += 1; - } + // Assert lkm between all tables and combined opcode circuits + let errors: Vec> = compare_lkm(lkm_tables, lkm_opcodes); - for (k, [input_multiplicity, row]) in lk_input_values_multiplicity { - let table_multiplicity = if let Some(table_multiplicity) = table.get_mut(k) { - if input_multiplicity <= table_multiplicity.to_canonical_u64() { - *table_multiplicity -= E::BaseField::from(input_multiplicity); - continue; - } - table_multiplicity.to_canonical_u64() - } else { - 0 - }; - // log mismatch error - let witness = wit_mles - .get(&circuit_name) - .map(|mles| { - mles.iter() - .map(|mle| E::from(mle.get_base_field_vec()[row as usize])) - .collect_vec() - }) - .unwrap(); - let values = input_value_exprs - .iter() - .map(|expr| { - eval_by_expr_with_instance( - &[], - &witness, - &[], - &instance, - challenges.as_slice(), - expr, - ) - .as_bases()[0] - }) - .collect_vec(); - tracing::error!( - "{}: value {:x?} mismatch lk_multiplicity: real {:x} > remaining {:x} in {:?} table", - lk_input_annotation, - values, - input_multiplicity, - table_multiplicity, - rom_type, - ); - } - } - // each table entry's multiplicity should equal to 0 - for (k, multiplicity) in table { - if !multiplicity.is_zero_vartime() { - tracing::error!( - "table {:?}: {:x?} multiplicity = {:x}", - rom_type, - k, - multiplicity.to_canonical_u64() - ); - } - } + if errors.is_empty() { + tracing::info!("Mock proving successful for tables"); + } else { + tracing::error!("Mock proving failed for tables - {} errors", errors.len()); + print_errors(&errors, &[], &[], true); } // find out r != w errors @@ -1205,6 +1181,56 @@ Hints: } } +fn compare_lkm( + lkm_a: LkMultiplicityRaw, + lkm_b: LkMultiplicityRaw, +) -> Vec> +where + E: ExtensionField, + K: LkMultiplicityKey + Default + Ord, +{ + let lkm_a = lkm_a.into_finalize_result(); + let lkm_b = lkm_b.into_finalize_result(); + + // Compare each LK Multiplicity. + izip!(ROMType::iter(), &lkm_a, &lkm_b) + .flat_map(|(rom_type, a_map, b_map)| { + // We use a BTreeSet, instead of a HashSet, to ensure deterministic order. + let keys: BTreeSet<_> = chain!(a_map.keys(), b_map.keys()).collect(); + keys.into_iter().filter_map(move |key| { + let count = + *a_map.get(key).unwrap_or(&0) as isize - *b_map.get(key).unwrap_or(&0) as isize; + + (count != 0).then_some(MockProverError::LkMultiplicityError { + rom_type, + key: *key, + count, + }) + }) + }) + .collect() +} + +fn print_errors( + errors: &[MockProverError], + wits_in: &[ArcMultilinearExtension], + wits_in_name: &[String], + panic_on_error: bool, +) { + println!("======================================================"); + for (count, error) in errors.iter().dedup_with_count() { + error.print(wits_in, wits_in_name); + if count > 1 { + println!("Error: {} duplicates hidden.", count - 1); + } + } + println!("Error: {} constraints not satisfied", errors.len()); + println!("======================================================"); + if panic_on_error { + panic!("(Unexpected) Constraints not satisfied"); + } +} + #[cfg(test)] mod tests { @@ -1328,6 +1354,7 @@ mod tests { assert!(result.is_err(), "Expected error"); let err = result.unwrap_err(); assert_eq!(err, vec![MockProverError::LookupError { + rom_type: ROMType::U5, expression: Expression::Sum( Box::new(Expression::ScaledSum( Box::new(Expression::WitIn(0)), diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index e694f03b5..ded091508 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -280,7 +280,7 @@ fn test_single_add_instance_e2e() { zkvm_witness .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); + zkvm_witness.finalize_lk_multiplicities(false); zkvm_witness .assign_table_circuit::>(&zkvm_cs, &u16_range_config, &()) .unwrap(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8ae405ec9..781d8e96a 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -254,6 +254,10 @@ impl ZKVMWitnesses { self.witnesses_tables.get(name).cloned() } + pub fn get_lk_mlt(&self, name: &String) -> Option<&LkMultiplicity> { + self.lk_mlts.get(name) + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, @@ -277,14 +281,24 @@ impl ZKVMWitnesses { } // merge the multiplicities in each opcode circuit into one - pub fn finalize_lk_multiplicities(&mut self) { + pub fn finalize_lk_multiplicities(&mut self, is_keep_raw_lk_mlts: bool) { assert!(self.combined_lk_mlt.is_none()); assert!(!self.lk_mlts.is_empty()); let mut combined_lk_mlt = vec![]; let keys = self.lk_mlts.keys().cloned().collect_vec(); for name in keys { - let lk_mlt = self.lk_mlts.remove(&name).unwrap().into_finalize_result(); + let lk_mlt = if is_keep_raw_lk_mlts { + // mock prover needs the lk_mlt for processing, so we do not remove it + self.lk_mlts + .get(&name) + .unwrap() + .deep_clone() + .into_finalize_result() + } else { + self.lk_mlts.remove(&name).unwrap().into_finalize_result() + }; + if combined_lk_mlt.is_empty() { combined_lk_mlt = lk_mlt.to_vec(); } else { diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 97150e086..e2db20102 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,19 +1,20 @@ use ff::Field; +use itertools::izip; +use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use std::{ - array, cell::RefCell, collections::HashMap, + fmt::Debug, + hash::Hash, mem::{self}, - ops::Index, + ops::{AddAssign, Index}, slice::{Chunks, ChunksMut}, sync::Arc, }; - -use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; -use rayon::{ - iter::{IntoParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, -}; use thread_local::ThreadLocal; use crate::{ @@ -125,26 +126,137 @@ impl Index for RowMajorMatrix { } } +pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; + +#[derive(Clone, Default, Debug)] +pub struct Multiplicity(pub MultiplicityRaw); + /// A lock-free thread safe struct to count logup multiplicity for each ROM type /// Lock-free by thread-local such that each thread will only have its local copy /// struct is cloneable, for internallly it use Arc so the clone will be low cost #[derive(Clone, Default, Debug)] #[allow(clippy::type_complexity)] -pub struct LkMultiplicity { - multiplicity: Arc; mem::variant_count::()]>>>, +pub struct LkMultiplicityRaw { + multiplicity: Arc>>>, } +impl AddAssign for LkMultiplicityRaw +where + K: Copy + Clone + Debug + Default + Eq + Hash + Send, +{ + fn add_assign(&mut self, rhs: Self) { + *self += Multiplicity(rhs.into_finalize_result()); + } +} + +impl AddAssign for Multiplicity +where + K: Eq + Hash, +{ + fn add_assign(&mut self, rhs: Self) { + for (lhs, rhs) in izip!(&mut self.0, rhs.0) { + for (key, value) in rhs { + *lhs.entry(key).or_default() += value; + } + } + } +} + +impl AddAssign> for LkMultiplicityRaw +where + K: Copy + Clone + Debug + Default + Eq + Hash + Send, +{ + fn add_assign(&mut self, rhs: Multiplicity) { + let multiplicity = self.multiplicity.get_or_default(); + for (lhs, rhs) in izip!(&mut multiplicity.borrow_mut().0, rhs.0) { + for (key, value) in rhs { + *lhs.entry(key).or_default() += value; + } + } + } +} + +impl AddAssign<((ROMType, K), usize)> for LkMultiplicityRaw +where + K: Copy + Clone + Debug + Default + Eq + Hash + Send, +{ + fn add_assign(&mut self, ((rom_type, key), value): ((ROMType, K), usize)) { + let multiplicity = self.multiplicity.get_or_default(); + (*multiplicity.borrow_mut().0[rom_type as usize] + .entry(key) + .or_default()) += value; + } +} + +impl AddAssign<(ROMType, K)> for LkMultiplicityRaw +where + K: Copy + Clone + Debug + Default + Eq + Hash + Send, +{ + fn add_assign(&mut self, (rom_type, key): (ROMType, K)) { + let multiplicity = self.multiplicity.get_or_default(); + (*multiplicity.borrow_mut().0[rom_type as usize] + .entry(key) + .or_default()) += 1; + } +} + +impl LkMultiplicityRaw { + /// Merge result from multiple thread local to single result. + pub fn into_finalize_result(self) -> MultiplicityRaw { + let mut results = Multiplicity::default(); + for y in Arc::try_unwrap(self.multiplicity).unwrap() { + results += y.into_inner(); + } + results.0 + } + + pub fn increment(&mut self, rom_type: ROMType, key: K) { + *self += (rom_type, key); + } + + pub fn set_count(&mut self, rom_type: ROMType, key: K, count: usize) { + if count == 0 { + return; + } + let multiplicity = self.multiplicity.get_or_default(); + let table = &mut multiplicity.borrow_mut().0[rom_type as usize]; + if count == 0 { + table.remove(&key); + } else { + table.insert(key, count); + } + } + + /// Clone inner, expensive operation. + pub fn deep_clone(&self) -> Self { + let multiplicity = self.multiplicity.get_or_default(); + let deep_cloned = multiplicity.borrow().clone(); + let thread_local = ThreadLocal::new(); + thread_local.get_or(|| RefCell::new(deep_cloned)); + LkMultiplicityRaw { + multiplicity: Arc::new(thread_local), + } + } +} + +/// Default LkMultiplicity with u64 key. +pub type LkMultiplicity = LkMultiplicityRaw; + impl LkMultiplicity { /// assert within range #[inline(always)] pub fn assert_ux(&mut self, v: u64) { - match C { - 16 => self.increment(ROMType::U16, v), - 14 => self.increment(ROMType::U14, v), - 8 => self.increment(ROMType::U8, v), - 5 => self.increment(ROMType::U5, v), - _ => panic!("Unsupported bit range"), - } + use ROMType::*; + self.increment( + match C { + 16 => U16, + 14 => U14, + 8 => U8, + 5 => U5, + _ => panic!("Unsupported bit range"), + }, + v, + ); } /// Track a lookup into a logic table (AndTable, etc). @@ -180,30 +292,6 @@ impl LkMultiplicity { pub fn fetch(&mut self, pc: u32) { self.increment(ROMType::Instruction, pc as u64); } - - /// merge result from multiple thread local to single result - pub fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { - Arc::try_unwrap(self.multiplicity) - .unwrap() - .into_iter() - .fold(array::from_fn(|_| HashMap::new()), |mut x, y| { - x.iter_mut().zip(y.borrow().iter()).for_each(|(m1, m2)| { - for (key, value) in m2 { - *m1.entry(*key).or_insert(0) += value; - } - }); - x - }) - } - - fn increment(&mut self, rom_type: ROMType, key: u64) { - let multiplicity = self - .multiplicity - .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); - (*multiplicity.borrow_mut()[rom_type as usize] - .entry(key) - .or_default()) += 1; - } } #[cfg(test)] From f91ec067bb6006f798a7976874faedbf3c50d3e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Thu, 2 Jan 2025 16:28:48 +0800 Subject: [PATCH 02/12] Implement some of Risc0's / SP1's syscalls to make Sproll work (#800) --- ceno_rt/src/lib.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 5562b1e8a..576615802 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -5,6 +5,10 @@ use getrandom::{Error, register_custom_getrandom}; #[cfg(target_arch = "riscv32")] use core::arch::{asm, global_asm}; +use std::{ + alloc::{Layout, alloc_zeroed}, + ptr::null, +}; #[cfg(target_arch = "riscv32")] mod allocator; @@ -26,19 +30,19 @@ pub use syscalls::*; #[no_mangle] #[linkage = "weak"] pub extern "C" fn sys_write(_fd: i32, _buf: *const u8, _count: usize) -> isize { - unimplemented!(); + 0 } #[no_mangle] #[linkage = "weak"] -pub extern "C" fn sys_alloc_words(_nwords: usize) -> *mut u32 { - unimplemented!(); +pub extern "C" fn sys_alloc_words(nwords: usize) -> *mut u32 { + unsafe { alloc_zeroed(Layout::from_size_align(4 * nwords, 4).unwrap()) as *mut u32 } } #[no_mangle] #[linkage = "weak"] pub extern "C" fn sys_getenv(_name: *const u8) -> *const u8 { - unimplemented!(); + null() } /// Generates random bytes. From 0c39f7f8c7ae444ae19e24bff850e5cc9041950d Mon Sep 17 00:00:00 2001 From: Ming Date: Thu, 2 Jan 2025 19:04:08 +0700 Subject: [PATCH 03/12] refactor to replace v1 by v2 and cleanup suffix (#791) Previously to quick verify idea and avoid massive change, there are new functionality with suffix `_v2`. After experiment with good result, long time ago all logic already stick to v2 version and no longer use v1. This PR clean up all leftover v1 version, do renaming and file replacement without modify existing logic. In summary - `sumcheck/src/prover_v2.rs` -> `sumcheck/src/prover.rs` - `multilinear_extensions/src/virtual_poly_v2.rs` -> `multilinear_extensions/src/virtual_poly.rs` - clean up all `V2` suffix This addressed previous out-dated PR https://github.com/scroll-tech/ceno/pull/162, and as a preparation for https://github.com/scroll-tech/ceno/issues/788, https://github.com/scroll-tech/ceno/issues/702 --- ceno_zkvm/src/expression.rs | 2 +- ceno_zkvm/src/scheme/mock_prover.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 13 +- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/scheme/utils.rs | 4 +- ceno_zkvm/src/scheme/verifier.rs | 6 +- ceno_zkvm/src/structs.rs | 2 +- ceno_zkvm/src/uint/arithmetic.rs | 2 +- ceno_zkvm/src/virtual_polys.rs | 19 +- mpcs/benches/basefold.rs | 2 +- mpcs/src/basefold.rs | 2 +- mpcs/src/lib.rs | 4 +- multilinear_extensions/src/lib.rs | 1 - multilinear_extensions/src/test.rs | 25 - multilinear_extensions/src/virtual_poly.rs | 203 +--- multilinear_extensions/src/virtual_poly_v2.rs | 263 ----- sumcheck/benches/devirgo_sumcheck.rs | 4 +- sumcheck/src/lib.rs | 1 - sumcheck/src/prover.rs | 580 +++++++---- sumcheck/src/prover_v2.rs | 919 ------------------ sumcheck/src/structs.rs | 22 +- sumcheck/src/test.rs | 8 +- sumcheck/src/util.rs | 45 +- sumcheck/src/verifier.rs | 12 +- 24 files changed, 497 insertions(+), 1646 deletions(-) delete mode 100644 multilinear_extensions/src/virtual_poly_v2.rs delete mode 100644 sumcheck/src/prover_v2.rs diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 6f0757a93..5d7233a00 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -12,7 +12,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use multilinear_extensions::virtual_poly::ArcMultilinearExtension; use crate::{ circuit_builder::CircuitBuilder, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 494fc6134..e63ae4ee7 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -23,7 +23,7 @@ use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::{GoldilocksExt2, SmallField}; use itertools::{Itertools, chain, enumerate, izip}; -use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; +use multilinear_extensions::{mle::IntoMLEs, virtual_poly::ArcMultilinearExtension}; use rand::thread_rng; use std::{ cmp::max, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index daaa1099d..a7042e5ca 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -10,13 +10,12 @@ use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, util::ceil_log2, - virtual_poly::build_eq_x_r_vec, - virtual_poly_v2::ArcMultilinearExtension, + virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec}, }; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProverMessage, IOPProverStateV2}, + structs::{IOPProverMessage, IOPProverState}, }; use transcript::{ForkableTranscript, Transcript}; @@ -583,7 +582,7 @@ impl> ZKVMProver { } tracing::debug!("main sel sumcheck start"); - let (main_sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + let (main_sel_sumcheck_proofs, state) = IOPProverState::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), transcript, @@ -1029,7 +1028,7 @@ impl> ZKVMProver { virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha); } - let (same_r_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + let (same_r_sumcheck_proofs, state) = IOPProverState::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), transcript, @@ -1241,7 +1240,7 @@ impl TowerProver { layer_polys .iter() .all(|f| { - f.evaluations().len() == (1 << (log_num_fanin * round)) + f.evaluations().len() == 1 << (log_num_fanin * round) }) ); @@ -1287,7 +1286,7 @@ impl TowerProver { // NOTE: at the time of adding this span, visualizing it with the flamegraph layer // shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys` // This is likely a bug in the tracing-flame crate. - let (sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + let (sumcheck_proofs, state) = IOPProverState::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), transcript, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index ded091508..ed747fdfe 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -12,7 +12,7 @@ use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme}; use multilinear_extensions::{ - mle::IntoMLE, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, + mle::IntoMLE, util::ceil_log2, virtual_poly::ArcMultilinearExtension, }; use transcript::{BasicTranscript, BasicTranscriptWithStat, StatisticRecorder, Transcript}; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 16746c9b1..9fc3b64a6 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -8,7 +8,7 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b, op_mle3_range, util::ceil_log2, - virtual_poly_v2::ArcMultilinearExtension, + virtual_poly::ArcMultilinearExtension, }; use rayon::{ iter::{ @@ -415,7 +415,7 @@ mod tests { commutative_op_mle_pair, mle::{FieldType, IntoMLE}, util::ceil_log2, - virtual_poly_v2::ArcMultilinearExtension, + virtual_poly::ArcMultilinearExtension, }; use crate::{ diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 9a082c25f..03ef32e64 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -344,7 +344,7 @@ impl> ZKVMVerifier &VPAuxInfo { // + 1 from sel_non_lc_zero_sumcheck max_degree: SEL_DEGREE.max(cs.max_non_lc_degree + 1), - num_variables: log2_num_instances, + max_num_variables: log2_num_instances, phantom: PhantomData, }, transcript, @@ -634,7 +634,7 @@ impl> ZKVMVerifier }, &VPAuxInfo { max_degree: SEL_DEGREE, - num_variables: expected_max_rounds, + max_num_variables: expected_max_rounds, phantom: PhantomData, }, transcript, @@ -904,7 +904,7 @@ impl TowerVerify { }, &VPAuxInfo { max_degree: NUM_FANIN + 1, // + 1 for eq - num_variables: (round + 1) * log2_num_fanin, + max_num_variables: (round + 1) * log2_num_fanin, phantom: PhantomData, }, transcript, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 781d8e96a..db1cd5dc8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -12,7 +12,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, + mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension, }; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 910bafa3c..c2e09b470 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -687,7 +687,7 @@ mod tests { use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::{ - mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, + mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension, }; type E = GoldilocksExt2; // 18446744069414584321 diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 4a0bcbb51..c3efd4172 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -9,14 +9,14 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ util::ceil_log2, - virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, + virtual_poly::{ArcMultilinearExtension, VirtualPolynomial}, }; use crate::{expression::Expression, utils::transpose}; pub struct VirtualPolynomials<'a, E: ExtensionField> { num_threads: usize, - polys: Vec>, + polys: Vec>, /// a storage to keep thread based mles, specific to multi-thread logic thread_based_mles_storage: HashMap>>, } @@ -26,7 +26,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { VirtualPolynomials { num_threads, polys: (0..num_threads) - .map(|_| VirtualPolynomialV2::new(max_num_variables - ceil_log2(num_threads))) + .map(|_| VirtualPolynomial::new(max_num_variables - ceil_log2(num_threads))) .collect_vec(), thread_based_mles_storage: HashMap::new(), } @@ -77,7 +77,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { }); } - pub fn get_batched_polys(self) -> Vec> { + pub fn get_batched_polys(self) -> Vec> { self.polys } @@ -174,10 +174,9 @@ mod tests { use itertools::Itertools; use multilinear_extensions::{ mle::IntoMLE, - virtual_poly::VPAuxInfo, - virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, + virtual_poly::{ArcMultilinearExtension, VPAuxInfo, VirtualPolynomial}, }; - use sumcheck::structs::{IOPProverStateV2, IOPVerifierState}; + use sumcheck::structs::{IOPProverState, IOPVerifierState}; use transcript::BasicTranscript as Transcript; use crate::{ @@ -284,7 +283,7 @@ mod tests { virtual_polys.add_mle_list(f2.iter().collect(), E::ONE); virtual_polys.add_mle_list(f3.iter().collect(), E::ONE); - let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys( + let (sumcheck_proofs, _) = IOPProverState::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), &mut transcript, @@ -296,13 +295,13 @@ mod tests { &sumcheck_proofs, &VPAuxInfo { max_degree: 3, - num_variables: max_num_vars, + max_num_variables: max_num_vars, phantom: std::marker::PhantomData, }, &mut transcript, ); - let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars); + let mut verifier_poly = VirtualPolynomial::new(max_num_vars); verifier_poly.add_mle_list(f1.to_vec(), E::ONE); verifier_poly.add_mle_list(f2.to_vec(), E::ONE); verifier_poly.add_mle_list(f3.to_vec(), E::ONE); diff --git a/mpcs/benches/basefold.rs b/mpcs/benches/basefold.rs index 3207ec5ae..64b9a5460 100644 --- a/mpcs/benches/basefold.rs +++ b/mpcs/benches/basefold.rs @@ -16,7 +16,7 @@ use mpcs::{ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, - virtual_poly_v2::ArcMultilinearExtension, + virtual_poly::ArcMultilinearExtension, }; use transcript::{BasicTranscript, Transcript}; diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index c7eb18f71..6204ed038 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -66,7 +66,7 @@ mod commit_phase; use commit_phase::{batch_commit_phase, commit_phase, simple_batch_commit_phase}; mod encoding; pub use encoding::{coset_fft, fft, fft_root_table}; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use multilinear_extensions::virtual_poly::ArcMultilinearExtension; mod query_phase; // This sumcheck module is different from the mpcs::sumcheck module, in that diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index fe110581c..fcfd1ba69 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -320,7 +320,7 @@ pub use basefold::{ EncodingScheme, RSCode, RSCodeDefaultSpec, coset_fft, fft, fft_root_table, one_level_eval_hc, one_level_interp_hc, }; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use multilinear_extensions::virtual_poly::ArcMultilinearExtension; fn validate_input( function: &str, @@ -377,7 +377,7 @@ pub mod test_util { use multilinear_extensions::mle::DenseMultilinearExtension; #[cfg(test)] use multilinear_extensions::{ - mle::MultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, + mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension, }; use rand::rngs::OsRng; #[cfg(test)] diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index 9f669e348..56cd26dc4 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -2,7 +2,6 @@ pub mod mle; pub mod util; pub mod virtual_poly; -pub mod virtual_poly_v2; #[cfg(test)] mod test; diff --git a/multilinear_extensions/src/test.rs b/multilinear_extensions/src/test.rs index 9ef9b94e7..91b176e71 100644 --- a/multilinear_extensions/src/test.rs +++ b/multilinear_extensions/src/test.rs @@ -31,31 +31,6 @@ fn test_virtual_polynomial_additions() { } } -#[test] -fn test_virtual_polynomial_mul_by_mle() { - let mut rng = test_rng(); - for nv in 2..5 { - for num_products in 2..5 { - let base: Vec = (0..nv).map(|_| E::random(&mut rng)).collect(); - - let (a, _a_sum) = VirtualPolynomial::::random(nv, (2, 3), num_products, &mut rng); - let (b, _b_sum) = DenseMultilinearExtension::::random_mle_list(nv, 1, &mut rng); - let b_mle = b[0].clone(); - let coeff = Goldilocks::random(&mut rng); - let b_vp = VirtualPolynomial::new_from_mle(b_mle.clone(), coeff); - - let mut c = a.clone(); - - c.mul_by_mle(b_mle, coeff); - - assert_eq!( - a.evaluate(base.as_ref()) * b_vp.evaluate(base.as_ref()), - c.evaluate(base.as_ref()) - ); - } - } -} - #[test] fn test_eq_xr() { let mut rng = test_rng(); diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 72c5b2649..bd50d659a 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -4,16 +4,18 @@ use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; -use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; -use ff::{Field, PrimeField}; +use ark_std::{end_timer, rand::Rng, start_timer}; +use ff::PrimeField; use ff_ext::ExtensionField; +use itertools::Itertools; use rayon::{ - iter::IntoParallelIterator, - prelude::{IndexedParallelIterator, ParallelIterator}, + iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, slice::ParallelSliceMut, }; use serde::{Deserialize, Serialize}; +pub type ArcMultilinearExtension<'a, E> = + Arc> + 'a>; #[rustfmt::skip] /// A virtual polynomial is a sum of products of multilinear polynomials; /// where the multilinear polynomials are stored via their multilinear @@ -41,15 +43,15 @@ use serde::{Deserialize, Serialize}; /// \] /// - raw_pointers_lookup_table maps fi to i /// -#[derive(Clone, Debug, Default, PartialEq)] -pub struct VirtualPolynomial { +#[derive(Default, Clone)] +pub struct VirtualPolynomial<'a, E: ExtensionField> { /// Aux information about the multilinear polynomial pub aux_info: VPAuxInfo, /// list of reference to products (as usize) of multilinear extension - pub products: Vec<(E::BaseField, Vec)>, + pub products: Vec<(E, Vec)>, /// Stores multilinear extensions in which product multiplicand can refer /// to. - pub flattened_ml_extensions: Vec>, + pub flattened_ml_extensions: Vec>, /// Pointers to the above poly extensions raw_pointers_lookup_table: HashMap, } @@ -59,20 +61,20 @@ pub struct VirtualPolynomial { pub struct VPAuxInfo { /// max number of multiplicands in each product pub max_degree: usize, - /// number of variables of the polynomial - pub num_variables: usize, + /// max number of variables of the polynomial + pub max_num_variables: usize, /// Associated field #[doc(hidden)] pub phantom: PhantomData, } -impl VirtualPolynomial { - /// Creates an empty virtual polynomial with `num_variables`. - pub fn new(num_variables: usize) -> Self { +impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { + /// Creates an empty virtual polynomial with `max_num_variables`. + pub fn new(max_num_variables: usize) -> Self { VirtualPolynomial { aux_info: VPAuxInfo { max_degree: 0, - num_variables, + max_num_variables, phantom: PhantomData, }, products: Vec::new(), @@ -82,8 +84,8 @@ impl VirtualPolynomial { } /// Creates an new virtual polynomial from a MLE and its coefficient. - pub fn new_from_mle(mle: ArcDenseMultilinearExtension, coefficient: E::BaseField) -> Self { - let mle_ptr: usize = Arc::as_ptr(&mle) as usize; + pub fn new_from_mle(mle: ArcMultilinearExtension<'a, E>, coefficient: E) -> Self { + let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; let mut hm = HashMap::new(); hm.insert(mle_ptr, 0); @@ -91,7 +93,7 @@ impl VirtualPolynomial { aux_info: VPAuxInfo { // The max degree is the max degree of any individual variable max_degree: 1, - num_variables: mle.num_vars, + max_num_variables: mle.num_vars(), phantom: PhantomData, }, // here `0` points to the first polynomial of `flattened_ml_extensions` @@ -102,31 +104,33 @@ impl VirtualPolynomial { } /// Add a product of list of multilinear extensions to self - /// Returns an error if the list is empty, or the MLE has a different - /// `num_vars` from self. + /// Returns an error if the list is empty. + /// + /// mle in mle_list must be in same num_vars() in same product, + /// while different product can have different num_vars() /// /// The MLEs will be multiplied together, and then multiplied by the scalar /// `coefficient`. - pub fn add_mle_list( - &mut self, - mle_list: Vec>, - coefficient: E::BaseField, - ) { - let mle_list: Vec> = mle_list.into_iter().collect(); + pub fn add_mle_list(&mut self, mle_list: Vec>, coefficient: E) { + let mle_list: Vec> = mle_list.into_iter().collect(); let mut indexed_product = Vec::with_capacity(mle_list.len()); assert!(!mle_list.is_empty(), "input mle_list is empty"); + // sanity check: all mle in mle_list must have same num_vars() + assert!( + mle_list + .iter() + .map(|m| { + assert!(m.num_vars() <= self.aux_info.max_num_variables); + m.num_vars() + }) + .all_equal() + ); self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); for mle in mle_list { - assert_eq!( - mle.num_vars, self.aux_info.num_variables, - "product has a multiplicand with wrong number of variables {} vs {}", - mle.num_vars, self.aux_info.num_variables - ); - - let mle_ptr: usize = Arc::as_ptr(&mle) as usize; + let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { indexed_product.push(*index) } else { @@ -140,10 +144,10 @@ impl VirtualPolynomial { } /// in-place merge with another virtual polynomial - pub fn merge(&mut self, other: &VirtualPolynomial) { + pub fn merge(&mut self, other: &VirtualPolynomial<'a, E>) { let start = start_timer!(|| "virtual poly add"); for (coeffient, products) in other.products.iter() { - let cur: Vec> = products + let cur: Vec<_> = products .iter() .map(|&x| other.flattened_ml_extensions[x].clone()) .collect(); @@ -153,62 +157,23 @@ impl VirtualPolynomial { end_timer!(start); } - /// Multiple the current VirtualPolynomial by an MLE: - /// - add the MLE to the MLE list; - /// - multiple each product by MLE and its coefficient. - /// Returns an error if the MLE has a different `num_vars` from self. - #[tracing::instrument(skip_all, name = "mul_by_mle")] - pub fn mul_by_mle(&mut self, mle: ArcDenseMultilinearExtension, coefficient: E::BaseField) { - let start = start_timer!(|| "mul by mle"); - - assert_eq!( - mle.num_vars, self.aux_info.num_variables, - "product has a multiplicand with wrong number of variables {} vs {}", - mle.num_vars, self.aux_info.num_variables - ); - - let mle_ptr = Arc::as_ptr(&mle) as usize; - - // check if this mle already exists in the virtual polynomial - let mle_index = match self.raw_pointers_lookup_table.get(&mle_ptr) { - Some(&p) => p, - None => { - self.raw_pointers_lookup_table - .insert(mle_ptr, self.flattened_ml_extensions.len()); - self.flattened_ml_extensions.push(mle); - self.flattened_ml_extensions.len() - 1 - } - }; - - for (prod_coef, indices) in self.products.iter_mut() { - // - add the MLE to the MLE list; - // - multiple each product by MLE and its coefficient. - indices.push(mle_index); - *prod_coef *= coefficient; - } - - // increase the max degree by one as the MLE has degree 1. - self.aux_info.max_degree += 1; - end_timer!(start); - } - /// Evaluate the virtual polynomial at point `point`. /// Returns an error is point.len() does not match `num_variables`. pub fn evaluate(&self, point: &[E]) -> E { let start = start_timer!(|| "evaluation"); assert_eq!( - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len(), "wrong number of variables {} vs {}", - self.aux_info.num_variables, + self.aux_info.max_num_variables, point.len() ); let evals: Vec = self .flattened_ml_extensions .iter() - .map(|x| x.evaluate(point)) + .map(|x| x.evaluate(&point[0..x.num_vars()])) .collect(); let res = self @@ -237,7 +202,9 @@ impl VirtualPolynomial { rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); let (product, product_sum) = DenseMultilinearExtension::random_mle_list(nv, num_multiplicands, rng); - let coefficient = E::BaseField::random(&mut rng); + let product: Vec> = + product.into_iter().map(|mle| mle as _).collect_vec(); + let coefficient = E::random(&mut rng); poly.add_mle_list(product, coefficient); sum += product_sum * coefficient; } @@ -246,90 +213,18 @@ impl VirtualPolynomial { (poly, sum) } - /// Sample a random virtual polynomial that evaluates to zero everywhere - /// over the boolean hypercube. - pub fn rand_zero( - nv: usize, - num_multiplicands_range: (usize, usize), - num_products: usize, - mut rng: impl Rng + Copy, - ) -> Self { - let mut poly = VirtualPolynomial::new(nv); - for _ in 0..num_products { - let num_multiplicands = - rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); - let product = - DenseMultilinearExtension::random_zero_mle_list(nv, num_multiplicands, rng); - let coefficient = E::BaseField::random(&mut rng); - poly.add_mle_list(product, coefficient); - } - - poly - } - - // Input poly f(x) and a random vector r, output - // \hat f(x) = \sum_{x_i \in eval_x} f(x_i) eq(x, r) - // where - // eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) - // - // This function is used in ZeroCheck. - pub fn build_f_hat(&self, r: &[E]) -> Self { - let start = start_timer!(|| "zero check build hat f"); - - assert_eq!( - self.aux_info.num_variables, - r.len(), - "r.len() is different from number of variables: {} vs {}", - r.len(), - self.aux_info.num_variables - ); - - let eq_x_r = build_eq_x_r(r); - let mut res = self.clone(); - res.mul_by_mle(eq_x_r, E::BaseField::ONE); - - end_timer!(start); - res - } - - /// Print out the evaluation map for testing. Panic if the num_vars > 5. + /// Print out the evaluation map for testing. Panic if the num_vars() > 5. pub fn print_evals(&self) { - if self.aux_info.num_variables > 5 { - panic!("this function is used for testing only. cannot print more than 5 num_vars") + if self.aux_info.max_num_variables > 5 { + panic!("this function is used for testing only. cannot print more than 5 num_vars()") } - for i in 0..1 << self.aux_info.num_variables { - let point = bit_decompose(i, self.aux_info.num_variables); + for i in 0..1 << self.aux_info.max_num_variables { + let point = bit_decompose(i, self.aux_info.max_num_variables); let point_fr: Vec = point.iter().map(|&x| E::from(x as u64)).collect(); println!("{} {:?}", i, self.evaluate(point_fr.as_ref())) } println!() } - - // TODO: This seems expensive. Is there a better way to covert poly into its ext fields? - pub fn to_ext_field(&self) -> VirtualPolynomial { - let timer = start_timer!(|| "convert VP to ext field"); - let products = self.products.iter().map(|(f, v)| (*f, v.clone())).collect(); - - let mut flattened_ml_extensions = vec![]; - let mut hm = HashMap::new(); - for mle in self.flattened_ml_extensions.iter() { - let mle_ptr = Arc::as_ptr(mle) as usize; - let index = self.raw_pointers_lookup_table.get(&mle_ptr).unwrap(); - - let mle_ext_field = mle.as_ref().to_ext_field(); - let mle_ext_field = Arc::new(mle_ext_field); - let mle_ext_field_ptr = Arc::as_ptr(&mle_ext_field) as usize; - flattened_ml_extensions.push(mle_ext_field); - hm.insert(mle_ext_field_ptr, *index); - } - end_timer!(timer); - VirtualPolynomial { - aux_info: self.aux_info.clone(), - products, - flattened_ml_extensions, - raw_pointers_lookup_table: hm, - } - } } /// Evaluate eq polynomial. diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs deleted file mode 100644 index 5d64d88bc..000000000 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ /dev/null @@ -1,263 +0,0 @@ -use std::{cmp::max, collections::HashMap, marker::PhantomData, sync::Arc}; - -use crate::{ - mle::{DenseMultilinearExtension, MultilinearExtension}, - util::bit_decompose, -}; -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::Itertools; -use serde::{Deserialize, Serialize}; - -pub type ArcMultilinearExtension<'a, E> = - Arc> + 'a>; -#[rustfmt::skip] -/// A virtual polynomial is a sum of products of multilinear polynomials; -/// where the multilinear polynomials are stored via their multilinear -/// extensions: `(coefficient, DenseMultilinearExtension)` -/// -/// * Number of products n = `polynomial.products.len()`, -/// * Number of multiplicands of ith product m_i = -/// `polynomial.products[i].1.len()`, -/// * Coefficient of ith product c_i = `polynomial.products[i].0` -/// -/// The resulting polynomial is -/// -/// $$ \sum_{i=0}^{n} c_i \cdot \prod_{j=0}^{m_i} P_{ij} $$ -/// -/// Example: -/// f = c0 * f0 * f1 * f2 + c1 * f3 * f4 -/// where f0 ... f4 are multilinear polynomials -/// -/// - flattened_ml_extensions stores the multilinear extension representation of -/// f0, f1, f2, f3 and f4 -/// - products is -/// \[ -/// (c0, \[0, 1, 2\]), -/// (c1, \[3, 4\]) -/// \] -/// - raw_pointers_lookup_table maps fi to i -/// -#[derive(Default, Clone)] -pub struct VirtualPolynomialV2<'a, E: ExtensionField> { - /// Aux information about the multilinear polynomial - pub aux_info: VPAuxInfo, - /// list of reference to products (as usize) of multilinear extension - pub products: Vec<(E, Vec)>, - /// Stores multilinear extensions in which product multiplicand can refer - /// to. - pub flattened_ml_extensions: Vec>, - /// Pointers to the above poly extensions - raw_pointers_lookup_table: HashMap, -} - -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] -/// Auxiliary information about the multilinear polynomial -pub struct VPAuxInfo { - /// max number of multiplicands in each product - pub max_degree: usize, - /// max number of variables of the polynomial - pub max_num_variables: usize, - /// Associated field - #[doc(hidden)] - pub phantom: PhantomData, -} - -impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { - /// Creates an empty virtual polynomial with `max_num_variables`. - pub fn new(max_num_variables: usize) -> Self { - VirtualPolynomialV2 { - aux_info: VPAuxInfo { - max_degree: 0, - max_num_variables, - phantom: PhantomData, - }, - products: Vec::new(), - flattened_ml_extensions: Vec::new(), - raw_pointers_lookup_table: HashMap::new(), - } - } - - /// Creates an new virtual polynomial from a MLE and its coefficient. - pub fn new_from_mle(mle: ArcMultilinearExtension<'a, E>, coefficient: E) -> Self { - let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; - let mut hm = HashMap::new(); - hm.insert(mle_ptr, 0); - - VirtualPolynomialV2 { - aux_info: VPAuxInfo { - // The max degree is the max degree of any individual variable - max_degree: 1, - max_num_variables: mle.num_vars(), - phantom: PhantomData, - }, - // here `0` points to the first polynomial of `flattened_ml_extensions` - products: vec![(coefficient, vec![0])], - flattened_ml_extensions: vec![mle], - raw_pointers_lookup_table: hm, - } - } - - /// Add a product of list of multilinear extensions to self - /// Returns an error if the list is empty. - /// - /// mle in mle_list must be in same num_vars() in same product, - /// while different product can have different num_vars() - /// - /// The MLEs will be multiplied together, and then multiplied by the scalar - /// `coefficient`. - pub fn add_mle_list(&mut self, mle_list: Vec>, coefficient: E) { - let mle_list: Vec> = mle_list.into_iter().collect(); - let mut indexed_product = Vec::with_capacity(mle_list.len()); - - assert!(!mle_list.is_empty(), "input mle_list is empty"); - // sanity check: all mle in mle_list must have same num_vars() - assert!( - mle_list - .iter() - .map(|m| { - assert!(m.num_vars() <= self.aux_info.max_num_variables); - m.num_vars() - }) - .all_equal() - ); - - self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); - - for mle in mle_list { - let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; - if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { - indexed_product.push(*index) - } else { - let curr_index = self.flattened_ml_extensions.len(); - self.flattened_ml_extensions.push(mle); - self.raw_pointers_lookup_table.insert(mle_ptr, curr_index); - indexed_product.push(curr_index); - } - } - self.products.push((coefficient, indexed_product)); - } - - /// in-place merge with another virtual polynomial - pub fn merge(&mut self, other: &VirtualPolynomialV2<'a, E>) { - let start = start_timer!(|| "virtual poly add"); - for (coeffient, products) in other.products.iter() { - let cur: Vec<_> = products - .iter() - .map(|&x| other.flattened_ml_extensions[x].clone()) - .collect(); - - self.add_mle_list(cur, *coeffient); - } - end_timer!(start); - } - - /// Multiple the current VirtualPolynomial by an MLE: - /// - add the MLE to the MLE list; - /// - multiple each product by MLE and its coefficient. - /// Returns an error if the MLE has a different `num_vars()` from self. - #[tracing::instrument(skip_all, name = "mul_by_mle")] - pub fn mul_by_mle(&mut self, mle: ArcMultilinearExtension<'a, E>, coefficient: E::BaseField) { - let start = start_timer!(|| "mul by mle"); - - assert_eq!( - mle.num_vars(), - self.aux_info.max_num_variables, - "product has a multiplicand with wrong number of variables {} vs {}", - mle.num_vars(), - self.aux_info.max_num_variables - ); - - let mle_ptr = Arc::as_ptr(&mle) as *const () as usize; - - // check if this mle already exists in the virtual polynomial - let mle_index = match self.raw_pointers_lookup_table.get(&mle_ptr) { - Some(&p) => p, - None => { - self.raw_pointers_lookup_table - .insert(mle_ptr, self.flattened_ml_extensions.len()); - self.flattened_ml_extensions.push(mle); - self.flattened_ml_extensions.len() - 1 - } - }; - - for (prod_coef, indices) in self.products.iter_mut() { - // - add the MLE to the MLE list; - // - multiple each product by MLE and its coefficient. - indices.push(mle_index); - *prod_coef *= coefficient; - } - - // increase the max degree by one as the MLE has degree 1. - self.aux_info.max_degree += 1; - end_timer!(start); - } - - /// Evaluate the virtual polynomial at point `point`. - /// Returns an error is point.len() does not match `num_variables`. - pub fn evaluate(&self, point: &[E]) -> E { - let start = start_timer!(|| "evaluation"); - - assert_eq!( - self.aux_info.max_num_variables, - point.len(), - "wrong number of variables {} vs {}", - self.aux_info.max_num_variables, - point.len() - ); - - let evals: Vec = self - .flattened_ml_extensions - .iter() - .map(|x| x.evaluate(&point[0..x.num_vars()])) - .collect(); - - let res = self - .products - .iter() - .map(|(c, p)| p.iter().map(|&i| evals[i]).product::() * *c) - .sum(); - - end_timer!(start); - res - } - - /// Print out the evaluation map for testing. Panic if the num_vars() > 5. - pub fn print_evals(&self) { - if self.aux_info.max_num_variables > 5 { - panic!("this function is used for testing only. cannot print more than 5 num_vars()") - } - for i in 0..1 << self.aux_info.max_num_variables { - let point = bit_decompose(i, self.aux_info.max_num_variables); - let point_fr: Vec = point.iter().map(|&x| E::from(x as u64)).collect(); - println!("{} {:?}", i, self.evaluate(point_fr.as_ref())) - } - println!() - } - - // // TODO: This seems expensive. Is there a better way to covert poly into its ext fields? - // pub fn to_ext_field(&self) -> VirtualPolynomialV2 { - // let timer = start_timer!(|| "convert VP to ext field"); - // let products = self.products.iter().map(|(f, v)| (*f, v.clone())).collect(); - - // let mut flattened_ml_extensions = vec![]; - // let mut hm = HashMap::new(); - // for mle in self.flattened_ml_extensions.iter() { - // let mle_ptr = Arc::as_ptr(mle) as *const () as usize; - // let index = self.raw_pointers_lookup_table.get(&mle_ptr).unwrap(); - - // let mle_ext_field = mle.as_ref().to_ext_field(); - // let mle_ext_field = Arc::new(mle_ext_field); - // let mle_ext_field_ptr = Arc::as_ptr(&mle_ext_field) as usize; - // flattened_ml_extensions.push(mle_ext_field); - // hm.insert(mle_ext_field_ptr, *index); - // } - // end_timer!(timer); - // VirtualPolynomialV2 { - // aux_info: self.aux_info.clone(), - // products, - // flattened_ml_extensions, - // raw_pointers_lookup_table: hm, - // } - // } -} diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 669ed9450..7fb919cb7 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -7,14 +7,14 @@ use ark_std::test_rng; use criterion::*; use ff_ext::ExtensionField; use itertools::Itertools; -use sumcheck::{structs::IOPProverStateV2 as IOPProverState, util::ceil_log2}; +use sumcheck::{structs::IOPProverState, util::ceil_log2}; use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, util::max_usable_threads, - virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, + virtual_poly::{ArcMultilinearExtension, VirtualPolynomial}, }; use transcript::BasicTranscript as Transcript; diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 4578b7ec2..797f151d4 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -2,7 +2,6 @@ #![feature(decl_macro)] pub mod macros; mod prover; -mod prover_v2; pub mod structs; pub mod util; mod verifier; diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 0dc5de870..af4169d83 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -3,10 +3,16 @@ use std::{array, mem, sync::Arc}; use ark_std::{end_timer, start_timer}; use crossbeam_channel::bounded; use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair, mle::MultilinearExtension, op_mle, virtual_poly::VirtualPolynomial, + commutative_op_mle_pair, + mle::{DenseMultilinearExtension, MultilinearExtension}, + op_mle, op_mle_product_3, op_mle3_range, + util::largest_even_below, + virtual_poly::VirtualPolynomial, }; use rayon::{ + Scope, iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, prelude::{IntoParallelIterator, ParallelIterator}, }; @@ -17,33 +23,36 @@ use crate::{ structs::{IOPProof, IOPProverMessage, IOPProverState}, util::{ AdditiveArray, AdditiveVec, barycentric_weights, ceil_log2, extrapolate, - merge_sumcheck_polys, + merge_sumcheck_polys, serial_extrapolate, }, }; -impl IOPProverState { +impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// Given a virtual polynomial, generate an IOP proof. /// multi-threads model follow https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck" /// This is experiment features. It's preferable that we move parallel level up more to /// "bould_poly" so it can be more isolation - #[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys")] + #[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys", level = "trace")] pub fn prove_batch_polys( max_thread_id: usize, - mut polys: Vec>, + mut polys: Vec>, transcript: &mut impl Transcript, - ) -> (IOPProof, IOPProverState) { + ) -> (IOPProof, IOPProverState<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 + assert!( + polys + .iter() + .map(|poly| (poly.aux_info.max_num_variables, poly.aux_info.max_degree)) + .all_equal() + ); let (num_variables, max_degree) = ( - polys[0].aux_info.num_variables, + polys[0].aux_info.max_num_variables, polys[0].aux_info.max_degree, ); - for poly in polys[1..].iter() { - assert!(poly.aux_info.num_variables == num_variables); - assert!(poly.aux_info.max_degree == max_degree); - } // return empty proof when target polymonial is constant if num_variables == 0 { @@ -68,137 +77,167 @@ impl IOPProverState { }) .collect::>(); - // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work - // thread - for thread_id in 0..(max_thread_id - 1) { + // spawn extra #(max_thread_id - 1) work threads + let num_worker_threads = max_thread_id - 1; + // whereas the main-thread be the last work thread + let main_thread_id = num_worker_threads; + let span = entered_span!("spawn loop", profiling_4 = true); + let scoped_fn = |s: &Scope<'a>| { + for (thread_id, poly) in polys.iter_mut().enumerate().take(num_worker_threads) { + let mut prover_state = Self::prover_init_with_extrapolation_aux( + mem::take(poly), + extrapolation_aux.clone(), + ); + let tx_prover_state = tx_prover_state.clone(); + let mut thread_based_transcript = thread_based_transcript.clone(); + s.spawn(move |_| { + let mut challenge = None; + // Note: This span is not nested into the "spawn loop" span, although lexically it looks so. + // Nesting is possible, but then `tracing-forest` does the wrong thing when measuring duration. + // TODO: investigate possibility of nesting with correct duration of parent span + let span = entered_span!("prove_rounds", profiling_5 = true); + for _ in 0..num_variables { + let prover_msg = IOPProverState::prove_round_and_update_state( + &mut prover_state, + &challenge, + ); + thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); + + challenge = Some( + thread_based_transcript.get_and_append_challenge(b"Internal round"), + ); + thread_based_transcript.commit_rolling(); + } + exit_span!(span); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } + }); + tx_prover_state + .send(Some((thread_id, prover_state))) + .unwrap(); + } else { + tx_prover_state.send(None).unwrap(); + } + }) + } + exit_span!(span); + + let mut prover_msgs = Vec::with_capacity(num_variables); let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(&mut polys[thread_id]), + mem::take(&mut polys[main_thread_id]), extrapolation_aux.clone(), ); let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); - let spawn_task = move || { - let mut challenge = None; - let span = entered_span!("prove_rounds"); - for _ in 0..num_variables { - let prover_msg = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); - thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); - - challenge = - Some(thread_based_transcript.get_and_append_challenge(b"Internal round")); - thread_based_transcript.commit_rolling(); - } - exit_span!(span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::make_mut(mle); - mle.fix_variables_in_place(&[p.elements]); - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); - } else { - tx_prover_state.send(None).unwrap(); + let main_thread_span = entered_span!("main_thread_prove_rounds"); + // main thread also be one worker thread + // NOTE inline main thread flow with worker thread to improve efficiency + // refactor to shared closure cause to 5% throuput drop + let mut challenge = None; + for _ in 0..num_variables { + let prover_msg = + IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + + // for each round, we must collect #SIZE prover message + let mut evaluations = AdditiveVec::new(max_degree + 1); + + // sum for all round poly evaluations vector + evaluations += AdditiveVec(prover_msg.evaluations); + for _ in 0..num_worker_threads { + let round_poly_coeffs = thread_based_transcript.read_field_element_exts(); + evaluations += AdditiveVec(round_poly_coeffs); } - }; - // create local thread pool if global rayon pool size < max_thread_id - // this usually cause by global pool size not power of 2. - if rayon::current_num_threads() >= max_thread_id { - rayon::spawn(spawn_task); - } else { - panic!( - "rayon global thread pool size {} mismatch with desired poly size {}.", - rayon::current_num_threads(), - polys.len() - ); - } - } + let get_challenge_span = entered_span!("main_thread_get_challenge"); + transcript.append_field_element_exts(&evaluations.0); - let mut prover_msgs = Vec::with_capacity(num_variables); - let thread_id = max_thread_id - 1; - let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(&mut polys[thread_id]), - extrapolation_aux.clone(), - ); - let tx_prover_state = tx_prover_state.clone(); - let mut thread_based_transcript = thread_based_transcript.clone(); + let next_challenge = transcript.get_and_append_challenge(b"Internal round"); + (0..num_worker_threads).for_each(|_| { + thread_based_transcript.send_challenge(next_challenge.elements); + }); - let span = entered_span!("main_thread_prove_rounds"); - // main thread also be one worker thread - // NOTE inline main thread flow with worker thread to improve efficiency - // refactor to shared closure cause to 5% throuput drop - let mut challenge = None; - for _ in 0..num_variables { - let prover_msg = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); - thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); + exit_span!(get_challenge_span); - // for each round, we must collect #SIZE prover message - let mut evaluations = AdditiveVec::new(max_degree + 1); + prover_msgs.push(IOPProverMessage { + evaluations: evaluations.0, + }); - // sum for all round poly evaluations vector - for _ in 0..max_thread_id { - let round_poly_coeffs = thread_based_transcript.read_field_element_exts(); - evaluations += AdditiveVec(round_poly_coeffs); + challenge = Some(next_challenge); + thread_based_transcript.commit_rolling(); + } + exit_span!(main_thread_span); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + if num_variables == 1 { + // first time fix variable should be create new instance + if mle.num_vars() > 0 { + *mle = mle.fix_variables(&[p.elements]).into(); + } else { + *mle = + Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) + } + } else { + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } + } + }); + tx_prover_state + .send(Some((main_thread_id, prover_state))) + .unwrap(); + } else { + tx_prover_state.send(None).unwrap(); } - let span = entered_span!("main_thread_get_challenge"); - transcript.append_field_element_exts(&evaluations.0); - - let next_challenge = transcript.get_and_append_challenge(b"Internal round"); - (0..max_thread_id).for_each(|_| { - thread_based_transcript.send_challenge(next_challenge.elements); - }); - - exit_span!(span); + let mut prover_states = (0..max_thread_id) + .map(|_| IOPProverState::default()) + .collect::>(); + for _ in 0..max_thread_id { + if let Some((index, prover_msg)) = rx_prover_state.recv().unwrap() { + prover_states[index] = prover_msg + } else { + println!("got empty msg, which is normal if virtual poly is constant function") + } + } - prover_msgs.push(IOPProverMessage { - evaluations: evaluations.0, - }); + (prover_states, prover_msgs) + }; - challenge = Some(thread_based_transcript.get_and_append_challenge(b"Internal round")); - thread_based_transcript.commit_rolling(); - } - exit_span!(span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::make_mut(mle); - mle.fix_variables_in_place(&[p.elements]); - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); + // create local thread pool if global rayon pool size < max_thread_id + // this usually cause by global pool size not power of 2. + let (mut prover_states, mut prover_msgs) = if rayon::current_num_threads() >= max_thread_id + { + rayon::in_place_scope(scoped_fn) } else { - tx_prover_state.send(None).unwrap(); - } - - let mut prover_states = (0..max_thread_id) - .map(|_| IOPProverState::default()) - .collect::>(); - for _ in 0..max_thread_id { - if let Some((index, prover_msg)) = rx_prover_state.recv().unwrap() { - prover_states[index] = prover_msg - } else { - println!("got empty msg, which is normal if virtual poly is constant function") - } - } + panic!( + "rayon global thread pool size {} mismatch with desired poly size {}.", + rayon::current_num_threads(), + polys.len() + ); + }; if log2_max_thread_id == 0 { let prover_state = mem::take(&mut prover_states[0]); @@ -243,10 +282,18 @@ impl IOPProverState { prover_state .poly .flattened_ml_extensions - .par_iter_mut() - .for_each(|mle| { - Arc::make_mut(mle).fix_variables_in_place(&[p.elements]); - }); + .iter_mut() + .for_each( + |mle: &mut Arc< + dyn MultilinearExtension>, + >| { + if mle.num_vars() > 0 { + Arc::get_mut(mle) + .unwrap() + .fix_variables_in_place(&[p.elements]); + } + }, + ); }; exit_span!(span); @@ -270,12 +317,12 @@ impl IOPProverState { /// Initialize the prover state to argue for the sum of the input polynomial /// over {0,1}^`num_vars`. pub fn prover_init_with_extrapolation_aux( - polynomial: VirtualPolynomial, + polynomial: VirtualPolynomial<'a, E>, extrapolation_aux: Vec<(Vec, Vec)>, ) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); end_timer!(start); @@ -283,7 +330,7 @@ impl IOPProverState { let max_degree = polynomial.aux_info.max_degree; assert!(extrapolation_aux.len() == max_degree - 1); Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux, @@ -303,7 +350,7 @@ impl IOPProverState { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -334,10 +381,13 @@ impl IOPProverState { let r = self.challenges[self.round - 1]; if self.challenges.len() == 1 { - self.poly - .flattened_ml_extensions - .iter_mut() - .for_each(|f| *f = Arc::new(f.fix_variables(&[r.elements]))); + self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } + }); } else { self.poly .flattened_ml_extensions @@ -345,9 +395,13 @@ impl IOPProverState { // benchmark result indicate make_mut achieve better performange than get_mut, // which can be +5% overhead rust docs doen't explain the // reason - .map(Arc::make_mut) + .map(Arc::get_mut) .for_each(|f| { - f.fix_variables_in_place(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place(&[r.elements]); + } + } }); } } @@ -358,6 +412,16 @@ impl IOPProverState { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) + // + // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, + // for it evaluation value we need to times 2^(max_num_vars - num_vars) + // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n + // For i round univariate poly, f^i(x) + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds + // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' + // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) + // same applied on f^i[1] + // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value let span = entered_span!("products_sum"); let AdditiveVec(products_sum) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), @@ -369,13 +433,25 @@ impl IOPProverState { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { |f| { - (0..f.len()) + let res = (0..largest_even_below(f.len())) .step_by(2) - .fold(AdditiveArray::(array::from_fn(|_| 0.into())), |mut acc, b| { + .rev() + .fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| { acc.0[0] += f[b]; acc.0[1] += f[b+1]; acc - }) + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } }, |sum| AdditiveArray(sum.0.map(E::from)) } @@ -387,32 +463,85 @@ impl IOPProverState { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()).step_by(2).fold( - AdditiveArray::(array::from_fn(|_| 0.into())), - |mut acc, b| { - acc.0[0] += f[b] * g[b]; - acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); - acc + |f, g| { + let res = (0..largest_even_below(f.len())).step_by(2).rev().fold( + AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())), + |mut acc, b| { + acc.0[0] += f[b] * g[b]; + acc.0[1] += f[b + 1] * g[b + 1]; + acc.0[2] += + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc + }); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, + |sum| AdditiveArray(sum.0.map(E::from)) + ) + .to_vec() + } + 3 => { + let (f1, f2, f3) = ( + &self.poly.flattened_ml_extensions[products[0]], + &self.poly.flattened_ml_extensions[products[1]], + &self.poly.flattened_ml_extensions[products[2]], + ); + op_mle_product_3!( + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) + .step_by(2) + .rev() + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }) + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res } - ), + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() } - _ => unimplemented!("do not support degree > 2"), + _ => unimplemented!("do not support degree > 3"), }; exit_span!(span); sum.iter_mut().for_each(|sum| *sum *= coefficient); let span = entered_span!("extrapolation"); let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .into_par_iter() .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; let at = E::from((products.len() + 1 + i) as u64); - extrapolate(points, weights, &sum, &at) + serial_extrapolate(points, weights, &sum, &at) }) .collect::>(); sum.extend(extrapolation); @@ -439,9 +568,9 @@ impl IOPProverState { .iter() .map(|mle| { assert!( - mle.evaluations.len() == 1, + mle.evaluations().len() == 1, "mle.evaluations.len() {} != 1, must be called after prove_round_and_update_state", - mle.evaluations.len(), + mle.evaluations().len(), ); op_mle! { |mle| mle[0], @@ -454,14 +583,15 @@ impl IOPProverState { /// parallel version #[deprecated(note = "deprecated parallel version due to syncronizaion overhead")] -impl IOPProverState { +impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// Given a virtual polynomial, generate an IOP proof. #[tracing::instrument(skip_all, name = "sumcheck::prove_parallel")] pub fn prove_parallel( - poly: VirtualPolynomial, + poly: VirtualPolynomial<'a, E>, transcript: &mut impl Transcript, - ) -> (IOPProof, IOPProverState) { - let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); + ) -> (IOPProof, IOPProverState<'a, E>) { + let (num_variables, max_degree) = + (poly.aux_info.max_num_variables, poly.aux_info.max_degree); // return empty proof when target polymonial is constant if num_variables == 0 { @@ -507,7 +637,16 @@ impl IOPProverState { .flattened_ml_extensions .par_iter_mut() .for_each(|mle| { - Arc::make_mut(mle).fix_variables_in_place_parallel(&[p.elements]); + if let Some(mle) = Arc::get_mut(mle) { + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]) + } + } else { + *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) + } }); }; exit_span!(span); @@ -529,16 +668,16 @@ impl IOPProverState { /// Initialize the prover state to argue for the sum of the input polynomial /// over {0,1}^`num_vars`. - pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomial) -> Self { + pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomial<'a, E>) -> Self { let start = start_timer!(|| "sum check prover init"); assert_ne!( - polynomial.aux_info.num_variables, 0, + polynomial.aux_info.max_num_variables, 0, "Attempt to prove a constant." ); let max_degree = polynomial.aux_info.max_degree; let prover_state = Self { - challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux: (1..max_degree) @@ -567,7 +706,7 @@ impl IOPProverState { start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); assert!( - self.round < self.poly.aux_info.num_variables, + self.round < self.poly.aux_info.max_num_variables, "Prover is not active" ); @@ -597,7 +736,13 @@ impl IOPProverState { self.poly .flattened_ml_extensions .par_iter_mut() - .for_each(|f| *f = f.fix_variables_parallel(&[r.elements]).into()); + .for_each(|f| { + if f.num_vars() > 0 { + *f = Arc::new(f.fix_variables_parallel(&[r.elements])); + } else { + panic!("calling sumcheck on constant") + } + }); } else { self.poly .flattened_ml_extensions @@ -605,9 +750,13 @@ impl IOPProverState { // benchmark result indicate make_mut achieve better performange than get_mut, // which can be +5% overhead rust docs doen't explain the // reason - .map(Arc::make_mut) + .map(Arc::get_mut) .for_each(|f| { - f.fix_variables_in_place_parallel(&[r.elements]); + if let Some(f) = f { + if f.num_vars() > 0 { + f.fix_variables_in_place_parallel(&[r.elements]) + } + } }); } } @@ -632,17 +781,30 @@ impl IOPProverState { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { - |f| (0..f.len()) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(), + |f| { + let res = (0..largest_even_below(f.len())) + .into_par_iter() + .step_by(2) + .with_min_len(64) + .map(|b| { + AdditiveArray([ + f[b], + f[b + 1] + ]) + }) + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 2>([f[0]; 2]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) } .to_vec() @@ -653,7 +815,8 @@ impl IOPProverState { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()) + |f, g| { + let res = (0..largest_even_below(f.len())) .into_par_iter() .step_by(2) .with_min_len(64) @@ -665,12 +828,65 @@ impl IOPProverState { * (g[b + 1] + g[b + 1] - g[b]), ]) }) - .sum::>(), + .sum::>(); + let res = if f.len() == 1 { + AdditiveArray::<_, 3>([f[0] * g[0]; 3]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, + |sum| AdditiveArray(sum.0.map(E::from)) + ) + .to_vec() + } + 3 => { + let (f1, f2, f3) = ( + &self.poly.flattened_ml_extensions[products[0]], + &self.poly.flattened_ml_extensions[products[1]], + &self.poly.flattened_ml_extensions[products[2]], + ); + op_mle_product_3!( + |f1, f2, f3| { + let res = (0..largest_even_below(f1.len())) + .step_by(2) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }) + .sum::>(); + let res = if f1.len() == 1 { + AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) + } else { + res + }; + let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); + if num_vars_multiplicity > 0 { + AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + } else { + res + } + }, |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() } - _ => unimplemented!("do not support degree > 2"), + _ => unimplemented!("do not support degree > 3"), }; exit_span!(span); sum.iter_mut().for_each(|sum| *sum *= coefficient); diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs deleted file mode 100644 index b4021b77d..000000000 --- a/sumcheck/src/prover_v2.rs +++ /dev/null @@ -1,919 +0,0 @@ -use std::{array, mem, sync::Arc}; - -use ark_std::{end_timer, start_timer}; -use crossbeam_channel::bounded; -use ff_ext::ExtensionField; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{DenseMultilinearExtension, MultilinearExtension}, - op_mle, op_mle_product_3, op_mle3_range, - util::largest_even_below, - virtual_poly_v2::VirtualPolynomialV2, -}; -use rayon::{ - Scope, - iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, - prelude::{IntoParallelIterator, ParallelIterator}, -}; -use transcript::{Challenge, Transcript, TranscriptSyncronized}; - -use crate::{ - macros::{entered_span, exit_span}, - structs::{IOPProof, IOPProverMessage, IOPProverStateV2}, - util::{ - AdditiveArray, AdditiveVec, barycentric_weights, ceil_log2, extrapolate, - merge_sumcheck_polys_v2, serial_extrapolate, - }, -}; - -impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { - /// Given a virtual polynomial, generate an IOP proof. - /// multi-threads model follow https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck" - /// This is experiment features. It's preferable that we move parallel level up more to - /// "bould_poly" so it can be more isolation - #[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys", level = "trace")] - pub fn prove_batch_polys( - max_thread_id: usize, - mut polys: Vec>, - transcript: &mut impl Transcript, - ) -> (IOPProof, IOPProverStateV2<'a, E>) { - assert!(!polys.is_empty()); - assert_eq!(polys.len(), max_thread_id); - assert!(max_thread_id.is_power_of_two()); - - let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 - assert!( - polys - .iter() - .map(|poly| (poly.aux_info.max_num_variables, poly.aux_info.max_degree)) - .all_equal() - ); - let (num_variables, max_degree) = ( - polys[0].aux_info.max_num_variables, - polys[0].aux_info.max_degree, - ); - - // return empty proof when target polymonial is constant - if num_variables == 0 { - return (IOPProof::default(), IOPProverStateV2 { - poly: polys[0].clone(), - ..Default::default() - }); - } - let start = start_timer!(|| "sum check prove"); - - transcript.append_message(&(num_variables + log2_max_thread_id).to_le_bytes()); - transcript.append_message(&max_degree.to_le_bytes()); - let thread_based_transcript = TranscriptSyncronized::new(max_thread_id); - let (tx_prover_state, rx_prover_state) = bounded(max_thread_id); - - // extrapolation_aux only need to init once - let extrapolation_aux = (1..max_degree) - .map(|degree| { - let points = (0..1 + degree as u64).map(E::from).collect::>(); - let weights = barycentric_weights(&points); - (points, weights) - }) - .collect::>(); - - // spawn extra #(max_thread_id - 1) work threads - let num_worker_threads = max_thread_id - 1; - // whereas the main-thread be the last work thread - let main_thread_id = num_worker_threads; - let span = entered_span!("spawn loop", profiling_4 = true); - let scoped_fn = |s: &Scope<'a>| { - for (thread_id, poly) in polys.iter_mut().enumerate().take(num_worker_threads) { - let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(poly), - extrapolation_aux.clone(), - ); - let tx_prover_state = tx_prover_state.clone(); - let mut thread_based_transcript = thread_based_transcript.clone(); - s.spawn(move |_| { - let mut challenge = None; - // Note: This span is not nested into the "spawn loop" span, although lexically it looks so. - // Nesting is possible, but then `tracing-forest` does the wrong thing when measuring duration. - // TODO: investigate possibility of nesting with correct duration of parent span - let span = entered_span!("prove_rounds", profiling_5 = true); - for _ in 0..num_variables { - let prover_msg = IOPProverStateV2::prove_round_and_update_state( - &mut prover_state, - &challenge, - ); - thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); - - challenge = Some( - thread_based_transcript.get_and_append_challenge(b"Internal round"), - ); - thread_based_transcript.commit_rolling(); - } - exit_span!(span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); - } else { - tx_prover_state.send(None).unwrap(); - } - }) - } - exit_span!(span); - - let mut prover_msgs = Vec::with_capacity(num_variables); - let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(&mut polys[main_thread_id]), - extrapolation_aux.clone(), - ); - let tx_prover_state = tx_prover_state.clone(); - let mut thread_based_transcript = thread_based_transcript.clone(); - - let main_thread_span = entered_span!("main_thread_prove_rounds"); - // main thread also be one worker thread - // NOTE inline main thread flow with worker thread to improve efficiency - // refactor to shared closure cause to 5% throuput drop - let mut challenge = None; - for _ in 0..num_variables { - let prover_msg = - IOPProverStateV2::prove_round_and_update_state(&mut prover_state, &challenge); - - // for each round, we must collect #SIZE prover message - let mut evaluations = AdditiveVec::new(max_degree + 1); - - // sum for all round poly evaluations vector - evaluations += AdditiveVec(prover_msg.evaluations); - for _ in 0..num_worker_threads { - let round_poly_coeffs = thread_based_transcript.read_field_element_exts(); - evaluations += AdditiveVec(round_poly_coeffs); - } - - let get_challenge_span = entered_span!("main_thread_get_challenge"); - transcript.append_field_element_exts(&evaluations.0); - - let next_challenge = transcript.get_and_append_challenge(b"Internal round"); - (0..num_worker_threads).for_each(|_| { - thread_based_transcript.send_challenge(next_challenge.elements); - }); - - exit_span!(get_challenge_span); - - prover_msgs.push(IOPProverMessage { - evaluations: evaluations.0, - }); - - challenge = Some(next_challenge); - thread_based_transcript.commit_rolling(); - } - exit_span!(main_thread_span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - if num_variables == 1 { - // first time fix variable should be create new instance - if mle.num_vars() > 0 { - *mle = mle.fix_variables(&[p.elements]).into(); - } else { - *mle = - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) - } - } else { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - } - }); - tx_prover_state - .send(Some((main_thread_id, prover_state))) - .unwrap(); - } else { - tx_prover_state.send(None).unwrap(); - } - - let mut prover_states = (0..max_thread_id) - .map(|_| IOPProverStateV2::default()) - .collect::>(); - for _ in 0..max_thread_id { - if let Some((index, prover_msg)) = rx_prover_state.recv().unwrap() { - prover_states[index] = prover_msg - } else { - println!("got empty msg, which is normal if virtual poly is constant function") - } - } - - (prover_states, prover_msgs) - }; - - // create local thread pool if global rayon pool size < max_thread_id - // this usually cause by global pool size not power of 2. - let (mut prover_states, mut prover_msgs) = if rayon::current_num_threads() >= max_thread_id - { - rayon::in_place_scope(scoped_fn) - } else { - panic!( - "rayon global thread pool size {} mismatch with desired poly size {}.", - rayon::current_num_threads(), - polys.len() - ); - }; - - if log2_max_thread_id == 0 { - let prover_state = mem::take(&mut prover_states[0]); - return ( - IOPProof { - point: prover_state - .challenges - .iter() - .map(|challenge| challenge.elements) - .collect(), - proofs: prover_msgs, - }, - prover_state, - ); - } - - // second stage sumcheck - let poly = merge_sumcheck_polys_v2(&prover_states, max_thread_id); - let mut prover_state = - Self::prover_init_with_extrapolation_aux(poly, extrapolation_aux.clone()); - - let mut challenge = None; - let span = entered_span!("prove_rounds_stage2"); - for _ in 0..log2_max_thread_id { - let prover_msg = - IOPProverStateV2::prove_round_and_update_state(&mut prover_state, &challenge); - - prover_msg - .evaluations - .iter() - .for_each(|e| transcript.append_field_element_ext(e)); - prover_msgs.push(prover_msg); - challenge = Some(transcript.get_and_append_challenge(b"Internal round")); - } - exit_span!(span); - - let span = entered_span!("after_rounds_prover_state_stage2"); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each( - |mle: &mut Arc< - dyn MultilinearExtension>, - >| { - if mle.num_vars() > 0 { - Arc::get_mut(mle) - .unwrap() - .fix_variables_in_place(&[p.elements]); - } - }, - ); - }; - exit_span!(span); - - end_timer!(start); - ( - IOPProof { - point: [ - mem::take(&mut prover_states[0]).challenges, - prover_state.challenges.clone(), - ] - .concat() - .iter() - .map(|challenge| challenge.elements) - .collect(), - proofs: prover_msgs, - }, - prover_state, - ) - } - - /// Initialize the prover state to argue for the sum of the input polynomial - /// over {0,1}^`num_vars`. - pub fn prover_init_with_extrapolation_aux( - polynomial: VirtualPolynomialV2<'a, E>, - extrapolation_aux: Vec<(Vec, Vec)>, - ) -> Self { - let start = start_timer!(|| "sum check prover init"); - assert_ne!( - polynomial.aux_info.max_num_variables, 0, - "Attempt to prove a constant." - ); - end_timer!(start); - - let max_degree = polynomial.aux_info.max_degree; - assert!(extrapolation_aux.len() == max_degree - 1); - Self { - challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), - round: 0, - poly: polynomial, - extrapolation_aux, - } - } - - /// Receive message from verifier, generate prover message, and proceed to - /// next round. - /// - /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). - #[tracing::instrument(skip_all, name = "sumcheck::prove_round_and_update_state")] - pub(crate) fn prove_round_and_update_state( - &mut self, - challenge: &Option>, - ) -> IOPProverMessage { - let start = - start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); - - assert!( - self.round < self.poly.aux_info.max_num_variables, - "Prover is not active" - ); - - // let fix_argument = start_timer!(|| "fix argument"); - - // Step 1: - // fix argument and evaluate f(x) over x_m = r; where r is the challenge - // for the current round, and m is the round number, indexed from 1 - // - // i.e.: - // at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle - // which has already been evaluated to - // - // g(r_1, ..., r_{m-1}, x_m ... x_n) - // - // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) - let span = entered_span!("fix_variables"); - if self.round == 0 { - assert!(challenge.is_none(), "first round should be prover first."); - } else { - assert!( - challenge.is_some(), - "verifier message is empty in round {}", - self.round - ); - let chal = challenge.unwrap(); - self.challenges.push(chal); - let r = self.challenges[self.round - 1]; - - if self.challenges.len() == 1 { - self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { - if let Some(f) = f { - if f.num_vars() > 0 { - f.fix_variables_in_place(&[r.elements]); - } - } - }); - } - } - exit_span!(span); - // end_timer!(fix_argument); - - self.round += 1; - - // Step 2: generate sum for the partial evaluated polynomial: - // f(r_1, ... r_m,, x_{m+1}... x_n) - // - // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, - // for it evaluation value we need to times 2^(max_num_vars - num_vars) - // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n - // For i round univariate poly, f^i(x) - // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds - // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' - // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) - // same applied on f^i[1] - // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value - let span = entered_span!("products_sum"); - let AdditiveVec(products_sum) = self.poly.products.iter().fold( - AdditiveVec::new(self.poly.aux_info.max_degree + 1), - |mut products_sum, (coefficient, products)| { - let span = entered_span!("sum"); - - let mut sum = match products.len() { - 1 => { - let f = &self.poly.flattened_ml_extensions[products[0]]; - op_mle! { - |f| { - let res = (0..largest_even_below(f.len())) - .step_by(2) - .fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| { - acc.0[0] += f[b]; - acc.0[1] += f[b+1]; - acc - }); - let res = if f.len() == 1 { - AdditiveArray::<_, 2>([f[0]; 2]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - } - .to_vec() - } - 2 => { - let (f, g) = ( - &self.poly.flattened_ml_extensions[products[0]], - &self.poly.flattened_ml_extensions[products[1]], - ); - commutative_op_mle_pair!( - |f, g| { - let res = (0..largest_even_below(f.len())).step_by(2).fold( - AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())), - |mut acc, b| { - acc.0[0] += f[b] * g[b]; - acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); - acc - }); - let res = if f.len() == 1 { - AdditiveArray::<_, 3>([f[0] * g[0]; 3]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - ) - .to_vec() - } - 3 => { - let (f1, f2, f3) = ( - &self.poly.flattened_ml_extensions[products[0]], - &self.poly.flattened_ml_extensions[products[1]], - &self.poly.flattened_ml_extensions[products[2]], - ); - op_mle_product_3!( - |f1, f2, f3| { - let res = (0..largest_even_below(f1.len())) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(); - let res = if f1.len() == 1 { - AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - ) - .to_vec() - } - _ => unimplemented!("do not support degree > 3"), - }; - exit_span!(span); - sum.iter_mut().for_each(|sum| *sum *= coefficient); - - let span = entered_span!("extrapolation"); - let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .map(|i| { - let (points, weights) = &self.extrapolation_aux[products.len() - 1]; - let at = E::from((products.len() + 1 + i) as u64); - serial_extrapolate(points, weights, &sum, &at) - }) - .collect::>(); - sum.extend(extrapolation); - exit_span!(span); - let span = entered_span!("extend_extrapolate"); - products_sum += AdditiveVec(sum); - exit_span!(span); - products_sum - }, - ); - exit_span!(span); - - end_timer!(start); - - IOPProverMessage { - evaluations: products_sum, - } - } - - /// collect all mle evaluation (claim) after sumcheck - pub fn get_mle_final_evaluations(&self) -> Vec { - self.poly - .flattened_ml_extensions - .iter() - .map(|mle| { - assert!( - mle.evaluations().len() == 1, - "mle.evaluations.len() {} != 1, must be called after prove_round_and_update_state", - mle.evaluations().len(), - ); - op_mle! { - |mle| mle[0], - |eval| E::from(eval) - } - }) - .collect() - } -} - -/// parallel version -#[deprecated(note = "deprecated parallel version due to syncronizaion overhead")] -impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { - /// Given a virtual polynomial, generate an IOP proof. - #[tracing::instrument(skip_all, name = "sumcheck::prove_parallel")] - pub fn prove_parallel( - poly: VirtualPolynomialV2<'a, E>, - transcript: &mut impl Transcript, - ) -> (IOPProof, IOPProverStateV2<'a, E>) { - let (num_variables, max_degree) = - (poly.aux_info.max_num_variables, poly.aux_info.max_degree); - - // return empty proof when target polymonial is constant - if num_variables == 0 { - return (IOPProof::default(), IOPProverStateV2 { - poly, - ..Default::default() - }); - } - let start = start_timer!(|| "sum check prove"); - - transcript.append_message(&num_variables.to_le_bytes()); - transcript.append_message(&max_degree.to_le_bytes()); - - let mut prover_state = Self::prover_init_parallel(poly); - let mut challenge = None; - let mut prover_msgs = Vec::with_capacity(num_variables); - let span = entered_span!("prove_rounds"); - for _ in 0..num_variables { - let prover_msg = IOPProverStateV2::prove_round_and_update_state_parallel( - &mut prover_state, - &challenge, - ); - - prover_msg - .evaluations - .iter() - .for_each(|e| transcript.append_field_element_ext(e)); - - prover_msgs.push(prover_msg); - let span = entered_span!("get_challenge"); - challenge = Some(transcript.get_and_append_challenge(b"Internal round")); - exit_span!(span); - } - exit_span!(span); - - let span = entered_span!("after_rounds_prover_state"); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|mle| { - if let Some(mle) = Arc::get_mut(mle) { - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]) - } - } else { - *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) - } - }); - }; - exit_span!(span); - - end_timer!(start); - ( - IOPProof { - // the point consists of the first elements in the challenge - point: prover_state - .challenges - .iter() - .map(|challenge| challenge.elements) - .collect(), - proofs: prover_msgs, - }, - prover_state, - ) - } - - /// Initialize the prover state to argue for the sum of the input polynomial - /// over {0,1}^`num_vars`. - pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomialV2<'a, E>) -> Self { - let start = start_timer!(|| "sum check prover init"); - assert_ne!( - polynomial.aux_info.max_num_variables, 0, - "Attempt to prove a constant." - ); - - let max_degree = polynomial.aux_info.max_degree; - let prover_state = Self { - challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), - round: 0, - poly: polynomial, - extrapolation_aux: (1..max_degree) - .map(|degree| { - let points = (0..1 + degree as u64).map(E::from).collect::>(); - let weights = barycentric_weights(&points); - (points, weights) - }) - .collect(), - }; - - end_timer!(start); - prover_state - } - - /// Receive message from verifier, generate prover message, and proceed to - /// next round. - /// - /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). - #[tracing::instrument(skip_all, name = "sumcheck::prove_round_and_update_state_parallel")] - pub(crate) fn prove_round_and_update_state_parallel( - &mut self, - challenge: &Option>, - ) -> IOPProverMessage { - let start = - start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); - - assert!( - self.round < self.poly.aux_info.max_num_variables, - "Prover is not active" - ); - - // let fix_argument = start_timer!(|| "fix argument"); - - // Step 1: - // fix argument and evaluate f(x) over x_m = r; where r is the challenge - // for the current round, and m is the round number, indexed from 1 - // - // i.e.: - // at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle - // which has already been evaluated to - // - // g(r_1, ..., r_{m-1}, x_m ... x_n) - // - // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) - let span = entered_span!("fix_variables"); - if self.round == 0 { - assert!(challenge.is_none(), "first round should be prover first."); - } else { - assert!(challenge.is_some(), "verifier message is empty"); - let chal = challenge.unwrap(); - self.challenges.push(chal); - let r = self.challenges[self.round - 1]; - - if self.challenges.len() == 1 { - self.poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables_parallel(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .par_iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { - if let Some(f) = f { - if f.num_vars() > 0 { - f.fix_variables_in_place_parallel(&[r.elements]) - } - } - }); - } - } - exit_span!(span); - // end_timer!(fix_argument); - - self.round += 1; - - // Step 2: generate sum for the partial evaluated polynomial: - // f(r_1, ... r_m,, x_{m+1}... x_n) - let span = entered_span!("products_sum"); - let AdditiveVec(products_sum) = self - .poly - .products - .par_iter() - .fold_with( - AdditiveVec::new(self.poly.aux_info.max_degree + 1), - |mut products_sum, (coefficient, products)| { - let span = entered_span!("sum"); - - let mut sum = match products.len() { - 1 => { - let f = &self.poly.flattened_ml_extensions[products[0]]; - op_mle! { - |f| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(); - let res = if f.len() == 1 { - AdditiveArray::<_, 2>([f[0]; 2]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - } - .to_vec() - } - 2 => { - let (f, g) = ( - &self.poly.flattened_ml_extensions[products[0]], - &self.poly.flattened_ml_extensions[products[1]], - ); - commutative_op_mle_pair!( - |f, g| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b] * g[b], - f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) - * (g[b + 1] + g[b + 1] - g[b]), - ]) - }) - .sum::>(); - let res = if f.len() == 1 { - AdditiveArray::<_, 3>([f[0] * g[0]; 3]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - ) - .to_vec() - } - 3 => { - let (f1, f2, f3) = ( - &self.poly.flattened_ml_extensions[products[0]], - &self.poly.flattened_ml_extensions[products[1]], - &self.poly.flattened_ml_extensions[products[2]], - ); - op_mle_product_3!( - |f1, f2, f3| { - let res = (0..largest_even_below(f1.len())) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(); - let res = if f1.len() == 1 { - AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) - } else { - res - } - }, - |sum| AdditiveArray(sum.0.map(E::from)) - ) - .to_vec() - } - _ => unimplemented!("do not support degree > 3"), - }; - exit_span!(span); - sum.iter_mut().for_each(|sum| *sum *= coefficient); - - let span = entered_span!("extrapolation"); - let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) - .into_par_iter() - .map(|i| { - let (points, weights) = &self.extrapolation_aux[products.len() - 1]; - let at = E::from((products.len() + 1 + i) as u64); - extrapolate(points, weights, &sum, &at) - }) - .collect::>(); - sum.extend(extrapolation); - exit_span!(span); - let span = entered_span!("extend_extrapolate"); - products_sum += AdditiveVec(sum); - exit_span!(span); - products_sum - }, - ) - .reduce_with(|acc, item| acc + item) - .unwrap(); - exit_span!(span); - - end_timer!(start); - - IOPProverMessage { - evaluations: products_sum, - } - } -} diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index a6089722a..2316d79aa 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -1,7 +1,5 @@ use ff_ext::ExtensionField; -use multilinear_extensions::{ - virtual_poly::VirtualPolynomial, virtual_poly_v2::VirtualPolynomialV2, -}; +use multilinear_extensions::virtual_poly::VirtualPolynomial; use serde::{Deserialize, Serialize}; use transcript::Challenge; @@ -28,27 +26,13 @@ pub struct IOPProverMessage { /// Prover State of a PolyIOP. #[derive(Default)] -pub struct IOPProverStateV2<'a, E: ExtensionField> { +pub struct IOPProverState<'a, E: ExtensionField> { /// sampled randomness given by the verifier pub challenges: Vec>, /// the current round number pub(crate) round: usize, /// pointer to the virtual polynomial - pub(crate) poly: VirtualPolynomialV2<'a, E>, - /// points with precomputed barycentric weights for extrapolating smaller - /// degree uni-polys to `max_degree + 1` evaluations. - pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, -} - -/// Prover State of a PolyIOP. -#[derive(Default)] -pub struct IOPProverState { - /// sampled randomness given by the verifier - pub challenges: Vec>, - /// the current round number - pub(crate) round: usize, - /// pointer to the virtual polynomial - pub(crate) poly: VirtualPolynomial, + pub(crate) poly: VirtualPolynomial<'a, E>, /// points with precomputed barycentric weights for extrapolating smaller /// degree uni-polys to `max_degree + 1` evaluations. pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 0c24e83cd..6b2d4a2c3 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -4,7 +4,7 @@ use ark_std::{rand::RngCore, test_rng}; use ff::Field; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; -use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::VirtualPolynomial}; +use multilinear_extensions::virtual_poly::VirtualPolynomial; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; @@ -52,7 +52,7 @@ fn test_sumcheck_internal( let mut rng = test_rng(); let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); - let (poly_info, num_variables) = (poly.aux_info.clone(), poly.aux_info.num_variables); + let (poly_info, num_variables) = (poly.aux_info.clone(), poly.aux_info.max_num_variables); #[allow(deprecated)] let mut prover_state = IOPProverState::prover_init_parallel(poly.clone()); let mut verifier_state = IOPVerifierState::verifier_init(&poly_info); @@ -81,7 +81,9 @@ fn test_sumcheck_internal( .flattened_ml_extensions .par_iter_mut() .for_each(|mle| { - Arc::make_mut(mle).fix_variables_in_place(&[p.elements]); + Arc::get_mut(mle) + .unwrap() + .fix_variables_in_place(&[p.elements]); }); }; let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum); diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index d42adb966..4fa84f1e2 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -2,7 +2,6 @@ use std::{ array, cmp::max, iter::Sum, - mem, ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign}, sync::Arc, }; @@ -11,14 +10,11 @@ use ark_std::{end_timer, start_timer}; use ff::PrimeField; use ff_ext::ExtensionField; use multilinear_extensions::{ - mle::{DenseMultilinearExtension, FieldType}, - op_mle, - virtual_poly::VirtualPolynomial, - virtual_poly_v2::VirtualPolynomialV2, + mle::DenseMultilinearExtension, op_mle, virtual_poly::VirtualPolynomial, }; use rayon::{prelude::ParallelIterator, slice::ParallelSliceMut}; -use crate::structs::{IOPProverState, IOPProverStateV2}; +use crate::structs::IOPProverState; pub fn barycentric_weights(points: &[F]) -> Vec { let mut weights = points @@ -221,41 +217,10 @@ pub fn ceil_log2(x: usize) -> usize { usize_bits - (x - 1).leading_zeros() as usize } -pub(crate) fn merge_sumcheck_polys( - prover_states: &[IOPProverState], +pub(crate) fn merge_sumcheck_polys<'a, E: ExtensionField>( + prover_states: &[IOPProverState<'a, E>], max_thread_id: usize, -) -> VirtualPolynomial { - let log2_max_thread_id = ceil_log2(max_thread_id); - let mut poly = prover_states[0].poly.clone(); // giving only one evaluation left, this clone is low cost. - poly.aux_info.num_variables = log2_max_thread_id; // size_log2 variates sumcheck - for i in 0..poly.flattened_ml_extensions.len() { - let ml_ext = Arc::make_mut(&mut poly.flattened_ml_extensions[i]); - let _ = mem::replace(&mut ml_ext.evaluations, { - let evaluations = prover_states - .iter() - .map(|prover_state| { - if let FieldType::Ext(evaluations) = - &prover_state.poly.flattened_ml_extensions[i].evaluations - { - assert!(evaluations.len() == 1); - evaluations[0] - } else { - unreachable!() - } - }) - .collect::>(); - assert!(evaluations.len() == max_thread_id); - FieldType::Ext(evaluations) - }); - ml_ext.num_vars = log2_max_thread_id; - } - poly -} - -pub(crate) fn merge_sumcheck_polys_v2<'a, E: ExtensionField>( - prover_states: &[IOPProverStateV2<'a, E>], - max_thread_id: usize, -) -> VirtualPolynomialV2<'a, E> { +) -> VirtualPolynomial<'a, E> { let log2_max_thread_id = ceil_log2(max_thread_id); let mut poly = prover_states[0].poly.clone(); // giving only one evaluation left, this clone is low cost. poly.aux_info.max_num_variables = log2_max_thread_id; // size_log2 variates sumcheck diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index 4dcd8767e..3adc70d27 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -15,7 +15,7 @@ impl IOPVerifierState { aux_info: &VPAuxInfo, transcript: &mut impl Transcript, ) -> SumCheckSubClaim { - if aux_info.num_variables == 0 { + if aux_info.max_num_variables == 0 { return SumCheckSubClaim { point: vec![], expected_evaluation: claimed_sum, @@ -23,11 +23,11 @@ impl IOPVerifierState { } let start = start_timer!(|| "sum check verify"); - transcript.append_message(&aux_info.num_variables.to_le_bytes()); + transcript.append_message(&aux_info.max_num_variables.to_le_bytes()); transcript.append_message(&aux_info.max_degree.to_le_bytes()); let mut verifier_state = IOPVerifierState::verifier_init(aux_info); - for i in 0..aux_info.num_variables { + for i in 0..aux_info.max_num_variables { let prover_msg = proof.proofs.get(i).expect("proof is incomplete"); prover_msg .evaluations @@ -47,11 +47,11 @@ impl IOPVerifierState { let start = start_timer!(|| "sum check verifier init"); let verifier_state = Self { round: 1, - num_vars: index_info.num_variables, + num_vars: index_info.max_num_variables, max_degree: index_info.max_degree, finished: false, - polynomials_received: Vec::with_capacity(index_info.num_variables), - challenges: Vec::with_capacity(index_info.num_variables), + polynomials_received: Vec::with_capacity(index_info.max_num_variables), + challenges: Vec::with_capacity(index_info.max_num_variables), }; end_timer!(start); verifier_state From 0f11835aea97b16e280e0f8e552123dcfa85e979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Fri, 3 Jan 2025 16:44:07 +0800 Subject: [PATCH 04/12] Remove obsolete `ceno_rt/Cargo.lock` (#803) 830b840b0db15c4e5ac2484acf19b038c1281d1f / PR#759 made `ceno_rt/Cargo.lock` obsolete, but we forgot deleting it back then. Extracted from https://github.com/scroll-tech/ceno/pull/802 --- ceno_rt/Cargo.lock | 211 --------------------------------------------- 1 file changed, 211 deletions(-) delete mode 100644 ceno_rt/Cargo.lock diff --git a/ceno_rt/Cargo.lock b/ceno_rt/Cargo.lock deleted file mode 100644 index 9c9ebc993..000000000 --- a/ceno_rt/Cargo.lock +++ /dev/null @@ -1,211 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "bytecheck" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c8f430744b23b54ad15161fcbc22d82a29b73eacbe425fea23ec822600bc6f" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "rancor", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523363cbe1df49b68215efdf500b103ac3b0fb4836aed6d15689a076eadb8fff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "bytes" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" - -[[package]] -name = "ceno_rt" -version = "0.1.0" -dependencies = [ - "rkyv", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "hashbrown" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" - -[[package]] -name = "indexmap" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "munge" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64142d38c84badf60abf06ff9bd80ad2174306a5b11bd4706535090a30a419df" -dependencies = [ - "munge_macro", -] - -[[package]] -name = "munge_macro" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bb5c1d8184f13f7d0ccbeeca0def2f9a181bce2624302793005f5ca8aa62e5e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "proc-macro2" -version = "1.0.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "ptr_meta" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9e76f66d3f9606f44e45598d155cb13ecf09f4a28199e48daf8c8fc937ea90" -dependencies = [ - "ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca414edb151b4c8d125c12566ab0d74dc9cdba36fb80eb7b848c15f495fd32d1" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rancor" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf5f7161924b9d1cea0e4cabc97c372cea92b5f927fc13c6bca67157a0ad947" -dependencies = [ - "ptr_meta", -] - -[[package]] -name = "rend" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35e8a6bf28cd121053a66aa2e6a2e3eaffad4a60012179f0e864aa5ffeff215" -dependencies = [ - "bytecheck", -] - -[[package]] -name = "rkyv" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b11a153aec4a6ab60795f8ebe2923c597b16b05bb1504377451e705ef1a45323" -dependencies = [ - "bytecheck", - "bytes", - "hashbrown", - "indexmap", - "munge", - "ptr_meta", - "rancor", - "rend", - "rkyv_derive", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb382a4d9f53bd5c0be86b10d8179c3f8a14c30bf774ff77096ed6581e35981" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "simdutf8" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" - -[[package]] -name = "syn" -version = "2.0.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "tinyvec" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "unicode-ident" -version = "1.0.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" - -[[package]] -name = "uuid" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" From b7d9524c87236ea84510e68f875e8c6b912506f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:54:37 +0800 Subject: [PATCH 05/12] Bump glob from 0.3.1 to 0.3.2 (#796) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [glob](https://github.com/rust-lang/glob) from 0.3.1 to 0.3.2.
Release notes

Sourced from glob's releases.

v0.3.2

What's Changed

New Contributors

Full Changelog: https://github.com/rust-lang/glob/compare/0.3.1...v0.3.2

Changelog

Sourced from glob's changelog.

0.3.2 - 2024-12-28

What's Changed

New Contributors

Full Changelog: https://github.com/rust-lang/glob/compare/0.3.1...0.3.2

Commits
  • 58d0748 chore: release v0.3.2
  • 55b1be0 Merge pull request #150 from tgross35/release-plz
  • 56054d2 Add release-plz for automated releases
  • b93bca1 Merge pull request #151 from tgross35/fix-ci
  • 1dff477 Add a success job to CI for branch protection
  • 9bd1af8 Update CI runners to the latest available versions
  • 8c5d22c Check only (no longer test) at the MSRV
  • 89ef8a3 Clean up the CI configuration file
  • 49ee1e9 Merge pull request #140 from rust-lang/dependabot/github_actions/actions/chec...
  • 9c9f43f Bump actions/checkout from 3 to 4
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=glob&package-manager=cargo&previous-version=0.3.1&new-version=0.3.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 355c8472c..0d234772f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -811,9 +811,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" From 989f607612957822a275246cc62dd76cfd611d9f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:54:51 +0800 Subject: [PATCH 06/12] Bump serde from 1.0.216 to 1.0.217 (#795) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [serde](https://github.com/serde-rs/serde) from 1.0.216 to 1.0.217.
Release notes

Sourced from serde's releases.

v1.0.217

  • Support serializing externally tagged unit variant inside flattened field (#2786, thanks @​Mingun)
Commits
  • 930401b Release 1.0.217
  • cb6eaea Fix roundtrip inconsistency:
  • b6f339c Resolve repr_packed_without_abi clippy lint in tests
  • 2a5caea Merge pull request #2872 from dtolnay/ehpersonality
  • b9f93f9 Add no-std CI on stable compiler
  • eb5cd47 Drop #[lang = "eh_personality"] from no-std test
  • 8478a3b Merge pull request #2871 from dtolnay/nostdstart
  • dbb9091 Replace #[start] with extern fn main
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=serde&package-manager=cargo&previous-version=1.0.216&new-version=1.0.217)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0d234772f..4104eee37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1876,18 +1876,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", From cfa349b7f27c9bdda2199e5cb0353316dcc23e38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 6 Jan 2025 10:47:03 +0800 Subject: [PATCH 07/12] Make examples part of the workspace (#804) That follows the lead of SP1, and allows things like `cargo clippy` to see the examples. Extracted from https://github.com/scroll-tech/ceno/pull/802 --- Cargo.lock | 8 ++++++++ Cargo.toml | 2 +- ceno_rt/src/lib.rs | 4 +--- ceno_rt/src/syscalls.rs | 8 ++++++-- examples-builder/build.rs | 2 +- 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4104eee37..50dfadc36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -710,6 +710,14 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "examples" +version = "0.1.0" +dependencies = [ + "ceno_rt", + "rkyv", +] + [[package]] name = "fastrand" version = "2.3.0" diff --git a/Cargo.toml b/Cargo.toml index 04c31fd15..0a5361732 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [workspace] -exclude = ["examples"] members = [ "ceno_emul", "ceno_host", "ceno_rt", "ceno_zkvm", "examples-builder", + "examples", "mpcs", "multilinear_extensions", "poseidon", diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 576615802..697d0952a 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -22,10 +22,8 @@ pub use io::info_out; mod params; pub use params::*; -#[cfg(target_arch = "riscv32")] mod syscalls; -#[cfg(target_arch = "riscv32")] -pub use syscalls::*; +pub use syscalls::syscall_keccak_permute; #[no_mangle] #[linkage = "weak"] diff --git a/ceno_rt/src/syscalls.rs b/ceno_rt/src/syscalls.rs index 90ace85da..7de730ae3 100644 --- a/ceno_rt/src/syscalls.rs +++ b/ceno_rt/src/syscalls.rs @@ -1,7 +1,9 @@ // Based on https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs +#[allow(dead_code)] const KECCAK_PERMUTE: u32 = 0x00_01_01_09; +#[cfg(target_os = "zkvm")] use core::arch::asm; /// Executes the Keccak256 permutation on the given state. @@ -11,8 +13,8 @@ use core::arch::asm; /// The caller must ensure that `state` is valid pointer to data that is aligned along a four /// byte boundary. #[allow(unused_variables)] -#[no_mangle] -pub extern "C" fn syscall_keccak_permute(state: &mut [u64; 25]) { +pub fn syscall_keccak_permute(state: &mut [u64; 25]) { + #[cfg(target_os = "zkvm")] unsafe { asm!( "ecall", @@ -21,4 +23,6 @@ pub extern "C" fn syscall_keccak_permute(state: &mut [u64; 25]) { in("a1") 0 ); } + #[cfg(not(target_os = "zkvm"))] + unreachable!() } diff --git a/examples-builder/build.rs b/examples-builder/build.rs index b6dee0cfe..10d79d4f3 100644 --- a/examples-builder/build.rs +++ b/examples-builder/build.rs @@ -40,7 +40,7 @@ fn build_elfs() { // TODO(Matthias): skip building the elfs if we are in clippy or check mode. // See git history for an attempt to do this. let output = Command::new("cargo") - .args(["build", "--release", "--examples"]) + .args(["build", "--release", "--examples", "--target-dir", "target"]) .current_dir("../examples") .env_clear() .envs(std::env::vars().filter(|x| !x.0.starts_with("CARGO_"))) From a3348f1a3d2caf57075472198083373105beb30c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 6 Jan 2025 11:09:55 +0800 Subject: [PATCH 08/12] Deduplicate magic syscall number (#802) Also make examples visible to a plain `cargo clippy --workspace --all-targets`. But that's now extracted into https://github.com/scroll-tech/ceno/pull/804 to be merged first. Similar with https://github.com/scroll-tech/ceno/pull/803 We are using a technique inspired by the SP1 example already mentioned in `ceno_emul/src/syscalls.rs`. --- Cargo.lock | 1 + ceno_emul/Cargo.toml | 1 + ceno_emul/src/syscalls.rs | 2 +- ceno_rt/src/lib.rs | 3 +- ceno_rt/src/syscalls.rs | 8 +- examples/Cargo.lock | 249 ---------------------------- examples/examples/ceno_rt_keccak.rs | 4 +- 7 files changed, 9 insertions(+), 259 deletions(-) delete mode 100644 examples/Cargo.lock diff --git a/Cargo.lock b/Cargo.lock index 50dfadc36..1d9cc9427 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,7 @@ name = "ceno_emul" version = "0.1.0" dependencies = [ "anyhow", + "ceno_rt", "elf", "itertools 0.13.0", "num-derive", diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 4f4f4a7c7..2bc4c830f 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -11,6 +11,7 @@ version.workspace = true [dependencies] anyhow.workspace = true +ceno_rt = { path = "../ceno_rt" } elf = "0.7" itertools.workspace = true num-derive.workspace = true diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 251bedbae..d5ca85402 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -6,7 +6,7 @@ pub mod keccak_permute; // Using the same function codes as sp1: // https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/code.rs -pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; +pub use ceno_rt::syscalls::KECCAK_PERMUTE; /// Trace the inputs and effects of a syscall. pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result { diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 697d0952a..b48279d69 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -22,8 +22,7 @@ pub use io::info_out; mod params; pub use params::*; -mod syscalls; -pub use syscalls::syscall_keccak_permute; +pub mod syscalls; #[no_mangle] #[linkage = "weak"] diff --git a/ceno_rt/src/syscalls.rs b/ceno_rt/src/syscalls.rs index 7de730ae3..1ebef70bc 100644 --- a/ceno_rt/src/syscalls.rs +++ b/ceno_rt/src/syscalls.rs @@ -1,11 +1,9 @@ // Based on https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs - -#[allow(dead_code)] -const KECCAK_PERMUTE: u32 = 0x00_01_01_09; - #[cfg(target_os = "zkvm")] use core::arch::asm; +pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; + /// Executes the Keccak256 permutation on the given state. /// /// ### Safety @@ -13,7 +11,7 @@ use core::arch::asm; /// The caller must ensure that `state` is valid pointer to data that is aligned along a four /// byte boundary. #[allow(unused_variables)] -pub fn syscall_keccak_permute(state: &mut [u64; 25]) { +pub fn keccak_permute(state: &mut [u64; 25]) { #[cfg(target_os = "zkvm")] unsafe { asm!( diff --git a/examples/Cargo.lock b/examples/Cargo.lock deleted file mode 100644 index 5808aab96..000000000 --- a/examples/Cargo.lock +++ /dev/null @@ -1,249 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "bytecheck" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c8f430744b23b54ad15161fcbc22d82a29b73eacbe425fea23ec822600bc6f" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "rancor", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523363cbe1df49b68215efdf500b103ac3b0fb4836aed6d15689a076eadb8fff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "bytes" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" - -[[package]] -name = "ceno_rt" -version = "0.1.0" -dependencies = [ - "getrandom", - "rkyv", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "examples" -version = "0.1.0" -dependencies = [ - "ceno_rt", - "rkyv", -] - -[[package]] -name = "getrandom" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "hashbrown" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" - -[[package]] -name = "indexmap" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "libc" -version = "0.2.169" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" - -[[package]] -name = "munge" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64142d38c84badf60abf06ff9bd80ad2174306a5b11bd4706535090a30a419df" -dependencies = [ - "munge_macro", -] - -[[package]] -name = "munge_macro" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bb5c1d8184f13f7d0ccbeeca0def2f9a181bce2624302793005f5ca8aa62e5e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "proc-macro2" -version = "1.0.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "ptr_meta" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9e76f66d3f9606f44e45598d155cb13ecf09f4a28199e48daf8c8fc937ea90" -dependencies = [ - "ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca414edb151b4c8d125c12566ab0d74dc9cdba36fb80eb7b848c15f495fd32d1" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rancor" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf5f7161924b9d1cea0e4cabc97c372cea92b5f927fc13c6bca67157a0ad947" -dependencies = [ - "ptr_meta", -] - -[[package]] -name = "rend" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35e8a6bf28cd121053a66aa2e6a2e3eaffad4a60012179f0e864aa5ffeff215" -dependencies = [ - "bytecheck", -] - -[[package]] -name = "rkyv" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b11a153aec4a6ab60795f8ebe2923c597b16b05bb1504377451e705ef1a45323" -dependencies = [ - "bytecheck", - "bytes", - "hashbrown", - "indexmap", - "munge", - "ptr_meta", - "rancor", - "rend", - "rkyv_derive", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb382a4d9f53bd5c0be86b10d8179c3f8a14c30bf774ff77096ed6581e35981" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "simdutf8" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" - -[[package]] -name = "syn" -version = "2.0.91" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53cbcb5a243bd33b7858b1d7f4aca2153490815872d86d955d6ea29f743c035" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "tinyvec" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "unicode-ident" -version = "1.0.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" - -[[package]] -name = "uuid" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index 7a01557bf..57d9f76ed 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -3,7 +3,7 @@ //! Iterate multiple times and log the state after each iteration. extern crate ceno_rt; -use ceno_rt::{info_out, syscall_keccak_permute}; +use ceno_rt::{info_out, syscalls::keccak_permute}; use core::slice; const ITERATIONS: usize = 3; @@ -12,7 +12,7 @@ fn main() { let mut state = [0_u64; 25]; for _ in 0..ITERATIONS { - syscall_keccak_permute(&mut state); + keccak_permute(&mut state); log_state(&state); } } From 845b475f14c249818d9146ca61606d9440c4c49a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 6 Jan 2025 20:28:33 +0800 Subject: [PATCH 09/12] Support conversion of hints to `Vec` (#807) That's useful to produce a hint file we can then hand back to our `e2e` binary. Supports https://github.com/scroll-tech/sproll-evm/pull/87 --- ceno_host/src/lib.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ceno_host/src/lib.rs b/ceno_host/src/lib.rs index 8e033365f..d6cd94ea1 100644 --- a/ceno_host/src/lib.rs +++ b/ceno_host/src/lib.rs @@ -99,14 +99,19 @@ impl Items { } } -impl From<&CenoStdin> for Vec { - fn from(stdin: &CenoStdin) -> Vec { +impl From<&CenoStdin> for Vec { + fn from(stdin: &CenoStdin) -> Vec { let mut items = Items::default(); for item in &stdin.items { items.append(Item::from(item)); } - items - .finalise() + items.finalise() + } +} + +impl From<&CenoStdin> for Vec { + fn from(stdin: &CenoStdin) -> Vec { + Vec::::from(stdin) .into_iter() .tuples() .map(|(a, b, c, d)| u32::from_le_bytes([a, b, c, d])) From c1b2a7c323d15743bd45e0d962dbeec95dbb3d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Tue, 7 Jan 2025 10:29:45 +0800 Subject: [PATCH 10/12] Heap and stack sizes in kiB and MiB (#808) Allow specification of guest heap and stack sizes on the command line in more convenient units than just raw bytes. --- Cargo.lock | 7 +++++++ ceno_zkvm/Cargo.toml | 1 + ceno_zkvm/src/bin/e2e.rs | 11 +++++++++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1d9cc9427..ef8dbb072 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -339,6 +339,7 @@ dependencies = [ "mpcs", "multilinear_extensions", "num-traits", + "parse-size", "paste", "pprof2", "prettytable-rs", @@ -1364,6 +1365,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "parse-size" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "487f2ccd1e17ce8c1bfab3a65c89525af41cfad4c8659021a1e9a2aacd73b89b" + [[package]] name = "pasta_curves" version = "0.5.1" diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 0f8241f57..97a7c9364 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -40,6 +40,7 @@ tracing-subscriber.workspace = true bincode = "1" clap = { version = "4.5", features = ["derive"] } generic_static = "0.2" +parse-size = "1.1" rand.workspace = true tempfile = "3.14" thread_local = "1.1" diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 47f268b36..5de9b3fab 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -18,6 +18,13 @@ use transcript::{ BasicTranscript as Transcript, BasicTranscriptWithStat as TranscriptWithStat, StatisticRecorder, }; +fn parse_size(s: &str) -> Result { + parse_size::Config::new() + .with_binary() + .parse_size(s) + .map(|size| size as u32) +} + /// Prove the execution of a fixed RISC-V program. #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -45,11 +52,11 @@ struct Args { hints: Option, /// Stack size in bytes. - #[arg(long, default_value = "32768")] + #[arg(long, default_value = "32k", value_parser = parse_size)] stack_size: u32, /// Heap size in bytes. - #[arg(long, default_value = "2097152")] + #[arg(long, default_value = "2M", value_parser = parse_size)] heap_size: u32, } From 4684ab9309d75e013ffbd2fe6caddb94e25b6597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Wed, 8 Jan 2025 14:56:40 +0800 Subject: [PATCH 11/12] Less spam when running `e2e` (#815) This avoids spam like the following going on for thousands of lines: ``` INFO e2e: Running on platform Ceno Platform { rom: 536870912..538373056, prog_data: {536870912, [...] ``` --- ceno_emul/src/platform.rs | 23 ++++++++++++++++++++++- ceno_zkvm/src/bin/e2e.rs | 2 +- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index e8b06721c..034340c54 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,4 +1,5 @@ -use std::{collections::BTreeSet, ops::Range}; +use core::fmt::{self, Formatter}; +use std::{collections::BTreeSet, fmt::Display, ops::Range}; use crate::addr::{Addr, RegIdx}; @@ -19,6 +20,26 @@ pub struct Platform { pub unsafe_ecall_nop: bool, } +impl Display for Platform { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let prog_data: Option> = match (self.prog_data.first(), self.prog_data.last()) { + (Some(first), Some(last)) => Some(*first..*last), + _ => None, + }; + write!( + f, + "Platform {{ rom: {:?}, prog_data: {:?}, stack: {:?}, heap: {:?}, public_io: {:?}, hints: {:?}, unsafe_ecall_nop: {} }}", + self.rom, + prog_data, + self.stack, + self.heap, + self.public_io, + self.hints, + self.unsafe_ecall_nop + ) + } +} + pub const CENO_PLATFORM: Platform = Platform { rom: 0x2000_0000..0x3000_0000, prog_data: BTreeSet::new(), diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 5de9b3fab..f1da01d51 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -113,7 +113,7 @@ fn main() { args.heap_size, pub_io_size, ); - tracing::info!("Running on platform {:?} {:?}", args.platform, platform); + tracing::info!("Running on platform {:?} {}", args.platform, platform); tracing::info!( "Stack: {} bytes. Heap: {} bytes.", args.stack_size, From b60c6fe3dde3d4641630733a9023eb7a085ae047 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Wed, 8 Jan 2025 15:06:05 +0800 Subject: [PATCH 12/12] Only complain about threads once (#816) This cuts down on logging spam. --- multilinear_extensions/src/util.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index a0a8e56a2..230f187f4 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -40,11 +40,16 @@ pub fn max_usable_threads() -> usize { if cfg!(test) { 1 } else { - let n = rayon::current_num_threads(); - let threads = prev_power_of_two(n); - if n != threads { - tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); - } - threads + static MAX_USABLE_THREADS: std::sync::OnceLock = std::sync::OnceLock::new(); + *MAX_USABLE_THREADS.get_or_init(|| { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!( + "thread size {n} is not power of 2, using {threads} threads instead." + ); + } + threads + }) } }