Skip to content

Commit 8020567

Browse files
committed
Merge grumpkin branch (PR a16z#1211): GLV scalar decomposition + generic MSM crate
2 parents 9e0cd8c + 892c12e commit 8020567

22 files changed

Lines changed: 2818 additions & 1 deletion

File tree

Cargo.lock

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ members = [
3636
"jolt-inlines/blake3",
3737
"jolt-inlines/bigint",
3838
"jolt-inlines/secp256k1",
39+
"jolt-inlines/grumpkin",
3940
"examples/btreemap/host",
4041
"examples/btreemap/guest",
4142
"examples/collatz",
@@ -76,6 +77,8 @@ members = [
7677
"examples/merkle-tree/guest",
7778
"examples/hash-bench",
7879
"examples/hash-bench/guest",
80+
"examples/msm",
81+
"examples/msm/guest",
7982
"zklean-extractor",
8083
"z3-verifier",
8184
]
@@ -254,3 +257,4 @@ jolt-inlines-blake2 = { path = "./jolt-inlines/blake2", default-features = false
254257
jolt-inlines-blake3 = { path = "./jolt-inlines/blake3", default-features = false }
255258
jolt-inlines-bigint = { path = "./jolt-inlines/bigint", default-features = false }
256259
jolt-inlines-secp256k1 = { path = "./jolt-inlines/secp256k1", default-features = false }
260+
jolt-inlines-grumpkin = { path = "./jolt-inlines/grumpkin", default-features = false }

examples/msm/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "msm-bench"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
jolt-sdk = { workspace = true, features = ["host"] }
8+
tracing-subscriber.workspace = true
9+
tracing.workspace = true
10+
jolt-inlines-grumpkin = { workspace = true, features = ["host"] }
11+
guest = { package = "msm-bench-guest", path = "./guest" }
12+
ark-bn254.workspace = true

examples/msm/guest/Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "msm-bench-guest"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[features]
7+
guest = []
8+
9+
[dependencies]
10+
jolt = { package = "jolt-sdk", path = "../../../jolt-sdk", features = [] }
11+
jolt-inlines-grumpkin.workspace = true
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use crate::fixed_base::FixedBaseTable as GenericFixedBaseTable;
2+
use crate::traits::{GlvCapable, MsmGroup};
3+
use jolt_inlines_grumpkin::{GrumpkinFr, GrumpkinPoint, UnwrapOrSpoilProof};
4+
5+
// ============================================================
6+
// Curve Parameters
7+
// ============================================================
8+
9+
pub const SCALAR_BITS: usize = 256;
10+
pub const GLV_SCALAR_BITS: usize = 128;
11+
12+
// Pippenger parameters for baseline (256-bit scalars).
13+
pub const BASELINE_WINDOW: usize = 12;
14+
pub const BASELINE_BUCKETS: usize = 1 << BASELINE_WINDOW;
15+
pub const BASELINE_WINDOWS: usize = SCALAR_BITS.div_ceil(BASELINE_WINDOW);
16+
17+
// Pippenger parameters for GLV (128-bit scalars).
18+
pub const GLV_WINDOW: usize = 8;
19+
pub const GLV_BUCKETS: usize = 1 << GLV_WINDOW;
20+
pub const GLV_WINDOWS: usize = GLV_SCALAR_BITS.div_ceil(GLV_WINDOW);
21+
22+
// Fixed-base (generator) windowed multiplication parameters (256-bit scalars).
23+
pub const FIXED_BASE_WINDOW: usize = 14;
24+
pub const FIXED_BASE_BUCKETS: usize = 1 << FIXED_BASE_WINDOW;
25+
pub const FIXED_BASE_WINDOWS: usize = SCALAR_BITS.div_ceil(FIXED_BASE_WINDOW);
26+
27+
pub type FixedBaseTable =
28+
GenericFixedBaseTable<GrumpkinPoint, FIXED_BASE_WINDOWS, FIXED_BASE_BUCKETS>;
29+
30+
// ============================================================
31+
// Trait Implementations
32+
// ============================================================
33+
34+
impl MsmGroup for GrumpkinPoint {
35+
#[inline(always)]
36+
fn identity() -> Self {
37+
GrumpkinPoint::infinity()
38+
}
39+
40+
#[inline(always)]
41+
fn is_identity(&self) -> bool {
42+
self.is_infinity()
43+
}
44+
45+
#[inline(always)]
46+
fn add(&self, other: &Self) -> Self {
47+
GrumpkinPoint::add(self, other)
48+
}
49+
50+
#[inline(always)]
51+
fn neg(&self) -> Self {
52+
GrumpkinPoint::neg(self)
53+
}
54+
55+
#[inline(always)]
56+
fn double(&self) -> Self {
57+
GrumpkinPoint::double(self)
58+
}
59+
60+
#[inline(always)]
61+
fn double_and_add(&self, other: &Self) -> Self {
62+
GrumpkinPoint::double_and_add(self, other)
63+
}
64+
}
65+
66+
impl GlvCapable for GrumpkinPoint {
67+
type HalfScalar = u128;
68+
type FullScalar = GrumpkinFr;
69+
70+
#[inline(always)]
71+
fn endomorphism(&self) -> Self {
72+
GrumpkinPoint::endomorphism(self)
73+
}
74+
75+
#[inline(always)]
76+
fn decompose_scalar(k: &GrumpkinFr) -> [(bool, u128); 2] {
77+
GrumpkinPoint::decompose_scalar(k)
78+
}
79+
}
80+
81+
// ============================================================
82+
// Benchmark Helpers
83+
// ============================================================
84+
85+
/// Grumpkin Fr modulus limbs for scalar reduction.
86+
const FR_MODULUS_LIMBS: [u64; 4] = [
87+
4332616871279656263,
88+
10917124144477883021,
89+
13281191951274694749,
90+
3486998266802970665,
91+
];
92+
93+
#[inline(always)]
94+
fn is_ge_modulus(x: &[u64; 4]) -> bool {
95+
let m = FR_MODULUS_LIMBS;
96+
if x[3] > m[3] {
97+
return true;
98+
}
99+
if x[3] < m[3] {
100+
return false;
101+
}
102+
if x[2] > m[2] {
103+
return true;
104+
}
105+
if x[2] < m[2] {
106+
return false;
107+
}
108+
if x[1] > m[1] {
109+
return true;
110+
}
111+
if x[1] < m[1] {
112+
return false;
113+
}
114+
x[0] >= m[0]
115+
}
116+
117+
#[inline(always)]
118+
fn sub_modulus(x: [u64; 4]) -> [u64; 4] {
119+
let m = FR_MODULUS_LIMBS;
120+
let mut out = [0u64; 4];
121+
let mut borrow = 0u128;
122+
let mut i = 0;
123+
while i < 4 {
124+
let xi = x[i] as u128;
125+
let mi = m[i] as u128 + borrow;
126+
if xi >= mi {
127+
out[i] = (xi - mi) as u64;
128+
borrow = 0;
129+
} else {
130+
out[i] = ((1u128 << 64) + xi - mi) as u64;
131+
borrow = 1;
132+
}
133+
i += 1;
134+
}
135+
out
136+
}
137+
138+
/// Reduce scalar modulo Fr.
139+
#[inline(always)]
140+
pub fn reduce_scalar(mut scalar: [u64; 4]) -> [u64; 4] {
141+
while is_ge_modulus(&scalar) {
142+
scalar = sub_modulus(scalar);
143+
}
144+
scalar
145+
}
146+
147+
/// Generate deterministic test scalars using a simple LCG.
148+
#[inline(always)]
149+
pub fn generate_scalars<const N: usize>(seed: u64) -> [[u64; 4]; N] {
150+
let mut scalars = [[0u64; 4]; N];
151+
let (a, c) = (6364136223846793005u64, 1442695040888963407u64);
152+
let mut state = seed;
153+
154+
for scalar in scalars.iter_mut() {
155+
for limb in scalar.iter_mut() {
156+
state = state.wrapping_mul(a).wrapping_add(c);
157+
*limb = state;
158+
}
159+
*scalar = reduce_scalar(*scalar);
160+
}
161+
scalars
162+
}
163+
164+
/// Generate deterministic test points by fixed-base scalar multiplication of generator.
165+
#[inline(always)]
166+
pub fn generate_points_fixed_base<const N: usize>(
167+
seed: u64,
168+
table_g: &FixedBaseTable,
169+
) -> [GrumpkinPoint; N] {
170+
let mut points: [GrumpkinPoint; N] = core::array::from_fn(|_| GrumpkinPoint::infinity());
171+
let (a, c) = (6364136223846793005u64, 1442695040888963407u64);
172+
let mut state = seed;
173+
174+
for point in points.iter_mut() {
175+
state = state.wrapping_mul(a).wrapping_add(c);
176+
let small_scalar = [state, 0, 0, 0];
177+
*point = table_g.scalar_mul(&small_scalar);
178+
}
179+
points
180+
}
181+
182+
/// Convert [u64; 4] to GrumpkinFr.
183+
#[inline(always)]
184+
pub fn scalar_to_fr(scalar: &[u64; 4]) -> GrumpkinFr {
185+
GrumpkinFr::from_u64_arr(scalar).unwrap_or_spoil_proof()
186+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod grumpkin;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use alloc::boxed::Box;
2+
3+
use crate::traits::{MsmGroup, WindowedScalar};
4+
5+
/// Fixed-base precomputed table for a single base point.
6+
/// table[window][digit] = digit · 2^(window·width) · base.
7+
///
8+
/// Type parameters:
9+
/// - `G`: The group/point type
10+
/// - `WINDOWS`: Number of windows (= ceil(scalar_bits / window_bits))
11+
/// - `BUCKETS`: Number of buckets per window (= 2^window_bits)
12+
pub struct FixedBaseTable<G, const WINDOWS: usize, const BUCKETS: usize> {
13+
table: Box<[[G; BUCKETS]; WINDOWS]>,
14+
}
15+
16+
impl<G: MsmGroup, const WINDOWS: usize, const BUCKETS: usize> FixedBaseTable<G, WINDOWS, BUCKETS> {
17+
/// Window size derived from BUCKETS at compile time.
18+
pub const WINDOW_BITS: usize = BUCKETS.trailing_zeros() as usize;
19+
20+
/// Precompute table for a given base point.
21+
/// Window size is derived from BUCKETS const generic.
22+
#[inline(always)]
23+
pub fn new(base: &G) -> Self {
24+
const {
25+
assert!(BUCKETS > 1, "BUCKETS must be > 1");
26+
assert!(BUCKETS.is_power_of_two(), "BUCKETS must be a power of two");
27+
assert!(
28+
BUCKETS.trailing_zeros() as usize <= 16,
29+
"BUCKETS (window size) must be <= 2^16"
30+
);
31+
}
32+
33+
let mut table: Box<[[G; BUCKETS]; WINDOWS]> = Box::new(core::array::from_fn(|_| {
34+
core::array::from_fn(|_| G::identity())
35+
}));
36+
37+
let mut window_base = base.clone();
38+
for window_table in table.iter_mut() {
39+
window_table[0] = G::identity();
40+
window_table[1] = window_base.clone();
41+
let mut digit = 2;
42+
while digit < BUCKETS {
43+
window_table[digit] = window_table[digit - 1].add(&window_base);
44+
digit += 1;
45+
}
46+
47+
let mut shift = 0;
48+
while shift < Self::WINDOW_BITS {
49+
window_base = window_base.double();
50+
shift += 1;
51+
}
52+
}
53+
54+
Self { table }
55+
}
56+
57+
/// Scalar multiplication using table lookup + addition.
58+
#[inline(always)]
59+
pub fn scalar_mul<S: WindowedScalar>(&self, scalar: &S) -> G {
60+
let mut result = G::identity();
61+
for (window_idx, window_table) in self.table.iter().enumerate() {
62+
let digit = scalar.window(window_idx * Self::WINDOW_BITS, Self::WINDOW_BITS) as usize;
63+
if digit != 0 {
64+
result = result.add(&window_table[digit]);
65+
}
66+
}
67+
result
68+
}
69+
}
70+
71+
/// Fixed-base MSM: Σ(scalar_i · base) using precomputed table.
72+
#[inline(always)]
73+
pub fn msm_fixed_base<G, S, const WINDOWS: usize, const BUCKETS: usize>(
74+
scalars: &[S],
75+
table: &FixedBaseTable<G, WINDOWS, BUCKETS>,
76+
) -> G
77+
where
78+
G: MsmGroup,
79+
S: WindowedScalar,
80+
{
81+
let mut result = G::identity();
82+
for scalar in scalars {
83+
let term = table.scalar_mul(scalar);
84+
result = result.add(&term);
85+
}
86+
result
87+
}

0 commit comments

Comments
 (0)