diff --git a/jolt-core/benches/e2e_profiling.rs b/jolt-core/benches/e2e_profiling.rs index 8638cff9a..64199c3f2 100644 --- a/jolt-core/benches/e2e_profiling.rs +++ b/jolt-core/benches/e2e_profiling.rs @@ -203,7 +203,7 @@ fn prove_example( let mut tasks = Vec::new(); let mut program = host::Program::new(example_name); let (bytecode, init_memory_state, _) = program.decode(); - let (_, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); + let (_lazy_trace, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); let padded_trace_len = (trace.len() + 1).next_power_of_two(); drop(trace); diff --git a/jolt-core/src/field/challenge/mont_ark_u128.rs b/jolt-core/src/field/challenge/mont_ark_u128.rs index fa82e3560..8a7073e6d 100644 --- a/jolt-core/src/field/challenge/mont_ark_u128.rs +++ b/jolt-core/src/field/challenge/mont_ark_u128.rs @@ -9,7 +9,7 @@ use crate::field::OptimizedMul; use crate::field::{tracked_ark::TrackedFr, JoltField}; -use crate::impl_field_ops_inline; +//use crate::impl_field_ops_inline; use allocative::Allocative; use ark_ff::{BigInt, PrimeField, UniformRand}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; diff --git a/jolt-core/src/field/challenge/mont_ark_u254.rs b/jolt-core/src/field/challenge/mont_ark_u254.rs index fb26ad731..3575e6454 100644 --- a/jolt-core/src/field/challenge/mont_ark_u254.rs +++ b/jolt-core/src/field/challenge/mont_ark_u254.rs @@ -4,7 +4,7 @@ use crate::field::OptimizedMul; use crate::field::{tracked_ark::TrackedFr, JoltField}; -use crate::impl_field_ops_inline; +//use crate::impl_field_ops_inline; use allocative::Allocative; use ark_ff::UniformRand; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; diff --git a/jolt-core/src/poly/mod.rs b/jolt-core/src/poly/mod.rs index 2f55b2395..6cd948840 100644 --- a/jolt-core/src/poly/mod.rs +++ b/jolt-core/src/poly/mod.rs @@ -7,6 +7,7 @@ pub mod identity_poly; pub mod lagrange_poly; pub mod lt_poly; pub mod multilinear_polynomial; +pub mod multiquadratic_poly; pub mod one_hot_polynomial; pub mod opening_proof; pub mod prefix_suffix; diff --git a/jolt-core/src/poly/multiquadratic_poly.rs b/jolt-core/src/poly/multiquadratic_poly.rs new file mode 100644 index 000000000..3d2e1d01c --- /dev/null +++ b/jolt-core/src/poly/multiquadratic_poly.rs @@ -0,0 +1,386 @@ +use allocative::Allocative; + +use crate::field::JoltField; +use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; + +/// Multiquadratic polynomial represented by its evaluations on the grid +/// {0, 1, ∞}^num_vars in base-3 layout (z_0 least-significant / fastest-varying). +#[derive(Allocative)] +pub struct MultiquadraticPolynomial { + num_vars: usize, + evals: Vec, +} + +impl MultiquadraticPolynomial { + /// Construct a multiquadratic polynomial from its full grid of evaluations. + /// The caller is responsible for ensuring that `evals` is laid out in base-3 + /// order with z_0 as the least-significant digit. + pub fn new(num_vars: usize, evals: Vec) -> Self { + let expected_len = 3usize.pow(num_vars as u32); + debug_assert!( + evals.len() == expected_len, + "MultiquadraticPolynomial: expected {} evals, got {}", + expected_len, + evals.len() + ); + Self { num_vars, evals } + } + + /// Number of variables in the polynomial. + pub fn num_vars(&self) -> usize { + self.num_vars + } + + /// Underlying evaluations on {0, 1, ∞}^num_vars. 
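+    /// Indexing follows the base-3 layout documented on `new`: the value at
+    /// (z_0, ..., z_{n-1}) lives at index Σ_i enc(z_i) * 3^i, where enc maps
+    /// {0, 1, ∞} to {0, 1, 2}. For example, with `num_vars = 2`, the
+    /// evaluation at (z_0, z_1) = (1, ∞) is stored at index 1 + 2 * 3 = 7.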
+ pub fn evals(&self) -> &[F] { + &self.evals + } + + /// Given evaluations of a degree-1 multivariate polynomial over {0,1}^dim, + /// expand them to the corresponding multiquadratic grid over {0,1,∞}^dim. + /// + /// The input is a length-2^dim slice `input` containing evaluations on the + /// Boolean hypercube. The caller must provide two length-3^dim buffers: + /// - `output` will contain the final {0,1,∞}^dim values on return + /// - `tmp` is a scratch buffer which this routine may use internally + /// + /// Layout is product-order with the last variable as the fastest-varying + /// coordinate. For each 1D slice (f0, f1) along a new dimension we write + /// (f(0), f(1), f(∞)) = (f0, f1, f1 - f0), so ∞ stores the slope. + /// + /// TODO: special-case dim ∈ {1,2,3} with hand-unrolled code to reduce + /// loop overhead on small windows. + #[inline(always)] + pub fn expand_linear_grid_to_multiquadratic( + input: &[F], // initial buffer (size 2^dim) + output: &mut [F], // final buffer (size 3^dim) + tmp: &mut [F], // scratch buffer, also (size 3^dim) + dim: usize, + ) { + let in_size = 1usize << dim; + let out_size = 3usize.pow(dim as u32); + + assert_eq!(input.len(), in_size); + assert_eq!(output.len(), out_size); + assert_eq!(tmp.len(), out_size); + + match dim { + 0 => { + if !input.is_empty() { + output[0] = input[0]; + } + return; + } + 1 => { + Self::expand_linear_dim1(input, output); + return; + } + 2 => { + Self::expand_linear_dim2(input, output); + return; + } + 3 => { + Self::expand_linear_dim3(input, output); + return; + } + _ => {} + } + + // Fill output by expanding one dimension at a time. + // We treat slices of increasing "arity" + + // Copy the initial evaluations into the start of either + // tmp or output, depending on parity of dim. + // We'll alternate between tmp and output as we expand dimensions. + let (mut cur, mut next) = if dim % 2 == 1 { + tmp[..input.len()].copy_from_slice(input); + (tmp, output) + } else { + output[..input.len()].copy_from_slice(input); + (output, tmp) + }; + + let mut in_stride = 1usize; + let mut out_stride = 1usize; + let mut blocks = 1 << (dim - 1); + + // sanity checks + assert_eq!(cur.len(), out_size); + assert_eq!(next.len(), out_size); + assert_eq!(input.len(), in_size); + + // start from the smallest subcubes and expand dimension by dimension + for _ in 0..dim { + for b in 0..blocks { + let in_off = b * 2 * in_stride; + let out_off = b * 3 * out_stride; + + for j in 0..in_stride { + // 1d extrapolate + let f0 = cur[in_off + j]; + let f1 = cur[in_off + in_stride + j]; + next[out_off + j] = f0; + next[out_off + out_stride + j] = f1; + next[out_off + 2 * out_stride + j] = f1 - f0; + } + } + // swap buffers + std::mem::swap(&mut cur, &mut next); + in_stride *= 3; + out_stride *= 3; + blocks /= 2; + } + } + + #[inline(always)] + fn expand_linear_dim1(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 2); + debug_assert_eq!(output.len(), 3); + + let f0 = input[0]; + let f1 = input[1]; + + output[0] = f0; + output[1] = f1; + output[2] = f1 - f0; + } + + #[inline(always)] + fn expand_linear_dim2(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 4); + debug_assert_eq!(output.len(), 9); + + let f00 = input[0]; // f(0,0) + let f01 = input[1]; // f(0,1) + let f10 = input[2]; // f(1,0) + let f11 = input[3]; // f(1,1) + + // First extrapolate along the fastest-varying variable (second coordinate). 
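+        // Worked example (illustrative values): for the multilinear
+        // f(x0, x1) = 1 + 2*x1 + 3*x0 the corners are
+        // (f00, f01, f10, f11) = (1, 3, 4, 6). Stage 1 gives row slopes
+        // a0_inf = 2 and a1_inf = 2; stage 2 gives inf0 = 3, inf1 = 3 and
+        // inf_inf = 0 (the x0*x1 coefficient, absent here).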
+ let a00 = f00; + let a01 = f01; + let a0_inf = f01 - f00; + + let a10 = f10; + let a11 = f11; + let a1_inf = f11 - f10; + + // Then extrapolate along the remaining variable. + let inf0 = a10 - a00; + let inf1 = a11 - a01; + let inf_inf = a1_inf - a0_inf; + + // Layout: index = 3 * enc(x0) + enc(x1), x1 fastest, enc: {0,1,∞} -> {0,1,2}. + output[0] = a00; // (0,0) + output[1] = a01; // (0,1) + output[2] = a0_inf; // (0,∞) + + output[3] = a10; // (1,0) + output[4] = a11; // (1,1) + output[5] = a1_inf; // (1,∞) + + output[6] = inf0; // (∞,0) + output[7] = inf1; // (∞,1) + output[8] = inf_inf; // (∞,∞) + } + + #[inline(always)] + fn expand_linear_dim3(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 8); + debug_assert_eq!(output.len(), 27); + + // Corner values f(x0, x1, x2) with x2 fastest. + let f000 = input[0]; + let f001 = input[1]; + let f010 = input[2]; + let f011 = input[3]; + let f100 = input[4]; + let f101 = input[5]; + let f110 = input[6]; + let f111 = input[7]; + + // Stage 1: extrapolate along x2 (fastest variable) for each (x0, x1). + let g000 = f000; + let g001 = f001; + let g00_inf = f001 - f000; + + let g010 = f010; + let g011 = f011; + let g01_inf = f011 - f010; + + let g100 = f100; + let g101 = f101; + let g10_inf = f101 - f100; + + let g110 = f110; + let g111 = f111; + let g11_inf = f111 - f110; + + // Stage 2: extrapolate along x1 for each (x0, x2). + // x0 = 0 + let h0_0_0 = g000; + let h0_1_0 = g010; + let h0_inf_0 = g010 - g000; + + let h0_0_1 = g001; + let h0_1_1 = g011; + let h0_inf_1 = g011 - g001; + + let h0_0_inf = g00_inf; + let h0_1_inf = g01_inf; + let h0_inf_inf = g01_inf - g00_inf; + + // x0 = 1 + let h1_0_0 = g100; + let h1_1_0 = g110; + let h1_inf_0 = g110 - g100; + + let h1_0_1 = g101; + let h1_1_1 = g111; + let h1_inf_1 = g111 - g101; + + let h1_0_inf = g10_inf; + let h1_1_inf = g11_inf; + let h1_inf_inf = g11_inf - g10_inf; + + // Stage 3: extrapolate along x0 for each (x1, x2). + // Index: idx(x0, x1, x2) = 9 * enc(x0) + 3 * enc(x1) + enc(x2), + // enc: {0,1,∞} -> {0,1,2}, x2 fastest. + + // (x1, x2) = (0, 0) + output[0] = h0_0_0; // (0,0,0) + output[9] = h1_0_0; // (1,0,0) + output[18] = h1_0_0 - h0_0_0; // (∞,0,0) + + // (0, 1) + output[1] = h0_0_1; // (0,0,1) + output[10] = h1_0_1; // (1,0,1) + output[19] = h1_0_1 - h0_0_1; // (∞,0,1) + + // (0, ∞) + output[2] = h0_0_inf; // (0,0,∞) + output[11] = h1_0_inf; // (1,0,∞) + output[20] = h1_0_inf - h0_0_inf; // (∞,0,∞) + + // (1, 0) + output[3] = h0_1_0; // (0,1,0) + output[12] = h1_1_0; // (1,1,0) + output[21] = h1_1_0 - h0_1_0; // (∞,1,0) + + // (1, 1) + output[4] = h0_1_1; // (0,1,1) + output[13] = h1_1_1; // (1,1,1) + output[22] = h1_1_1 - h0_1_1; // (∞,1,1) + + // (1, ∞) + output[5] = h0_1_inf; // (0,1,∞) + output[14] = h1_1_inf; // (1,1,∞) + output[23] = h1_1_inf - h0_1_inf; // (∞,1,∞) + + // (∞, 0) + output[6] = h0_inf_0; // (0,∞,0) + output[15] = h1_inf_0; // (1,∞,0) + output[24] = h1_inf_0 - h0_inf_0; // (∞,∞,0) + + // (∞, 1) + output[7] = h0_inf_1; // (0,∞,1) + output[16] = h1_inf_1; // (1,∞,1) + output[25] = h1_inf_1 - h0_inf_1; // (∞,∞,1) + + // (∞, ∞) + output[8] = h0_inf_inf; // (0,∞,∞) + output[17] = h1_inf_inf; // (1,∞,∞) + output[26] = h1_inf_inf - h0_inf_inf; // (∞,∞,∞) + } + + /// Bind the first (least-significant) variable z_0 := r, reducing the + /// dimension from w to w-1 and keeping the base-3 layout invariant. + /// + /// For each assignment to (z_1, ..., z_{w-1}), we have three stored values + /// f(0, ..), f(1, ..), f(∞, ..) 
+ /// and interpolate the unique quadratic in z_0 that matches them, then + /// evaluate it at z_0 = r. + pub fn bind_first_variable(&mut self, r: F::Challenge) { + let w = self.num_vars; + debug_assert!(w > 0); + + let new_size = 3_usize.pow((w - 1) as u32); + let one = F::one(); + + let r_term = r * (r - one); + for new_idx in 0..new_size { + let old_base_idx = new_idx * 3; + let eval_at_0 = self.evals[old_base_idx]; // z_0 = 0 + let eval_at_1 = self.evals[old_base_idx + 1]; // z_0 = 1 + let eval_at_inf = self.evals[old_base_idx + 2]; // z_0 = ∞ + + self.evals[new_idx] = eval_at_0 * (one - r) + eval_at_1 * r + eval_at_inf * r_term; + } + + self.num_vars -= 1; + self.evals.truncate(new_size); + } + + /// Project t'(z_0, z_1, ..., z_{w-1}) to a univariate in z_0 by summing + /// against `E_active` over the remaining coordinates. + /// + /// The `E_active` table is interpreted identically to the existing outer + /// Spartan streaming implementation: each index encodes, in binary, which + /// of z_1..z_{w-1} take the "active" value (mapped to base-3 offset 1). + /// `first_coord_val` is the z_0 coordinate in {0, 1, 2}, where 2 encodes ∞. + pub fn project_to_first_variable(&self, E_active: &[F], first_coord_val: usize) -> F { + let w = self.num_vars; + debug_assert!(w >= 1); + + let offset = first_coord_val; // z_0 lives at the units place in base-3 + + E_active + .iter() + .enumerate() + .map(|(eq_active_idx, eq_active_val)| { + let mut index = offset; + let mut temp = eq_active_idx; + let mut power = 3; // start at 3^1 for z_1 + + for _ in 0..(w - 1) { + if temp & 1 == 1 { + index += power; + } + power *= 3; + temp >>= 1; + } + + self.evals[index] * *eq_active_val + }) + .sum() + } +} + +impl PolynomialBinding for MultiquadraticPolynomial { + fn is_bound(&self) -> bool { + self.num_vars == 0 || self.evals.len() == 1 + } + + #[tracing::instrument(skip_all, name = "MultiquadraticPolynomial::bind")] + fn bind(&mut self, r: F::Challenge, order: BindingOrder) { + match order { + BindingOrder::LowToHigh => self.bind_first_variable(r), + BindingOrder::HighToLow => { + // Not currently needed by the outer Spartan streaming code. + unimplemented!( + "HighToLow binding order is not implemented for MultiquadraticPolynomial" + ) + } + } + } + + fn bind_parallel(&mut self, r: F::Challenge, order: BindingOrder) { + // Window sizes are small; fall back to the sequential implementation. + self.bind(r, order); + } + + fn final_sumcheck_claim(&self) -> F { + debug_assert!(self.is_bound()); + debug_assert_eq!(self.evals.len(), 1); + self.evals[0] + } +} diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index c2dd0ee1c..3514dff71 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -33,6 +33,9 @@ pub struct GruenSplitEqPolynomial { pub(crate) w: Vec, pub(crate) E_in_vec: Vec>, pub(crate) E_out_vec: Vec>, + /// Cached `[1]` table used to represent eq over zero variables when a side + /// (head, inner, or active) has no bits. 
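+    /// (eq over an empty tuple of variables is identically 1, so a borrowed
+    /// single-entry table is the correct degenerate factor and keeps the
+    /// window accessors below allocation-free.)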
+ one_table: Vec, pub(crate) binding_order: BindingOrder, } @@ -52,18 +55,20 @@ impl GruenSplitEqPolynomial { // | | last element // | second half of remaining elements (for E_in) // first half of remaining elements (for E_out) - let (_, wprime) = w.split_last().unwrap(); + let (_w_last, wprime) = w.split_last().unwrap(); let (w_out, w_in) = wprime.split_at(m); let (E_out_vec, E_in_vec) = rayon::join( || EqPolynomial::evals_cached(w_out), || EqPolynomial::evals_cached(w_in), ); + let one_table = vec![F::one()]; Self { current_index: w.len(), current_scalar: scaling_factor.unwrap_or(F::one()), w: w.to_vec(), E_in_vec, E_out_vec, + one_table, binding_order, } } @@ -78,6 +83,7 @@ impl GruenSplitEqPolynomial { || EqPolynomial::evals_cached_rev(w_in), || EqPolynomial::evals_cached_rev(w_out), ); + let one_table = vec![F::one()]; Self { current_index: 0, // Start from 0 for high-to-low up to w.len() - 1 @@ -85,6 +91,7 @@ impl GruenSplitEqPolynomial { w: w.to_vec(), E_in_vec, E_out_vec, + one_table, binding_order, } } @@ -107,6 +114,16 @@ impl GruenSplitEqPolynomial { } } + /// Number of variables that have already been bound into `current_scalar`. + /// For LowToHigh this is `w.len() - current_index`; for HighToLow it is + /// `current_index`. + pub fn num_challenges(&self) -> usize { + match self.binding_order { + BindingOrder::LowToHigh => self.w.len() - self.current_index, + BindingOrder::HighToLow => self.current_index, + } + } + pub fn E_in_current_len(&self) -> usize { self.E_in_vec.last().map_or(0, |v| v.len()) } @@ -125,6 +142,146 @@ impl GruenSplitEqPolynomial { self.E_out_vec.last().map_or(&[], |v| v.as_slice()) } + /// Return the (E_out, E_in) tables corresponding to a streaming window of the + /// given `window_size`, using an explicit slice-based factorisation of the + /// current unbound variables. + /// + /// Semantics (LowToHigh): + /// - Let `num_unbound = current_index` and `remaining_w = w[..num_unbound]`. + /// - For a window of size `window_size >= 1`, define: + /// - `w_window` as the last `window_size` bits of `remaining_w` + /// - `w_head` as the prefix before `w_window` + /// - within `w_window`, the last bit is the current Gruen variable and the + /// preceding `window_size - 1` bits are the "active" window bits + /// - This function returns eq tables over `w_head`, split into two halves + /// `w_out` and `w_in`: + /// - `w_head = [w_out || w_in]` with `w_out` = first `⌊|w_head| / 2⌋` bits + /// - `eq(w_head, (x_out, x_in)) = E_out[x_out] * E_in[x_in]`. + /// + /// The active window bits are handled separately by [`E_active_for_window`]. + /// Together they satisfy, for `BindingOrder::LowToHigh`, + /// log2(|E_out|) + log2(|E_in|) + log2(|E_active|) + 1 = #unbound bits, + /// where the final `+ 1` accounts for the current linear Gruen bit. + /// + /// This helper returns slices and represents "no head bits" as + /// single-entry `[1]` tables, matching `eq((), ()) = 1`. + pub fn E_out_in_for_window(&self, window_size: usize) -> (&[F], &[F]) { + if window_size == 0 { + return (&self.one_table, &self.one_table); + } + + match self.binding_order { + BindingOrder::LowToHigh => { + let num_unbound = self.current_index; + if num_unbound == 0 { + return (&self.one_table, &self.one_table); + } + + // Restrict window size to the actually available unbound bits. + let window_size = core::cmp::min(window_size, num_unbound); + let head_len = num_unbound.saturating_sub(window_size); + if head_len == 0 { + // No head bits: represent as eq over zero vars. 
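+                    // e.g. num_unbound = 4 with window_size = 4: the window
+                    // consumes the entire remaining prefix, and eq over the
+                    // empty head is the constant 1.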
+ return (&self.one_table, &self.one_table); + } + + // The head prefix consists of the earliest `head_len` bits of `w`. + // These live entirely in the original `[w_out || w_in] = w[..n-1]` + // region, so we can factor them via prefixes of `w_out` and `w_in`. + let n = self.w.len(); + let m = n / 2; + + let head_out_bits = core::cmp::min(head_len, m); + let head_in_bits = head_len.saturating_sub(head_out_bits); + + // Invariant: head_out_bits + head_in_bits == head_len + // and head_len <= num_unbound - window_size < n - 1 + debug_assert_eq!( + head_out_bits + head_in_bits, + head_len, + "head bit split mismatch: {head_out_bits} + {head_in_bits} != {head_len}", + ); + debug_assert!( + head_out_bits <= m, + "head_out_bits={head_out_bits} exceeds m={m}", + ); + debug_assert!( + head_in_bits <= n - 1 - m, + "head_in_bits={} exceeds available in bits={}", + head_in_bits, + n - 1 - m + ); + + let e_out = if head_out_bits == 0 { + &self.one_table + } else { + debug_assert!( + head_out_bits < self.E_out_vec.len(), + "head_out_bits={} out of bounds for E_out_vec.len()={}", + head_out_bits, + self.E_out_vec.len() + ); + &self.E_out_vec[head_out_bits] + }; + let e_in = if head_in_bits == 0 { + &self.one_table + } else { + debug_assert!( + head_in_bits < self.E_in_vec.len(), + "head_in_bits={} E_in_vec.len()={}", + head_in_bits, + self.E_in_vec.len() + ); + &self.E_in_vec[head_in_bits] + }; + + (e_out, e_in) + } + BindingOrder::HighToLow => { + // Streaming windows are not defined for HighToLow in the current + // Spartan code paths; return neutral head tables. + unimplemented!("Not implemented for high to low"); + } + } + } + + /// Return the equality table over the "active" window bits (all but the + /// last variable in the current streaming window). This is used when + /// projecting the multiquadratic t'(z_0, ..., z_{w-1}) down to a univariate + /// in the first variable by summing against eq(tau_active, ·) over the + /// remaining coordinates. + /// + /// We derive the active slice directly from the unbound portion of `w`. + /// For LowToHigh binding, the unbound variables are `w[..current_index]`; + /// the last `window_size` of these belong to the current window, and all + /// but the final one are "active". + pub fn E_active_for_window(&self, window_size: usize) -> Vec { + if window_size <= 1 { + // No active bits in a size-0/1 window; eq over zero vars is [1]. + return vec![F::one()]; + } + + match self.binding_order { + BindingOrder::LowToHigh => { + let num_unbound = self.current_index; + if window_size > num_unbound { + // Clamp to the maximum meaningful window size at this round. + return vec![F::one()]; + } + let remaining_w = &self.w[..num_unbound]; + let window_start = remaining_w.len() - window_size; + let (_w_body, w_window) = remaining_w.split_at(window_start); + let (w_active, _w_curr_slice) = w_window.split_at(window_size - 1); + // We only need the full eq table over the active window bits. + EqPolynomial::::evals(w_active) + } + BindingOrder::HighToLow => { + // Not used for the outer Spartan streaming code. 
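+                // (The Spartan outer prover constructs its split-eq instance
+                // with LowToHigh, so this arm is currently unreachable there.)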
+ unimplemented!("Not implemented for high to low"); + } + } + } + #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::bind")] pub fn bind(&mut self, r: F::Challenge) { match self.binding_order { @@ -425,6 +582,21 @@ mod tests { use ark_bn254::Fr; use ark_std::test_rng; + #[test] + fn window_out_in() { + const NUM_VARS: usize = 17; + let mut rng = test_rng(); + let w: Vec<::Challenge> = + std::iter::repeat_with(|| ::Challenge::random(&mut rng)) + .take(NUM_VARS) + .collect(); + + let split_eq = GruenSplitEqPolynomial::::new(&w, BindingOrder::LowToHigh); + let (e_prime_out, e_prime_in) = split_eq.E_out_in_for_window(1); + assert_eq!(split_eq.E_out_current_len(), 1 << 8); + assert_eq!(e_prime_out.len(), split_eq.E_out_current().len()); + assert_eq!(e_prime_in.len(), split_eq.E_in_current().len()); + } #[test] fn bind_low_high() { const NUM_VARS: usize = 10; @@ -473,4 +645,99 @@ mod tests { assert_eq!(regular_eq.Z[..regular_eq.len()], merged.Z[..merged.len()]); } } + + /// For window_size = 1, `E_out_in_for_window` should factor the eq polynomial + /// over the head bits `w[..current_index-1]` into a product of two tables. + #[test] + fn window_size_one_matches_current() { + const NUM_VARS: usize = 10; + let mut rng = test_rng(); + let w: Vec<::Challenge> = + std::iter::repeat_with(|| ::Challenge::random(&mut rng)) + .take(NUM_VARS) + .collect(); + + let mut split_eq: GruenSplitEqPolynomial = + GruenSplitEqPolynomial::new(&w, BindingOrder::LowToHigh); + + for _round in 0..NUM_VARS { + let num_unbound = split_eq.current_index; + if num_unbound <= 1 { + break; + } + + // Factor head = w[..num_unbound-1] into (E_out, E_in). + let (e_out_window, e_in_window) = split_eq.E_out_in_for_window(1); + let w_head = &split_eq.w[..num_unbound - 1]; + let head_evals = EqPolynomial::evals(w_head); + + let num_x_out = e_out_window.len(); + let num_x_in = e_in_window.len(); + assert_eq!(num_x_out * num_x_in, head_evals.len()); + + let x_in_bits = num_x_in.log_2(); + for x_out in 0..num_x_out { + for x_in in 0..num_x_in { + let idx = (x_out << x_in_bits) | x_in; + assert_eq!( + e_out_window[x_out] * e_in_window[x_in], + head_evals[idx], + "factorisation mismatch at round={_round}, x_out={x_out}, x_in={x_in}", + ); + } + } + + let r = ::Challenge::random(&mut rng); + split_eq.bind(r); + } + } + + /// Check basic bit-accounting invariants for the streaming factorisation: + /// log2(|E_out|) + log2(|E_in|) + log2(|E_active|) + 1 = number of unbound variables + /// for all window sizes and all rounds (LowToHigh). + #[test] + fn window_bit_accounting_invariants() { + const NUM_VARS: usize = 8; + let mut rng = test_rng(); + let w: Vec<::Challenge> = + std::iter::repeat_with(|| ::Challenge::random(&mut rng)) + .take(NUM_VARS) + .collect(); + + let mut split_eq: GruenSplitEqPolynomial = + GruenSplitEqPolynomial::new(&w, BindingOrder::LowToHigh); + + // Walk through all rounds, checking all window sizes that are + // meaningful at that point (at least one unbound variable). + for _round in 0..NUM_VARS { + let num_unbound = split_eq.len().log_2(); + if num_unbound == 0 { + break; + } + + for window_size in 1..=num_unbound { + let (e_out, e_in) = split_eq.E_out_in_for_window(window_size); + let e_active = split_eq.E_active_for_window(window_size); + // By construction, each side represents at least one entry. 
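+            // A single-entry `[1]` table contributes log2(1) = 0 bits, so
+            // degenerate sides drop out of the accounting below.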
+                debug_assert!(!e_out.is_empty());
+                debug_assert!(!e_in.is_empty());
+                debug_assert!(!e_active.is_empty());
+
+                let bits_out = e_out.len().log_2();
+                let bits_in = e_in.len().log_2();
+                let bits_active = e_active.len().log_2();
+
+                // One bit is reserved for the current variable in the Gruen
+                // cubic (the eq polynomial is linear in that bit).
+                assert_eq!(
+                    bits_out + bits_in + bits_active + 1,
+                    num_unbound,
+                    "bit accounting failed for window_size={window_size} (bits_out={bits_out}, bits_in={bits_in}, bits_active={bits_active}, num_unbound={num_unbound})",
+                );
+            }
+
+            let r = <Fr as JoltField>::Challenge::random(&mut rng);
+            split_eq.bind(r);
+        }
+    }
 }
diff --git a/jolt-core/src/subprotocols/mod.rs b/jolt-core/src/subprotocols/mod.rs
index d7cb4465f..5b10836bb 100644
--- a/jolt-core/src/subprotocols/mod.rs
+++ b/jolt-core/src/subprotocols/mod.rs
@@ -1,6 +1,7 @@
 pub mod booleanity;
 pub mod hamming_weight;
 pub mod mles_product_sum;
+pub mod streaming_schedule;
 pub mod sumcheck;
 pub mod sumcheck_prover;
 pub mod sumcheck_verifier;
diff --git a/jolt-core/src/subprotocols/streaming_schedule.rs b/jolt-core/src/subprotocols/streaming_schedule.rs
new file mode 100644
index 000000000..72c6f039d
--- /dev/null
+++ b/jolt-core/src/subprotocols/streaming_schedule.rs
@@ -0,0 +1,346 @@
+use allocative::Allocative;
+// TODO: Robustify the half split.
+pub trait StreamingSchedule: Send + Sync {
+    /// If this round is the switch-over round, we finally materialise the
+    /// sum-check polynomials in memory. Prior to this round, the streaming
+    /// data structure used to respond to verifier messages is constructed
+    /// directly from the trace.
+    fn is_switch_over_point(&self, round: usize) -> bool;
+    fn before_switch_over_point(&self, round: usize) -> bool;
+    fn after_switch_over_point(&self, round: usize) -> bool;
+
+    /// Returns true if we are starting a new streaming window.
+    /// This triggers recomputation of the streaming data structure
+    /// storing the multi-variable polynomial used to compute prover messages.
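+    /// For example, `HalfSplitSchedule::new(16, 3)` starts windows at rounds
+    /// 0, 3 and 6, then at every round from the switch-over point (9) onwards.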
+ fn is_window_start(&self, round: usize) -> bool; + + /// Get the total number of rounds for sumcheck + fn num_rounds(&self) -> usize; + + /// Get the number of unbound variables in given window + fn num_unbound_vars(&self, round: usize) -> usize; +} + +#[derive(Debug, Clone, Allocative)] +pub struct HalfSplitSchedule { + num_rounds: usize, + constant_window_width: usize, + switch_over_point: usize, + window_starts: Vec, +} + +impl HalfSplitSchedule { + pub fn new(num_rounds: usize, window_width: usize) -> Self { + let m = num_rounds / 2; + let mut window_starts = Vec::new(); + + // Generate constant-width windows until we exceed m + let mut j = 0; + loop { + let window_start = j * window_width; + if window_start >= m { + break; + } + window_starts.push(window_start); + j += 1; + } + + // Switch-over point is the first window that would exceed m + let switch_over_point = j * window_width; + + // Add single-round windows from switch-over to end + for i in switch_over_point..num_rounds { + window_starts.push(i); + } + + Self { + num_rounds, + constant_window_width: window_width, + switch_over_point, + window_starts, + } + } +} + +impl StreamingSchedule for HalfSplitSchedule { + fn is_switch_over_point(&self, round: usize) -> bool { + self.switch_over_point == round + } + + fn after_switch_over_point(&self, round: usize) -> bool { + round > self.switch_over_point + } + + fn before_switch_over_point(&self, round: usize) -> bool { + round < self.switch_over_point + } + + fn is_window_start(&self, round: usize) -> bool { + self.window_starts.contains(&round) + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + fn num_unbound_vars(&self, round: usize) -> usize { + // Find the index of the window start that is <= round + let current_idx = self + .window_starts + .iter() + .rposition(|&w| w <= round) + .expect("no window start found for round"); + + // Get the next window start value + let next_window_start = if current_idx + 1 < self.window_starts.len() { + self.window_starts[current_idx + 1] + } else { + // If this is the last window, assume num_rounds or some default + self.num_rounds + }; + + // Return the difference + next_window_start - round + } +} + +//#[derive(Debug, Clone, Allocative)] +//pub struct IncreasingWindowSchedule { +// pub(crate) num_rounds: usize, +// pub(crate) linear_start: usize, +// pub(crate) window_starts: Vec, +//} +// +//impl IncreasingWindowSchedule { +// pub fn new(num_rounds: usize) -> Self { +// let linear_start = num_rounds.div_ceil(2); +// +// let mut window_starts = Vec::new(); +// let mut round = 0usize; +// let mut width = 1usize; +// while round < linear_start { +// window_starts.push(round); +// let remaining = linear_start - round; +// let w = core::cmp::min(width, remaining); +// round += w; +// width += 1; +// } +// +// Self { +// num_rounds, +// linear_start, +// window_starts, +// } +// } +//} +// +//impl StreamingSchedule for IncreasingWindowSchedule { +// fn is_streaming(&self, round: usize) -> bool { +// round < self.linear_start +// } +// +// fn is_window_start(&self, round: usize) -> bool { +// self.window_starts.contains(&round) +// } +// +// fn is_first_linear(&self, round: usize) -> bool { +// round == self.linear_start +// } +// +// fn num_rounds(&self) -> usize { +// self.num_rounds +// } +// +// fn num_unbound_vars(&self, round: usize) -> usize { +// if round >= self.num_rounds { +// return 0; +// } +// if self.is_streaming(round) { +// let next_boundary = self +// .window_starts +// .iter() +// .find(|&&start| start > round) +// 
.copied() +// .unwrap_or(self.linear_start); +// next_boundary - round +// } else { +// self.num_rounds - round +// } +// } +//} + +/// A schedule that disables streaming and runs all sumcheck rounds in +/// the linear-time mode. +#[derive(Debug, Clone, Allocative)] +pub struct LinearOnlySchedule { + num_rounds: usize, +} + +impl LinearOnlySchedule { + pub fn new(num_rounds: usize) -> Self { + Self { num_rounds } + } +} + +impl StreamingSchedule for LinearOnlySchedule { + fn is_window_start(&self, _round: usize) -> bool { + true + } + + fn after_switch_over_point(&self, round: usize) -> bool { + round > 0 + } + + fn before_switch_over_point(&self, _round: usize) -> bool { + false + } + + fn is_switch_over_point(&self, round: usize) -> bool { + round == 0 + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn num_unbound_vars(&self, _round: usize) -> usize { + 1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_construction() { + let schedule = HalfSplitSchedule::new(16, 3); + assert_eq!(schedule.num_rounds(), 16); + assert_eq!(schedule.switch_over_point, 9); + assert_eq!( + schedule.window_starts, + vec![0, 3, 6, 9, 10, 11, 12, 13, 14, 15] + ); + } + + #[test] + fn test_switch_over_point() { + let schedule = HalfSplitSchedule::new(16, 3); + + // Before switch-over + assert!(schedule.before_switch_over_point(0)); + assert!(schedule.before_switch_over_point(8)); + assert!(!schedule.before_switch_over_point(9)); + + // At switch-over + assert!(schedule.is_switch_over_point(9)); + assert!(!schedule.is_switch_over_point(8)); + assert!(!schedule.is_switch_over_point(10)); + + // After switch-over + assert!(schedule.after_switch_over_point(10)); + assert!(schedule.after_switch_over_point(15)); + assert!(!schedule.after_switch_over_point(9)); + assert!(!schedule.after_switch_over_point(8)); + } + + #[test] + fn test_window_starts() { + let schedule = HalfSplitSchedule::new(16, 3); + + // First half - constant width windows + assert!(schedule.is_window_start(0)); + assert!(schedule.is_window_start(3)); + assert!(schedule.is_window_start(6)); + assert!(!schedule.is_window_start(1)); + assert!(!schedule.is_window_start(2)); + assert!(!schedule.is_window_start(4)); + + // After switch-over - every round is a window start + assert!(schedule.is_window_start(9)); + assert!(schedule.is_window_start(10)); + assert!(schedule.is_window_start(11)); + assert!(schedule.is_window_start(15)); + } + + #[test] + fn test_num_unbound_vars_first_window() { + let schedule = HalfSplitSchedule::new(16, 3); + + // Window [0, 3): 3 unbound vars initially + assert_eq!(schedule.num_unbound_vars(0), 3); + assert_eq!(schedule.num_unbound_vars(1), 2); + assert_eq!(schedule.num_unbound_vars(2), 1); + } + + #[test] + fn test_num_unbound_vars_middle_window() { + let schedule = HalfSplitSchedule::new(16, 3); + + // Window [3, 6): 3 unbound vars + assert_eq!(schedule.num_unbound_vars(3), 3); + assert_eq!(schedule.num_unbound_vars(4), 2); + assert_eq!(schedule.num_unbound_vars(5), 1); + + // Window [6, 9): 3 unbound vars + assert_eq!(schedule.num_unbound_vars(6), 3); + assert_eq!(schedule.num_unbound_vars(7), 2); + assert_eq!(schedule.num_unbound_vars(8), 1); + } + + #[test] + fn test_num_unbound_vars_after_switchover() { + let schedule = HalfSplitSchedule::new(16, 3); + + // After switch-over, each window is size 1 + assert_eq!(schedule.num_unbound_vars(9), 1); + assert_eq!(schedule.num_unbound_vars(10), 1); + assert_eq!(schedule.num_unbound_vars(11), 1); + 
assert_eq!(schedule.num_unbound_vars(14), 1);
+
+        // Last round
+        assert_eq!(schedule.num_unbound_vars(15), 1);
+    }
+
+    #[test]
+    fn test_odd_num_rounds() {
+        let schedule = HalfSplitSchedule::new(15, 3);
+
+        // m = 15 / 2 = 7, so the switch-over point is the first multiple of
+        // the window width (3) that is >= 7, i.e. 9.
+        assert_eq!(schedule.switch_over_point, 9);
+        assert_eq!(schedule.window_starts, vec![0, 3, 6, 9, 10, 11, 12, 13, 14]);
+    }
+
+    #[test]
+    fn test_small_num_rounds() {
+        let schedule = HalfSplitSchedule::new(4, 2);
+
+        assert_eq!(schedule.switch_over_point, 2);
+        assert_eq!(schedule.window_starts, vec![0, 2, 3]);
+
+        assert_eq!(schedule.num_unbound_vars(0), 2);
+        assert_eq!(schedule.num_unbound_vars(1), 1);
+        assert_eq!(schedule.num_unbound_vars(2), 1);
+        assert_eq!(schedule.num_unbound_vars(3), 1);
+    }
+
+    #[test]
+    fn test_large_window_width() {
+        let schedule = HalfSplitSchedule::new(20, 10);
+
+        assert_eq!(schedule.switch_over_point, 10);
+        assert_eq!(
+            schedule.window_starts,
+            vec![0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+        );
+
+        // First window covers rounds 0-9
+        assert_eq!(schedule.num_unbound_vars(0), 10);
+        assert_eq!(schedule.num_unbound_vars(5), 5);
+        assert_eq!(schedule.num_unbound_vars(9), 1);
+
+        // Second window is just round 10
+        assert_eq!(schedule.num_unbound_vars(10), 1);
+    }
+}
diff --git a/jolt-core/src/utils/expanding_table.rs b/jolt-core/src/utils/expanding_table.rs
index bec7ac712..8ab24abb0 100644
--- a/jolt-core/src/utils/expanding_table.rs
+++ b/jolt-core/src/utils/expanding_table.rs
@@ -22,6 +22,10 @@ impl<F: JoltField> ExpandingTable<F> {
         self.len
     }
 
+    pub fn order(&self) -> BindingOrder {
+        self.binding_order
+    }
+
     /// Initializes an `ExpandingTable` with the given `capacity`.
     #[tracing::instrument(skip_all, name = "ExpandingTable::new")]
     pub fn new(capacity: usize, binding_order: BindingOrder) -> Self {
@@ -52,6 +56,7 @@ impl<F: JoltField> ExpandingTable<F> {
     /// Updates this table (expanding it by a factor of 2) to incorporate
     /// the new random challenge `r_j`.
+    /// TODO: this is bad parallelisation.
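+    /// Assuming the usual low-to-high update rule, each entry `v` expands to
+    /// the pair `(v * (1 - r_j), v * r_j)`, so after `k` updates the table
+    /// holds all 2^k products of eq factors over the bound challenges; this
+    /// is exactly the `r_grid` weighting consumed by the streaming prover.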
#[tracing::instrument(skip_all, name = "ExpandingTable::update")] pub fn update(&mut self, r_j: F::Challenge) { match self.binding_order { diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 8e7ef4c76..b649b6db1 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1,3 +1,4 @@ +use crate::subprotocols::streaming_schedule::LinearOnlySchedule; use std::{ collections::HashMap, fs::File, @@ -97,6 +98,8 @@ pub struct JoltCpuProver< pub program_io: JoltDevice, pub lazy_trace: LazyTraceIterator, pub trace: Arc>, + pub checkpoints: Vec>, + pub checkpoint_interval: usize, pub advice: JoltAdvice, pub unpadded_trace_len: usize, pub padded_trace_len: usize, @@ -107,7 +110,6 @@ pub struct JoltCpuProver< pub final_ram_state: Vec, pub one_hot_params: OneHotParams, } - impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscript: Transcript> JoltCpuProver<'a, F, PCS, ProofTranscript> { @@ -129,6 +131,21 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip program_size: Some(preprocessing.memory_layout.program_size), }; + // TODO: Currently we're manifesting the entire trace + // There is some debate on how to stream this efficiently + // We can move forward with streaming implementations assuming this will be + // fixed in the coming days + + //let checkpoint_interval = 256; + //let (checkpoints, _jolt_device) = trace_checkpoints( + // elf_contents, + // inputs, + // untrusted_advice, + // trusted_advice, + // &memory_config, + // checkpoint_interval, + //); + let (lazy_trace, trace, final_memory_state, program_io) = { let _pprof_trace = pprof_scope!("trace"); guest::program::trace( @@ -140,6 +157,37 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &memory_config, ) }; + + //#[cfg(debug_assertions)] + //{ + // for (time_step_idx, expected_cycle) in trace.iter().enumerate() { + // // Calculate which checkpoint and offset + // let checkpoint_idx = time_step_idx / checkpoint_interval; + // let offset = time_step_idx % checkpoint_interval; + // + // // Clone the checkpoint and advance to target + // let mut iter = checkpoints[checkpoint_idx].clone(); + // + // // Skip offset cycles + // for _ in 0..offset { + // iter.next(); + // } + // + // // Get the cycle from checkpoint + // let checkpoint_cycle = iter.next().expect("checkpoint should have cycle"); + // + // // Assert they match + // assert_eq!( + // &checkpoint_cycle, expected_cycle, + // "Mismatch at cycle {time_step_idx}: checkpoint != trace", + // ); + // } + // println!( + // "✓ All {} cycles match between checkpoints and full trace", + // trace.len() + // ); + //} + // let num_riscv_cycles: usize = trace .par_iter() .map(|cycle| { @@ -162,14 +210,20 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trace.len(), ); - Self::gen_from_trace( + let mut prover = Self::gen_from_trace( preprocessing, lazy_trace, trace, program_io, trusted_advice_commitment, final_memory_state, - ) + ); + + // Set checkpoints after construction + // Vec> + prover.checkpoints = Vec::new(); + prover.checkpoint_interval = 0; + prover } pub fn gen_from_trace( @@ -233,6 +287,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip program_io, lazy_trace, trace: trace.into(), + checkpoints: Vec::new(), // Empty by default + checkpoint_interval: 0, // Default value advice: JoltAdvice { untrusted_advice_polynomial: None, trusted_advice_commitment, @@ -492,10 +548,17 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, 
ProofTranscrip &mut self.transcript, ); + // Every sum-check with num_rounds > 1 requires a schedule + // which dictates the compute_message and bind methods + let schedule = LinearOnlySchedule::new(uni_skip_state.tau.len() - 1); + //let schedule = HalfSplitSchedule::new(uni_skip_state.tau.len() - 1, 3); let mut spartan_outer_remaining = OuterRemainingSumcheckProver::gen( Arc::clone(&self.trace), + &self.checkpoints, + self.checkpoint_interval, &self.preprocessing.bytecode, &uni_skip_state, + schedule, ); let (sumcheck_proof, _r_stage1) = BatchedSumcheck::prove( @@ -1417,7 +1480,6 @@ mod tests { init_memory_state, 1 << 16, ); - let prover = RV64IMACProver::gen_from_trace( &preprocessing, lazy_trace, diff --git a/jolt-core/src/zkvm/r1cs/evaluation.rs b/jolt-core/src/zkvm/r1cs/evaluation.rs index 9b6b2767b..f1af9c456 100644 --- a/jolt-core/src/zkvm/r1cs/evaluation.rs +++ b/jolt-core/src/zkvm/r1cs/evaluation.rs @@ -138,6 +138,29 @@ pub struct AzFirstGroup { pub must_start_sequence: bool, // NextIsVirtual && !NextIsFirstInSequence } +impl AzFirstGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `az_at_r_first_group` + /// but keeps the result in an `Acc5U` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc5U, + ) { + acc.fmadd(&w[0], &self.not_load_store); + acc.fmadd(&w[1], &self.load_a); + acc.fmadd(&w[2], &self.load_b); + acc.fmadd(&w[3], &self.store); + acc.fmadd(&w[4], &self.add_sub_mul); + acc.fmadd(&w[5], &self.not_add_sub_mul); + acc.fmadd(&w[6], &self.assert_flag); + acc.fmadd(&w[7], &self.should_jump); + acc.fmadd(&w[8], &self.virtual_instruction); + acc.fmadd(&w[9], &self.must_start_sequence); + } +} + /// Magnitudes for the first group (kept small: bool/u64/S64) #[derive(Clone, Copy, Debug)] pub struct BzFirstGroup { @@ -153,6 +176,29 @@ pub struct BzFirstGroup { pub one_minus_do_not_update_unexpanded_pc: bool, // 1 - DoNotUpdateUnexpandedPC } +impl BzFirstGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `bz_at_r_first_group` + /// but keeps the result in an `Acc6S` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc6S, + ) { + acc.fmadd(&w[0], &self.ram_addr); + acc.fmadd(&w[1], &self.ram_read_minus_ram_write); + acc.fmadd(&w[2], &self.ram_read_minus_rd_write); + acc.fmadd(&w[3], &self.rs2_minus_ram_write); + acc.fmadd(&w[4], &self.left_lookup); + acc.fmadd(&w[5], &self.left_lookup_minus_left_input); + acc.fmadd(&w[6], &self.lookup_output_minus_one); + acc.fmadd(&w[7], &self.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[8], &self.next_pc_minus_pc_plus_one); + acc.fmadd(&w[9], &self.one_minus_do_not_update_unexpanded_pc); + } +} + /// Guards for the second group (all booleans except two u8 flags) #[derive(Clone, Copy, Debug)] pub struct AzSecondGroup { @@ -167,6 +213,28 @@ pub struct AzSecondGroup { pub not_jump_or_branch: bool, // !(Jump || ShouldBranch) } +impl AzSecondGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `az_at_r_second_group` + /// but keeps the result in an `Acc5U` accumulator without reducing. 
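+    /// Deferring the modular reduction keeps all nine weight-times-flag
+    /// products in wide integer form; the caller performs a single
+    /// `barrett_reduce` per accumulator once the row is fully absorbed.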
+ #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc5U, + ) { + acc.fmadd(&w[0], &self.load_or_store); + acc.fmadd(&w[1], &self.add); + acc.fmadd(&w[2], &self.sub); + acc.fmadd(&w[3], &self.mul); + acc.fmadd(&w[4], &self.not_add_sub_mul_advice); + acc.fmadd(&w[5], &self.write_lookup_to_rd); + acc.fmadd(&w[6], &self.write_pc_to_rd); + acc.fmadd(&w[7], &self.should_branch); + acc.fmadd(&w[8], &self.not_jump_or_branch); + } +} + /// Magnitudes for the second group (mixed precision up to S160) #[derive(Clone, Copy, Debug)] pub struct BzSecondGroup { @@ -181,6 +249,28 @@ pub struct BzSecondGroup { pub next_unexp_pc_minus_expected: S64, // NextUnexpandedPC - (UnexpandedPC + const) } +impl BzSecondGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `bz_at_r_second_group` + /// but keeps the result in an `Acc7S` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc7S, + ) { + acc.fmadd(&w[0], &self.ram_addr_minus_rs1_plus_imm); + acc.fmadd(&w[1], &self.right_lookup_minus_add_result); + acc.fmadd(&w[2], &self.right_lookup_minus_sub_result); + acc.fmadd(&w[3], &self.right_lookup_minus_product); + acc.fmadd(&w[4], &self.right_lookup_minus_right_input); + acc.fmadd(&w[5], &self.rd_write_minus_lookup_output); + acc.fmadd(&w[6], &self.rd_write_minus_pc_plus_const); + acc.fmadd(&w[7], &self.next_unexp_pc_minus_pc_plus_imm); + acc.fmadd(&w[8], &self.next_unexp_pc_minus_expected); + } +} + /// Unified evaluator wrapper with typed accessors for both groups #[derive(Clone, Copy, Debug)] pub struct R1CSEval<'a, F: JoltField> { @@ -293,6 +383,22 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc.barrett_reduce() } + /// Fused accumulate of first-group Az and Bz into unreduced accumulators using + /// Lagrange weights `w`. This keeps everything in unreduced form; callers are + /// responsible for reducing at the end. + #[inline(always)] + pub fn fmadd_first_group_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc_az: &mut Acc5U, + acc_bz: &mut Acc6S, + ) { + let az = self.eval_az_first_group(); + let bz = self.eval_bz_first_group(); + az.fmadd_at_r(w, acc_az); + bz.fmadd_at_r(w, acc_bz); + } + /// Product Az·Bz at the j-th extended uniskip target for the first group (uses precomputed weights). pub fn extended_azbz_product_first_group(&self, j: usize) -> S192 { let coeffs_i32: &[i32; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE] = &COEFFS_PER_J[j]; @@ -519,6 +625,21 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc.barrett_reduce() } + /// Fused accumulate of second-group Az and Bz into unreduced accumulators + /// using Lagrange weights `w`. This keeps everything in unreduced form; callers + /// are responsible for reducing at the end. + #[inline(always)] + pub fn fmadd_second_group_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc_az: &mut Acc5U, + acc_bz: &mut Acc7S, + ) { + let az = self.eval_az_second_group(); + let bz = self.eval_bz_second_group(); + az.fmadd_at_r(w, acc_az); + bz.fmadd_at_r(w, acc_bz); + } /// Product Az·Bz at the j-th extended uniskip target for the second group (uses precomputed weights). 
pub fn extended_azbz_product_second_group(&self, j: usize) -> S192 { #[cfg(test)] diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index 6b08d6692..3dd2639fd 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -4,25 +4,30 @@ use allocative::Allocative; use ark_std::Zero; use rayon::prelude::*; use tracer::instruction::Cycle; +use tracer::LazyTraceIterator; +use crate::field::BarrettReduce; use crate::field::{FMAdd, JoltField, MontgomeryReduce}; use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::eq_poly::EqPolynomial; use crate::poly::lagrange_poly::LagrangePolynomial; -use crate::poly::multilinear_polynomial::BindingOrder; +use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; +use crate::poly::multiquadratic_poly::MultiquadraticPolynomial; use crate::poly::opening_proof::{ OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }; use crate::poly::split_eq_poly::GruenSplitEqPolynomial; use crate::poly::unipoly::UniPoly; +use crate::subprotocols::streaming_schedule::StreamingSchedule; use crate::subprotocols::sumcheck_prover::{ SumcheckInstanceProver, UniSkipFirstRoundInstanceProver, }; use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::subprotocols::univariate_skip::{build_uniskip_first_round_poly, UniSkipState}; use crate::transcripts::Transcript; -use crate::utils::accumulation::Acc8S; +use crate::utils::accumulation::{Acc5U, Acc6S, Acc7S, Acc8S}; +use crate::utils::expanding_table::ExpandingTable; use crate::utils::math::Math; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; @@ -43,6 +48,10 @@ use allocative::FlameGraphBuilder; /// Degree bound of the sumcheck round polynomials for [`OuterRemainingSumcheckVerifier`]. const OUTER_REMAINING_DEGREE_BOUND: usize = 3; +// this represents the index position in multi-quadratic poly array +// This should actually be d where degree is the degree of the streaming data structure +// For example : MultiQuadratic has d=2; for cubic this would be 3 etc. +const INFINITY: usize = 2; // Spartan Outer sumcheck // (with univariate-skip first round on Z, and no Cz term given all eq conditional constraints) @@ -204,29 +213,58 @@ impl UniSkipFirstRoundInstanceProver /// SumcheckInstance for Spartan outer rounds after the univariate-skip first round. /// Round 0 in this instance corresponds to the "streaming" round; subsequent rounds /// use the remaining linear-time algorithm over cycle variables. 
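+/// Each remaining round sends a degree-3 message: the linear eq factor for the
+/// current variable (from the Gruen split-eq optimisation) multiplied by the
+/// quadratic t_i(z), which is reconstructed from t_i(0), the leading
+/// coefficient t_i(∞), and the running sumcheck claim.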
+//#[derive(Allocative)] +//pub struct OuterRemainingSumcheckProver { +// #[allocative(skip)] +// bytecode_preprocessing: BytecodePreprocessing, +// #[allocative(skip)] +// trace: Arc>, +// split_eq_poly: GruenSplitEqPolynomial, +// az: Option>, +// bz: Option>, +// t_prime_poly: Option>, // multiquadratic polynomial used to answer queries in a streaming window +// /// The first round evals (t0, t_inf) computed from a streaming pass over the trace +// first_round_evals: (F, F), +// #[allocative(skip)] +// params: OuterRemainingSumcheckParams, +//} +// #[derive(Allocative)] -pub struct OuterRemainingSumcheckProver { +pub struct OuterRemainingSumcheckProver<'a, F: JoltField, S: StreamingSchedule + Allocative> { #[allocative(skip)] bytecode_preprocessing: BytecodePreprocessing, #[allocative(skip)] trace: Arc>, + /// Split-eq instance used for both streaming and linear phases of the + /// outer Spartan sumcheck over cycle variables. split_eq_poly: GruenSplitEqPolynomial, - az: DensePolynomial, - bz: DensePolynomial, + az: Option>, + bz: Option>, + t_prime_poly: Option>, // multiquadratic polynomial used to answer queries in a streaming window + r_grid: ExpandingTable, // hadamard product of (1 - r_j, r_j) for bound variables so far to help with streaming /// The first round evals (t0, t_inf) computed from a streaming pass over the trace - first_round_evals: (F, F), #[allocative(skip)] params: OuterRemainingSumcheckParams, + lagrange_evals_r0: [F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + schedule: S, + #[allocative(skip)] + checkpoints: &'a [std::iter::Take], + checkpoint_interval: usize, } -impl OuterRemainingSumcheckProver { +impl<'a, F: JoltField, S: StreamingSchedule + Allocative> OuterRemainingSumcheckProver<'a, F, S> { #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::gen")] pub fn gen( trace: Arc>, + checkpoints: &'a [std::iter::Take], // Add lifetime 'a, use slice + checkpoint_interval: usize, bytecode_preprocessing: &BytecodePreprocessing, uni: &UniSkipState, + schedule: S, ) -> Self { let bytecode_preprocessing = bytecode_preprocessing.clone(); + let n_cycle_vars = trace.len().log_2(); + let outer_params = OuterRemainingSumcheckParams::new(n_cycle_vars, uni); let lagrange_evals_r = LagrangePolynomial::::evals::< F::Challenge, @@ -248,60 +286,617 @@ impl OuterRemainingSumcheckProver { Some(lagrange_tau_r0), ); - let (t0, t_inf, az_bound, bz_bound) = Self::compute_first_quadratic_evals_and_bound_polys( - &bytecode_preprocessing, - &trace, - &lagrange_evals_r, - &split_eq_poly, - ); - - let n_cycle_vars = trace.len().ilog2() as usize; + // NOTE: The API changed recently: Both binding orders will technically pass + // based on current implementation. + let mut r_grid = ExpandingTable::new(1 << n_cycle_vars, BindingOrder::LowToHigh); + r_grid.reset(F::one()); Self { split_eq_poly, bytecode_preprocessing, trace, - az: az_bound, - bz: bz_bound, - first_round_evals: (t0, t_inf), - params: OuterRemainingSumcheckParams::new(n_cycle_vars, uni), + checkpoints, + checkpoint_interval, + az: None, + bz: None, + t_prime_poly: None, + r_grid, + params: outer_params, + lagrange_evals_r0: lagrange_evals_r, + schedule, } } - /// Compute the quadratic evaluations for the streaming round (right after univariate skip). - /// - /// This uses the streaming algorithm to compute the sum-check polynomial for the round - /// right after the univariate skip round. 
- /// - /// Recall that we need to compute - /// - /// `t_i(0) = \sum_{x_out} E_out[x_out] \sum_{x_in} E_in[x_in] * - /// unbound_coeffs_a(x_out, x_in, 0, r) * unbound_coeffs_b(x_out, x_in, 0, r)` - /// - /// and - /// - /// `t_i(∞) = \sum_{x_out} E_out[x_out] \sum_{x_in} E_in[x_in] * (unbound_coeffs_a(x_out, - /// x_in, ∞, r) * unbound_coeffs_b(x_out, x_in, ∞, r))` - /// - /// Here the "_a,b" subscript indicates the coefficients of `unbound_coeffs` corresponding to - /// Az and Bz respectively. Note that we index with x_out being the MSB here. - /// - /// Importantly, since the eval at `r` is not cached, we will need to recompute it via another - /// sum - /// - /// `unbound_coeffs_{a,b}(x_out, x_in, {0,∞}, r) = \sum_{y in D} Lagrange(r, y) * - /// unbound_coeffs_{a,b}(x_out, x_in, {0,∞}, y)` - /// - /// (and the eval at ∞ is computed as (eval at 1) - (eval at 0)) - #[inline] - fn compute_first_quadratic_evals_and_bound_polys( - bytecode_preprocessing: &BytecodePreprocessing, - trace: &[Cycle], - lagrange_evals_r: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], - split_eq_poly: &GruenSplitEqPolynomial, - ) -> (F, F, DensePolynomial, DensePolynomial) { - let num_x_out_vals = split_eq_poly.E_out_current_len(); - let num_x_in_vals = split_eq_poly.E_in_current_len(); + // gets the evaluations of az(x, {0,1}^log(jlen), r) and bz(x, {0,1}^log(jlen), r) + // where x is determined by the bit decomposition of offset + // and r is log(klen) variables + // this is used both in window computation (jlen is window size) + // and in converting to linear time (offset is 0, log(jlen) is the number of unbound variables) + // The caller must pass in `scaled_w`, the tensor product of the Lagrange weights + // at r0 with the current `r_grid` weights: + // scaled_w[k][t] = lagrange_evals_r0[t] * r_grid[k] (for klen > 1) + // and scaled_w[0][t] = lagrange_evals_r0[t] when klen == 1 (no r_grid factor). + #[allow(clippy::too_many_arguments)] + fn extrapolate_from_binary_grid_to_tertiary_grid( + &self, + grid_az: &mut [F], + grid_bz: &mut [F], + jlen: usize, + klen: usize, + offset: usize, + parallel: bool, + scaled_w: &[[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]], + ) { + let preprocess = &self.bytecode_preprocessing; + let _checkpoints = &self.checkpoints; + let _checkpoint_interval = self.checkpoint_interval; + let trace = &self.trace; + debug_assert_eq!(scaled_w.len(), klen); + debug_assert_eq!(grid_az.len(), jlen); + debug_assert_eq!(grid_bz.len(), jlen); + + // Unreduced accumulators per j for Az and the two Bz groups. + let mut acc_az = vec![Acc5U::::zero(); jlen]; + let mut acc_bz_first = vec![Acc6S::::zero(); jlen]; + let mut acc_bz_second = vec![Acc7S::::zero(); jlen]; + + if !parallel { + // Sequential traversal: iterate over j first and then k so that we + // walk consecutive cycles in memory (full_idx increases by 1 inside + // the inner loop). 
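+            // Index layout: full_idx = offset + j * klen + k, so for fixed j
+            // the k loop visits klen consecutive virtual rows. Each trace
+            // cycle owns two constraint rows: full_idx >> 1 selects the cycle
+            // and the low bit selects the first or second constraint group.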
+ for j in 0..jlen { + for k in 0..klen { + let full_idx = offset + j * klen + k; + let current_step_idx = full_idx >> 1; + let selector = (full_idx & 1) == 1; + + // TODO: use the lazy trace iterator here instead of indexing directly into the + // trace that is all that needs to change for now + //let row_inputs_prime = R1CSCycleInputs::from_checkpoints::( + // preprocess, + // checkpoints, + // checkpoint_interval, + // current_step_idx, + //); + let row_inputs = + R1CSCycleInputs::from_trace::(preprocess, trace, current_step_idx); + let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + let w_k = &scaled_w[k]; + + if !selector { + eval.fmadd_first_group_at_r(w_k, &mut acc_az[j], &mut acc_bz_first[j]); + } else { + eval.fmadd_second_group_at_r(w_k, &mut acc_az[j], &mut acc_bz_second[j]); + } + } + } + } else { + // Parallel traversal over j for the linear-time prover. + // Each worker owns disjoint accumulators for a fixed j, so there + // are no data races. We reuse the precomputed scaled Lagrange weights + // per k from `scaled_w`, avoiding redundant tensor products. + acc_az + .par_iter_mut() + .zip(acc_bz_first.par_iter_mut()) + .zip(acc_bz_second.par_iter_mut()) + .enumerate() + .for_each(|(j, ((acc_az_j, acc_bz_first_j), acc_bz_second_j))| { + for k in 0..klen { + let full_idx = offset + j * klen + k; + let current_step_idx = full_idx >> 1; + let selector = (full_idx & 1) == 1; + + let row_inputs = + R1CSCycleInputs::from_trace::(preprocess, trace, current_step_idx); + let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + let w_k = &scaled_w[k]; + + if !selector { + eval.fmadd_first_group_at_r(w_k, acc_az_j, acc_bz_first_j); + } else { + eval.fmadd_second_group_at_r(w_k, acc_az_j, acc_bz_second_j); + } + } + }); + } + + // Final reductions: reduce accumulators and write to output slices. + // Each chunk writes to disjoint indices, so parallel iteration is safe. + const REDUCE_CHUNK_SIZE: usize = 4096; + grid_az + .par_chunks_mut(REDUCE_CHUNK_SIZE) + .zip(grid_bz.par_chunks_mut(REDUCE_CHUNK_SIZE)) + .enumerate() + .for_each(|(chunk_idx, (az_chunk, bz_chunk))| { + let start = chunk_idx * REDUCE_CHUNK_SIZE; + for (local_j, (az_out, bz_out)) in + az_chunk.iter_mut().zip(bz_chunk.iter_mut()).enumerate() + { + let j = start + local_j; + *az_out = acc_az[j].barrett_reduce(); + let bz_first_j = acc_bz_first[j].barrett_reduce(); + let bz_second_j = acc_bz_second[j].barrett_reduce(); + *bz_out = bz_first_j + bz_second_j; + } + }); + } + + // returns the grid of evaluations on {0,1,inf}^window_size + // touches each cycle of the trace exactly once and in order! + #[tracing::instrument( + skip_all, + name = "OuterRemainingSumcheckProver::compute_evaluation_grid_from_trace" + )] + fn compute_evaluation_grid_from_trace(&mut self, window_size: usize) { + // semantics in one place (see `split_eq_poly::E_out_in_for_window`). + let split_eq = &self.split_eq_poly; + + // helper constants + let three_pow_dim = 3_usize.pow(window_size as u32); + let jlen = 1 << window_size; + let klen = 1 << split_eq.num_challenges(); + + // Precompute the tensor product of the Lagrange weights at r0 with the + // current r_grid weights so that all calls into `build_grids` can reuse + // these scaled tables. 
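+        // scaled_w has one row per assignment k of the bound variables; each
+        // row holds the OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE Lagrange weights at
+        // r0 scaled by r_grid[k]. With no variables bound yet (klen == 1) the
+        // single row is just the unscaled weights.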
+ let lagrange_evals_r = &self.lagrange_evals_r0;
+ let r_grid = &self.r_grid;
+ let mut scaled_w = vec![[F::zero(); OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]; klen];
+ if klen > 1 {
+ debug_assert_eq!(klen, r_grid.len());
+ for k in 0..klen {
+ let weight = r_grid[k];
+ let row = &mut scaled_w[k];
+ for t in 0..OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE {
+ row[t] = lagrange_evals_r[t] * weight;
+ }
+ }
+ } else {
+ debug_assert_eq!(klen, 1);
+ scaled_w[0].copy_from_slice(lagrange_evals_r);
+ }
+
+ // Head-factor eq tables for this window.
+ let (e_out, e_in) = split_eq.E_out_in_for_window(window_size);
+ let e_in_len = e_in.len();
+
+ // Main logic: parallelize outer sum over E_out_current; for each x_out,
+ // perform an inner unreduced accumulation over E_in_current and only
+ // reduce once per grid cell, then multiply by E_out unreduced.
+ let res_unr = e_out
+ .par_iter()
+ .enumerate()
+ .map(|(out_idx, out_val)| {
+ // Local unreduced accumulators and scratch buffers for this out_idx.
+ let mut local_res_unr = vec![F::Unreduced::<9>::zero(); three_pow_dim];
+ let mut buff_a: Vec<F> = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+ let mut grid_a = vec![F::zero(); jlen];
+ let mut grid_b = vec![F::zero(); jlen];
+
+ for (in_idx, in_val) in e_in.iter().enumerate() {
+ let i = out_idx * e_in_len + in_idx;
+
+ // Reuse the same grid buffers across all x_in for this x_out.
+ grid_a.fill(F::zero());
+ grid_b.fill(F::zero());
+ // Keep this call sequential to avoid nested rayon parallelism.
+ self.extrapolate_from_binary_grid_to_tertiary_grid(
+ &mut grid_a,
+ &mut grid_b,
+ jlen,
+ klen,
+ i * jlen * klen,
+ false,
+ &scaled_w,
+ );
+
+ // Extrapolate grid_a and grid_b from {0,1}^window_size to {0,1,∞}^window_size.
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &grid_a,
+ &mut buff_a,
+ &mut tmp,
+ window_size,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &grid_b,
+ &mut buff_b,
+ &mut tmp,
+ window_size,
+ );
+
+ let e_in_val = *in_val;
+ for idx in 0..three_pow_dim {
+ let val = buff_a[idx] * buff_b[idx];
+ local_res_unr[idx] += e_in_val.mul_unreduced::<9>(val);
+ }
+ }
+
+ // Fold in E_out for this x_out.
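+ // The inner x_in sum is reduced exactly once per grid cell here, then
+ // scaled by E_out back into unreduced form; the final reduction happens
+ // only after summing over all x_out below.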
+ let e_out_val = *out_val;
+ for idx in 0..three_pow_dim {
+ let inner_red = F::from_montgomery_reduce::<9>(local_res_unr[idx]);
+ local_res_unr[idx] = e_out_val.mul_unreduced::<9>(inner_red);
+ }
+ local_res_unr
+ })
+ .reduce(
+ || vec![F::Unreduced::<9>::zero(); three_pow_dim],
+ |mut acc, local| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local[idx];
+ }
+ acc
+ },
+ );
+
+ // Final reduction over all (x_out, x_in)
+ let res: Vec<F> = res_unr
+ .into_iter()
+ .map(|unr| F::from_montgomery_reduce::<9>(unr))
+ .collect();
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(window_size, res));
+ }
+
+ #[tracing::instrument(
+ skip_all,
+ name = "OuterRemainingSumcheckProver::compute_evaluation_grid_from_poly_parallel"
+ )]
+ pub fn compute_evaluation_grid_from_polynomials_parallel(&mut self, num_vars: usize) {
+ let eq_poly = &self.split_eq_poly;
+
+ let n = self.az.as_ref().expect("az should be initialized").len();
+ let az = self.az.as_ref().expect("az should be initialized");
+ let bz = self.bz.as_ref().expect("bz should be initialized");
+ debug_assert_eq!(n, bz.len());
+
+ let three_pow_dim = 3_usize.pow(num_vars as u32);
+ let grid_size = 1 << num_vars;
+ let (E_out, E_in) = eq_poly.E_out_in_for_window(num_vars);
+ let ans: Vec<F> = if E_in.len() == 1 {
+ // Parallel version with reduction
+ (0..E_out.len())
+ .into_par_iter()
+ .map(|i| {
+ let mut local_ans = vec![F::zero(); three_pow_dim];
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+ let mut buff_a = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+
+ // Fill grids
+ for j in 0..grid_size {
+ let index = grid_size * i + j;
+ az_grid[j] = az[index];
+ bz_grid[j] = bz[index];
+ }
+
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+
+ for idx in 0..three_pow_dim {
+ local_ans[idx] = buff_a[idx] * buff_b[idx] * E_out[i];
+ }
+
+ local_ans
+ })
+ .reduce(
+ || vec![F::zero(); three_pow_dim],
+ |mut acc, local_ans| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local_ans[idx];
+ }
+ acc
+ },
+ )
+ } else {
+ let num_xin_bits = E_in.len().log_2();
+ (0..E_out.len())
+ .into_par_iter()
+ .map(|x_out| {
+ let mut local_ans = vec![F::zero(); three_pow_dim];
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+ let mut buff_a = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+
+ for x_in in 0..E_in.len() {
+ let i = (x_out << num_xin_bits) | x_in;
+
+ // The grids are fully overwritten below, so no clearing is needed.
+ for j in 0..grid_size {
+ az_grid[j] = az[grid_size * i + j];
+ bz_grid[j] = bz[grid_size * i + j];
+ }
+
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+
+ for idx in 0..three_pow_dim {
+ local_ans[idx] += buff_a[idx] * buff_b[idx] * E_in[x_in];
+ }
+ }
+ for idx in 0..three_pow_dim {
+ local_ans[idx] *= E_out[x_out];
+ }
+
+ local_ans
+ })
+ .reduce(
+ || vec![F::zero(); three_pow_dim],
+ |mut acc, local_ans| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local_ans[idx];
+ }
+
acc
+ },
+ )
+ };
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(num_vars, ans));
+ }
+
+ // Kept as a reference implementation.
+ #[tracing::instrument(
+ skip_all,
+ name = "OuterRemainingSumcheckProver::compute_evaluation_grid_from_poly_serial"
+ )]
+ pub fn _compute_evaluation_grid_from_polynomials_serial(&mut self, num_vars: usize) {
+ let eq_poly = &self.split_eq_poly;
+
+ let n = self.az.as_ref().expect("az should be initialized").len();
+ let az = self.az.as_ref().expect("az should be initialized");
+ let bz = self.bz.as_ref().expect("bz should be initialized");
+ debug_assert_eq!(n, bz.len());
+
+ let three_pow_dim = 3_usize.pow(num_vars as u32);
+ let grid_size = 1 << num_vars;
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+ let mut buff_a: Vec<F> = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+ let mut ans = vec![F::zero(); three_pow_dim];
+
+ let (E_out, E_in) = eq_poly.E_out_in_for_window(num_vars);
+ if E_in.len() == 1 {
+ // This is the simple case: a single linear loop.
+ for i in 0..E_out.len() {
+ az_grid.fill(F::zero());
+ bz_grid.fill(F::zero());
+ for j in 0..grid_size {
+ let index = grid_size * i + j;
+ az_grid[j] = az[index];
+ bz_grid[j] = bz[index];
+ }
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+ for idx in 0..three_pow_dim {
+ ans[idx] += buff_a[idx] * buff_b[idx] * E_out[i];
+ }
+ }
+ } else {
+ let num_xin_bits = E_in.len().log_2();
+ for x_out in 0..E_out.len() {
+ for x_in in 0..E_in.len() {
+ let i = (x_out << num_xin_bits) | x_in;
+ az_grid.fill(F::zero());
+ bz_grid.fill(F::zero());
+ for j in 0..grid_size {
+ az_grid[j] = az[grid_size * i + j];
+ bz_grid[j] = bz[grid_size * i + j];
+ }
+
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+ for idx in 0..three_pow_dim {
+ ans[idx] += buff_a[idx] * buff_b[idx] * E_out[x_out] * E_in[x_in];
+ }
+ }
+ }
+ }
+
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(num_vars, ans));
+ }
+
+ // NOTE: no small value optimisation
+ #[tracing::instrument(
+ skip_all,
+ name = "OuterRemainingSumcheckProver::materialise_poly_from_trace_parallel"
+ )]
+ fn materialise_polynomials_from_trace_parallel_dim_one(&mut self) {
+ let num_x_out_vals = self.split_eq_poly.E_out_current_len();
+ let num_x_in_vals = self.split_eq_poly.E_in_current_len();
+ let r_grid = &self.r_grid;
+ let num_r_vals = r_grid.len();
+
+ // Output arrays are sized by (x_out, x_in) pairs
+ let output_size = num_x_out_vals * num_x_in_vals;
+ let mut az_bound: Vec<F> = unsafe_allocate_zero_vec(2 * output_size);
+ let mut bz_bound: Vec<F> = unsafe_allocate_zero_vec(2 * output_size);
+
+ let num_r_bits = num_r_vals.log_2();
+ let num_x_in_bits = num_x_in_vals.log_2();
+
+ // Dynamic chunking for parallelization
+ let num_threads = rayon::current_num_threads();
+ let target_chunks = num_threads * 4;
+ let min_chunk_pairs = 16;
+ let pairs_per_chunk = output_size.div_ceil(target_chunks).max(min_chunk_pairs);
+ let chunk_size =
pairs_per_chunk * 2;
+
+ // Parallel computation with reduction
+ let (t0_acc, t_inf_acc) = az_bound
+ .par_chunks_mut(chunk_size)
+ .zip(bz_bound.par_chunks_mut(chunk_size))
+ .enumerate()
+ .fold(
+ || (F::zero(), F::zero()),
+ |(mut t0_local, mut t_inf_local), (chunk_idx, (az_chunk, bz_chunk))| {
+ let start_pair = chunk_idx * pairs_per_chunk;
+ let end_pair = (start_pair + pairs_per_chunk).min(output_size);
+
+ for pair_idx in start_pair..end_pair {
+ let x_in_val = pair_idx % num_x_in_vals;
+ let x_out_val = pair_idx / num_x_in_vals;
+
+ let mut az0_sum = F::zero(); // For X=0
+ let mut az1_sum = F::zero(); // For X=1
+ let mut bz0_sum = F::zero(); // For X=0
+ let mut bz1_sum = F::zero(); // For X=1
+
+ // Single loop over r values, computing both X=0 and X=1
+ for r_idx in 0..num_r_vals {
+ let r_eval = r_grid[r_idx];
+
+ // Build indices for both X=0 and X=1
+ let base_idx = (x_out_val << (num_x_in_bits + 1 + num_r_bits))
+ | (x_in_val << (1 + num_r_bits));
+
+ let full_idx_x0 = base_idx | (0 << num_r_bits) | r_idx;
+ let full_idx_x1 = base_idx | (1 << num_r_bits) | r_idx;
+
+ // Process X=0
+ let step_idx_x0 = full_idx_x0 >> 1;
+ let selector_x0 = (full_idx_x0 & 1) == 1;
+
+ let row_inputs_x0 = R1CSCycleInputs::from_trace::<F>(
+ &self.bytecode_preprocessing,
+ &self.trace,
+ step_idx_x0,
+ );
+ let eval_x0 = R1CSEval::<F>::from_cycle_inputs(&row_inputs_x0);
+
+ let (az_x0, bz_x0) = if !selector_x0 {
+ (
+ eval_x0.az_at_r_first_group(&self.lagrange_evals_r0),
+ eval_x0.bz_at_r_first_group(&self.lagrange_evals_r0),
+ )
+ } else {
+ (
+ eval_x0.az_at_r_second_group(&self.lagrange_evals_r0),
+ eval_x0.bz_at_r_second_group(&self.lagrange_evals_r0),
+ )
+ };
+
+ // Process X=1
+ let step_idx_x1 = full_idx_x1 >> 1;
+ let selector_x1 = (full_idx_x1 & 1) == 1;
+
+ let row_inputs_x1 = R1CSCycleInputs::from_trace::<F>(
+ &self.bytecode_preprocessing,
+ &self.trace,
+ step_idx_x1,
+ );
+ let eval_x1 = R1CSEval::<F>::from_cycle_inputs(&row_inputs_x1);
+
+ let (az_x1, bz_x1) = if !selector_x1 {
+ (
+ eval_x1.az_at_r_first_group(&self.lagrange_evals_r0),
+ eval_x1.bz_at_r_first_group(&self.lagrange_evals_r0),
+ )
+ } else {
+ (
+ eval_x1.az_at_r_second_group(&self.lagrange_evals_r0),
+ eval_x1.bz_at_r_second_group(&self.lagrange_evals_r0),
+ )
+ };
+
+ // Accumulate both with the same r_eval
+ az0_sum += az_x0 * r_eval;
+ bz0_sum += bz_x0 * r_eval;
+ az1_sum += az_x1 * r_eval;
+ bz1_sum += bz_x1 * r_eval;
+ }
+
+ // Store in chunk-relative position
+ let buffer_offset = 2 * (pair_idx - start_pair);
+ az_chunk[buffer_offset] = az0_sum;
+ az_chunk[buffer_offset + 1] = az1_sum;
+ bz_chunk[buffer_offset] = bz0_sum;
+ bz_chunk[buffer_offset + 1] = bz1_sum;
+
+ // Local accumulation for t_0 and t_inf
+ let e_in = self.split_eq_poly.E_in_current()[x_in_val];
+ let e_out = self.split_eq_poly.E_out_current()[x_out_val];
+ let p0 = az0_sum * bz0_sum;
+ let slope = (az1_sum - az0_sum) * (bz1_sum - bz0_sum);
+
+ t0_local += e_out * e_in * p0;
+ t_inf_local += e_out * e_in * slope;
+ }
+
+ (t0_local, t_inf_local)
+ },
+ )
+ .reduce(
+ || (F::zero(), F::zero()),
+ |(t0_a, t_inf_a), (t0_b, t_inf_b)| (t0_a + t0_b, t_inf_a + t_inf_b),
+ );
+
+ self.az = Some(DensePolynomial::new(az_bound));
+ self.bz = Some(DensePolynomial::new(bz_bound));
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(
+ 1,
+ vec![t0_acc, t0_acc, t_inf_acc],
+ ));
+ }
+
+ // If the first round of the sumcheck is the switchover point,
+ // then materialising Az and Bz is significantly simpler:
+ // we do not need to deal with challenges.
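+ // At round zero no sumcheck challenge has been bound yet, so there is
+ // no r_grid factor: each Az/Bz value comes straight from a single trace
+ // row, weighted only by the Lagrange evaluations at r0.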
+ #[tracing::instrument(
+ skip_all,
+ name = "OuterRemainingSumcheckProver::materialise_poly_from_trace_round_zero"
+ )]
+ fn materialise_polynomials_from_trace_round_zero_dim_one(&mut self) {
+ let num_x_out_vals = self.split_eq_poly.E_out_current_len();
+ let num_x_in_vals = self.split_eq_poly.E_in_current_len();
 let iter_num_x_in_vars = num_x_in_vals.log_2();
 let groups_exact = num_x_out_vals
@@ -324,19 +919,21 @@ impl OuterRemainingSumcheckProver {
 let mut inner_sum_inf = F::Unreduced::<9>::zero();
 for x_in_val in 0..num_x_in_vals {
 let current_step_idx = (x_out_val << iter_num_x_in_vars) | x_in_val;
+
+ // Possibly re-design this
 let row_inputs = R1CSCycleInputs::from_trace::<F>(
- bytecode_preprocessing,
- trace,
+ &self.bytecode_preprocessing,
+ &self.trace,
 current_step_idx,
 );
 let eval = R1CSEval::<F>::from_cycle_inputs(&row_inputs);
- let az0 = eval.az_at_r_first_group(lagrange_evals_r);
- let bz0 = eval.bz_at_r_first_group(lagrange_evals_r);
- let az1 = eval.az_at_r_second_group(lagrange_evals_r);
- let bz1 = eval.bz_at_r_second_group(lagrange_evals_r);
+ let az0 = eval.az_at_r_first_group(&self.lagrange_evals_r0);
+ let bz0 = eval.bz_at_r_first_group(&self.lagrange_evals_r0);
+ let az1 = eval.az_at_r_second_group(&self.lagrange_evals_r0);
+ let bz1 = eval.bz_at_r_second_group(&self.lagrange_evals_r0);
 let p0 = az0 * bz0;
 let slope = (az1 - az0) * (bz1 - bz0);
- let e_in = split_eq_poly.E_in_current()[x_in_val];
+ let e_in = self.split_eq_poly.E_in_current()[x_in_val];
 inner_sum0 += e_in.mul_unreduced::<9>(p0);
 inner_sum_inf += e_in.mul_unreduced::<9>(slope);
 let off = 2 * x_in_val;
@@ -345,7 +942,7 @@ impl OuterRemainingSumcheckProver {
 bz_chunk[off] = bz0;
 bz_chunk[off + 1] = bz1;
 }
- let e_out = split_eq_poly.E_out_current()[x_out_val];
+ let e_out = self.split_eq_poly.E_out_current()[x_out_val];
 let reduced0 = F::from_montgomery_reduce::<9>(inner_sum0);
 let reduced_inf = F::from_montgomery_reduce::<9>(inner_sum_inf);
 acc0 += e_out.mul_unreduced::<9>(reduced0);
@@ -358,49 +955,453 @@ impl OuterRemainingSumcheckProver {
 |a, b| (a.0 + b.0, a.1 + b.1),
 );
- (
- F::from_montgomery_reduce::<9>(t0_acc_unr),
- F::from_montgomery_reduce::<9>(t_inf_acc_unr),
- DensePolynomial::new(az_bound),
- DensePolynomial::new(bz_bound),
- )
+ self.az = Some(DensePolynomial::new(az_bound));
+ self.bz = Some(DensePolynomial::new(bz_bound));
+ let t_0 = F::from_montgomery_reduce::<9>(t0_acc_unr);
+ let t_inf = F::from_montgomery_reduce::<9>(t_inf_acc_unr);
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(1, vec![t_0, t_0, t_inf]));
 }
+ fn fused_materialise_polynomials_general_with_multiquadratic(&mut self, window_size: usize) {
+ let (E_out, E_in) = self.split_eq_poly.E_out_in_for_window(window_size);
+ let num_x_out_vals = E_out.len();
+ let num_x_in_vals = E_in.len();
+ let r_grid = &self.r_grid;
+ let num_r_vals = r_grid.len();
- // No special binding path needed; az/bz hold interleaved [lo,hi] ready for binding
+ let three_pow_dim = 3_usize.pow(window_size as u32);
+ let grid_size = 1 << window_size;
+ let num_evals_az = E_out.len() * E_in.len() * grid_size;
- /// Compute the polynomial for each of the remaining rounds, using the
- /// linear-time algorithm with split-eq optimizations.
- ///
- /// At this point, we have computed the `bound_coeffs` for the current round.
- /// We need to compute:
- ///
- /// `t_i(0) = \sum_{x_out} E_out[x_out] \sum_{x_in} E_in[x_in] *
- /// (az_bound[x_out, x_in, 0] * bz_bound[x_out, x_in, 0] - cz_bound[x_out, x_in, 0])`
- ///
- /// and
- ///
- /// `t_i(∞) = \sum_{x_out} E_out[x_out] \sum_{x_in} E_in[x_in] *
- /// az_bound[x_out, x_in, ∞] * bz_bound[x_out, x_in, ∞]`
- ///
- /// (ordering of indices is MSB to LSB, so x_out is the MSB and x_in is the LSB)
- #[inline]
- fn remaining_quadratic_evals(&self) -> (F, F) {
- let n = self.az.len();
- debug_assert_eq!(n, self.bz.len());
- let [t0, tinf] = self.split_eq_poly.par_fold_out_in_unreduced::<9, 2>(&|g| {
- let az0 = self.az[2 * g];
- let az1 = self.az[2 * g + 1];
- let bz0 = self.bz[2 * g];
- let bz1 = self.bz[2 * g + 1];
- let p0 = az0 * bz0;
- let slope = (az1 - az0) * (bz1 - bz0);
- [p0, slope]
- });
- (t0, tinf)
+ let mut az_bound: Vec<F> = unsafe_allocate_zero_vec(num_evals_az);
+ let mut bz_bound: Vec<F> = unsafe_allocate_zero_vec(num_evals_az);
+
+ let num_r_bits = num_r_vals.log_2();
+ let num_x_in_bits = num_x_in_vals.log_2();
+
+ let output_size = num_x_out_vals * num_x_in_vals;
+
+ // Dynamic chunking for parallelization
+ let num_threads = rayon::current_num_threads();
+ let target_chunks = num_threads * 4;
+ let min_chunk_pairs = 16;
+ let pairs_per_chunk = output_size.div_ceil(target_chunks).max(min_chunk_pairs);
+ let chunk_size = pairs_per_chunk * grid_size;
+
+ // Parallel computation with reduction
+ let ans = az_bound
+ .par_chunks_mut(chunk_size)
+ .zip(bz_bound.par_chunks_mut(chunk_size))
+ .enumerate()
+ .fold(
+ || vec![F::zero(); three_pow_dim],
+ |mut local_ans, (chunk_idx, (az_chunk, bz_chunk))| {
+ let start_pair = chunk_idx * pairs_per_chunk;
+ let end_pair = (start_pair + pairs_per_chunk).min(output_size);
+
+ // Thread-local buffers for multiquadratic expansion
+ let mut buff_a = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+
+ for pair_idx in start_pair..end_pair {
+ let x_in_val = pair_idx % num_x_in_vals;
+ let x_out_val = pair_idx / num_x_in_vals;
+
+ // Initialize grid accumulator for this (x_out, x_in) pair
+ az_grid.fill(F::zero());
+ bz_grid.fill(F::zero());
+
+ // Loop over r values and accumulate for each grid point
+ for r_idx in 0..num_r_vals {
+ let r_eval = r_grid[r_idx];
+
+ let base_idx = (x_out_val
+ << (num_x_in_bits + window_size + num_r_bits))
+ | (x_in_val << (window_size + num_r_bits));
+
+ // Process all grid_size points
+ for x_val in 0..grid_size {
+ let full_idx = base_idx | (x_val << num_r_bits) | r_idx;
+
+ let step_idx = full_idx >> 1;
+ let selector = (full_idx & 1) == 1;
+
+ let row_inputs = R1CSCycleInputs::from_trace::<F>(
+ &self.bytecode_preprocessing,
+ &self.trace,
+ step_idx,
+ );
+ let eval = R1CSEval::<F>::from_cycle_inputs(&row_inputs);
+
+ let (az_val, bz_val) = if !selector {
+ (
+ eval.az_at_r_first_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_first_group(&self.lagrange_evals_r0),
+ )
+ } else {
+ (
+ eval.az_at_r_second_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_second_group(&self.lagrange_evals_r0),
+ )
+ };
+
+ // Accumulate with r_eval weight
+ az_grid[x_val] += az_val * r_eval;
+ bz_grid[x_val] += bz_val * r_eval;
+ }
+ }
+
+ // Store the accumulated grid in chunk-relative position
+ let buffer_offset = grid_size * (pair_idx - start_pair);
+ let end = buffer_offset + grid_size;
+
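+ // Persist the raw {0,1}^window_size grid into az_bound/bz_bound for the
+ // later az/bz rounds; the multiquadratic expansion below only feeds the
+ // t' accumulator and is not stored.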
az_chunk[buffer_offset..end].copy_from_slice(&az_grid[..grid_size]);
+ bz_chunk[buffer_offset..end].copy_from_slice(&bz_grid[..grid_size]);
+ // Expand to multiquadratic
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ window_size,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ window_size,
+ );
+
+ // Accumulate into local ans with E_out * E_in weight
+ let e_in = E_in[x_in_val];
+ let e_out = E_out[x_out_val];
+ let e_product = e_out * e_in;
+
+ for idx in 0..three_pow_dim {
+ local_ans[idx] += buff_a[idx] * buff_b[idx] * e_product;
+ }
+ }
+
+ local_ans
+ },
+ )
+ .reduce(
+ || vec![F::zero(); three_pow_dim],
+ |mut acc, local_ans| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local_ans[idx];
+ }
+ acc
+ },
+ );
+
+ self.az = Some(DensePolynomial::new(az_bound));
+ self.bz = Some(DensePolynomial::new(bz_bound));
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(window_size, ans));
+ }
+ fn fused_materialise_polynomials_round_zero(&mut self, num_vars: usize) {
+ // Note: this is the simplest materialise as there are no challenges to deal with.
+ let eq_poly = &self.split_eq_poly;
+
+ let three_pow_dim = 3_usize.pow(num_vars as u32);
+ let grid_size = 1 << num_vars;
+ let (E_out, E_in) = eq_poly.E_out_in_for_window(num_vars);
+
+ let num_evals_az = E_out.len() * E_in.len() * grid_size;
+ let mut az: Vec<F> = unsafe_allocate_zero_vec(num_evals_az);
+ let mut bz: Vec<F> = unsafe_allocate_zero_vec(num_evals_az);
+
+ let ans: Vec<F> = if E_in.len() == 1 {
+ // Parallel version with reduction (E_in has only one element)
+ az.par_chunks_exact_mut(grid_size)
+ .zip(bz.par_chunks_exact_mut(grid_size))
+ .enumerate()
+ .map(|(i, (az_chunk, bz_chunk))| {
+ let mut local_ans = vec![F::zero(); three_pow_dim];
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+ let mut buff_a = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+
+ for j in 0..grid_size {
+ let full_idx = grid_size * i + j;
+ // Extract time_step_idx and selector from full_idx
+ let time_step_idx = full_idx >> 1;
+ let selector = (full_idx & 1) == 1;
+
+ let row_inputs = R1CSCycleInputs::from_trace::<F>(
+ &self.bytecode_preprocessing,
+ &self.trace,
+ time_step_idx,
+ );
+ let eval = R1CSEval::<F>::from_cycle_inputs(&row_inputs);
+
+ let (az_at_full_idx, bz_at_full_idx) = if !selector {
+ (
+ eval.az_at_r_first_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_first_group(&self.lagrange_evals_r0),
+ )
+ } else {
+ (
+ eval.az_at_r_second_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_second_group(&self.lagrange_evals_r0),
+ )
+ };
+
+ az_chunk[j] = az_at_full_idx;
+ bz_chunk[j] = bz_at_full_idx;
+ az_grid[j] = az_at_full_idx;
+ bz_grid[j] = bz_at_full_idx;
+ }
+
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+
+ for idx in 0..three_pow_dim {
+ local_ans[idx] = buff_a[idx] * buff_b[idx] * E_out[i];
+ }
+
+ local_ans
+ })
+ .reduce(
+ || vec![F::zero(); three_pow_dim],
+ |mut acc, local_ans| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local_ans[idx];
+ }
+ acc
+ },
+ )
+ } else {
+ // Handle case where E_in has multiple elements
+ let num_xin_bits = E_in.len().log_2();
+
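+ // Each x_out owns a contiguous slice of grid_size * E_in.len() entries,
+ // so the parallel chunks write to disjoint regions of az/bz and need no
+ // synchronisation.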
az.par_chunks_exact_mut(grid_size * E_in.len())
+ .zip(bz.par_chunks_exact_mut(grid_size * E_in.len()))
+ .enumerate()
+ .map(|(x_out, (az_outer_chunk, bz_outer_chunk))| {
+ let mut local_ans = vec![F::zero(); three_pow_dim];
+ let mut az_grid = vec![F::zero(); grid_size];
+ let mut bz_grid = vec![F::zero(); grid_size];
+ let mut buff_a = vec![F::zero(); three_pow_dim];
+ let mut buff_b = vec![F::zero(); three_pow_dim];
+ let mut tmp = vec![F::zero(); three_pow_dim];
+
+ for x_in in 0..E_in.len() {
+ let i = (x_out << num_xin_bits) | x_in;
+
+ // Fill grids for this (x_out, x_in) pair
+ for j in 0..grid_size {
+ let full_idx = grid_size * i + j;
+ let time_step_idx = full_idx >> 1;
+ let selector = (full_idx & 1) == 1;
+
+ let row_inputs = R1CSCycleInputs::from_trace::<F>(
+ &self.bytecode_preprocessing,
+ &self.trace,
+ time_step_idx,
+ );
+ let eval = R1CSEval::<F>::from_cycle_inputs(&row_inputs);
+
+ let (az_at_full_idx, bz_at_full_idx) = if !selector {
+ (
+ eval.az_at_r_first_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_first_group(&self.lagrange_evals_r0),
+ )
+ } else {
+ (
+ eval.az_at_r_second_group(&self.lagrange_evals_r0),
+ eval.bz_at_r_second_group(&self.lagrange_evals_r0),
+ )
+ };
+
+ let offset_in_chunk = x_in * grid_size + j;
+ az_outer_chunk[offset_in_chunk] = az_at_full_idx;
+ bz_outer_chunk[offset_in_chunk] = bz_at_full_idx;
+ az_grid[j] = az_at_full_idx;
+ bz_grid[j] = bz_at_full_idx;
+ }
+
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &az_grid,
+ &mut buff_a,
+ &mut tmp,
+ num_vars,
+ );
+ MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+ &bz_grid,
+ &mut buff_b,
+ &mut tmp,
+ num_vars,
+ );
+
+ let e_product = E_out[x_out] * E_in[x_in];
+ for idx in 0..three_pow_dim {
+ local_ans[idx] += buff_a[idx] * buff_b[idx] * e_product;
+ }
+ }
+
+ local_ans
+ })
+ .reduce(
+ || vec![F::zero(); three_pow_dim],
+ |mut acc, local_ans| {
+ for idx in 0..three_pow_dim {
+ acc[idx] += local_ans[idx];
+ }
+ acc
+ },
+ )
+ };
+ self.t_prime_poly = Some(MultiquadraticPolynomial::new(num_vars, ans));
+ self.az = Some(DensePolynomial::new(az));
+ self.bz = Some(DensePolynomial::new(bz));
+ }
+
+ #[tracing::instrument(
+ skip_all,
+ name = "OuterRemainingSumcheckProver::materialise_poly_from_trace"
+ )]
+ fn materialise_polynomials_from_trace(&mut self, window_size: usize) {
+ let split_eq_poly = &self.split_eq_poly;
+ let is_not_first_round_of_sumcheck = split_eq_poly.num_challenges() > 0;
+ match (is_not_first_round_of_sumcheck, window_size == 1) {
+ (true, true) => self.materialise_polynomials_from_trace_parallel_dim_one(),
+ (true, false) => {
+ self.fused_materialise_polynomials_general_with_multiquadratic(window_size)
+ }
+ (false, true) => self.materialise_polynomials_from_trace_round_zero_dim_one(),
+ (false, false) => self.fused_materialise_polynomials_round_zero(window_size),
+ }
+ }
+
+ // Compute prover message directly for window size 1
+ fn _remaining_quadratic_evals(&mut self) -> (F, F) {
+ let eq_poly = &self.split_eq_poly;
+
+ let n = self.az.as_ref().expect("az should be initialized").len();
+ let az =
self.az.as_ref().expect("az should be initialized"); + let bz = self.bz.as_ref().expect("bz should be initialized"); + + debug_assert_eq!(n, bz.len()); + if eq_poly.E_in_current_len() == 1 { + // groups are pairs (0,1) + let groups = n / 2; + let (t0_unr, tinf_unr) = (0..groups) + .into_par_iter() + .map(|g| { + let az0 = az[2 * g]; + let az1 = az[2 * g + 1]; + let bz0 = bz[2 * g]; + let bz1 = bz[2 * g + 1]; + let eq = eq_poly.E_out_current()[g]; + let p0 = az0 * bz0; + let slope = (az1 - az0) * (bz1 - bz0); + let t0_unr = eq.mul_unreduced::<9>(p0); + let tinf_unr = eq.mul_unreduced::<9>(slope); + (t0_unr, tinf_unr) + }) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + ( + F::from_montgomery_reduce::<9>(t0_unr), + F::from_montgomery_reduce::<9>(tinf_unr), + ) + } else { + let num_x1_bits = eq_poly.E_in_current_len().log_2(); + let x1_len = eq_poly.E_in_current_len(); + let x2_len = eq_poly.E_out_current_len(); + let (sum0_unr, suminf_unr) = (0..x2_len) + .into_par_iter() + .map(|x2| { + let mut inner0_unr = F::Unreduced::<9>::zero(); + let mut inner_inf_unr = F::Unreduced::<9>::zero(); + for x1 in 0..x1_len { + let g = (x2 << num_x1_bits) | x1; + let az0 = az[2 * g]; + let az1 = az[2 * g + 1]; + let bz0 = bz[2 * g]; + let bz1 = bz[2 * g + 1]; + let e_in = eq_poly.E_in_current()[x1]; + let p0 = az0 * bz0; + let slope = (az1 - az0) * (bz1 - bz0); + inner0_unr += e_in.mul_unreduced::<9>(p0); + inner_inf_unr += e_in.mul_unreduced::<9>(slope); + } + let e_out = eq_poly.E_out_current()[x2]; + let inner0_red = F::from_montgomery_reduce::<9>(inner0_unr); + let inner_inf_red = F::from_montgomery_reduce::<9>(inner_inf_unr); + let t0_unr = e_out.mul_unreduced::<9>(inner0_red); + let tinf_unr = e_out.mul_unreduced::<9>(inner_inf_red); + (t0_unr, tinf_unr) + }) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + ( + F::from_montgomery_reduce::<9>(sum0_unr), + F::from_montgomery_reduce::<9>(suminf_unr), + ) + } + } + + pub fn compute_t_evals(&self, window_size: usize) -> (F, F) { + let t_prime_poly = self + .t_prime_poly + .as_ref() + .expect("t_prime_poly should be initialized"); + + // Equality weights over the active window bits (all but the first). 
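+ // t'(0) and t'(∞) are read off the {0,1,∞}^window_size grid by projecting
+ // onto the first window variable, folding the remaining variables with the
+ // eq weights in `e_active`.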
+ let e_active = self.split_eq_poly.E_active_for_window(window_size);
+ let t_prime_0 = t_prime_poly.project_to_first_variable(&e_active, 0);
+ let t_prime_inf = t_prime_poly.project_to_first_variable(&e_active, INFINITY);
+ (t_prime_0, t_prime_inf)
+ }
+
+ pub fn final_sumcheck_evals(&self) -> [F; 2] {
+ let az = self.az.as_ref().expect("az should be initialized");
+ let bz = self.bz.as_ref().expect("bz should be initialized");
+
+ let az0 = if !az.is_empty() { az[0] } else { F::zero() };
+ let bz0 = if !bz.is_empty() { bz[0] } else { F::zero() };
+ [az0, bz0]
 }
 }
-impl SumcheckInstanceProver for OuterRemainingSumcheckProver {
+impl SumcheckInstanceProver
+ for OuterRemainingSumcheckProver<'_, F, S>
+{
 fn degree(&self) -> usize {
 OUTER_REMAINING_DEGREE_BOUND
 }
@@ -415,24 +1416,51 @@ impl SumcheckInstanceProver for OuterRemainingSumcheckProver
 #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::compute_message")]
 fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly<F> {
- let (t0, t_inf) = if round == 0 {
- self.first_round_evals
- } else {
- self.remaining_quadratic_evals()
- };
+ let num_unbound_vars = self.schedule.num_unbound_vars(round);
+ if self.schedule.is_switch_over_point(round) {
+ self.materialise_polynomials_from_trace(num_unbound_vars);
+ } else if self.schedule.is_window_start(round) {
+ if self.schedule.before_switch_over_point(round) {
+ self.compute_evaluation_grid_from_trace(num_unbound_vars);
+ } else {
+ self.compute_evaluation_grid_from_polynomials_parallel(num_unbound_vars);
+ }
+ }
+ let (t_prime_0, t_prime_inf) = self.compute_t_evals(num_unbound_vars);
 self.split_eq_poly
- .gruen_poly_deg_3(t0, t_inf, previous_claim)
+ .gruen_poly_deg_3(t_prime_0, t_prime_inf, previous_claim)
 }
 #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::ingest_challenge")]
 fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) {
- rayon::join(
- || self.az.bind_parallel(r_j, BindingOrder::LowToHigh),
- || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh),
- );
-
- // Bind eq_poly for next round
 self.split_eq_poly.bind(r_j);
+
+ if self.az.is_none() {
+ let t_prime_poly = self
+ .t_prime_poly
+ .as_mut()
+ .expect("t_prime_poly should be initialized");
+ t_prime_poly.bind(r_j, BindingOrder::LowToHigh);
+ self.r_grid.update(r_j);
+ } else {
+ // NOTE: as we bind low-to-high while streaming, the binding
+ // algorithm should be revisited to optimise cache-line usage.
+ rayon::join(
+ || {
+ self.az
+ .as_mut()
+ .expect("az should be initialised")
+ .bind_parallel(r_j, BindingOrder::LowToHigh)
+ },
+ || {
+ self.bz
+ .as_mut()
+ .expect("bz should be initialised")
+ .bind_parallel(r_j, BindingOrder::LowToHigh)
+ },
+ );
+ }
 }
 fn cache_openings(
diff --git a/tracer/src/utils/inline_helpers.rs b/tracer/src/utils/inline_helpers.rs
index 982e7e624..447c5ad9d 100644
--- a/tracer/src/utils/inline_helpers.rs
+++ b/tracer/src/utils/inline_helpers.rs
@@ -448,7 +448,7 @@ impl InstrAssembler {
 }
 pub fn add(&mut self, rs1: Value, rs2: Value, rd: u8) -> Value {
- self.bin::(rs1, rs2, rd, |x, y| ((x).wrapping_add(y)))
+ self.bin::(rs1, rs2, rd, |x, y| x.wrapping_add(y))
 }
 /// Logical right-shift immediate on a 32-bit word.