Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sumcheck macro #823

Merged
merged 16 commits into from
Jan 22, 2025
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"examples",
"mpcs",
"multilinear_extensions",
"sumcheck_macro",
"poseidon",
"sumcheck",
"transcript",
Expand Down
1 change: 1 addition & 0 deletions sumcheck/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ tracing.workspace = true

crossbeam-channel.workspace = true
multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] }
sumcheck_macro = { path = "../sumcheck_macro" }
transcript = { path = "../transcript" }

[dev-dependencies]
Expand Down
256 changes: 32 additions & 224 deletions sumcheck/src/prover.rs

Large diffs are not rendered by default.

27 changes: 19 additions & 8 deletions sumcheck/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ark_std::{rand::RngCore, test_rng};
use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::GoldilocksExt2;
use multilinear_extensions::virtual_poly::VirtualPolynomial;
use multilinear_extensions::{mle::DenseMultilinearExtension, virtual_poly::VirtualPolynomial};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use transcript::{BasicTranscript, Transcript};

Expand Down Expand Up @@ -81,9 +81,22 @@ fn test_sumcheck_internal<E: ExtensionField>(
.flattened_ml_extensions
.par_iter_mut()
.for_each(|mle| {
Arc::get_mut(mle)
.unwrap()
.fix_variables_in_place(&[p.elements]);
if num_variables == 1 {
// first time fix variable should be create new instance
if mle.num_vars() > 0 {
*mle = mle.fix_variables(&[p.elements]).into();
} else {
*mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart(
0,
mle.get_base_field_vec().to_vec(),
))
}
} else {
let mle = Arc::get_mut(mle).unwrap();
if mle.num_vars() > 0 {
mle.fix_variables_in_place(&[p.elements]);
}
}
});
};
let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum);
Expand All @@ -101,29 +114,27 @@ fn test_sumcheck_internal<E: ExtensionField>(
}

#[test]
#[ignore = "temporarily not supporting degree > 2"]
fn test_trivial_polynomial() {
test_trivial_polynomial_helper::<GoldilocksExt2>();
}

fn test_trivial_polynomial_helper<E: ExtensionField>() {
let nv = 1;
let num_multiplicands_range = (4, 13);
let num_multiplicands_range = (3, 5);
let num_products = 5;

test_sumcheck::<E>(nv, num_multiplicands_range, num_products);
test_sumcheck_internal::<E>(nv, num_multiplicands_range, num_products);
}

#[test]
#[ignore = "temporarily not supporting degree > 2"]
fn test_normal_polynomial() {
test_normal_polynomial_helper::<GoldilocksExt2>();
}

fn test_normal_polynomial_helper<E: ExtensionField>() {
let nv = 12;
let num_multiplicands_range = (4, 9);
let num_multiplicands_range = (3, 5);
let num_products = 5;

test_sumcheck::<E>(nv, num_multiplicands_range, num_products);
Expand Down
26 changes: 26 additions & 0 deletions sumcheck_macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
categories.workspace = true
description = "Macros for the Ceno project"
edition.workspace = true
keywords.workspace = true
license.workspace = true
name = "sumcheck_macro"
readme.workspace = true
repository.workspace = true
version.workspace = true

[lib]
proc-macro = true

[dependencies]
itertools.workspace = true
proc-macro2 = "1.0.92"
quote = "1.0"
syn = { version = "2.0", features = ["full"] }

[dev-dependencies]
ff_ext = { path = "../ff_ext" }
goldilocks.workspace = true
multilinear_extensions = { path = "../multilinear_extensions" }
rayon = { workspace = true }
sumcheck = { path = "../sumcheck" }
28 changes: 28 additions & 0 deletions sumcheck_macro/examples/expand.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/// To see the code generated by the macro, run the following command in the sumcheck_macro directory:
/// ```sh
/// cargo expand --example expand
/// ```
use ff_ext::ExtensionField;
use goldilocks::GoldilocksExt2;
use multilinear_extensions::{
mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial,
};
use sumcheck::util::{AdditiveArray, ceil_log2};

#[derive(Default)]
struct Container<'a, E: ExtensionField> {
poly: VirtualPolynomial<'a, E>,
round: usize,
}

fn main() {
let c = Container::<GoldilocksExt2>::default();
c.run();
}

impl<E: ExtensionField> Container<'_, E> {
pub fn run(&self) {
let _result: AdditiveArray<_, 4> =
sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]);
}
}
Loading