Skip to content

Commit 937a924

Browse files
authored
Optimized linear combination of points (#380)
Add `lincomb()` as an alias for a 2-point linear combination
1 parent 2c12fc5 commit 937a924

File tree

6 files changed

+192
-51
lines changed

6 files changed

+192
-51
lines changed

k256/bench/scalar.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use criterion::{
66
use hex_literal::hex;
77
use k256::{
88
elliptic_curve::{generic_array::arr, group::ff::PrimeField},
9-
ProjectivePoint, Scalar,
9+
lincomb, ProjectivePoint, Scalar,
1010
};
1111

1212
fn test_scalar_x() -> Scalar {
@@ -34,9 +34,18 @@ fn bench_point_mul<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
3434
group.bench_function("point-scalar mul", |b| b.iter(|| &p * &s));
3535
}
3636

37+
fn bench_point_lincomb<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
38+
let p = ProjectivePoint::generator();
39+
let m = hex!("AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522");
40+
let s = Scalar::from_repr(m.into()).unwrap();
41+
group.bench_function("lincomb via mul+add", |b| b.iter(|| &p * &s + &p * &s));
42+
group.bench_function("lincomb()", |b| b.iter(|| lincomb(&p, &s, &p, &s)));
43+
}
44+
3745
fn bench_high_level(c: &mut Criterion) {
3846
let mut group = c.benchmark_group("high-level operations");
3947
bench_point_mul(&mut group);
48+
bench_point_lincomb(&mut group);
4049
group.finish();
4150
}
4251

k256/src/arithmetic.rs

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) mod scalar;
88
mod util;
99

1010
pub use field::FieldElement;
11+
pub use mul::lincomb;
1112

1213
use affine::AffinePoint;
1314
use projective::ProjectivePoint;

k256/src/arithmetic/mul.rs

+168-43
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ use core::ops::{Mul, MulAssign};
7070
use elliptic_curve::subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
7171

7272
/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]`
73+
#[derive(Copy, Clone, Default)]
7374
struct LookupTable([ProjectivePoint; 8]);
7475

7576
impl From<&ProjectivePoint> for LookupTable {
@@ -147,94 +148,218 @@ fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) {
147148
(r1, r2)
148149
}
149150

150-
/// Returns `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
151-
/// and `-8 <= a_j <= 7`.
152-
/// Assumes `x < 2^128`.
153-
fn to_radix_16_half(x: &Scalar) -> [i8; 33] {
154-
// `x` can have up to 256 bits, so we need an additional byte to store the carry.
155-
let mut output = [0i8; 33];
156-
157-
// Step 1: change radix.
158-
// Convert from radix 256 (bytes) to radix 16 (nibbles)
159-
let bytes = x.to_bytes();
160-
for i in 0..16 {
161-
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
162-
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
163-
}
151+
// This needs to be an object to have Default implemented for it
152+
// (required because it's used in static_map later)
153+
// Otherwise we could just have a function returning an array.
154+
#[derive(Copy, Clone)]
155+
struct Radix16Decomposition([i8; 33]);
156+
157+
impl Radix16Decomposition {
158+
/// Returns an object containing a decomposition
159+
/// `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
160+
/// and `-8 <= a_j <= 7`.
161+
/// Assumes `x < 2^128`.
162+
fn new(x: &Scalar) -> Self {
163+
debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);
164+
165+
// The resulting decomposition can be negative, so, despite the limit on `x`,
166+
// it can have up to 256 bits, and we need an additional byte to store the carry.
167+
let mut output = [0i8; 33];
168+
169+
// Step 1: change radix.
170+
// Convert from radix 256 (bytes) to radix 16 (nibbles)
171+
let bytes = x.to_bytes();
172+
for i in 0..16 {
173+
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
174+
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
175+
}
164176

165-
debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);
177+
// Step 2: recenter coefficients from [0,16) to [-8,8)
178+
for i in 0..32 {
179+
let carry = (output[i] + 8) >> 4;
180+
output[i] -= carry << 4;
181+
output[i + 1] += carry;
182+
}
166183

167-
// Step 2: recenter coefficients from [0,16) to [-8,8)
168-
for i in 0..32 {
169-
let carry = (output[i] + 8) >> 4;
170-
output[i] -= carry << 4;
171-
output[i + 1] += carry;
184+
Self(output)
172185
}
173-
174-
output
175186
}
176187

177-
fn mul_windowed(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
178-
let (r1, r2) = decompose_scalar(k);
179-
let x_beta = x.endomorphism();
188+
impl Default for Radix16Decomposition {
189+
fn default() -> Self {
190+
Self([0i8; 33])
191+
}
192+
}
180193

181-
let r1_sign = r1.is_high();
182-
let r1_c = Scalar::conditional_select(&r1, &-r1, r1_sign);
183-
let r2_sign = r2.is_high();
184-
let r2_c = Scalar::conditional_select(&r2, &-r2, r2_sign);
194+
/// Maps an array `x` to an array using the predicate `f`.
195+
/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays.
196+
/// Consequently, since we cannot have an uninitialized array (without `unsafe`),
197+
/// a default value needs to be provided.
198+
fn static_map<T: Copy, V: Copy, const N: usize>(
199+
f: impl Fn(T) -> V,
200+
x: &[T; N],
201+
default: V,
202+
) -> [V; N] {
203+
let mut res = [default; N];
204+
for i in 0..N {
205+
res[i] = f(x[i]);
206+
}
207+
res
208+
}
185209

186-
let table1 = LookupTable::from(&ProjectivePoint::conditional_select(x, &-x, r1_sign));
187-
let table2 = LookupTable::from(&ProjectivePoint::conditional_select(
188-
&x_beta, &-x_beta, r2_sign,
189-
));
210+
/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments.
211+
fn static_zip_map<T: Copy, S: Copy, V: Copy, const N: usize>(
212+
f: impl Fn(T, S) -> V,
213+
x: &[T; N],
214+
y: &[S; N],
215+
default: V,
216+
) -> [V; N] {
217+
let mut res = [default; N];
218+
for i in 0..N {
219+
res[i] = f(x[i], y[i]);
220+
}
221+
res
222+
}
190223

191-
let digits1 = to_radix_16_half(&r1_c);
192-
let digits2 = to_radix_16_half(&r2_c);
224+
/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
225+
#[inline(always)]
226+
fn lincomb_generic<const N: usize>(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint {
227+
let rs = static_map(
228+
|k| decompose_scalar(&k),
229+
ks,
230+
(Scalar::default(), Scalar::default()),
231+
);
232+
let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default());
233+
let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default());
234+
235+
let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default());
236+
237+
let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8));
238+
let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8));
239+
240+
let r1s_c = static_zip_map(
241+
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
242+
&r1s,
243+
&r1_signs,
244+
Scalar::default(),
245+
);
246+
let r2s_c = static_zip_map(
247+
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
248+
&r2s,
249+
&r2_signs,
250+
Scalar::default(),
251+
);
252+
253+
let tables1 = static_zip_map(
254+
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
255+
&xs,
256+
&r1_signs,
257+
LookupTable::default(),
258+
);
259+
let tables2 = static_zip_map(
260+
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
261+
&xs_beta,
262+
&r2_signs,
263+
LookupTable::default(),
264+
);
265+
266+
let digits1 = static_map(
267+
|r| Radix16Decomposition::new(&r),
268+
&r1s_c,
269+
Radix16Decomposition::default(),
270+
);
271+
let digits2 = static_map(
272+
|r| Radix16Decomposition::new(&r),
273+
&r2s_c,
274+
Radix16Decomposition::default(),
275+
);
276+
277+
let mut acc = ProjectivePoint::identity();
278+
for component in 0..N {
279+
acc += &tables1[component].select(digits1[component].0[32]);
280+
acc += &tables2[component].select(digits2[component].0[32]);
281+
}
193282

194-
let mut acc = table1.select(digits1[32]) + table2.select(digits2[32]);
195283
for i in (0..32).rev() {
196284
for _j in 0..4 {
197285
acc = acc.double();
198286
}
199287

200-
acc += &table1.select(digits1[i]);
201-
acc += &table2.select(digits2[i]);
288+
for component in 0..N {
289+
acc += &tables1[component].select(digits1[component].0[i]);
290+
acc += &tables2[component].select(digits2[component].0[i]);
291+
}
202292
}
203293
acc
204294
}
205295

296+
#[inline(always)]
297+
fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
298+
lincomb_generic(&[*x], &[*k])
299+
}
300+
301+
/// Calculates `x * k + y * l`.
302+
pub fn lincomb(
303+
x: &ProjectivePoint,
304+
k: &Scalar,
305+
y: &ProjectivePoint,
306+
l: &Scalar,
307+
) -> ProjectivePoint {
308+
lincomb_generic(&[*x, *y], &[*k, *l])
309+
}
310+
206311
impl Mul<Scalar> for ProjectivePoint {
207312
type Output = ProjectivePoint;
208313

209314
fn mul(self, other: Scalar) -> ProjectivePoint {
210-
mul_windowed(&self, &other)
315+
mul(&self, &other)
211316
}
212317
}
213318

214319
impl Mul<&Scalar> for &ProjectivePoint {
215320
type Output = ProjectivePoint;
216321

217322
fn mul(self, other: &Scalar) -> ProjectivePoint {
218-
mul_windowed(self, other)
323+
mul(self, other)
219324
}
220325
}
221326

222327
impl Mul<&Scalar> for ProjectivePoint {
223328
type Output = ProjectivePoint;
224329

225330
fn mul(self, other: &Scalar) -> ProjectivePoint {
226-
mul_windowed(&self, other)
331+
mul(&self, other)
227332
}
228333
}
229334

230335
impl MulAssign<Scalar> for ProjectivePoint {
231336
fn mul_assign(&mut self, rhs: Scalar) {
232-
*self = mul_windowed(self, &rhs);
337+
*self = mul(self, &rhs);
233338
}
234339
}
235340

236341
impl MulAssign<&Scalar> for ProjectivePoint {
237342
fn mul_assign(&mut self, rhs: &Scalar) {
238-
*self = mul_windowed(self, rhs);
343+
*self = mul(self, rhs);
344+
}
345+
}
346+
347+
#[cfg(test)]
348+
mod tests {
349+
use super::lincomb;
350+
use crate::arithmetic::{ProjectivePoint, Scalar};
351+
use elliptic_curve::rand_core::OsRng;
352+
use elliptic_curve::{Field, Group};
353+
354+
#[test]
355+
fn test_lincomb() {
356+
let x = ProjectivePoint::random(&mut OsRng);
357+
let y = ProjectivePoint::random(&mut OsRng);
358+
let k = Scalar::random(&mut OsRng);
359+
let l = Scalar::random(&mut OsRng);
360+
361+
let reference = &x * &k + &y * &l;
362+
let test = lincomb(&x, &k, &y, &l);
363+
assert_eq!(reference, test);
239364
}
240365
}

k256/src/ecdsa/recoverable.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ use crate::{
5151
consts::U32, generic_array::GenericArray, ops::Invert, subtle::Choice,
5252
weierstrass::DecompressPoint,
5353
},
54-
AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
54+
lincomb, AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
5555
};
5656

5757
#[cfg(feature = "keccak256")]
@@ -185,7 +185,7 @@ impl Signature {
185185
let r_inv = r.invert().unwrap();
186186
let u1 = -(r_inv * z);
187187
let u2 = r_inv * *s;
188-
let pk = ((ProjectivePoint::generator() * u1) + (R * u2)).to_affine();
188+
let pk = lincomb(&ProjectivePoint::generator(), &u1, &R, &u2).to_affine();
189189

190190
// TODO(tarcieri): ensure the signature verifies?
191191
Ok(VerifyingKey::from(&pk))

k256/src/ecdsa/verify.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
33
use super::{recoverable, Error, Signature};
44
use crate::{
5-
AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar, Secp256k1,
5+
lincomb, AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar,
6+
Secp256k1,
67
};
78
use core::convert::TryFrom;
89
use ecdsa_core::{hazmat::VerifyPrimitive, signature};
@@ -90,9 +91,14 @@ impl VerifyPrimitive<Secp256k1> for AffinePoint {
9091
let u1 = z * &s_inv;
9192
let u2 = *r * s_inv;
9293

93-
let x = ((ProjectivePoint::generator() * u1) + (ProjectivePoint::from(*self) * u2))
94-
.to_affine()
95-
.x;
94+
let x = lincomb(
95+
&ProjectivePoint::generator(),
96+
&u1,
97+
&ProjectivePoint::from(*self),
98+
&u2,
99+
)
100+
.to_affine()
101+
.x;
96102

97103
if Scalar::from_bytes_reduced(&x.to_bytes()).eq(&r) {
98104
Ok(())

k256/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub mod test_vectors;
6767
pub use elliptic_curve::{self, bigint::U256};
6868

6969
#[cfg(feature = "arithmetic")]
70-
pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};
70+
pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar};
7171

7272
#[cfg(feature = "expose-field")]
7373
pub use arithmetic::FieldElement;

0 commit comments

Comments
 (0)