Skip to content

Commit 928be2e

Browse files
committed
Implement MultilinearExtension for DensePolynomialPqx
1 parent 901a56c commit 928be2e

File tree

2 files changed

+213
-18
lines changed

2 files changed

+213
-18
lines changed

spartan_parallel/src/custom_dense_mlpoly.rs

Lines changed: 205 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#![allow(clippy::too_many_arguments)]
2+
use core::{unimplemented, unreachable};
23
use std::cmp::min;
34

4-
use crate::dense_mlpoly::DensePolynomial;
5+
use crate::{dense_mlpoly::DensePolynomial, instance};
6+
use ff::Field;
57
use ff_ext::ExtensionField;
6-
use multilinear_extensions::mle::DenseMultilinearExtension;
8+
use multilinear_extensions::mle::{FieldType, MultilinearExtension, DenseMultilinearExtension, RangedMultilinearExtension};
9+
use std::{any::TypeId, borrow::Cow, mem, sync::Arc};
710

811
use super::math::Math;
912

@@ -29,6 +32,7 @@ pub struct DensePolynomialPqx<E: ExtensionField> {
2932
// Let Q_max = max_num_proofs, assume that for a given P, num_proofs[P] = Q_i, then let STEP = Q_max / Q_i,
3033
// Z(P, y, .) is only non-zero if y is a multiple of STEP, so Z[P][j][.] actually stores Z(P, j*STEP, .)
3134
// The same applies to X
35+
pub dense_multilinear: Option<DenseMultilinearExtension<E>>,
3236
}
3337

3438
// Reverse the bits in q or x
@@ -50,15 +54,18 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
5054
) -> Self {
5155
let num_instances = z_mat.len().next_power_of_two();
5256
let num_witness_secs = z_mat[0][0].len().next_power_of_two();
53-
DensePolynomialPqx {
57+
let mut inst = DensePolynomialPqx {
5458
num_instances,
5559
num_proofs,
5660
max_num_proofs,
5761
num_witness_secs,
5862
num_inputs,
5963
max_num_inputs,
6064
Z: z_mat,
61-
}
65+
dense_multilinear: None,
66+
};
67+
inst.fill_dense_Z_poly();
68+
inst
6269
}
6370

6471
// Assume z_mat is in its standard form of (p, q, x)
@@ -101,15 +108,18 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
101108
}
102109
}
103110
}
104-
DensePolynomialPqx {
111+
let mut inst = DensePolynomialPqx {
105112
num_instances: num_instances.next_power_of_two(),
106113
num_proofs,
107114
max_num_proofs,
108115
num_witness_secs: num_witness_secs.next_power_of_two(),
109116
num_inputs,
110117
max_num_inputs,
111118
Z,
112-
}
119+
dense_multilinear: None,
120+
};
121+
inst.fill_dense_Z_poly();
122+
inst
113123
}
114124

115125
pub fn len(&self) -> usize {
@@ -319,6 +329,14 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
319329
}
320330
}
321331

332+
pub fn flattened_len(&self) -> usize {
333+
self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs
334+
}
335+
336+
pub fn num_flattened_vars(&self) -> usize {
337+
self.flattened_len().log_2()
338+
}
339+
322340
pub fn evaluate(&self, r_p: &Vec<E>, r_q: &Vec<E>, r_w: &Vec<E>, r_x: &Vec<E>) -> E {
323341
let mut cl = self.clone();
324342
cl.bound_poly_vars_rx(r_x);
@@ -328,11 +346,11 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
328346
cl.index(0, 0, 0, 0)
329347
}
330348

331-
fn to_dense_Z_poly(&self) -> Vec<E> {
349+
fn fill_dense_Z_poly(&mut self) {
332350
let mut Z_poly =
333351
vec![
334352
E::ZERO;
335-
self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs
353+
self.flattened_len()
336354
];
337355
for p in 0..min(self.num_instances, self.Z.len()) {
338356
let step_q = self.max_num_proofs / self.num_proofs[p];
@@ -351,17 +369,191 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
351369
}
352370
}
353371

354-
Z_poly
372+
self.dense_multilinear = Some(DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly));
355373
}
356374

357-
// Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x)
375+
// Convert to Ceno prover compatible multilinear poly
358376
pub fn to_dense_poly(&self) -> DensePolynomial<E> {
359-
DensePolynomial::new(self.to_dense_Z_poly())
377+
match self.evaluations() {
378+
FieldType::Ext(v) => DensePolynomial::new(v.to_vec()),
379+
_ => { unreachable!() }
380+
}
360381
}
361382

362383
// Convert to Ceno prover compatible multilinear poly
363384
pub fn to_ceno_multilinear(&self) -> DenseMultilinearExtension<E> {
364-
let Z_poly = self.to_dense_Z_poly();
365-
DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly)
385+
match self.evaluations() {
386+
FieldType::Ext(v) => DenseMultilinearExtension::from_evaluations_ext_vec(v.len().log_2(), v.to_vec()),
387+
_ => { unreachable!() }
388+
}
389+
}
390+
}
391+
392+
impl<E: ExtensionField> MultilinearExtension<E> for DensePolynomialPqx<E> {
393+
type Output = DenseMultilinearExtension<E>;
394+
/// Reduce the number of variables of `self` by fixing the
395+
/// `partial_point.len()` variables at `partial_point`.
396+
fn fix_variables(&self, partial_point: &[E]) -> Self::Output {
397+
// TODO: return error.
398+
assert!(
399+
partial_point.len() <= self.num_vars(),
400+
"invalid size of partial point"
401+
);
402+
403+
let mut poly = self.clone();
404+
405+
for point in partial_point.iter() {
406+
poly.fix_variables_in_place(&[*point])
407+
}
408+
assert!(poly.num_flattened_vars() == self.num_flattened_vars() - partial_point.len(),);
409+
poly.to_ceno_multilinear()
410+
}
411+
412+
/// Reduce the number of variables of `self` by fixing the
413+
/// `partial_point.len()` variables at `partial_point` in place
414+
fn fix_variables_in_place(&mut self, partial_point: &[E]) {
415+
// TODO: return error.
416+
assert!(
417+
partial_point.len() <= self.num_flattened_vars(),
418+
"partial point len {} >= num_vars {}",
419+
partial_point.len(),
420+
self.num_flattened_vars()
421+
);
422+
423+
let mut instance_vars = self.num_instances.log_2();
424+
let mut proofs_vars = self.max_num_proofs.log_2();
425+
let mut witness_secs_vars = self.num_witness_secs.log_2();
426+
let mut input_vars = self.max_num_inputs.log_2();
427+
428+
for point in partial_point.iter() {
429+
if input_vars > 0 {
430+
self.bound_poly_vars_rx(&vec![*point]);
431+
input_vars /= 2;
432+
} else if witness_secs_vars > 0 {
433+
self.bound_poly_vars_rw(&vec![*point]);
434+
witness_secs_vars /= 2;
435+
} else if proofs_vars > 0 {
436+
self.bound_poly_vars_rq(&vec![*point]);
437+
proofs_vars /= 2;
438+
} else {
439+
self.bound_poly_vars_rp(&vec![*point]);
440+
instance_vars /= 2;
441+
}
442+
}
443+
}
444+
445+
/// Reduce the number of variables of `self` by fixing the
446+
/// `partial_point.len()` variables at `partial_point` from high position
447+
fn fix_high_variables(&self, _partial_point: &[E]) -> Self::Output {
448+
unimplemented!()
449+
}
450+
451+
/// Reduce the number of variables of `self` by fixing the
452+
/// `partial_point.len()` variables at `partial_point` from high position in place
453+
fn fix_high_variables_in_place(&mut self, _partial_point: &[E]) {
454+
unimplemented!()
455+
}
456+
457+
/// Evaluate the MLE at a give point.
458+
/// Returns an error if the MLE length does not match the point.
459+
fn evaluate(&self, point: &[E]) -> E {
460+
// TODO: return error.
461+
assert_eq!(
462+
self.num_vars(),
463+
point.len(),
464+
"MLE size does not match the point"
465+
);
466+
let mle = self.fix_variables_parallel(point);
467+
468+
if let Some(f) = &self.dense_multilinear {
469+
match &f.evaluations {
470+
FieldType::Ext(v) => v[0],
471+
_ => unreachable!()
472+
}
473+
} else {
474+
unreachable!()
475+
}
476+
}
477+
478+
fn num_vars(&self) -> usize {
479+
self.num_flattened_vars()
480+
}
481+
482+
/// Reduce the number of variables of `self` by fixing the
483+
/// `partial_point.len()` variables at `partial_point`.
484+
fn fix_variables_parallel(&self, partial_point: &[E]) -> Self::Output {
485+
self.fix_variables(partial_point)
486+
}
487+
488+
/// Reduce the number of variables of `self` by fixing the
489+
/// `partial_point.len()` variables at `partial_point` in place
490+
fn fix_variables_in_place_parallel(&mut self, partial_point: &[E]) {
491+
self.fix_variables_in_place(partial_point);
492+
}
493+
494+
fn evaluations(&self) -> &FieldType<E> {
495+
&self.dense_multilinear.as_ref().unwrap().evaluations
496+
}
497+
498+
fn evaluations_to_owned(self) -> FieldType<E> {
499+
unimplemented!()
500+
}
501+
502+
fn evaluations_range(&self) -> Option<(usize, usize)> {
503+
None
504+
}
505+
506+
fn name(&self) -> &'static str {
507+
"DensePolynomialPqx"
508+
}
509+
510+
/// assert and get base field vector
511+
/// panic if not the case
512+
fn get_base_field_vec(&self) -> &[E::BaseField] {
513+
if let Some(f) = &self.dense_multilinear {
514+
match &f.evaluations {
515+
FieldType::Base(evaluations) => &evaluations[..],
516+
_ => unreachable!(),
517+
}
518+
} else {
519+
unreachable!()
520+
}
521+
}
522+
523+
fn merge(&mut self, _rhs: DenseMultilinearExtension<E>) {
524+
unimplemented!()
525+
}
526+
527+
/// get ranged multiliear extention
528+
fn get_ranged_mle(
529+
&self,
530+
num_range: usize,
531+
range_index: usize,
532+
) -> RangedMultilinearExtension<'_, E> {
533+
assert!(num_range > 0);
534+
// ranged_mle is exclusively used in multi-thread parallelism
535+
// The number of ranges must be a power of 2
536+
assert!(num_range.next_power_of_two() == num_range);
537+
let offset = self.evaluations().len() / num_range;
538+
let start = offset * range_index;
539+
RangedMultilinearExtension::new(self.dense_multilinear.as_ref().unwrap(), start, offset)
540+
}
541+
542+
/// resize to new size (num_instances * new_size_per_instance / num_range)
543+
/// and selected by range_index
544+
/// only support resize base fields, otherwise panic
545+
fn resize_ranged(
546+
&self,
547+
_num_instances: usize,
548+
_new_size_per_instance: usize,
549+
_num_range: usize,
550+
_range_index: usize,
551+
) -> Self::Output {
552+
unimplemented!()
553+
}
554+
555+
/// dup to new size 1 << (self.num_vars + ceil_log2(num_dups))
556+
fn dup(&self, _num_instances: usize, _num_dups: usize) -> Self::Output {
557+
unimplemented!()
366558
}
367559
}

spartan_parallel/src/r1csproof.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use serde::Serialize;
1313
use std::cmp::min;
1414
use std::iter::zip;
1515
use std::sync::Arc;
16+
use std::cmp::max;
1617
use multilinear_extensions::{
1718
mle::{IntoMLE, MultilinearExtension, DenseMultilinearExtension},
1819
virtual_poly::VPAuxInfo,
@@ -245,13 +246,13 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
245246
timer_tmp.stop();
246247

247248
// == test: ceno_verifier_bench ==
248-
let max_num_vars = poly_tau.get_num_vars();
249249
let num_threads = 32;
250+
let max_num_vars = poly_tau.get_num_vars();
250251

251252
let arc_A: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_tau.to_ceno_multilinear());
252-
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Az.to_ceno_multilinear());
253-
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Bz.to_ceno_multilinear());
254-
let arc_D: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Cz.to_ceno_multilinear());
253+
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Az);
254+
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Bz);
255+
let arc_D: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Cz);
255256

256257
let mut virtual_polys =
257258
VirtualPolynomials::new(num_threads, max_num_vars);
@@ -430,7 +431,6 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
430431
let mut eq_p_rp_poly = DensePolynomial::new(
431432
tmp_rp_poly.into_iter().map(|i| vec![i; scale]).collect::<Vec<Vec<E>>>().concat()
432433
);
433-
let max_num_vars_phase2 = ABC_poly.get_num_vars();
434434

435435
let mut claimed_sum = E::ZERO;
436436
let mut claimed_partial_sum = E::ZERO;
@@ -445,6 +445,9 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
445445
c_sum += c.clone();
446446
}
447447

448+
// debug_ceno_prover
449+
let max_num_vars_phase2 = max(ABC_poly.get_num_vars(), max(Z_poly.get_num_vars(), eq_p_rp_poly.get_num_vars()));
450+
448451
let arc_A: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(ABC_poly.to_ceno_multilinear());
449452
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(Z_poly.to_ceno_multilinear());
450453
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(eq_p_rp_poly.to_ceno_multilinear());

0 commit comments

Comments
 (0)