diff --git a/spartan_parallel/src/custom_dense_mlpoly.rs b/spartan_parallel/src/custom_dense_mlpoly.rs index b8470db7..4cdcdaff 100644 --- a/spartan_parallel/src/custom_dense_mlpoly.rs +++ b/spartan_parallel/src/custom_dense_mlpoly.rs @@ -1,9 +1,12 @@ #![allow(clippy::too_many_arguments)] +use core::{unimplemented, unreachable}; use std::cmp::min; -use crate::dense_mlpoly::DensePolynomial; +use crate::{dense_mlpoly::DensePolynomial, instance}; +use ff::Field; use ff_ext::ExtensionField; -use multilinear_extensions::mle::DenseMultilinearExtension; +use multilinear_extensions::mle::{FieldType, MultilinearExtension, DenseMultilinearExtension, RangedMultilinearExtension}; +use std::{any::TypeId, borrow::Cow, mem, sync::Arc}; use super::math::Math; @@ -29,6 +32,7 @@ pub struct DensePolynomialPqx { // Let Q_max = max_num_proofs, assume that for a given P, num_proofs[P] = Q_i, then let STEP = Q_max / Q_i, // Z(P, y, .) is only non-zero if y is a multiple of STEP, so Z[P][j][.] actually stores Z(P, j*STEP, .) // The same applies to X + pub dense_multilinear: Option>, } // Reverse the bits in q or x @@ -50,7 +54,7 @@ impl DensePolynomialPqx { ) -> Self { let num_instances = z_mat.len().next_power_of_two(); let num_witness_secs = z_mat[0][0].len().next_power_of_two(); - DensePolynomialPqx { + let mut inst = DensePolynomialPqx { num_instances, num_proofs, max_num_proofs, @@ -58,7 +62,10 @@ impl DensePolynomialPqx { num_inputs, max_num_inputs, Z: z_mat, - } + dense_multilinear: None, + }; + inst.fill_dense_Z_poly(); + inst } // Assume z_mat is in its standard form of (p, q, x) @@ -101,7 +108,7 @@ impl DensePolynomialPqx { } } } - DensePolynomialPqx { + let mut inst = DensePolynomialPqx { num_instances: num_instances.next_power_of_two(), num_proofs, max_num_proofs, @@ -109,7 +116,10 @@ impl DensePolynomialPqx { num_inputs, max_num_inputs, Z, - } + dense_multilinear: None, + }; + inst.fill_dense_Z_poly(); + inst } pub fn len(&self) -> usize { @@ -319,6 +329,14 @@ impl DensePolynomialPqx { } } + pub fn flattened_len(&self) -> usize { + self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs + } + + pub fn num_flattened_vars(&self) -> usize { + self.flattened_len().log_2() + } + pub fn evaluate(&self, r_p: &Vec, r_q: &Vec, r_w: &Vec, r_x: &Vec) -> E { let mut cl = self.clone(); cl.bound_poly_vars_rx(r_x); @@ -328,11 +346,11 @@ impl DensePolynomialPqx { cl.index(0, 0, 0, 0) } - fn to_dense_Z_poly(&self) -> Vec { + fn fill_dense_Z_poly(&mut self) { let mut Z_poly = vec![ E::ZERO; - self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs + self.flattened_len() ]; for p in 0..min(self.num_instances, self.Z.len()) { let step_q = self.max_num_proofs / self.num_proofs[p]; @@ -351,17 +369,191 @@ impl DensePolynomialPqx { } } - Z_poly + self.dense_multilinear = Some(DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly)); } - // Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x) + // Convert to Ceno prover compatible multilinear poly pub fn to_dense_poly(&self) -> DensePolynomial { - DensePolynomial::new(self.to_dense_Z_poly()) + match self.evaluations() { + FieldType::Ext(v) => DensePolynomial::new(v.to_vec()), + _ => { unreachable!() } + } } // Convert to Ceno prover compatible multilinear poly pub fn to_ceno_multilinear(&self) -> DenseMultilinearExtension { - let Z_poly = self.to_dense_Z_poly(); - DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly) + match self.evaluations() { + FieldType::Ext(v) => DenseMultilinearExtension::from_evaluations_ext_vec(v.len().log_2(), v.to_vec()), + _ => { unreachable!() } + } + } +} + +impl MultilinearExtension for DensePolynomialPqx { + type Output = DenseMultilinearExtension; + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point`. + fn fix_variables(&self, partial_point: &[E]) -> Self::Output { + // TODO: return error. + assert!( + partial_point.len() <= self.num_vars(), + "invalid size of partial point" + ); + + let mut poly = self.clone(); + + for point in partial_point.iter() { + poly.fix_variables_in_place(&[*point]) + } + assert!(poly.num_flattened_vars() == self.num_flattened_vars() - partial_point.len(),); + poly.to_ceno_multilinear() + } + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point` in place + fn fix_variables_in_place(&mut self, partial_point: &[E]) { + // TODO: return error. + assert!( + partial_point.len() <= self.num_flattened_vars(), + "partial point len {} >= num_vars {}", + partial_point.len(), + self.num_flattened_vars() + ); + + let mut instance_vars = self.num_instances.log_2(); + let mut proofs_vars = self.max_num_proofs.log_2(); + let mut witness_secs_vars = self.num_witness_secs.log_2(); + let mut input_vars = self.max_num_inputs.log_2(); + + for point in partial_point.iter() { + if input_vars > 0 { + self.bound_poly_vars_rx(&vec![*point]); + input_vars /= 2; + } else if witness_secs_vars > 0 { + self.bound_poly_vars_rw(&vec![*point]); + witness_secs_vars /= 2; + } else if proofs_vars > 0 { + self.bound_poly_vars_rq(&vec![*point]); + proofs_vars /= 2; + } else { + self.bound_poly_vars_rp(&vec![*point]); + instance_vars /= 2; + } + } + } + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point` from high position + fn fix_high_variables(&self, _partial_point: &[E]) -> Self::Output { + unimplemented!() + } + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point` from high position in place + fn fix_high_variables_in_place(&mut self, _partial_point: &[E]) { + unimplemented!() + } + + /// Evaluate the MLE at a give point. + /// Returns an error if the MLE length does not match the point. + fn evaluate(&self, point: &[E]) -> E { + // TODO: return error. + assert_eq!( + self.num_vars(), + point.len(), + "MLE size does not match the point" + ); + let mle = self.fix_variables_parallel(point); + + if let Some(f) = &self.dense_multilinear { + match &f.evaluations { + FieldType::Ext(v) => v[0], + _ => unreachable!() + } + } else { + unreachable!() + } + } + + fn num_vars(&self) -> usize { + self.num_flattened_vars() + } + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point`. + fn fix_variables_parallel(&self, partial_point: &[E]) -> Self::Output { + self.fix_variables(partial_point) + } + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point` in place + fn fix_variables_in_place_parallel(&mut self, partial_point: &[E]) { + self.fix_variables_in_place(partial_point); + } + + fn evaluations(&self) -> &FieldType { + &self.dense_multilinear.as_ref().unwrap().evaluations + } + + fn evaluations_to_owned(self) -> FieldType { + unimplemented!() + } + + fn evaluations_range(&self) -> Option<(usize, usize)> { + None + } + + fn name(&self) -> &'static str { + "DensePolynomialPqx" + } + + /// assert and get base field vector + /// panic if not the case + fn get_base_field_vec(&self) -> &[E::BaseField] { + if let Some(f) = &self.dense_multilinear { + match &f.evaluations { + FieldType::Base(evaluations) => &evaluations[..], + _ => unreachable!(), + } + } else { + unreachable!() + } + } + + fn merge(&mut self, _rhs: DenseMultilinearExtension) { + unimplemented!() + } + + /// get ranged multiliear extention + fn get_ranged_mle( + &self, + num_range: usize, + range_index: usize, + ) -> RangedMultilinearExtension<'_, E> { + assert!(num_range > 0); + // ranged_mle is exclusively used in multi-thread parallelism + // The number of ranges must be a power of 2 + assert!(num_range.next_power_of_two() == num_range); + let offset = self.evaluations().len() / num_range; + let start = offset * range_index; + RangedMultilinearExtension::new(self.dense_multilinear.as_ref().unwrap(), start, offset) + } + + /// resize to new size (num_instances * new_size_per_instance / num_range) + /// and selected by range_index + /// only support resize base fields, otherwise panic + fn resize_ranged( + &self, + _num_instances: usize, + _new_size_per_instance: usize, + _num_range: usize, + _range_index: usize, + ) -> Self::Output { + unimplemented!() + } + + /// dup to new size 1 << (self.num_vars + ceil_log2(num_dups)) + fn dup(&self, _num_instances: usize, _num_dups: usize) -> Self::Output { + unimplemented!() } } diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 90d261b5..c752ad7f 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -13,6 +13,7 @@ use serde::Serialize; use std::cmp::min; use std::iter::zip; use std::sync::Arc; +use std::cmp::max; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension, DenseMultilinearExtension}, virtual_poly::VPAuxInfo, @@ -245,13 +246,13 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof { timer_tmp.stop(); // == test: ceno_verifier_bench == - let max_num_vars = poly_tau.get_num_vars(); let num_threads = 32; + let max_num_vars = poly_tau.get_num_vars(); let arc_A: Arc>> = Arc::new(poly_tau.to_ceno_multilinear()); - let arc_B: Arc>> = Arc::new(poly_Az.to_ceno_multilinear()); - let arc_C: Arc>> = Arc::new(poly_Bz.to_ceno_multilinear()); - let arc_D: Arc>> = Arc::new(poly_Cz.to_ceno_multilinear()); + let arc_B: Arc>> = Arc::new(poly_Az); + let arc_C: Arc>> = Arc::new(poly_Bz); + let arc_D: Arc>> = Arc::new(poly_Cz); let mut virtual_polys = VirtualPolynomials::new(num_threads, max_num_vars); @@ -430,7 +431,6 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof { let mut eq_p_rp_poly = DensePolynomial::new( tmp_rp_poly.into_iter().map(|i| vec![i; scale]).collect::>>().concat() ); - let max_num_vars_phase2 = ABC_poly.get_num_vars(); let mut claimed_sum = E::ZERO; let mut claimed_partial_sum = E::ZERO; @@ -445,6 +445,9 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof { c_sum += c.clone(); } + // debug_ceno_prover + let max_num_vars_phase2 = max(ABC_poly.get_num_vars(), max(Z_poly.get_num_vars(), eq_p_rp_poly.get_num_vars())); + let arc_A: Arc>> = Arc::new(ABC_poly.to_ceno_multilinear()); let arc_B: Arc>> = Arc::new(Z_poly.to_ceno_multilinear()); let arc_C: Arc>> = Arc::new(eq_p_rp_poly.to_ceno_multilinear());