@@ -770,7 +770,7 @@ impl Scalar {
770770 /// Scalar::from(11u64),
771771 /// ];
772772 ///
773- /// let allinv = Scalar::batch_invert (&mut scalars);
773+ /// let allinv = Scalar::invert_batch (&mut scalars);
774774 ///
775775 /// assert_eq!(allinv, Scalar::from(3*5*7*11u64).invert());
776776 /// assert_eq!(scalars[0], Scalar::from(3u64).invert());
@@ -779,20 +779,46 @@ impl Scalar {
779779 /// assert_eq!(scalars[3], Scalar::from(11u64).invert());
780780 /// # }
781781 /// ```
782+ pub fn invert_batch < const N : usize > ( inputs : & mut [ Scalar ; N ] ) -> Scalar {
783+ let one: UnpackedScalar = Scalar :: ONE . unpack ( ) . as_montgomery ( ) ;
784+
785+ let mut scratch = [ one; N ] ;
786+
787+ Self :: invert_batch_internal ( inputs, & mut scratch)
788+ }
789+
790+ /// Given a slice of nonzero (possibly secret) `Scalar`s, compute their inverses in a batch.
791+ /// This the allocating form of [`Self::invert_batch`]. See those docs for examples.
792+ ///
793+ /// # Return
794+ ///
795+ /// Each element of `inputs` is replaced by its inverse.
796+ ///
797+ /// The product of all inverses is returned.
798+ ///
799+ /// # Warning
800+ ///
801+ /// All input `Scalars` **MUST** be nonzero. If you cannot
802+ /// *prove* that this is the case, you **SHOULD NOT USE THIS
803+ /// FUNCTION**.
782804 #[ cfg( feature = "alloc" ) ]
783- pub fn batch_invert ( inputs : & mut [ Scalar ] ) -> Scalar {
805+ pub fn invert_batch_alloc ( inputs : & mut [ Scalar ] ) -> Scalar {
806+ let n = inputs. len ( ) ;
807+ let one: UnpackedScalar = Scalar :: ONE . unpack ( ) . as_montgomery ( ) ;
808+
809+ let mut scratch = vec ! [ one; n] ;
810+
811+ Self :: invert_batch_internal ( inputs, & mut scratch)
812+ }
813+
814+ fn invert_batch_internal ( inputs : & mut [ Scalar ] , scratch : & mut [ UnpackedScalar ] ) -> Scalar {
784815 // This code is essentially identical to the FieldElement
785816 // implementation, and is documented there. Unfortunately,
786817 // it's not easy to write it generically, since here we want
787818 // to use `UnpackedScalar`s internally, and `Scalar`s
788819 // externally, but there's no corresponding distinction for
789820 // field elements.
790821
791- let n = inputs. len ( ) ;
792- let one: UnpackedScalar = Scalar :: ONE . unpack ( ) . as_montgomery ( ) ;
793-
794- let mut scratch = vec ! [ one; n] ;
795-
796822 // Keep an accumulator of all of the previous products
797823 let mut acc = Scalar :: ONE . unpack ( ) . as_montgomery ( ) ;
798824
@@ -826,7 +852,7 @@ impl Scalar {
826852 }
827853
828854 #[ cfg( feature = "zeroize" ) ]
829- Zeroize :: zeroize ( & mut scratch) ;
855+ Zeroize :: zeroize ( & mut scratch. iter_mut ( ) ) ;
830856
831857 ret
832858 }
@@ -1845,25 +1871,44 @@ pub(crate) mod test {
18451871 assert_eq ! ( X , bincode:: deserialize( X . as_bytes( ) ) . unwrap( ) , ) ;
18461872 }
18471873
1848- #[ cfg( all ( debug_assertions, feature = "alloc" ) ) ]
1874+ #[ cfg( debug_assertions) ]
18491875 #[ test]
18501876 #[ should_panic]
1851- fn batch_invert_with_a_zero_input_panics ( ) {
1852- let mut xs = vec ! [ Scalar :: ONE ; 16 ] ;
1877+ fn invert_batch_with_a_zero_input_panics ( ) {
1878+ let mut xs = [ Scalar :: ONE ; 16 ] ;
18531879 xs[ 3 ] = Scalar :: ZERO ;
18541880 // This should panic in debug mode.
1855- Scalar :: batch_invert ( & mut xs) ;
1881+ Scalar :: invert_batch ( & mut xs) ;
18561882 }
18571883
18581884 #[ test]
1859- #[ cfg( feature = "alloc" ) ]
1860- fn batch_invert_empty ( ) {
1861- assert_eq ! ( Scalar :: ONE , Scalar :: batch_invert( & mut [ ] ) ) ;
1885+ fn invert_batch_empty ( ) {
1886+ assert_eq ! ( Scalar :: ONE , Scalar :: invert_batch( & mut [ ] ) ) ;
1887+ }
1888+
1889+ #[ test]
1890+ fn invert_batch_consistency ( ) {
1891+ let mut x = Scalar :: from ( 1u64 ) ;
1892+ let mut v1: [ Scalar ; 16 ] = core:: array:: from_fn ( |_| {
1893+ let tmp = x;
1894+ x = x + x;
1895+ tmp
1896+ } ) ;
1897+ let v2 = v1;
1898+
1899+ let expected: Scalar = v1. iter ( ) . product ( ) ;
1900+ let expected = expected. invert ( ) ;
1901+ let ret = Scalar :: invert_batch ( & mut v1) ;
1902+ assert_eq ! ( ret, expected) ;
1903+
1904+ for ( a, b) in v1. iter ( ) . zip ( v2. iter ( ) ) {
1905+ assert_eq ! ( a * b, Scalar :: ONE ) ;
1906+ }
18621907 }
18631908
18641909 #[ test]
18651910 #[ cfg( feature = "alloc" ) ]
1866- fn batch_invert_consistency ( ) {
1911+ fn batch_vec_invert_consistency ( ) {
18671912 let mut x = Scalar :: from ( 1u64 ) ;
18681913 let mut v1: Vec < _ > = ( 0 ..16 )
18691914 . map ( |_| {
@@ -1876,7 +1921,7 @@ pub(crate) mod test {
18761921
18771922 let expected: Scalar = v1. iter ( ) . product ( ) ;
18781923 let expected = expected. invert ( ) ;
1879- let ret = Scalar :: batch_invert ( & mut v1) ;
1924+ let ret = Scalar :: invert_batch_alloc ( & mut v1) ;
18801925 assert_eq ! ( ret, expected) ;
18811926
18821927 for ( a, b) in v1. iter ( ) . zip ( v2. iter ( ) ) {
0 commit comments