Skip to content

Commit 4443375

Browse files
authored
Add fixed-array Scalar batch inversion (#789)
1 parent 015707a commit 4443375

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

curve25519-dalek/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ major series.
55

66
## 5.x series
77

8+
## 5.0.0-pre.1
9+
10+
* Rename `Scalar::batch_invert` -> `Scalar::invert_batch` for consistency. Also make it no-alloc.
11+
* Add an allocating batch inversion called `Scalar::invert_batch_alloc`.
12+
813
## 5.0.0-pre.0
914

1015
* Update edition to 2024

curve25519-dalek/benches/dalek_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ mod scalar_benches {
393393
(0..size).map(|_| Scalar::random(&mut rng)).collect();
394394
b.iter(|| {
395395
let mut s = scalars.clone();
396-
Scalar::batch_invert(&mut s);
396+
Scalar::invert_batch_alloc(&mut s);
397397
});
398398
},
399399
);

curve25519-dalek/src/scalar.rs

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ impl Scalar {
770770
/// Scalar::from(11u64),
771771
/// ];
772772
///
773-
/// let allinv = Scalar::batch_invert(&mut scalars);
773+
/// let allinv = Scalar::invert_batch(&mut scalars);
774774
///
775775
/// assert_eq!(allinv, Scalar::from(3*5*7*11u64).invert());
776776
/// assert_eq!(scalars[0], Scalar::from(3u64).invert());
@@ -779,20 +779,46 @@ impl Scalar {
779779
/// assert_eq!(scalars[3], Scalar::from(11u64).invert());
780780
/// # }
781781
/// ```
782+
pub fn invert_batch<const N: usize>(inputs: &mut [Scalar; N]) -> Scalar {
783+
let one: UnpackedScalar = Scalar::ONE.unpack().as_montgomery();
784+
785+
let mut scratch = [one; N];
786+
787+
Self::invert_batch_internal(inputs, &mut scratch)
788+
}
789+
790+
/// Given a slice of nonzero (possibly secret) `Scalar`s, compute their inverses in a batch.
791+
/// This the allocating form of [`Self::invert_batch`]. See those docs for examples.
792+
///
793+
/// # Return
794+
///
795+
/// Each element of `inputs` is replaced by its inverse.
796+
///
797+
/// The product of all inverses is returned.
798+
///
799+
/// # Warning
800+
///
801+
/// All input `Scalars` **MUST** be nonzero. If you cannot
802+
/// *prove* that this is the case, you **SHOULD NOT USE THIS
803+
/// FUNCTION**.
782804
#[cfg(feature = "alloc")]
783-
pub fn batch_invert(inputs: &mut [Scalar]) -> Scalar {
805+
pub fn invert_batch_alloc(inputs: &mut [Scalar]) -> Scalar {
806+
let n = inputs.len();
807+
let one: UnpackedScalar = Scalar::ONE.unpack().as_montgomery();
808+
809+
let mut scratch = vec![one; n];
810+
811+
Self::invert_batch_internal(inputs, &mut scratch)
812+
}
813+
814+
fn invert_batch_internal(inputs: &mut [Scalar], scratch: &mut [UnpackedScalar]) -> Scalar {
784815
// This code is essentially identical to the FieldElement
785816
// implementation, and is documented there. Unfortunately,
786817
// it's not easy to write it generically, since here we want
787818
// to use `UnpackedScalar`s internally, and `Scalar`s
788819
// externally, but there's no corresponding distinction for
789820
// field elements.
790821

791-
let n = inputs.len();
792-
let one: UnpackedScalar = Scalar::ONE.unpack().as_montgomery();
793-
794-
let mut scratch = vec![one; n];
795-
796822
// Keep an accumulator of all of the previous products
797823
let mut acc = Scalar::ONE.unpack().as_montgomery();
798824

@@ -826,7 +852,7 @@ impl Scalar {
826852
}
827853

828854
#[cfg(feature = "zeroize")]
829-
Zeroize::zeroize(&mut scratch);
855+
Zeroize::zeroize(&mut scratch.iter_mut());
830856

831857
ret
832858
}
@@ -1845,25 +1871,44 @@ pub(crate) mod test {
18451871
assert_eq!(X, bincode::deserialize(X.as_bytes()).unwrap(),);
18461872
}
18471873

1848-
#[cfg(all(debug_assertions, feature = "alloc"))]
1874+
#[cfg(debug_assertions)]
18491875
#[test]
18501876
#[should_panic]
1851-
fn batch_invert_with_a_zero_input_panics() {
1852-
let mut xs = vec![Scalar::ONE; 16];
1877+
fn invert_batch_with_a_zero_input_panics() {
1878+
let mut xs = [Scalar::ONE; 16];
18531879
xs[3] = Scalar::ZERO;
18541880
// This should panic in debug mode.
1855-
Scalar::batch_invert(&mut xs);
1881+
Scalar::invert_batch(&mut xs);
18561882
}
18571883

18581884
#[test]
1859-
#[cfg(feature = "alloc")]
1860-
fn batch_invert_empty() {
1861-
assert_eq!(Scalar::ONE, Scalar::batch_invert(&mut []));
1885+
fn invert_batch_empty() {
1886+
assert_eq!(Scalar::ONE, Scalar::invert_batch(&mut []));
1887+
}
1888+
1889+
#[test]
1890+
fn invert_batch_consistency() {
1891+
let mut x = Scalar::from(1u64);
1892+
let mut v1: [Scalar; 16] = core::array::from_fn(|_| {
1893+
let tmp = x;
1894+
x = x + x;
1895+
tmp
1896+
});
1897+
let v2 = v1;
1898+
1899+
let expected: Scalar = v1.iter().product();
1900+
let expected = expected.invert();
1901+
let ret = Scalar::invert_batch(&mut v1);
1902+
assert_eq!(ret, expected);
1903+
1904+
for (a, b) in v1.iter().zip(v2.iter()) {
1905+
assert_eq!(a * b, Scalar::ONE);
1906+
}
18621907
}
18631908

18641909
#[test]
18651910
#[cfg(feature = "alloc")]
1866-
fn batch_invert_consistency() {
1911+
fn batch_vec_invert_consistency() {
18671912
let mut x = Scalar::from(1u64);
18681913
let mut v1: Vec<_> = (0..16)
18691914
.map(|_| {
@@ -1876,7 +1921,7 @@ pub(crate) mod test {
18761921

18771922
let expected: Scalar = v1.iter().product();
18781923
let expected = expected.invert();
1879-
let ret = Scalar::batch_invert(&mut v1);
1924+
let ret = Scalar::invert_batch_alloc(&mut v1);
18801925
assert_eq!(ret, expected);
18811926

18821927
for (a, b) in v1.iter().zip(v2.iter()) {

0 commit comments

Comments
 (0)