Skip to content

Commit 4095dcf

Browse files
committed
Add allocation-free RistrettoPoint::double_and_compress_batch()
1 parent 1ad4603 commit 4095dcf

File tree

2 files changed

+110
-62
lines changed

2 files changed

+110
-62
lines changed

curve25519-dalek/benches/dalek_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ mod ristretto_benches {
289289
let points: Vec<RistrettoPoint> = (0..size)
290290
.map(|_| RistrettoPoint::try_from_rng(&mut rng).unwrap())
291291
.collect();
292-
b.iter(|| RistrettoPoint::double_and_compress_batch(&points));
292+
b.iter(|| RistrettoPoint::double_and_compress_alloc_batch(&points));
293293
},
294294
);
295295
}

curve25519-dalek/src/ristretto.rs

Lines changed: 109 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@
161161
#[cfg(feature = "alloc")]
162162
use alloc::vec::Vec;
163163

164-
use core::array::TryFromSliceError;
164+
use core::array::{self, TryFromSliceError};
165165
use core::borrow::Borrow;
166166
use core::fmt::Debug;
167167
use core::iter::Sum;
@@ -532,6 +532,47 @@ impl RistrettoPoint {
532532
CompressedRistretto(s.to_bytes())
533533
}
534534

535+
/// Double-and-compress a batch of points. The Ristretto encoding
536+
/// is not batchable, since it requires an inverse square root.
537+
///
538+
/// However, given input points \\( P\_1, \ldots, P\_n, \\)
539+
/// it is possible to compute the encodings of their doubles \\(
540+
/// \mathrm{enc}( \[2\]P\_1), \ldots, \mathrm{enc}( \[2\]P\_n ) \\)
541+
/// in a batch.
542+
///
543+
#[cfg_attr(feature = "rand_core", doc = "```")]
544+
#[cfg_attr(not(feature = "rand_core"), doc = "```ignore")]
545+
/// # use curve25519_dalek::ristretto::RistrettoPoint;
546+
/// use rand_core::{OsRng, TryRngCore};
547+
///
548+
/// # // Need fn main() here in comment so the doctest compiles
549+
/// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
550+
/// # fn main() {
551+
/// let mut rng = OsRng.unwrap_err();
552+
///
553+
/// let points: [RistrettoPoint; 32] =
554+
/// core::array::from_fn(|_| RistrettoPoint::random(&mut rng));
555+
///
556+
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
557+
///
558+
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559+
/// assert_eq!(*P2_compressed, (P + P).compress());
560+
/// }
561+
/// # }
562+
/// ```
563+
pub fn double_and_compress_batch<const N: usize>(
564+
points: &[RistrettoPoint; N],
565+
) -> [CompressedRistretto; N] {
566+
let states: [BatchCompressState; N] =
567+
array::from_fn(|i| BatchCompressState::from(&points[i]));
568+
569+
let mut invs: [FieldElement; N] = array::from_fn(|i| states[i].efgh());
570+
571+
FieldElement::batch_invert(&mut invs[..]);
572+
573+
array::from_fn(|i| Self::internal_double_and_compress_batch(&states[i], &invs[i]))
574+
}
575+
535576
/// Double-and-compress a batch of points. The Ristretto encoding
536577
/// is not batchable, since it requires an inverse square root.
537578
///
@@ -553,54 +594,18 @@ impl RistrettoPoint {
553594
/// let points: Vec<RistrettoPoint> =
554595
/// (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
555596
///
556-
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
597+
/// let compressed = RistrettoPoint::double_and_compress_alloc_batch(&points);
557598
///
558599
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559600
/// assert_eq!(*P2_compressed, (P + P).compress());
560601
/// }
561602
/// # }
562603
/// ```
563604
#[cfg(feature = "alloc")]
564-
pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
605+
pub fn double_and_compress_alloc_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
565606
where
566607
I: IntoIterator<Item = &'a RistrettoPoint>,
567608
{
568-
#[derive(Copy, Clone, Debug)]
569-
struct BatchCompressState {
570-
e: FieldElement,
571-
f: FieldElement,
572-
g: FieldElement,
573-
h: FieldElement,
574-
eg: FieldElement,
575-
fh: FieldElement,
576-
}
577-
578-
impl BatchCompressState {
579-
fn efgh(&self) -> FieldElement {
580-
&self.eg * &self.fh
581-
}
582-
}
583-
584-
impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
585-
#[rustfmt::skip] // keep alignment of explanatory comments
586-
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
587-
let XX = P.0.X.square();
588-
let YY = P.0.Y.square();
589-
let ZZ = P.0.Z.square();
590-
let dTT = &P.0.T.square() * &constants::EDWARDS_D;
591-
592-
let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
593-
let f = &ZZ + &dTT; // = Z^2 + d*T^2
594-
let g = &YY + &XX; // = Y^2 - a*X^2
595-
let h = &ZZ - &dTT; // = Z^2 - d*T^2
596-
597-
let eg = &e * &g;
598-
let fh = &f * &h;
599-
600-
BatchCompressState{ e, f, g, h, eg, fh }
601-
}
602-
}
603-
604609
let states: Vec<BatchCompressState> =
605610
points.into_iter().map(BatchCompressState::from).collect();
606611

@@ -612,38 +617,45 @@ impl RistrettoPoint {
612617
.iter()
613618
.zip(invs.iter())
614619
.map(|(state, inv): (&BatchCompressState, &FieldElement)| {
615-
let Zinv = &state.eg * inv;
616-
let Tinv = &state.fh * inv;
620+
Self::internal_double_and_compress_batch(state, inv)
621+
})
622+
.collect()
623+
}
617624

618-
let mut magic = constants::INVSQRT_A_MINUS_D;
625+
fn internal_double_and_compress_batch(
626+
state: &BatchCompressState,
627+
inv: &FieldElement,
628+
) -> CompressedRistretto {
629+
let Zinv = &state.eg * inv;
630+
let Tinv = &state.fh * inv;
619631

620-
let negcheck1 = (&state.eg * &Zinv).is_negative();
632+
let mut magic = constants::INVSQRT_A_MINUS_D;
621633

622-
let mut e = state.e;
623-
let mut g = state.g;
624-
let mut h = state.h;
634+
let negcheck1 = (&state.eg * &Zinv).is_negative();
625635

626-
let minus_e = -&e;
627-
let f_times_sqrta = &state.f * &constants::SQRT_M1;
636+
let mut e = state.e;
637+
let mut g = state.g;
638+
let mut h = state.h;
628639

629-
e.conditional_assign(&state.g, negcheck1);
630-
g.conditional_assign(&minus_e, negcheck1);
631-
h.conditional_assign(&f_times_sqrta, negcheck1);
640+
let minus_e = -&e;
641+
let f_times_sqrta = &state.f * &constants::SQRT_M1;
632642

633-
magic.conditional_assign(&constants::SQRT_M1, negcheck1);
643+
e.conditional_assign(&state.g, negcheck1);
644+
g.conditional_assign(&minus_e, negcheck1);
645+
h.conditional_assign(&f_times_sqrta, negcheck1);
634646

635-
let negcheck2 = (&(&h * &e) * &Zinv).is_negative();
647+
magic.conditional_assign(&constants::SQRT_M1, negcheck1);
636648

637-
g.conditional_negate(negcheck2);
649+
let negcheck2 = (&(&h * &e) * &Zinv).is_negative();
638650

639-
let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));
651+
g.conditional_negate(negcheck2);
640652

641-
let s_is_negative = s.is_negative();
642-
s.conditional_negate(s_is_negative);
653+
let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));
643654

644-
CompressedRistretto(s.to_bytes())
645-
})
646-
.collect()
655+
let s_is_negative = s.is_negative();
656+
s.conditional_negate(s_is_negative);
657+
658+
CompressedRistretto(s.to_bytes())
647659
}
648660

649661
/// Return the coset self + E\[4\], for debugging.
@@ -1156,6 +1168,42 @@ impl RistrettoBasepointTable {
11561168
}
11571169
}
11581170

1171+
#[derive(Copy, Clone, Debug)]
1172+
struct BatchCompressState {
1173+
e: FieldElement,
1174+
f: FieldElement,
1175+
g: FieldElement,
1176+
h: FieldElement,
1177+
eg: FieldElement,
1178+
fh: FieldElement,
1179+
}
1180+
1181+
impl BatchCompressState {
1182+
fn efgh(&self) -> FieldElement {
1183+
&self.eg * &self.fh
1184+
}
1185+
}
1186+
1187+
impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
1188+
#[rustfmt::skip] // keep alignment of explanatory comments
1189+
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
1190+
let XX = P.0.X.square();
1191+
let YY = P.0.Y.square();
1192+
let ZZ = P.0.Z.square();
1193+
let dTT = &P.0.T.square() * &constants::EDWARDS_D;
1194+
1195+
let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
1196+
let f = &ZZ + &dTT; // = Z^2 + d*T^2
1197+
let g = &YY + &XX; // = Y^2 - a*X^2
1198+
let h = &ZZ - &dTT; // = Z^2 - d*T^2
1199+
1200+
let eg = &e * &g;
1201+
let fh = &f * &h;
1202+
1203+
BatchCompressState{ e, f, g, h, eg, fh }
1204+
}
1205+
}
1206+
11591207
// ------------------------------------------------------------------------
11601208
// Constant-time conditional selection
11611209
// ------------------------------------------------------------------------
@@ -1860,7 +1908,7 @@ mod test {
18601908
.collect();
18611909
points[500] = <RistrettoPoint as Group>::identity();
18621910

1863-
let compressed = RistrettoPoint::double_and_compress_batch(&points);
1911+
let compressed = RistrettoPoint::double_and_compress_alloc_batch(&points);
18641912

18651913
for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
18661914
assert_eq!(*P2_compressed, (P + P).compress());

0 commit comments

Comments
 (0)