Skip to content

Commit d2dec5d

Browse files
committed
Add allocation-free RistrettoPoint::double_and_compress_batch()
1 parent 1ad4603 commit d2dec5d

File tree

4 files changed

+134
-71
lines changed

4 files changed

+134
-71
lines changed

curve25519-dalek/benches/dalek_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ mod ristretto_benches {
289289
let points: Vec<RistrettoPoint> = (0..size)
290290
.map(|_| RistrettoPoint::try_from_rng(&mut rng).unwrap())
291291
.collect();
292-
b.iter(|| RistrettoPoint::double_and_compress_batch(&points));
292+
b.iter(|| RistrettoPoint::double_and_compress_alloc_batch(&points));
293293
},
294294
);
295295
}

curve25519-dalek/src/edwards.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ impl EdwardsPoint {
599599

600600
// Compute the denominators in a batch
601601
let mut denominators = eds.iter().map(|p| &p.Z - &p.Y).collect::<Vec<_>>();
602-
FieldElement::batch_invert(&mut denominators);
602+
FieldElement::batch_alloc_invert(&mut denominators);
603603

604604
// Now compute the Montgomery u coordinate for every point
605605
let mut ret = Vec::with_capacity(eds.len());
@@ -621,7 +621,7 @@ impl EdwardsPoint {
621621
#[cfg(feature = "alloc")]
622622
pub fn compress_batch(inputs: &[EdwardsPoint]) -> Vec<CompressedEdwardsY> {
623623
let mut zs = inputs.iter().map(|input| input.Z).collect::<Vec<_>>();
624-
FieldElement::batch_invert(&mut zs);
624+
FieldElement::batch_alloc_invert(&mut zs);
625625

626626
inputs
627627
.iter()

curve25519-dalek/src/field.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,17 +203,32 @@ impl FieldElement {
203203
(t19, t3)
204204
}
205205

206+
/// Given a fixed-size array of public `FieldElements`, replace each with its inverse.
207+
///
208+
/// When an input `FieldElement` is zero, its value is unchanged.
209+
pub(crate) fn batch_invert<const N: usize>(inputs: &mut [FieldElement; N]) {
210+
let mut scratch = [FieldElement::ONE; N];
211+
212+
Self::internal_batch_invert(inputs, &mut scratch);
213+
}
214+
206215
/// Given a slice of public `FieldElements`, replace each with its inverse.
207216
///
208217
/// When an input `FieldElement` is zero, its value is unchanged.
209218
#[cfg(feature = "alloc")]
210-
pub(crate) fn batch_invert(inputs: &mut [FieldElement]) {
219+
pub(crate) fn batch_alloc_invert(inputs: &mut [FieldElement]) {
220+
let n = inputs.len();
221+
let mut scratch = vec![FieldElement::ONE; n];
222+
223+
Self::internal_batch_invert(inputs, &mut scratch);
224+
}
225+
226+
fn internal_batch_invert(inputs: &mut [FieldElement], scratch: &mut [FieldElement]) {
211227
// Montgomery’s Trick and Fast Implementation of Masked AES
212228
// Genelle, Prouff and Quisquater
213229
// Section 3.2
214230

215-
let n = inputs.len();
216-
let mut scratch = vec![FieldElement::ONE; n];
231+
debug_assert_eq!(inputs.len(), scratch.len());
217232

218233
// Keep an accumulator of all of the previous products
219234
let mut acc = FieldElement::ONE;
@@ -234,12 +249,12 @@ impl FieldElement {
234249

235250
// Pass through the vector backwards to compute the inverses
236251
// in place
237-
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.into_iter().rev()) {
252+
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter_mut().rev()) {
238253
let tmp = &acc * input;
239254
// input <- acc * scratch, then acc <- tmp
240255
// Again, we skip zeros in a constant-time way
241256
let nz = !input.is_zero();
242-
input.conditional_assign(&(&acc * &scratch), nz);
257+
input.conditional_assign(&(&acc * scratch), nz);
243258
acc.conditional_assign(&tmp, nz);
244259
}
245260
}
@@ -485,7 +500,7 @@ mod test {
485500
let a2 = &a + &a;
486501
let a_list = vec![a, ap58, asq, ainv, a0, a2];
487502
let mut ainv_list = a_list.clone();
488-
FieldElement::batch_invert(&mut ainv_list[..]);
503+
FieldElement::batch_alloc_invert(&mut ainv_list[..]);
489504
for i in 0..6 {
490505
assert_eq!(a_list[i].invert(), ainv_list[i]);
491506
}

curve25519-dalek/src/ristretto.rs

Lines changed: 110 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@
161161
#[cfg(feature = "alloc")]
162162
use alloc::vec::Vec;
163163

164-
use core::array::TryFromSliceError;
164+
use core::array::{self, TryFromSliceError};
165165
use core::borrow::Borrow;
166166
use core::fmt::Debug;
167167
use core::iter::Sum;
@@ -532,6 +532,47 @@ impl RistrettoPoint {
532532
CompressedRistretto(s.to_bytes())
533533
}
534534

535+
/// Double-and-compress a batch of points. The Ristretto encoding
536+
/// is not batchable, since it requires an inverse square root.
537+
///
538+
/// However, given input points \\( P\_1, \ldots, P\_n, \\)
539+
/// it is possible to compute the encodings of their doubles \\(
540+
/// \mathrm{enc}( \[2\]P\_1), \ldots, \mathrm{enc}( \[2\]P\_n ) \\)
541+
/// in a batch.
542+
///
543+
#[cfg_attr(feature = "rand_core", doc = "```")]
544+
#[cfg_attr(not(feature = "rand_core"), doc = "```ignore")]
545+
/// # use curve25519_dalek::ristretto::RistrettoPoint;
546+
/// use rand_core::{OsRng, TryRngCore};
547+
///
548+
/// # // Need fn main() here in comment so the doctest compiles
549+
/// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
550+
/// # fn main() {
551+
/// let mut rng = OsRng.unwrap_err();
552+
///
553+
/// let points: [RistrettoPoint; 32] =
554+
/// core::array::from_fn(|_| RistrettoPoint::random(&mut rng));
555+
///
556+
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
557+
///
558+
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559+
/// assert_eq!(*P2_compressed, (P + P).compress());
560+
/// }
561+
/// # }
562+
/// ```
563+
pub fn double_and_compress_batch<const N: usize>(
564+
points: &[RistrettoPoint; N],
565+
) -> [CompressedRistretto; N] {
566+
let states: [BatchCompressState; N] =
567+
array::from_fn(|i| BatchCompressState::from(&points[i]));
568+
569+
let mut invs: [FieldElement; N] = array::from_fn(|i| states[i].efgh());
570+
571+
FieldElement::batch_invert(&mut invs);
572+
573+
array::from_fn(|i| Self::internal_double_and_compress_batch(&states[i], &invs[i]))
574+
}
575+
535576
/// Double-and-compress a batch of points. The Ristretto encoding
536577
/// is not batchable, since it requires an inverse square root.
537578
///
@@ -553,97 +594,68 @@ impl RistrettoPoint {
553594
/// let points: Vec<RistrettoPoint> =
554595
/// (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
555596
///
556-
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
597+
/// let compressed = RistrettoPoint::double_and_compress_alloc_batch(&points);
557598
///
558599
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559600
/// assert_eq!(*P2_compressed, (P + P).compress());
560601
/// }
561602
/// # }
562603
/// ```
563604
#[cfg(feature = "alloc")]
564-
pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
605+
pub fn double_and_compress_alloc_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
565606
where
566607
I: IntoIterator<Item = &'a RistrettoPoint>,
567608
{
568-
#[derive(Copy, Clone, Debug)]
569-
struct BatchCompressState {
570-
e: FieldElement,
571-
f: FieldElement,
572-
g: FieldElement,
573-
h: FieldElement,
574-
eg: FieldElement,
575-
fh: FieldElement,
576-
}
577-
578-
impl BatchCompressState {
579-
fn efgh(&self) -> FieldElement {
580-
&self.eg * &self.fh
581-
}
582-
}
583-
584-
impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
585-
#[rustfmt::skip] // keep alignment of explanatory comments
586-
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
587-
let XX = P.0.X.square();
588-
let YY = P.0.Y.square();
589-
let ZZ = P.0.Z.square();
590-
let dTT = &P.0.T.square() * &constants::EDWARDS_D;
591-
592-
let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
593-
let f = &ZZ + &dTT; // = Z^2 + d*T^2
594-
let g = &YY + &XX; // = Y^2 - a*X^2
595-
let h = &ZZ - &dTT; // = Z^2 - d*T^2
596-
597-
let eg = &e * &g;
598-
let fh = &f * &h;
599-
600-
BatchCompressState{ e, f, g, h, eg, fh }
601-
}
602-
}
603-
604609
let states: Vec<BatchCompressState> =
605610
points.into_iter().map(BatchCompressState::from).collect();
606611

607612
let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();
608613

609-
FieldElement::batch_invert(&mut invs[..]);
614+
FieldElement::batch_alloc_invert(&mut invs[..]);
610615

611616
states
612617
.iter()
613618
.zip(invs.iter())
614619
.map(|(state, inv): (&BatchCompressState, &FieldElement)| {
615-
let Zinv = &state.eg * inv;
616-
let Tinv = &state.fh * inv;
620+
Self::internal_double_and_compress_batch(state, inv)
621+
})
622+
.collect()
623+
}
617624

618-
let mut magic = constants::INVSQRT_A_MINUS_D;
625+
fn internal_double_and_compress_batch(
626+
state: &BatchCompressState,
627+
inv: &FieldElement,
628+
) -> CompressedRistretto {
629+
let Zinv = &state.eg * inv;
630+
let Tinv = &state.fh * inv;
619631

620-
let negcheck1 = (&state.eg * &Zinv).is_negative();
632+
let mut magic = constants::INVSQRT_A_MINUS_D;
621633

622-
let mut e = state.e;
623-
let mut g = state.g;
624-
let mut h = state.h;
634+
let negcheck1 = (&state.eg * &Zinv).is_negative();
625635

626-
let minus_e = -&e;
627-
let f_times_sqrta = &state.f * &constants::SQRT_M1;
636+
let mut e = state.e;
637+
let mut g = state.g;
638+
let mut h = state.h;
628639

629-
e.conditional_assign(&state.g, negcheck1);
630-
g.conditional_assign(&minus_e, negcheck1);
631-
h.conditional_assign(&f_times_sqrta, negcheck1);
640+
let minus_e = -&e;
641+
let f_times_sqrta = &state.f * &constants::SQRT_M1;
632642

633-
magic.conditional_assign(&constants::SQRT_M1, negcheck1);
643+
e.conditional_assign(&state.g, negcheck1);
644+
g.conditional_assign(&minus_e, negcheck1);
645+
h.conditional_assign(&f_times_sqrta, negcheck1);
634646

635-
let negcheck2 = (&(&h * &e) * &Zinv).is_negative();
647+
magic.conditional_assign(&constants::SQRT_M1, negcheck1);
636648

637-
g.conditional_negate(negcheck2);
649+
let negcheck2 = (&(&h * &e) * &Zinv).is_negative();
638650

639-
let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));
651+
g.conditional_negate(negcheck2);
640652

641-
let s_is_negative = s.is_negative();
642-
s.conditional_negate(s_is_negative);
653+
let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));
643654

644-
CompressedRistretto(s.to_bytes())
645-
})
646-
.collect()
655+
let s_is_negative = s.is_negative();
656+
s.conditional_negate(s_is_negative);
657+
658+
CompressedRistretto(s.to_bytes())
647659
}
648660

649661
/// Return the coset self + E\[4\], for debugging.
@@ -1156,6 +1168,42 @@ impl RistrettoBasepointTable {
11561168
}
11571169
}
11581170

1171+
#[derive(Copy, Clone, Debug)]
1172+
struct BatchCompressState {
1173+
e: FieldElement,
1174+
f: FieldElement,
1175+
g: FieldElement,
1176+
h: FieldElement,
1177+
eg: FieldElement,
1178+
fh: FieldElement,
1179+
}
1180+
1181+
impl BatchCompressState {
1182+
fn efgh(&self) -> FieldElement {
1183+
&self.eg * &self.fh
1184+
}
1185+
}
1186+
1187+
impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
1188+
#[rustfmt::skip] // keep alignment of explanatory comments
1189+
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
1190+
let XX = P.0.X.square();
1191+
let YY = P.0.Y.square();
1192+
let ZZ = P.0.Z.square();
1193+
let dTT = &P.0.T.square() * &constants::EDWARDS_D;
1194+
1195+
let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
1196+
let f = &ZZ + &dTT; // = Z^2 + d*T^2
1197+
let g = &YY + &XX; // = Y^2 - a*X^2
1198+
let h = &ZZ - &dTT; // = Z^2 - d*T^2
1199+
1200+
let eg = &e * &g;
1201+
let fh = &f * &h;
1202+
1203+
BatchCompressState{ e, f, g, h, eg, fh }
1204+
}
1205+
}
1206+
11591207
// ------------------------------------------------------------------------
11601208
// Constant-time conditional selection
11611209
// ------------------------------------------------------------------------
@@ -1860,7 +1908,7 @@ mod test {
18601908
.collect();
18611909
points[500] = <RistrettoPoint as Group>::identity();
18621910

1863-
let compressed = RistrettoPoint::double_and_compress_batch(&points);
1911+
let compressed = RistrettoPoint::double_and_compress_alloc_batch(&points);
18641912

18651913
for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
18661914
assert_eq!(*P2_compressed, (P + P).compress());

0 commit comments

Comments
 (0)