Skip to content

Commit

Permalink
remove redundant to_vec
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Jan 16, 2025
1 parent 4ab3396 commit c847a12
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 353 deletions.
238 changes: 16 additions & 222 deletions sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::{array, mem, sync::Arc};
use std::{mem, sync::Arc};

use ark_std::{end_timer, start_timer};
use crossbeam_channel::bounded;
use ff_ext::ExtensionField;
use itertools::Itertools;
use multilinear_extensions::{
commutative_op_mle_pair,
mle::{DenseMultilinearExtension, FieldType, MultilinearExtension},
op_mle, op_mle_product_3, op_mle3_range,
op_mle,
util::largest_even_below,
virtual_poly::VirtualPolynomial,
};
Expand All @@ -16,6 +15,7 @@ use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator},
prelude::{IntoParallelIterator, ParallelIterator},
};
use sumcheck_macro::sumcheck_code_gen;
use transcript::{Challenge, Transcript, TranscriptSyncronized};

use crate::{
Expand Down Expand Up @@ -428,114 +428,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
|mut products_sum, (coefficient, products)| {
let span = entered_span!("sum");

let f = &self.poly.flattened_ml_extensions;
let mut sum = match products.len() {
// use macro to generate sumcheck protocol boipolate code here
1 => {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| {
let res = (0..largest_even_below(f.len()))
.step_by(2)
.rev()
.fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| {
acc.0[0] += f[b];
acc.0[1] += f[b+1];
acc
});
let res = if f.len() == 1 {
AdditiveArray::<_, 2>([f[0]; 2])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
}
.to_vec()
}
2 => {
// sumcheck_macro::sumcheck_code_gen!(2, |i| &self.poly.flattened_ml_extensions[products[i]])
let (f, g) = (
&self.poly.flattened_ml_extensions[products[0]],
&self.poly.flattened_ml_extensions[products[1]],
);
commutative_op_mle_pair!(
|f, g| {
let res = (0..largest_even_below(f.len())).step_by(2).rev().fold(
AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())),
// compiler can do more optimisation if below code goes from f1 * f2 to f1 * f2 * f3 * f4
|mut acc, b| {
acc.0[0] += f[b] * g[b];
acc.0[1] += f[b + 1] * g[b + 1];
acc.0[2] +=
(f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]);
acc
});
let res = if f.len() == 1 {
AdditiveArray::<_, 3>([f[0] * g[0]; 3])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
}
3 => {
let (f1, f2, f3) = (
&self.poly.flattened_ml_extensions[products[0]],
&self.poly.flattened_ml_extensions[products[1]],
&self.poly.flattened_ml_extensions[products[2]],
);
op_mle_product_3!(
|f1, f2, f3| {
let res = (0..largest_even_below(f1.len()))
.step_by(2)
.rev()
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
})
.sum::<AdditiveArray<_, 4>>();
let res = if f1.len() == 1 {
AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
}
4 => sumcheck_macro::sumcheck_code_gen!(4, |i| &self.poly.flattened_ml_extensions[products[i]]).to_vec(),
_ => unimplemented!("do not support degree > 3"), // I have to support degree 4 and more here
1 => sumcheck_code_gen!(1, false, |i| &f[products[i]]).to_vec(),
2 => sumcheck_code_gen!(2, false, |i| &f[products[i]]).to_vec(),
3 => sumcheck_code_gen!(3, false, |i| &f[products[i]]).to_vec(),
4 => sumcheck_code_gen!(4, false, |i| &f[products[i]]).to_vec(),
_ => unimplemented!("do not support degree > 4"),
};
exit_span!(span);
sum.iter_mut().for_each(|sum| *sum *= coefficient);
Expand Down Expand Up @@ -781,119 +680,14 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
|mut products_sum, (coefficient, products)| {
let span = entered_span!("sum");

let f = &self.poly.flattened_ml_extensions;
let mut sum = match products.len() {
1 => {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
// following code for Base or Ext
|f| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
AdditiveArray([
f[b],
f[b + 1]
])
})
.sum::<AdditiveArray<_, 2>>();
let res = if f.len() == 1 {
AdditiveArray::<_, 2>([f[0]; 2])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
// following code for Base after running above
|sum| AdditiveArray(sum.0.map(E::from))
}
.to_vec()
}
2 => {
let (f, g) = (
&self.poly.flattened_ml_extensions[products[0]],
&self.poly.flattened_ml_extensions[products[1]],
);
commutative_op_mle_pair!(
|f, g| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b])
* (g[b + 1] + g[b + 1] - g[b]),
])
})
.sum::<AdditiveArray<_, 3>>();
let res = if f.len() == 1 {
AdditiveArray::<_, 3>([f[0] * g[0]; 3])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
}
3 => {
let (f1, f2, f3) = (
&self.poly.flattened_ml_extensions[products[0]],
&self.poly.flattened_ml_extensions[products[1]],
&self.poly.flattened_ml_extensions[products[2]],
);
op_mle_product_3!(
|f1, f2, f3| {
let res = (0..largest_even_below(f1.len()))
.step_by(2)
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
})
.sum::<AdditiveArray<_, 4>>();
let res = if f1.len() == 1 {
AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
}
4 => sumcheck_macro::sumcheck_code_gen!(4, |i| &self.poly.flattened_ml_extensions[products[i]]).to_vec(),
_ => unimplemented!("do not support degree > 3"),
1 => sumcheck_code_gen!(1, true, |i| &f[products[i]]).to_vec(),
2 => sumcheck_code_gen!(2, true, |i| &f[products[i]]).to_vec(),
3 => sumcheck_code_gen!(3, true, |i| &f[products[i]]).to_vec(),
4 => sumcheck_code_gen!(4, true, |i| &f[products[i]]).to_vec(),

_ => unimplemented!("do not support degree > 5"),
};
exit_span!(span);
sum.iter_mut().for_each(|sum| *sum *= coefficient);
Expand Down
8 changes: 2 additions & 6 deletions sumcheck_macro/examples/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ use goldilocks::GoldilocksExt2;
use multilinear_extensions::{
mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial,
};
use rayon::{
iter::IndexedParallelIterator,
prelude::{IntoParallelIterator, ParallelIterator},
};
use sumcheck::util::{AdditiveArray, ceil_log2};

#[derive(Default)]
Expand All @@ -22,7 +18,7 @@ fn main() {

impl<E: ExtensionField> Container<'_, E> {
pub fn run(&self) {
let _result: AdditiveArray<_, 3> =
sumcheck_macro::sumcheck_code_gen!(2, |_| self.poly.flattened_ml_extensions[0].clone());
let _result: AdditiveArray<_, 4> =
sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]);
}
}
Loading

0 comments on commit c847a12

Please sign in to comment.