Skip to content

Commit fc88157

Browse files
tarcieridaxpedda
andauthored
Scalar::div_by_2 (#805)
* [WIP] Scalar::div_by_2 * debug_assert that carry is 0 * revise tests * Test multiply by half scalar, double and compress (#804) * Test `div_by_2` with `proptest` (#806) --------- Co-authored-by: daxpedda <[email protected]>
1 parent 1ad4603 commit fc88157

File tree

5 files changed

+116
-10
lines changed

5 files changed

+116
-10
lines changed

curve25519-dalek/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ sha2 = { version = "0.11.0-rc.0", default-features = false }
4141
bincode = "1"
4242
criterion = { version = "0.5", features = ["html_reports"] }
4343
hex = "0.4.2"
44+
proptest = "1"
4445
rand = "0.9"
4546
rand_core = { version = "0.9", default-features = false, features = ["os_rng"] }
4647

curve25519-dalek/src/backend/serial/u32/scalar.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,33 @@ impl Scalar29 {
197197
}
198198

199199
// conditionally add l if the difference is negative
200+
difference.conditional_add_l(Choice::from((borrow >> 31) as u8));
201+
difference
202+
}
203+
204+
pub(crate) fn conditional_add_l(&mut self, condition: Choice) -> u32 {
200205
let mut carry: u32 = 0;
206+
let mask = (1u32 << 29) - 1;
207+
201208
for i in 0..9 {
202-
let underflow = Choice::from((borrow >> 31) as u8);
203-
let addend = u32::conditional_select(&0, &constants::L[i], underflow);
204-
carry = (carry >> 29) + difference[i] + addend;
205-
difference[i] = carry & mask;
209+
let addend = u32::conditional_select(&0, &constants::L[i], condition);
210+
carry = (carry >> 29) + self[i] + addend;
211+
self[i] = carry & mask;
206212
}
213+
carry
214+
}
207215

208-
difference
216+
/// Compute a raw in-place carrying right shift over the limbs.
217+
#[inline(always)]
218+
pub(crate) fn shr1_assign(&mut self) -> u32 {
219+
let mut carry: u32 = 0;
220+
for i in (0..9).rev() {
221+
let limb = self[i];
222+
let next_carry = limb & 1;
223+
self[i] = (limb >> 1) | (carry << 28);
224+
carry = next_carry;
225+
}
226+
carry
209227
}
210228

211229
/// Compute `a * b`.

curve25519-dalek/src/backend/serial/u64/scalar.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,34 @@ impl Scalar52 {
186186
}
187187

188188
// conditionally add l if the difference is negative
189+
difference.conditional_add_l(Choice::from((borrow >> 63) as u8));
190+
difference
191+
}
192+
193+
pub(crate) fn conditional_add_l(&mut self, condition: Choice) -> u64 {
189194
let mut carry: u64 = 0;
195+
let mask = (1u64 << 52) - 1;
196+
190197
for i in 0..5 {
191-
let underflow = Choice::from((borrow >> 63) as u8);
192-
let addend = u64::conditional_select(&0, &constants::L[i], underflow);
193-
carry = (carry >> 52) + difference[i] + addend;
194-
difference[i] = carry & mask;
198+
let addend = u64::conditional_select(&0, &constants::L[i], condition);
199+
carry = (carry >> 52) + self[i] + addend;
200+
self[i] = carry & mask;
195201
}
196202

197-
difference
203+
carry
204+
}
205+
206+
/// Compute a raw in-place carrying right shift over the limbs.
207+
#[inline(always)]
208+
pub(crate) fn shr1_assign(&mut self) -> u64 {
209+
let mut carry: u64 = 0;
210+
for i in (0..5).rev() {
211+
let limb = self[i];
212+
let next_carry = limb & 1;
213+
self[i] = (limb >> 1) | (carry << 51);
214+
carry = next_carry;
215+
}
216+
carry
198217
}
199218

200219
/// Compute `a * b`

curve25519-dalek/src/ristretto.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,8 @@ impl Zeroize for RistrettoPoint {
13211321
mod test {
13221322
use super::*;
13231323
use crate::edwards::CompressedEdwardsY;
1324+
#[cfg(feature = "group")]
1325+
use proptest::prelude::*;
13241326

13251327
use rand_core::{OsRng, TryRngCore};
13261328

@@ -1867,6 +1869,39 @@ mod test {
18671869
}
18681870
}
18691871

1872+
#[cfg(feature = "group")]
1873+
proptest! {
1874+
#[test]
1875+
fn multiply_double_and_compress_random_points(
1876+
p1 in any::<[u8; 64]>(),
1877+
p2 in any::<[u8; 64]>(),
1878+
s1 in any::<[u8; 32]>(),
1879+
s2 in any::<[u8; 32]>(),
1880+
) {
1881+
use group::Group;
1882+
1883+
let scalars = [
1884+
Scalar::from_bytes_mod_order(s1),
1885+
Scalar::ZERO,
1886+
Scalar::from_bytes_mod_order(s2),
1887+
];
1888+
1889+
let points = [
1890+
RistrettoPoint::from_uniform_bytes(&p1),
1891+
<RistrettoPoint as Group>::identity(),
1892+
RistrettoPoint::from_uniform_bytes(&p2),
1893+
];
1894+
1895+
let multiplied_points: [_; 3] =
1896+
core::array::from_fn(|i| scalars[i].div_by_2() * points[i]);
1897+
let compressed = RistrettoPoint::double_and_compress_batch(&multiplied_points);
1898+
1899+
for ((s, P), P2_compressed) in scalars.iter().zip(points).zip(compressed) {
1900+
prop_assert_eq!(P2_compressed, (s * P).compress());
1901+
}
1902+
}
1903+
}
1904+
18701905
#[test]
18711906
#[cfg(feature = "alloc")]
18721907
fn vartime_precomputed_vs_nonprecomputed_multiscalar() {

curve25519-dalek/src/scalar.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,22 @@ impl Scalar {
831831
ret
832832
}
833833

834+
/// Compute `b` such that `b + b = a mod modulus`.
835+
pub fn div_by_2(&self) -> Self {
836+
// We are looking for such `b` that `b + b = a mod modulus`.
837+
// Two possibilities:
838+
// - if `a` is even, we can just divide by 2;
839+
// - if `a` is odd, we divide `(a + modulus)` by 2.
840+
let is_odd = Choice::from(self.as_bytes()[0] & 1);
841+
let mut scalar = self.unpack();
842+
scalar.conditional_add_l(is_odd);
843+
844+
let carry = scalar.shr1_assign();
845+
debug_assert_eq!(carry, 0);
846+
847+
scalar.pack()
848+
}
849+
834850
/// Get the bits of the scalar, in little-endian order
835851
pub(crate) fn bits_le(&self) -> impl DoubleEndedIterator<Item = bool> + '_ {
836852
(0..256).map(|i| {
@@ -1677,6 +1693,23 @@ pub(crate) mod test {
16771693
}
16781694
}
16791695

1696+
#[test]
1697+
fn div_by_2() {
1698+
// test a range of small scalars
1699+
for i in 0u64..32 {
1700+
let scalar = Scalar::from(i);
1701+
let dividend = scalar.div_by_2();
1702+
assert_eq!(scalar, dividend + dividend);
1703+
}
1704+
1705+
// test a range of scalars near the modulus
1706+
for i in 0u64..32 {
1707+
let scalar = Scalar::ZERO - Scalar::from(i);
1708+
let dividend = scalar.div_by_2();
1709+
assert_eq!(scalar, dividend + dividend);
1710+
}
1711+
}
1712+
16801713
#[test]
16811714
fn reduce() {
16821715
let biggest = Scalar::from_bytes_mod_order([0xff; 32]);

0 commit comments

Comments
 (0)