161161#[ cfg( feature = "alloc" ) ]
162162use alloc:: vec:: Vec ;
163163
164- use core:: array:: TryFromSliceError ;
164+ use core:: array:: { self , TryFromSliceError } ;
165165use core:: borrow:: Borrow ;
166166use core:: fmt:: Debug ;
167167use core:: iter:: Sum ;
@@ -532,6 +532,47 @@ impl RistrettoPoint {
532532 CompressedRistretto ( s. to_bytes ( ) )
533533 }
534534
535+ /// Double-and-compress a batch of points. The Ristretto encoding
536+ /// is not batchable, since it requires an inverse square root.
537+ ///
538+ /// However, given input points \\( P\_1, \ldots, P\_n, \\)
539+ /// it is possible to compute the encodings of their doubles \\(
540+ /// \mathrm{enc}( \[2\]P\_1), \ldots, \mathrm{enc}( \[2\]P\_n ) \\)
541+ /// in a batch.
542+ ///
543+ #[ cfg_attr( feature = "rand_core" , doc = "```" ) ]
544+ #[ cfg_attr( not( feature = "rand_core" ) , doc = "```ignore" ) ]
545+ /// # use curve25519_dalek::ristretto::RistrettoPoint;
546+ /// use rand_core::{OsRng, TryRngCore};
547+ ///
548+ /// # // Need fn main() here in comment so the doctest compiles
549+ /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
550+ /// # fn main() {
551+ /// let mut rng = OsRng.unwrap_err();
552+ ///
553+ /// let points: [RistrettoPoint; 32] =
554+ /// core::array::from_fn(|_| RistrettoPoint::random(&mut rng));
555+ ///
556+ /// let compressed = RistrettoPoint::double_and_compress_batch(&points);
557+ ///
558+ /// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559+ /// assert_eq!(*P2_compressed, (P + P).compress());
560+ /// }
561+ /// # }
562+ /// ```
563+ pub fn double_and_compress_batch < const N : usize > (
564+ points : & [ RistrettoPoint ; N ] ,
565+ ) -> [ CompressedRistretto ; N ] {
566+ let states: [ BatchCompressState ; N ] =
567+ array:: from_fn ( |i| BatchCompressState :: from ( & points[ i] ) ) ;
568+
569+ let mut invs: [ FieldElement ; N ] = array:: from_fn ( |i| states[ i] . efgh ( ) ) ;
570+
571+ FieldElement :: batch_invert ( & mut invs) ;
572+
573+ array:: from_fn ( |i| Self :: internal_double_and_compress_batch ( & states[ i] , & invs[ i] ) )
574+ }
575+
535576 /// Double-and-compress a batch of points. The Ristretto encoding
536577 /// is not batchable, since it requires an inverse square root.
537578 ///
@@ -553,97 +594,68 @@ impl RistrettoPoint {
553594 /// let points: Vec<RistrettoPoint> =
554595 /// (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
555596 ///
556- /// let compressed = RistrettoPoint::double_and_compress_batch (&points);
597+ /// let compressed = RistrettoPoint::double_and_compress_alloc_batch (&points);
557598 ///
558599 /// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
559600 /// assert_eq!(*P2_compressed, (P + P).compress());
560601 /// }
561602 /// # }
562603 /// ```
563604 #[ cfg( feature = "alloc" ) ]
564- pub fn double_and_compress_batch < ' a , I > ( points : I ) -> Vec < CompressedRistretto >
605+ pub fn double_and_compress_alloc_batch < ' a , I > ( points : I ) -> Vec < CompressedRistretto >
565606 where
566607 I : IntoIterator < Item = & ' a RistrettoPoint > ,
567608 {
568- #[ derive( Copy , Clone , Debug ) ]
569- struct BatchCompressState {
570- e : FieldElement ,
571- f : FieldElement ,
572- g : FieldElement ,
573- h : FieldElement ,
574- eg : FieldElement ,
575- fh : FieldElement ,
576- }
577-
578- impl BatchCompressState {
579- fn efgh ( & self ) -> FieldElement {
580- & self . eg * & self . fh
581- }
582- }
583-
584- impl < ' a > From < & ' a RistrettoPoint > for BatchCompressState {
585- #[ rustfmt:: skip] // keep alignment of explanatory comments
586- fn from ( P : & ' a RistrettoPoint ) -> BatchCompressState {
587- let XX = P . 0 . X . square ( ) ;
588- let YY = P . 0 . Y . square ( ) ;
589- let ZZ = P . 0 . Z . square ( ) ;
590- let dTT = & P . 0 . T . square ( ) * & constants:: EDWARDS_D ;
591-
592- let e = & P . 0 . X * & ( & P . 0 . Y + & P . 0 . Y ) ; // = 2*X*Y
593- let f = & ZZ + & dTT; // = Z^2 + d*T^2
594- let g = & YY + & XX ; // = Y^2 - a*X^2
595- let h = & ZZ - & dTT; // = Z^2 - d*T^2
596-
597- let eg = & e * & g;
598- let fh = & f * & h;
599-
600- BatchCompressState { e, f, g, h, eg, fh }
601- }
602- }
603-
604609 let states: Vec < BatchCompressState > =
605610 points. into_iter ( ) . map ( BatchCompressState :: from) . collect ( ) ;
606611
607612 let mut invs: Vec < FieldElement > = states. iter ( ) . map ( |state| state. efgh ( ) ) . collect ( ) ;
608613
609- FieldElement :: batch_invert ( & mut invs[ ..] ) ;
614+ FieldElement :: batch_alloc_invert ( & mut invs[ ..] ) ;
610615
611616 states
612617 . iter ( )
613618 . zip ( invs. iter ( ) )
614619 . map ( |( state, inv) : ( & BatchCompressState , & FieldElement ) | {
615- let Zinv = & state. eg * inv;
616- let Tinv = & state. fh * inv;
620+ Self :: internal_double_and_compress_batch ( state, inv)
621+ } )
622+ . collect ( )
623+ }
617624
618- let mut magic = constants:: INVSQRT_A_MINUS_D ;
625+ fn internal_double_and_compress_batch (
626+ state : & BatchCompressState ,
627+ inv : & FieldElement ,
628+ ) -> CompressedRistretto {
629+ let Zinv = & state. eg * inv;
630+ let Tinv = & state. fh * inv;
619631
620- let negcheck1 = ( & state . eg * & Zinv ) . is_negative ( ) ;
632+ let mut magic = constants :: INVSQRT_A_MINUS_D ;
621633
622- let mut e = state. e ;
623- let mut g = state. g ;
624- let mut h = state. h ;
634+ let negcheck1 = ( & state. eg * & Zinv ) . is_negative ( ) ;
625635
626- let minus_e = -& e;
627- let f_times_sqrta = & state. f * & constants:: SQRT_M1 ;
636+ let mut e = state. e ;
637+ let mut g = state. g ;
638+ let mut h = state. h ;
628639
629- e. conditional_assign ( & state. g , negcheck1) ;
630- g. conditional_assign ( & minus_e, negcheck1) ;
631- h. conditional_assign ( & f_times_sqrta, negcheck1) ;
640+ let minus_e = -& e;
641+ let f_times_sqrta = & state. f * & constants:: SQRT_M1 ;
632642
633- magic. conditional_assign ( & constants:: SQRT_M1 , negcheck1) ;
643+ e. conditional_assign ( & state. g , negcheck1) ;
644+ g. conditional_assign ( & minus_e, negcheck1) ;
645+ h. conditional_assign ( & f_times_sqrta, negcheck1) ;
634646
635- let negcheck2 = ( & ( & h * & e ) * & Zinv ) . is_negative ( ) ;
647+ magic . conditional_assign ( & constants :: SQRT_M1 , negcheck1 ) ;
636648
637- g . conditional_negate ( negcheck2 ) ;
649+ let negcheck2 = ( & ( & h * & e ) * & Zinv ) . is_negative ( ) ;
638650
639- let mut s = & ( & h - & g ) * & ( & magic * & ( & g * & Tinv ) ) ;
651+ g . conditional_negate ( negcheck2 ) ;
640652
641- let s_is_negative = s. is_negative ( ) ;
642- s. conditional_negate ( s_is_negative) ;
653+ let mut s = & ( & h - & g) * & ( & magic * & ( & g * & Tinv ) ) ;
643654
644- CompressedRistretto ( s. to_bytes ( ) )
645- } )
646- . collect ( )
655+ let s_is_negative = s. is_negative ( ) ;
656+ s. conditional_negate ( s_is_negative) ;
657+
658+ CompressedRistretto ( s. to_bytes ( ) )
647659 }
648660
649661 /// Return the coset self + E\[4\], for debugging.
@@ -1156,6 +1168,42 @@ impl RistrettoBasepointTable {
11561168 }
11571169}
11581170
1171+ #[ derive( Copy , Clone , Debug ) ]
1172+ struct BatchCompressState {
1173+ e : FieldElement ,
1174+ f : FieldElement ,
1175+ g : FieldElement ,
1176+ h : FieldElement ,
1177+ eg : FieldElement ,
1178+ fh : FieldElement ,
1179+ }
1180+
1181+ impl BatchCompressState {
1182+ fn efgh ( & self ) -> FieldElement {
1183+ & self . eg * & self . fh
1184+ }
1185+ }
1186+
1187+ impl < ' a > From < & ' a RistrettoPoint > for BatchCompressState {
1188+ #[ rustfmt:: skip] // keep alignment of explanatory comments
1189+ fn from ( P : & ' a RistrettoPoint ) -> BatchCompressState {
1190+ let XX = P . 0 . X . square ( ) ;
1191+ let YY = P . 0 . Y . square ( ) ;
1192+ let ZZ = P . 0 . Z . square ( ) ;
1193+ let dTT = & P . 0 . T . square ( ) * & constants:: EDWARDS_D ;
1194+
1195+ let e = & P . 0 . X * & ( & P . 0 . Y + & P . 0 . Y ) ; // = 2*X*Y
1196+ let f = & ZZ + & dTT; // = Z^2 + d*T^2
1197+ let g = & YY + & XX ; // = Y^2 - a*X^2
1198+ let h = & ZZ - & dTT; // = Z^2 - d*T^2
1199+
1200+ let eg = & e * & g;
1201+ let fh = & f * & h;
1202+
1203+ BatchCompressState { e, f, g, h, eg, fh }
1204+ }
1205+ }
1206+
11591207// ------------------------------------------------------------------------
11601208// Constant-time conditional selection
11611209// ------------------------------------------------------------------------
@@ -1860,7 +1908,7 @@ mod test {
18601908 . collect ( ) ;
18611909 points[ 500 ] = <RistrettoPoint as Group >:: identity ( ) ;
18621910
1863- let compressed = RistrettoPoint :: double_and_compress_batch ( & points) ;
1911+ let compressed = RistrettoPoint :: double_and_compress_alloc_batch ( & points) ;
18641912
18651913 for ( P , P2_compressed ) in points. iter ( ) . zip ( compressed. iter ( ) ) {
18661914 assert_eq ! ( * P2_compressed , ( P + P ) . compress( ) ) ;
0 commit comments