81 changes: 65 additions & 16 deletions crates/multilinear/src/eval.rs
@@ -1,7 +1,8 @@
use hypercube_alloc::{buffer, Buffer, CpuBackend};
use hypercube_alloc::{Buffer, CpuBackend};
use hypercube_tensor::{Dimensions, Tensor};
use p3_field::{AbstractExtensionField, AbstractField};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};

use crate::{partial_lagrange_blocking, Point};

@@ -16,22 +17,70 @@ pub(crate) fn eval_mle_at_point_blocking<
let mut sizes = mle.sizes().to_vec();
sizes.remove(0);
let dimensions = Dimensions::try_from(sizes).unwrap();
let mut dst = Tensor { storage: buffer![], dimensions };
let total_len = dst.total_len();
let dot_products = mle
.as_buffer()
let total_len = dimensions.total_len();

// Pre-allocate the shared result accumulator
let result = Arc::new(Mutex::new(vec![EF::zero(); total_len]));

// Process in parallel using Rayon
mle.as_buffer()
.par_chunks_exact(mle.strides()[0])
.zip(partial_lagrange.as_buffer().par_iter())
.map(|(chunk, scalar)| chunk.iter().map(|a| scalar.clone() * a.clone()).collect())
.reduce(
|| vec![EF::zero(); total_len],
|mut a, b| {
a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b.clone());
a
},
);
.for_each(|(chunk, scalar)| {
// Accumulate into a thread-local buffer so the shared lock
// is taken only once per chunk
let mut local_result = vec![EF::zero(); total_len];

for (i, a) in chunk.iter().enumerate() {
if i < total_len {
// Compute scalar * a directly into the local accumulator
local_result[i] = scalar.clone() * a.clone();
}
}

// Merge the local computation into the shared result
let mut global_result = result.lock().unwrap();
for i in 0..total_len {
global_result[i] += local_result[i].clone();
}
});

let dot_products = Buffer::from(dot_products);
dst.storage = dot_products;
dst
// Create the final tensor
let result_buffer = Buffer::from(Arc::try_unwrap(result).unwrap().into_inner().unwrap());
Tensor { storage: result_buffer, dimensions }
}

// A specialized implementation for the case where the number of polynomials is small
pub(crate) fn eval_mle_at_point_small_batch<
F: AbstractField + Sync,
EF: AbstractExtensionField<F> + Send + Sync,
>(
mle: &Tensor<F, CpuBackend>,
point: &Point<EF, CpuBackend>,
) -> Tensor<EF, CpuBackend> {
// For small batches (fewer than 4 polynomials), use a different approach
// that avoids the overhead of parallelization
let partial_lagrange = partial_lagrange_blocking(point);
let mut sizes = mle.sizes().to_vec();
sizes.remove(0);
let dimensions = Dimensions::try_from(sizes).unwrap();
let total_len = dimensions.total_len();

// Direct computation without parallelization for small batches
let mut result = vec![EF::zero(); total_len];

for (chunk, scalar) in mle.as_buffer()
.chunks_exact(mle.strides()[0])
.zip(partial_lagrange.as_buffer().iter())
{
for (i, a) in chunk.iter().enumerate() {
if i < total_len {
result[i] += scalar.clone() * a.clone();
}
}
}

let result_buffer = Buffer::from(result);
Tensor { storage: result_buffer, dimensions }
}
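
The accumulation pattern above (scale each chunk into a thread-local buffer, then merge it into a shared accumulator behind a Mutex) is easier to see in isolation. Below is a minimal, self-contained sketch of the same pattern over plain f64 data; the names, sizes, and values are illustrative only and assume nothing beyond the rayon crate:

use rayon::prelude::*;
use std::sync::{Arc, Mutex};

fn main() {
    // Three rows of four coefficients, flattened row-major,
    // mirroring mle.as_buffer() chunked by mle.strides()[0].
    let rows: Vec<f64> = (0..12).map(|i| i as f64).collect();
    // One scalar weight per row, standing in for the
    // partial Lagrange coefficients.
    let scalars: Vec<f64> = vec![0.5, 0.25, 0.25];
    let row_len = 4;

    let result = Arc::new(Mutex::new(vec![0.0f64; row_len]));

    rows.par_chunks_exact(row_len)
        .zip(scalars.par_iter())
        .for_each(|(row, &scalar)| {
            // Scale into a thread-local buffer; no lock held here.
            let local: Vec<f64> = row.iter().map(|a| scalar * a).collect();
            // Merge once per row under the shared lock.
            let mut global = result.lock().unwrap();
            for (g, l) in global.iter_mut().zip(local.iter()) {
                *g += l;
            }
        });

    let result = Arc::try_unwrap(result).unwrap().into_inner().unwrap();
    assert_eq!(result, vec![3.0, 4.0, 5.0, 6.0]);
}

The same weighted row sum can also be written with Rayon's fold/reduce, as the previous code did, trading the lock for one accumulator allocation per reduction step; the Mutex variant instead merges each chunk eagerly into a single buffer.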
70 changes: 49 additions & 21 deletions crates/multilinear/src/mle.rs
@@ -12,7 +12,8 @@ use p3_field::{AbstractExtensionField, AbstractField, Field};
use rand::{distributions::Standard, prelude::Distribution, Rng};
use serde::{Deserialize, Serialize};

use crate::{eval::eval_mle_at_point_blocking, partial_lagrange_blocking, MleBaseBackend, Point};
use crate::eval::{eval_mle_at_point_blocking, eval_mle_at_point_small_batch};
use crate::{partial_lagrange_blocking, MleBaseBackend, Point};

/// A batch of multilinear polynomials.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -263,7 +264,15 @@ impl<T> Mle<T, CpuBackend> {
T: AbstractField + 'static + Send + Sync,
E: AbstractExtensionField<T> + 'static + Send + Sync,
{
MleEval::new(eval_mle_at_point_blocking(self.guts(), point))
// Dispatch on batch size: small batches skip the parallel machinery
if self.num_polynomials() < 4 {
// For small batches, use the specialized serial implementation
MleEval::new(eval_mle_at_point_small_batch(self.guts(), point))
} else {
// For larger batches, use the parallel implementation
MleEval::new(eval_mle_at_point_blocking(self.guts(), point))
}
}

pub fn blocking_partial_lagrange(point: &Point<T>) -> Mle<T, CpuBackend>
@@ -284,25 +293,43 @@ impl<T> Mle<T, CpuBackend> {
///
/// The polynomial f(X,Y) is an important building block in zerocheck and other protocols which use
/// sumcheck.
pub fn full_lagrange_eval<EF>(point_1: &Point<T>, point_2: &Point<EF>) -> EF
where
T: AbstractField,
EF: AbstractExtensionField<T>,
{
assert_eq!(point_1.dimension(), point_2.dimension());

// Iterate over all values in the n-variates X and Y.
point_1
.iter()
.zip(point_2.iter())
.map(|(x, y)| {
// Multiply by (x_i * y_i + (1-x_i) * (1-y_i)).
let prod = y.clone() * x.clone();
prod.clone() + prod + EF::one() - x.clone() - y.clone()
})
.product()
}
}
pub fn full_lagrange_eval<EF>(point_1: &Point<T>, point_2: &Point<EF>) -> EF
where
T: AbstractField,
EF: AbstractExtensionField<T>,
{
assert_eq!(point_1.dimension(), point_2.dimension());

// Iterate over all values in the n-variates X and Y.
point_1
.iter()
.zip(point_2.iter())
.map(|(x, y)| {
// Multiply by (x_i * y_i + (1-x_i) * (1-y_i)).
let prod = y.clone() * x.clone();
prod.clone() + prod + EF::one() - x.clone() - y.clone()
})
.product()
}
}
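
For reference, the f(X,Y) computed here is the multilinear equality (Lagrange) kernel, and the `prod.clone() + prod + EF::one() - x.clone() - y.clone()` in the loop is just the expanded form of the per-coordinate factor named in the comment:

\mathrm{eq}(X, Y) = \prod_{i=1}^{n} \bigl( x_i y_i + (1 - x_i)(1 - y_i) \bigr)

with each factor expanded as

x_i y_i + (1 - x_i)(1 - y_i) = 2 x_i y_i + 1 - x_i - y_i .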


// impl<T: AbstractField + Send + Sync> TryInto<p3_matrix::dense::RowMajorMatrix<T>>
// for Mle<T, CpuBackend>
@@ -492,3 +519,4 @@ impl<T> FromIterator<T> for MleEval<T, CpuBackend> {
Self::new(Tensor::from(iter.into_iter().collect::<Vec<_>>()))
}
}
