diff --git a/benches/bench.rs b/benches/bench.rs index 26ee63a4a..188741713 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -153,6 +153,59 @@ fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { }); } +fn bench_inv_mod(group: &mut BenchmarkGroup<'_, M>) { + group.bench_function("inv_odd_mod, U256", |b| { + b.iter_batched( + || { + let m = U256::random(&mut OsRng) | U256::ONE; + loop { + let x = U256::random(&mut OsRng); + let (_, is_some) = x.inv_odd_mod(&m); + if is_some.into() { + break (x, m); + } + } + }, + |(x, m)| x.inv_odd_mod(&m), + BatchSize::SmallInput, + ) + }); + + group.bench_function("inv_mod, U256, odd modulus", |b| { + b.iter_batched( + || { + let m = U256::random(&mut OsRng) | U256::ONE; + loop { + let x = U256::random(&mut OsRng); + let (_, is_some) = x.inv_odd_mod(&m); + if is_some.into() { + break (x, m); + } + } + }, + |(x, m)| x.inv_mod(&m), + BatchSize::SmallInput, + ) + }); + + group.bench_function("inv_mod, U256", |b| { + b.iter_batched( + || { + let m = U256::random(&mut OsRng); + loop { + let x = U256::random(&mut OsRng); + let (_, is_some) = x.inv_mod(&m); + if is_some.into() { + break (x, m); + } + } + }, + |(x, m)| x.inv_mod(&m), + BatchSize::SmallInput, + ) + }); +} + fn bench_wrapping_ops(c: &mut Criterion) { let mut group = c.benchmark_group("wrapping ops"); bench_division(&mut group); @@ -169,6 +222,7 @@ fn bench_montgomery(c: &mut Criterion) { fn bench_modular_ops(c: &mut Criterion) { let mut group = c.benchmark_group("modular ops"); bench_shifts(&mut group); + bench_inv_mod(&mut group); group.finish(); } diff --git a/src/ct_choice.rs b/src/ct_choice.rs index 3ad2d2323..921e72de9 100644 --- a/src/ct_choice.rs +++ b/src/ct_choice.rs @@ -29,6 +29,17 @@ impl CtChoice { Self(value.wrapping_neg()) } + /// Returns the truthy value if `value != 0`, and the falsy value otherwise. + pub(crate) const fn from_usize_being_nonzero(value: usize) -> Self { + const HI_BIT: u32 = usize::BITS - 1; + Self::from_lsb(((value | value.wrapping_neg()) >> HI_BIT) as Word) + } + + /// Returns the truthy value if `x == y`, and the falsy value otherwise. + pub(crate) const fn from_usize_equality(x: usize, y: usize) -> Self { + Self::from_usize_being_nonzero(x.wrapping_sub(y)).not() + } + /// Returns the truthy value if `x < y`, and the falsy value otherwise. pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self { let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1); @@ -39,6 +50,10 @@ impl CtChoice { Self(!self.0) } + pub(crate) const fn or(&self, other: Self) -> Self { + Self(self.0 | other.0) + } + pub(crate) const fn and(&self, other: Self) -> Self { Self(self.0 & other.0) } diff --git a/src/uint/bits.rs b/src/uint/bits.rs index 500a9314a..506bf9967 100644 --- a/src/uint/bits.rs +++ b/src/uint/bits.rs @@ -69,7 +69,7 @@ impl Uint { /// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`. /// Returns the falsy value for indices out of range. pub const fn bit(&self, index: usize) -> CtChoice { - let limb_num = Limb((index / Limb::BITS) as Word); + let limb_num = index / Limb::BITS; let index_in_limb = index % Limb::BITS; let index_mask = 1 << index_in_limb; @@ -79,18 +79,36 @@ impl Uint { let mut i = 0; while i < LIMBS { let bit = limbs[i] & index_mask; - let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word)); + let is_right_limb = CtChoice::from_usize_equality(i, limb_num); result |= is_right_limb.if_true(bit); i += 1; } CtChoice::from_lsb(result >> index_in_limb) } + + /// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`. + pub(crate) const fn set_bit(self, index: usize, bit_value: CtChoice) -> Self { + let mut result = self; + let limb_num = index / Limb::BITS; + let index_in_limb = index % Limb::BITS; + let index_mask = 1 << index_in_limb; + + let mut i = 0; + while i < LIMBS { + let is_right_limb = CtChoice::from_usize_equality(i, limb_num); + let old_limb = result.limbs[i].0; + let new_limb = bit_value.select(old_limb & !index_mask, old_limb | index_mask); + result.limbs[i] = Limb(is_right_limb.select(old_limb, new_limb)); + i += 1; + } + result + } } #[cfg(test)] mod tests { - use crate::U256; + use crate::{CtChoice, U256}; fn uint_with_bits_at(positions: &[usize]) -> U256 { let mut result = U256::ZERO; @@ -159,4 +177,31 @@ mod tests { let u = U256::ZERO; assert_eq!(u.trailing_zeros() as u32, 256); } + + #[test] + fn set_bit() { + let u = uint_with_bits_at(&[16, 79, 150]); + assert_eq!( + u.set_bit(127, CtChoice::TRUE), + uint_with_bits_at(&[16, 79, 127, 150]) + ); + + let u = uint_with_bits_at(&[16, 79, 150]); + assert_eq!( + u.set_bit(150, CtChoice::TRUE), + uint_with_bits_at(&[16, 79, 150]) + ); + + let u = uint_with_bits_at(&[16, 79, 150]); + assert_eq!( + u.set_bit(127, CtChoice::FALSE), + uint_with_bits_at(&[16, 79, 150]) + ); + + let u = uint_with_bits_at(&[16, 79, 150]); + assert_eq!( + u.set_bit(150, CtChoice::FALSE), + uint_with_bits_at(&[16, 79]) + ); + } } diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index ef3c161f5..cf7cf071c 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -1,28 +1,68 @@ use super::Uint; -use crate::{CtChoice, Limb}; +use crate::CtChoice; impl Uint { - /// Computes 1/`self` mod 2^k as specified in Algorithm 4 from - /// A Secure Algorithm for Inversion Modulo 2k by - /// Sadiel de la Fe and Carles Ferrer. See - /// . + /// Computes 1/`self` mod `2^k`. + /// This method is constant-time w.r.t. `self` but not `k`. /// /// Conditions: `self` < 2^k and `self` must be odd - pub const fn inv_mod2k(&self, k: usize) -> Self { - let mut x = Self::ZERO; - let mut b = Self::ONE; + pub const fn inv_mod2k_vartime(&self, k: usize) -> Self { + // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k" + // by Sadiel de la Fe and Carles Ferrer. + // See . + + // Note that we are not using Alrgorithm 4, since we have a different approach + // of enforcing constant-timeness w.r.t. `self`. + + let mut x = Self::ZERO; // keeps `x` during iterations + let mut b = Self::ONE; // keeps `b_i` during iterations let mut i = 0; while i < k { - let mut x_i = Self::ZERO; - let j = b.limbs[0].0 & 1; - x_i.limbs[0] = Limb(j); - x = x.bitor(&x_i.shl_vartime(i)); + // X_i = b_i mod 2 + let x_i = b.limbs[0].0 & 1; + let x_i_choice = CtChoice::from_lsb(x_i); + // b_{i+1} = (b_i - a * X_i) / 2 + b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1); + // Store the X_i bit in the result (x = x | (1 << X_i)) + x = x.bitor(&Uint::from_word(x_i).shl_vartime(i)); + + i += 1; + } + + x + } + + /// Computes 1/`self` mod `2^k`. + /// + /// Conditions: `self` < 2^k and `self` must be odd + pub const fn inv_mod2k(&self, k: usize) -> Self { + // This is the same algorithm as in `inv_mod2k_vartime()`, + // but made constant-time w.r.t `k` as well. + + let mut x = Self::ZERO; // keeps `x` during iterations + let mut b = Self::ONE; // keeps `b_i` during iterations + let mut i = 0; + + while i < Self::BITS { + // Only iterations for i = 0..k need to change `x`, + // the rest are dummy ones performed for the sake of constant-timeness. + let within_range = CtChoice::from_usize_lt(i, k); + + // X_i = b_i mod 2 + let x_i = b.limbs[0].0 & 1; + let x_i_choice = CtChoice::from_lsb(x_i); + // b_{i+1} = (b_i - a * X_i) / 2 + b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1); + + // Store the X_i bit in the result (x = x | (1 << X_i)) + // Don't change the result in dummy iterations. + let x_i_choice = x_i_choice.and(within_range); + x = x.set_bit(i, x_i_choice); - let t = b.wrapping_sub(self); - b = Self::ct_select(&b, &t, CtChoice::from_lsb(j)).shr_vartime(1); i += 1; } + x } @@ -97,10 +137,45 @@ impl Uint { } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. - /// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`. + /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists, + /// otherwise `(undefined, CtChoice::FALSE)`. pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) { self.inv_odd_mod_bounded(modulus, Uint::::BITS, Uint::::BITS) } + + /// Computes the multiplicative inverse of `self` mod `modulus`. + /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists, + /// otherwise `(undefined, CtChoice::FALSE)`. + pub fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) { + // Decompose `modulus = s * 2^k` where `s` is odd + let k = modulus.trailing_zeros(); + let s = modulus.shr(k); + + // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. + // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` + let (a, a_is_some) = self.inv_odd_mod(&s); + let b = self.inv_mod2k(k); + // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd. + let b_is_some = CtChoice::from_usize_being_nonzero(k) + .not() + .or(self.ct_is_odd()); + + // Restore from RNS: + // self^{-1} = a mod s = b mod 2^k + // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k) + // (essentially one step of the Garner's algorithm for recovery from RNS). + + let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists + + // This part is mod 2^k + let mask = (Uint::ONE << k).wrapping_sub(&Uint::ONE); + let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)) & mask; + + // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, + // so `a + s * t <= s * 2^k - 1 == modulus - 1`. + let result = a.wrapping_add(&s.wrapping_mul(&t)); + (result, a_is_some.and(b_is_some)) + } } #[cfg(test)] @@ -125,7 +200,7 @@ mod tests { } #[test] - fn test_invert() { + fn test_invert_odd() { let a = U1024::from_be_hex(concat![ "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E", "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8", @@ -138,15 +213,45 @@ mod tests { "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C", "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" ]); - - let (res, is_some) = a.inv_odd_mod(&m); - let expected = U1024::from_be_hex(concat![ "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55", "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57", "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA", "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336" ]); + + let (res, is_some) = a.inv_odd_mod(&m); + assert!(is_some.is_true_vartime()); + assert_eq!(res, expected); + + // Even though it is less efficient, it still works + let (res, is_some) = a.inv_mod(&m); + assert!(is_some.is_true_vartime()); + assert_eq!(res, expected); + } + + #[test] + fn test_invert_even() { + let a = U1024::from_be_hex(concat![ + "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E", + "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8", + "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8", + "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985" + ]); + let m = U1024::from_be_hex(concat![ + "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F", + "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A", + "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C", + "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000" + ]); + let expected = U1024::from_be_hex(concat![ + "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357", + "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225", + "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3", + "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D", + ]); + + let (res, is_some) = a.inv_mod(&m); assert!(is_some.is_true_vartime()); assert_eq!(res, expected); } diff --git a/src/uint/modular/constant_mod/macros.rs b/src/uint/modular/constant_mod/macros.rs index 5e8ede8aa..dfa440e02 100644 --- a/src/uint/modular/constant_mod/macros.rs +++ b/src/uint/modular/constant_mod/macros.rs @@ -32,7 +32,7 @@ macro_rules! impl_modulus { const MOD_NEG_INV: $crate::Limb = $crate::Limb( $crate::Word::MIN.wrapping_sub( Self::MODULUS - .inv_mod2k($crate::Word::BITS as usize) + .inv_mod2k_vartime($crate::Word::BITS as usize) .as_limbs()[0] .0, ), diff --git a/src/uint/modular/runtime_mod.rs b/src/uint/modular/runtime_mod.rs index 4468d71ac..f04723298 100644 --- a/src/uint/modular/runtime_mod.rs +++ b/src/uint/modular/runtime_mod.rs @@ -47,8 +47,9 @@ impl DynResidueParams { // Since we are calculating the inverse modulo (Word::MAX+1), // we can take the modulo right away and calculate the inverse of the first limb only. let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]); - let mod_neg_inv = - Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS as usize).limbs[0].0)); + let mod_neg_inv = Limb( + Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS as usize).limbs[0].0), + ); let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv); diff --git a/tests/proptests.rs b/tests/proptests.rs index b87c70113..bad14bc1b 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -2,7 +2,7 @@ use crypto_bigint::{ modular::runtime_mod::{DynResidue, DynResidueParams}, - Encoding, Limb, NonZero, Word, U256, + CtChoice, Encoding, Limb, NonZero, Word, U256, }; use num_bigint::BigUint; use num_integer::Integer; @@ -204,7 +204,7 @@ proptest! { let a_bi = to_biguint(&a); let b_bi = to_biguint(&b); - if b_bi.is_zero() { + if !b_bi.is_zero() { let expected = to_uint(a_bi % b_bi); let actual = a.wrapping_rem(&b); @@ -212,6 +212,44 @@ proptest! { } } + #[test] + fn inv_mod2k(a in uint(), k in any::()) { + let a = a | U256::ONE; // make odd + let k = k % (U256::BITS + 1); + let a_bi = to_biguint(&a); + let m_bi = BigUint::one() << k; + + let actual = a.inv_mod2k(k); + let actual_vartime = a.inv_mod2k_vartime(k); + assert_eq!(actual, actual_vartime); + + if k == 0 { + assert_eq!(actual, U256::ZERO); + } + else { + let inv_bi = to_biguint(&actual); + let res = (inv_bi * a_bi) % m_bi; + assert_eq!(res, BigUint::one()); + } + } + + #[test] + fn inv_mod(a in uint(), b in uint()) { + let a_bi = to_biguint(&a); + let b_bi = to_biguint(&b); + + let expected_is_some = if a_bi.gcd(&b_bi) == BigUint::one() { CtChoice::TRUE } else { CtChoice::FALSE }; + let (actual, actual_is_some) = a.inv_mod(&b); + + assert_eq!(bool::from(expected_is_some), bool::from(actual_is_some)); + + if actual_is_some.into() { + let inv_bi = to_biguint(&actual); + let res = (inv_bi * a_bi) % b_bi; + assert_eq!(res, BigUint::one()); + } + } + #[test] fn wrapping_sqrt(a in uint()) { let a_bi = to_biguint(&a);