diff --git a/.github/workflows/crypto-bigint.yml b/.github/workflows/crypto-bigint.yml index 8e98a130..bb5556e0 100644 --- a/.github/workflows/crypto-bigint.yml +++ b/.github/workflows/crypto-bigint.yml @@ -19,7 +19,7 @@ jobs: strategy: matrix: rust: - - 1.57.0 # MSRV + - 1.61.0 # MSRV - stable target: - thumbv7em-none-eabi @@ -49,7 +49,7 @@ jobs: include: # 32-bit Linux - target: i686-unknown-linux-gnu - rust: 1.57.0 # MSRV + rust: 1.61.0 # MSRV deps: sudo apt update && sudo apt install gcc-multilib - target: i686-unknown-linux-gnu rust: stable @@ -57,7 +57,7 @@ jobs: # 64-bit Linux - target: x86_64-unknown-linux-gnu - rust: 1.57.0 # MSRV + rust: 1.61.0 # MSRV - target: x86_64-unknown-linux-gnu rust: stable steps: @@ -122,7 +122,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 with: - toolchain: 1.57.0 + toolchain: 1.61.0 components: clippy override: true profile: minimal diff --git a/Cargo.toml b/Cargo.toml index 909bcdba..e2e0775f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["arbitrary", "crypto", "bignum", "integer", "precision"] readme = "README.md" resolver = "2" edition = "2021" -rust-version = "1.57" +rust-version = "1.61" [dependencies] subtle = { version = "2.4", default-features = false } diff --git a/src/lib.rs b/src/lib.rs index 4d376421..855ea9b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,15 @@ //! pub const MODULUS_SHR1: U256 = MODULUS.shr_vartime(1); //! ``` //! +//! Note that large constant computations may accidentally trigger the `const_eval_limit` of the compiler. +//! The current way to deal with this problem is to either simplify this computation, +//! or increase the compiler's limit (currently a nightly feature). +//! One can completely remove the compiler's limit using: +//! ```ignore +//! #![feature(const_eval_limit)] +//! #![const_eval_limit = "0"] +//! ``` +//! //! ### Trait-based usage //! //! 
The [`UInt`] type itself does not implement the standard arithmetic traits @@ -100,6 +109,10 @@ //! assert_eq!(b, U256::ZERO); //! ``` //! +//! It also supports modular arithmetic over constant moduli using `Residue`. +//! That includes modular exponentiation and multiplicative inverses. +//! These features are described in the [`modular`] module. +//! //! ### Random number generation //! //! When the `rand_core` or `rand` features of this crate are enabled, it's diff --git a/src/uint.rs b/src/uint.rs index 5aac67c5..2e730986 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -25,6 +25,7 @@ mod from; mod inv_mod; mod mul; mod mul_mod; +mod neg; mod neg_mod; mod resize; mod shl; @@ -33,6 +34,9 @@ mod sqrt; mod sub; mod sub_mod; +/// Implements modular arithmetic for constant moduli. +pub mod modular; + #[cfg(feature = "generic-array")] mod array; diff --git a/src/uint/add.rs b/src/uint/add.rs index 2822e9e6..8a7f929c 100644 --- a/src/uint/add.rs +++ b/src/uint/add.rs @@ -1,6 +1,6 @@ //! [`UInt`] addition operations. -use crate::{Checked, CheckedAdd, Limb, UInt, Wrapping, Zero}; +use crate::{Checked, CheckedAdd, Limb, UInt, Word, Wrapping, Zero}; use core::ops::{Add, AddAssign}; use subtle::CtOption; @@ -36,6 +36,14 @@ impl UInt { pub const fn wrapping_add(&self, rhs: &Self) -> Self { self.adc(rhs, Limb::ZERO).0 } + + /// Perform wrapping addition, returning the overflow bit as a `Word` that is either 0...0 or 1...1. 
+ pub(crate) const fn conditional_wrapping_add(&self, rhs: &Self, choice: Word) -> (Self, Word) { + let actual_rhs = UInt::ct_select(UInt::ZERO, *rhs, choice); + let (sum, carry) = self.adc(&actual_rhs, Limb::ZERO); + + (sum, carry.0.wrapping_mul(Word::MAX)) + } } impl CheckedAdd<&UInt> for UInt { diff --git a/src/uint/bit_and.rs b/src/uint/bit_and.rs index cab89a42..e6eab65f 100644 --- a/src/uint/bit_and.rs +++ b/src/uint/bit_and.rs @@ -46,6 +46,7 @@ impl BitAnd for UInt { impl BitAnd<&UInt> for UInt { type Output = UInt; + #[allow(clippy::needless_borrow)] fn bitand(self, rhs: &UInt) -> UInt { (&self).bitand(rhs) } diff --git a/src/uint/bit_not.rs b/src/uint/bit_not.rs index 747d3b49..914774a8 100644 --- a/src/uint/bit_not.rs +++ b/src/uint/bit_not.rs @@ -23,6 +23,7 @@ impl UInt { impl Not for UInt { type Output = Self; + #[allow(clippy::needless_borrow)] fn not(self) -> ::Output { (&self).not() } diff --git a/src/uint/bit_or.rs b/src/uint/bit_or.rs index 4a01a834..8f6b84db 100644 --- a/src/uint/bit_or.rs +++ b/src/uint/bit_or.rs @@ -46,6 +46,7 @@ impl BitOr for UInt { impl BitOr<&UInt> for UInt { type Output = UInt; + #[allow(clippy::needless_borrow)] fn bitor(self, rhs: &UInt) -> UInt { (&self).bitor(rhs) } diff --git a/src/uint/bit_xor.rs b/src/uint/bit_xor.rs index 16d78ad3..0e886ca2 100644 --- a/src/uint/bit_xor.rs +++ b/src/uint/bit_xor.rs @@ -46,6 +46,7 @@ impl BitXor for UInt { impl BitXor<&UInt> for UInt { type Output = UInt; + #[allow(clippy::needless_borrow)] fn bitxor(self, rhs: &UInt) -> UInt { (&self).bitxor(rhs) } diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs index 19046df9..676622ce 100644 --- a/src/uint/cmp.rs +++ b/src/uint/cmp.rs @@ -24,6 +24,14 @@ impl UInt { UInt { limbs } } + #[inline] + pub(crate) const fn ct_swap(a: UInt, b: UInt, c: Word) -> (Self, Self) { + let new_a = Self::ct_select(a, b, c); + let new_b = Self::ct_select(b, a, c); + + (new_a, new_b) + } + /// Returns all 1's if `self`!=0 or 0 if `self`==0. 
/// /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. @@ -38,6 +46,10 @@ impl UInt { Limb::is_nonzero(Limb(b)) } + pub(crate) const fn ct_is_odd(&self) -> Word { + (self.limbs[0].0 & 1).wrapping_mul(Word::MAX) + } + /// Returns -1 if self < rhs /// 0 if self == rhs /// 1 if self > rhs diff --git a/src/uint/div.rs b/src/uint/div.rs index f7d9d6bf..828b9f50 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -68,6 +68,43 @@ impl UInt { (rem, (is_some & 1) as u8) } + /// Computes `lower_upper` % `rhs`, returns the remainder + /// and 1 for is_some or 0 for is_none. The results can be wrapped in [`CtOption`]. + /// NOTE: Use only if you need to access const fn. Otherwise use `reduce`. + /// This is variable only with respect to `rhs`. + /// + /// When used with a fixed `rhs`, this function is constant-time with respect + /// to `lower_upper`. + #[allow(dead_code)] + pub(crate) const fn ct_reduce_wide(lower_upper: (Self, Self), rhs: &Self) -> (Self, u8) { + let mb = rhs.bits_vartime(); + + // The number of bits to consider is two sets of limbs * BIT_SIZE - mb (modulus bitcount) + let mut bd = (2 * LIMBS * Limb::BIT_SIZE) - mb; + + // The wide integer to reduce, split into two halves + let (mut lower, mut upper) = lower_upper; + + // Factor of the modulus, split into two halves + let mut c = Self::shl_vartime_wide((*rhs, UInt::ZERO), bd); + + loop { + let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); + let (upper_sub, borrow) = upper.sbb(&c.1, borrow); + + lower = Self::ct_select(lower_sub, lower, borrow.0); + upper = Self::ct_select(upper_sub, upper, borrow.0); + if bd == 0 { + break; + } + bd -= 1; + c = Self::shr_vartime_wide(c, 1); + } + + let is_some = Limb(mb as Word).is_nonzero(); + (lower, (is_some & 1) as u8) + } + /// Computes `self` % 2^k. Faster than reduce since its a power of 2. /// Limited to 2^16-1 since UInt doesn't support higher. 
pub const fn reduce2k(&self, k: usize) -> Self { @@ -466,6 +503,19 @@ mod tests { assert_eq!(r, U256::from(3u8)); } + #[test] + fn reduce_tests_wide_zero_padded() { + let (r, is_some) = U256::ct_reduce_wide((U256::from(10u8), U256::ZERO), &U256::from(2u8)); + assert_eq!(is_some, 1); + assert_eq!(r, U256::ZERO); + let (r, is_some) = U256::ct_reduce_wide((U256::from(10u8), U256::ZERO), &U256::from(3u8)); + assert_eq!(is_some, 1); + assert_eq!(r, U256::ONE); + let (r, is_some) = U256::ct_reduce_wide((U256::from(10u8), U256::ZERO), &U256::from(7u8)); + assert_eq!(is_some, 1); + assert_eq!(r, U256::from(3u8)); + } + #[test] fn reduce_max() { let mut a = U256::ZERO; diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index a1140856..911fbc4e 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -1,5 +1,7 @@ +use subtle::{Choice, CtOption}; + use super::UInt; -use crate::Limb; +use crate::{Limb, Word}; impl UInt { /// Computes 1/`self` mod 2^k as specified in Algorithm 4 from @@ -25,11 +27,72 @@ impl UInt { } x } + + /// Computes the multiplicative inverse of `self` mod `modulus`. In other words `self^-1 mod modulus`. Returns `(inverse, 1...1)` if an inverse exists, otherwise `(undefined, 0...0)`. The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. + pub const fn inv_odd_mod(self, modulus: UInt) -> (Self, Word) { + debug_assert!(modulus.ct_is_odd() == Word::MAX); + + let mut a = self; + + let mut u = UInt::ONE; + let mut v = UInt::ZERO; + + let mut b = modulus; + + // TODO: This can be lower if `self` is known to be small. + let bit_size = 2 * LIMBS * 64; + + let mut m1hp = modulus; + let (m1hp_new, carry) = m1hp.shr_1(); + debug_assert!(carry == Word::MAX); + m1hp = m1hp_new.wrapping_add(&UInt::ONE); + + let mut i = 0; + while i < bit_size { + debug_assert!(b.ct_is_odd() == Word::MAX); + + let self_odd = a.ct_is_odd(); + + // Set `self -= b` if `self` is odd. 
+ let (new_a, swap) = a.conditional_wrapping_sub(&b, self_odd); + // Set `b += self` if `swap` is true. + b = UInt::ct_select(b, b.wrapping_add(&new_a), swap); + // Negate `self` if `swap` is true. + a = new_a.conditional_wrapping_neg(swap); + + let (new_u, new_v) = UInt::ct_swap(u, v, swap); + let (new_u, cy) = new_u.conditional_wrapping_sub(&new_v, self_odd); + let (new_u, cyy) = new_u.conditional_wrapping_add(&modulus, cy); + debug_assert!(cy == cyy); + + let (new_a, overflow) = a.shr_1(); + debug_assert!(overflow == 0); + let (new_u, cy) = new_u.shr_1(); + let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy); + debug_assert!(cy == 0); + + a = new_a; + u = new_u; + v = new_v; + + i += 1; + } + + debug_assert!(a.ct_cmp(&UInt::ZERO) == 0); + + (v, b.ct_not_eq(&UInt::ONE) ^ Word::MAX) + } + + /// Computes the multiplicative inverse of `self` mod `modulus`. In other words `self^-1 mod modulus`. Returns `None` if the inverse does not exist. The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. 
+ pub fn inv_odd_mod_option(self, modulus: UInt) -> CtOption { + let (inverse, exists) = self.inv_odd_mod(modulus); + CtOption::new(inverse, Choice::from((exists == Word::MAX) as u8)) + } } #[cfg(test)] mod tests { - use crate::U256; + use crate::{U1024, U256, U64}; #[test] fn inv_mod2k() { @@ -59,4 +122,35 @@ mod tests { let a = v.inv_mod2k(256); assert_eq!(e, a); } + + #[test] + fn test_invert() { + let a = U1024::from_be_hex("000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"); + let m = U1024::from_be_hex("D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001AD198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"); + + let res = a.inv_odd_mod_option(m); + + let expected = U1024::from_be_hex("B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE5788D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"); + assert_eq!(res.unwrap(), expected); + } + + #[test] + fn test_invert_small() { + let a = U64::from(3u64); + let m = U64::from(13u64); + + let res = a.inv_odd_mod_option(m); + + assert_eq!(U64::from(9u64), res.unwrap()); + } + + #[test] + fn test_no_inverse_small() { + let a = U64::from(14u64); + let m = U64::from(49u64); + + let res = a.inv_odd_mod_option(m); + + assert!(res.is_none().unwrap_u8() == 1); + } } diff --git a/src/uint/modular/add.rs b/src/uint/modular/add.rs new file mode 100644 index 00000000..ef3e5a0e --- /dev/null +++ b/src/uint/modular/add.rs @@ -0,0 +1,14 @@ +use crate::UInt; + +pub trait AddResidue { + /// Computes the (reduced) sum of two residues. 
+ fn add(&self, rhs: &Self) -> Self; +} + +pub(crate) const fn add_montgomery_form( + a: &UInt, + b: &UInt, + modulus: &UInt, +) -> UInt { + a.add_mod(b, modulus) +} diff --git a/src/uint/modular/constant_mod/const_add.rs b/src/uint/modular/constant_mod/const_add.rs new file mode 100644 index 00000000..288b2a76 --- /dev/null +++ b/src/uint/modular/constant_mod/const_add.rs @@ -0,0 +1,72 @@ +use core::ops::AddAssign; + +use crate::{ + modular::add::{add_montgomery_form, AddResidue}, + UInt, +}; + +use super::{Residue, ResidueParams}; + +impl, const LIMBS: usize> AddResidue for Residue { + fn add(&self, rhs: &Self) -> Self { + self.add(rhs) + } +} + +impl, const LIMBS: usize> Residue { + /// Adds two residues together. + pub const fn add(&self, rhs: &Self) -> Self { + Residue { + montgomery_form: add_montgomery_form( + &self.montgomery_form, + &rhs.montgomery_form, + &MOD::MODULUS, + ), + phantom: core::marker::PhantomData, + } + } +} + +impl, const LIMBS: usize> AddAssign<&UInt> + for Residue +{ + fn add_assign(&mut self, rhs: &UInt) { + *self += &Residue::new(*rhs); + } +} + +impl, const LIMBS: usize> AddAssign<&Self> for Residue { + fn add_assign(&mut self, rhs: &Self) { + *self = self.add(rhs); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + const_residue, impl_modulus, modular::constant_mod::ResidueParams, traits::Encoding, U256, + }; + + impl_modulus!( + Modulus, + U256, + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + ); + + #[test] + fn add_overflow() { + let x = + U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); + let mut x_mod = const_residue!(x, Modulus); + + let y = + U256::from_be_hex("d5777c45019673125ad240f83094d4252d829516fac8601ed01979ec1ec1a251"); + + x_mod += &y; + + let expected = + U256::from_be_hex("1a2472fde50286541d97ca6a3592dd75beb9c9646e40c511b82496cfc3926956"); + + assert_eq!(expected, x_mod.retrieve()); + } +} diff --git a/src/uint/modular/constant_mod/const_inv.rs 
b/src/uint/modular/constant_mod/const_inv.rs new file mode 100644 index 00000000..8bc94360 --- /dev/null +++ b/src/uint/modular/constant_mod/const_inv.rs @@ -0,0 +1,72 @@ +use core::marker::PhantomData; + +use subtle::{Choice, CtOption}; + +use crate::{ + modular::inv::{inv_montgomery_form, InvResidue}, + Word, +}; + +use super::{Residue, ResidueParams}; + +impl, const LIMBS: usize> InvResidue for Residue { + fn inv(self) -> CtOption { + let (montgomery_form, error) = inv_montgomery_form( + self.montgomery_form, + MOD::MODULUS, + &MOD::R3, + MOD::MOD_NEG_INV, + ); + + let value = Self { + montgomery_form, + phantom: PhantomData, + }; + + CtOption::new(value, Choice::from((error == Word::MAX) as u8)) + } +} + +impl, const LIMBS: usize> Residue { + /// Computes the residue `self^-1` representing the multiplicative inverse of `self`. I.e. `self * self^-1 = 1`. Panics if `self` was not invertible. + pub const fn inv(self) -> Self { + let (montgomery_form, error) = inv_montgomery_form( + self.montgomery_form, + MOD::MODULUS, + &MOD::R3, + MOD::MOD_NEG_INV, + ); + + assert!(error == Word::MAX); + + Self { + montgomery_form, + phantom: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + const_residue, impl_modulus, modular::constant_mod::ResidueParams, traits::Encoding, U256, + }; + + impl_modulus!( + Modulus, + U256, + "15477BCCEFE197328255BFA79A1217899016D927EF460F4FF404029D24FA4409" + ); + + #[test] + fn test_self_inverse() { + let x = + U256::from_be_hex("77117F1273373C26C700D076B3F780074D03339F56DD0EFB60E7F58441FD3685"); + let x_mod = const_residue!(x, Modulus); + + let inv = x_mod.inv(); + let res = &x_mod * &inv; + + assert_eq!(res.retrieve(), U256::ONE); + } +} diff --git a/src/uint/modular/constant_mod/const_mul.rs b/src/uint/modular/constant_mod/const_mul.rs new file mode 100644 index 00000000..def2d71a --- /dev/null +++ b/src/uint/modular/constant_mod/const_mul.rs @@ -0,0 +1,60 @@ +use core::{ + marker::PhantomData, + ops::{Mul, MulAssign}, 
+}; + +use crate::modular::{ + mul::{mul_montgomery_form, MulResidue}, + reduction::montgomery_reduction, +}; + +use super::{Residue, ResidueParams}; + +impl, const LIMBS: usize> MulResidue for Residue { + fn mul(&self, rhs: &Self) -> Self { + self.mul(rhs) + } + + fn square(&self) -> Self { + self.square() + } +} + +impl, const LIMBS: usize> Residue { + /// Computes the (reduced) product between two residues. + pub const fn mul(&self, rhs: &Self) -> Self { + Self { + montgomery_form: mul_montgomery_form( + &self.montgomery_form, + &rhs.montgomery_form, + MOD::MODULUS, + MOD::MOD_NEG_INV, + ), + phantom: PhantomData, + } + } + + /// Computes the (reduced) square of a residue. + pub const fn square(&self) -> Self { + let lo_hi = self.montgomery_form.square_wide(); + + Self { + montgomery_form: montgomery_reduction::(lo_hi, MOD::MODULUS, MOD::MOD_NEG_INV), + phantom: PhantomData, + } + } +} + +impl, const LIMBS: usize> MulAssign<&Self> for Residue { + fn mul_assign(&mut self, rhs: &Self) { + *self = self.mul(rhs) + } +} + +impl, const LIMBS: usize> Mul for &Residue { + type Output = Residue; + + fn mul(self, rhs: Self) -> Self::Output { + self.mul(rhs) + } +} diff --git a/src/uint/modular/constant_mod/const_pow.rs b/src/uint/modular/constant_mod/const_pow.rs new file mode 100644 index 00000000..321daa68 --- /dev/null +++ b/src/uint/modular/constant_mod/const_pow.rs @@ -0,0 +1,97 @@ +use crate::{ + modular::pow::{pow_montgomery_form, PowResidue}, + UInt, Word, +}; + +use super::{Residue, ResidueParams}; + +impl, const LIMBS: usize> PowResidue for Residue { + fn pow_specific(self, exponent: &UInt, exponent_bits: usize) -> Self { + self.pow_specific(exponent, exponent_bits) + } +} + +impl, const LIMBS: usize> Residue { + /// Performs modular exponentiation using Montgomery's ladder. + pub const fn pow(self, exponent: &UInt) -> Residue { + self.pow_specific(exponent, LIMBS * Word::BITS as usize) + } + + /// Performs modular exponentiation using Montgomery's ladder. 
`exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. + pub const fn pow_specific( + self, + exponent: &UInt, + exponent_bits: usize, + ) -> Residue { + Self { + montgomery_form: pow_montgomery_form( + self.montgomery_form, + exponent, + exponent_bits, + MOD::MODULUS, + MOD::R, + MOD::MOD_NEG_INV, + ), + phantom: core::marker::PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + const_residue, impl_modulus, modular::constant_mod::ResidueParams, traits::Encoding, U256, + }; + + impl_modulus!( + Modulus, + U256, + "9CC24C5DF431A864188AB905AC751B727C9447A8E99E6366E1AD78A21E8D882B" + ); + + #[test] + fn test_powmod_small_base() { + let base = U256::from(105u64); + let base_mod = const_residue!(base, Modulus); + + let exponent = + U256::from_be_hex("77117F1273373C26C700D076B3F780074D03339F56DD0EFB60E7F58441FD3685"); + + let res = base_mod.pow(&exponent); + + let expected = + U256::from_be_hex("7B2CD7BDDD96C271E6F232F2F415BB03FE2A90BD6CCCEA5E94F1BFD064993766"); + assert_eq!(res.retrieve(), expected); + } + + #[test] + fn test_powmod_small_exponent() { + let base = + U256::from_be_hex("3435D18AA8313EBBE4D20002922225B53F75DC4453BB3EEC0378646F79B524A4"); + let base_mod = const_residue!(base, Modulus); + + let exponent = U256::from(105u64); + + let res = base_mod.pow(&exponent); + + let expected = + U256::from_be_hex("89E2A4E99F649A5AE2C18068148C355CA927B34A3245C938178ED00D6EF218AA"); + assert_eq!(res.retrieve(), expected); + } + + #[test] + fn test_powmod() { + let base = + U256::from_be_hex("3435D18AA8313EBBE4D20002922225B53F75DC4453BB3EEC0378646F79B524A4"); + let base_mod = const_residue!(base, Modulus); + + let exponent = + U256::from_be_hex("77117F1273373C26C700D076B3F780074D03339F56DD0EFB60E7F58441FD3685"); + + let res = base_mod.pow(&exponent); + + let expected = + U256::from_be_hex("3681BC0FEA2E5D394EB178155A127B0FD2EF405486D354251C385BDD51B9D421"); + 
assert_eq!(res.retrieve(), expected); + } +} diff --git a/src/uint/modular/constant_mod/macros.rs b/src/uint/modular/constant_mod/macros.rs new file mode 100644 index 00000000..8c23fb57 --- /dev/null +++ b/src/uint/modular/constant_mod/macros.rs @@ -0,0 +1,47 @@ +// TODO: Use `adt_const_params` once stabilized to make a `Residue` generic around a modulus rather than having to implement a ZST + trait +#[macro_export] +/// Implements a modulus with the given name, type, and value, in that specific order. Please `use crypto_bigint::traits::Encoding` to make this work. +/// For example, `impl_modulus!(MyModulus, U256, "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001");` implements a 256-bit modulus named `MyModulus`. +macro_rules! impl_modulus { + ($name:ident, $uint_type:ty, $value:expr) => { + #[derive(Clone, Copy, PartialEq, Eq, Debug)] + pub struct $name {} + impl ResidueParams<{ nlimbs!(<$uint_type>::BIT_SIZE) }> for $name + where + $crate::UInt<{ nlimbs!(<$uint_type>::BIT_SIZE) }>: + $crate::traits::Concat>, + $crate::UInt: $crate::traits::Split, + { + const LIMBS: usize = { nlimbs!(<$uint_type>::BIT_SIZE) }; + const MODULUS: $crate::UInt<{ nlimbs!(<$uint_type>::BIT_SIZE) }> = + <$uint_type>::from_be_hex($value); + const R: $crate::UInt<{ nlimbs!(<$uint_type>::BIT_SIZE) }> = $crate::UInt::MAX + .ct_reduce(&Self::MODULUS) + .0 + .wrapping_add(&$crate::UInt::ONE); + const R2: $crate::UInt<{ nlimbs!(<$uint_type>::BIT_SIZE) }> = + $crate::UInt::ct_reduce_wide(Self::R.square_wide(), &Self::MODULUS).0; + const MOD_NEG_INV: $crate::Limb = $crate::Limb( + $crate::Word::MIN + .wrapping_sub(Self::MODULUS.inv_mod2k($crate::Word::BITS as usize).limbs[0].0), + ); + const R3: $crate::UInt<{ nlimbs!(<$uint_type>::BIT_SIZE) }> = + $crate::uint::modular::reduction::montgomery_reduction( + Self::R2.square_wide(), + Self::MODULUS, + Self::MOD_NEG_INV, + ); + } + }; +} + +#[macro_export] +/// Creates a `Residue` with the given value for a specific modulus. 
+/// For example, `const_residue!(x, MyModulus);` creates a `Residue` for `x` mod `MyModulus`, where `x` is an identifier bound to a value such as `U256::from(105u64)`. +macro_rules! const_residue { + ($variable:ident, $modulus:ident) => { + $crate::uint::modular::constant_mod::Residue::<$modulus, { $modulus::LIMBS }>::new( + $variable, + ) + }; +} diff --git a/src/uint/modular/constant_mod/mod.rs b/src/uint/modular/constant_mod/mod.rs new file mode 100644 index 00000000..877ccc3e --- /dev/null +++ b/src/uint/modular/constant_mod/mod.rs @@ -0,0 +1,102 @@ +use core::marker::PhantomData; + +use subtle::{Choice, ConditionallySelectable}; + +use crate::{Limb, UInt}; + +use super::{reduction::montgomery_reduction, GenericResidue}; + +/// Additions between residues with a constant modulus +mod const_add; +/// Multiplicative inverses of residues with a constant modulus +mod const_inv; +/// Multiplications between residues with a constant modulus +mod const_mul; +/// Exponentiation of residues with a constant modulus +mod const_pow; + +#[macro_use] +/// Macros to remove the boilerplate code when dealing with constant moduli. +pub mod macros; + +/// The parameters to efficiently go to and from the Montgomery form for a given odd modulus. An easy way to generate these parameters is using the `impl_modulus!` macro. These parameters are constant, so they cannot be set at runtime. +/// +/// Unfortunately, `LIMBS` must be generic for now until const generics are stabilized. +pub trait ResidueParams: Copy { + /// Number of limbs required to encode a residue + const LIMBS: usize; + + /// The constant modulus + const MODULUS: UInt; + /// Parameter used in Montgomery reduction + const R: UInt; + /// R^2, used to move into Montgomery form + const R2: UInt; + /// R^3, used to perform a multiplicative inverse + const R3: UInt; + /// The lowest limbs of -(MODULUS^-1) mod R + // We only need the LSB because during reduction this value is multiplied modulo 2**64. 
+ const MOD_NEG_INV: Limb; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// A residue mod `MOD`, represented using `LIMBS` limbs. The modulus of this residue is constant, so it cannot be set at runtime. +pub struct Residue +where + MOD: ResidueParams, +{ + montgomery_form: UInt, + phantom: PhantomData, +} + +impl, const LIMBS: usize> Residue { + /// The representation of 1 mod `MOD`. + pub const ONE: Self = Self { + montgomery_form: MOD::R, + phantom: PhantomData, + }; + + /// Instantiates a new `Residue` that represents this `integer` mod `MOD`. + pub const fn new(integer: UInt) -> Self { + let mut modular_integer = Residue { + montgomery_form: integer, + phantom: PhantomData, + }; + + let product = integer.mul_wide(&MOD::R2); + modular_integer.montgomery_form = + montgomery_reduction::(product, MOD::MODULUS, MOD::MOD_NEG_INV); + + modular_integer + } + + /// Retrieves the integer currently encoded in this `Residue`, guaranteed to be reduced. + pub const fn retrieve(&self) -> UInt { + montgomery_reduction::( + (self.montgomery_form, UInt::ZERO), + MOD::MODULUS, + MOD::MOD_NEG_INV, + ) + } +} + +impl, const LIMBS: usize> GenericResidue for Residue { + fn retrieve(&self) -> UInt { + self.retrieve() + } +} + +impl + Copy, const LIMBS: usize> ConditionallySelectable + for Residue +{ + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Residue { + montgomery_form: UInt::conditional_select( + &a.montgomery_form, + &b.montgomery_form, + choice, + ), + phantom: PhantomData, + } + } +} diff --git a/src/uint/modular/inv.rs b/src/uint/modular/inv.rs new file mode 100644 index 00000000..60096bd3 --- /dev/null +++ b/src/uint/modular/inv.rs @@ -0,0 +1,24 @@ +use subtle::CtOption; + +use crate::{modular::reduction::montgomery_reduction, Limb, UInt, Word}; + +pub trait InvResidue +where + Self: Sized, +{ + /// Computes the (reduced) multiplicative inverse of the residue. Returns CtOption, which is None if the residue was not invertible. 
+ fn inv(self) -> CtOption; +} + +pub const fn inv_montgomery_form( + x: UInt, + modulus: UInt, + r3: &UInt, + mod_neg_inv: Limb, +) -> (UInt, Word) { + let (inverse, error) = x.inv_odd_mod(modulus); + ( + montgomery_reduction(inverse.mul_wide(r3), modulus, mod_neg_inv), + error, + ) +} diff --git a/src/uint/modular/mod.rs b/src/uint/modular/mod.rs new file mode 100644 index 00000000..23534149 --- /dev/null +++ b/src/uint/modular/mod.rs @@ -0,0 +1,161 @@ +use crate::UInt; + +use self::{add::AddResidue, inv::InvResidue, mul::MulResidue, pow::PowResidue}; + +mod reduction; + +/// Implements `Residue`s, supporting modular arithmetic with a constant modulus. +pub mod constant_mod; + +mod add; +mod inv; +mod mul; +mod pow; + +/// The `GenericResidue` trait provides a consistent API for dealing with residues with a constant modulus. +pub trait GenericResidue: + AddResidue + MulResidue + PowResidue + InvResidue +{ + /// Retrieves the integer currently encoded in this `Residue`, guaranteed to be reduced. 
+ fn retrieve(&self) -> UInt; +} + +#[cfg(test)] +mod tests { + use crate::{ + const_residue, impl_modulus, + modular::{ + constant_mod::Residue, constant_mod::ResidueParams, reduction::montgomery_reduction, + }, + traits::Encoding, + UInt, U256, U64, + }; + + impl_modulus!( + Modulus1, + U256, + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001" + ); + + #[test] + fn test_montgomery_params() { + assert_eq!( + Modulus1::R, + U256::from_be_hex("1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe") + ); + assert_eq!( + Modulus1::R2, + U256::from_be_hex("0748d9d99f59ff1105d314967254398f2b6cedcb87925c23c999e990f3f29c6d") + ); + assert_eq!( + Modulus1::MOD_NEG_INV, + U64::from_be_hex("fffffffeffffffff").limbs[0] + ); + } + + impl_modulus!( + Modulus2, + U256, + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + ); + + #[test] + fn test_reducing_r() { + // Divide the value R by R, which should equal 1 + assert_eq!( + montgomery_reduction::<{ Modulus2::LIMBS }>( + (Modulus2::R, UInt::ZERO), + Modulus2::MODULUS, + Modulus2::MOD_NEG_INV + ), + UInt::ONE + ); + } + + #[test] + fn test_reducing_r2() { + // Divide the value R^2 by R, which should equal R + assert_eq!( + montgomery_reduction::<{ Modulus2::LIMBS }>( + (Modulus2::R2, UInt::ZERO), + Modulus2::MODULUS, + Modulus2::MOD_NEG_INV + ), + Modulus2::R + ); + } + + #[test] + fn test_reducing_r2_wide() { + // Divide the value R^2 by R, which should equal R + let (hi, lo) = Modulus2::R.square().split(); + assert_eq!( + montgomery_reduction::<{ Modulus2::LIMBS }>( + (lo, hi), + Modulus2::MODULUS, + Modulus2::MOD_NEG_INV + ), + Modulus2::R + ); + } + + #[test] + fn test_reducing_xr_wide() { + // Reducing xR should return x + let x = + U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); + let product = x.mul_wide(&Modulus2::R); + assert_eq!( + montgomery_reduction::<{ Modulus2::LIMBS }>( + product, + Modulus2::MODULUS, + Modulus2::MOD_NEG_INV + ), 
+ x + ); + } + + #[test] + fn test_reducing_xr2_wide() { + // Reducing xR^2 should return xR + let x = + U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); + let product = x.mul_wide(&Modulus2::R2); + + // Computing xR mod modulus without Montgomery reduction + let (lo, hi) = x.mul_wide(&Modulus2::R); + let c = hi.concat(&lo); + let red = c.reduce(&U256::ZERO.concat(&Modulus2::MODULUS)).unwrap(); + let (hi, lo) = red.split(); + assert_eq!(hi, UInt::ZERO); + + assert_eq!( + montgomery_reduction::<{ Modulus2::LIMBS }>( + product, + Modulus2::MODULUS, + Modulus2::MOD_NEG_INV + ), + lo + ); + } + + #[test] + fn test_new_retrieve() { + let x = + U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); + let x_mod = Residue::::new(x); + + // Confirm that when creating a Modular and retrieving the value, that it equals the original + assert_eq!(x, x_mod.retrieve()); + } + + #[test] + fn test_residue_macro() { + let x = + U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); + assert_eq!( + Residue::::new(x), + const_residue!(x, Modulus2) + ); + } +} diff --git a/src/uint/modular/mul.rs b/src/uint/modular/mul.rs new file mode 100644 index 00000000..03008f0e --- /dev/null +++ b/src/uint/modular/mul.rs @@ -0,0 +1,34 @@ +use crate::{Limb, UInt}; + +use super::reduction::montgomery_reduction; + +pub trait MulResidue +where + Self: Sized, +{ + /// Computes the (reduced) product of two residues. 
+ fn mul(&self, rhs: &Self) -> Self; + + fn square(&self) -> Self { + self.mul(self) + } +} + +pub(crate) const fn mul_montgomery_form( + a: &UInt, + b: &UInt, + modulus: UInt, + mod_neg_inv: Limb, +) -> UInt { + let product = a.mul_wide(b); + montgomery_reduction::(product, modulus, mod_neg_inv) +} + +pub(crate) const fn square_montgomery_form( + a: &UInt, + modulus: UInt, + mod_neg_inv: Limb, +) -> UInt { + let product = a.square_wide(); + montgomery_reduction::(product, modulus, mod_neg_inv) +} diff --git a/src/uint/modular/pow.rs b/src/uint/modular/pow.rs new file mode 100644 index 00000000..cc3b67e2 --- /dev/null +++ b/src/uint/modular/pow.rs @@ -0,0 +1,52 @@ +use crate::{Limb, UInt, Word}; + +use super::mul::{mul_montgomery_form, square_montgomery_form}; + +pub trait PowResidue +where + Self: Sized, +{ + /// Computes the (reduced) exponentiation of a residue. + fn pow(self, exponent: &UInt) -> Self { + self.pow_specific(exponent, LIMBS * Word::BITS as usize) + } + + /// Computes the (reduced) exponentiation of a residue, here `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. + fn pow_specific(self, exponent: &UInt, exponent_bits: usize) -> Self; +} + +/// Performs modular exponentiation using Montgomery's ladder. `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. 
+pub const fn pow_montgomery_form( + x: UInt, + exponent: &UInt, + exponent_bits: usize, + modulus: UInt, + r: UInt, + mod_neg_inv: Limb, +) -> UInt { + let mut x1: UInt = r; + let mut x2: UInt = x; + + // Shift the exponent all the way to the left so the leftmost bit is the MSB of the `UInt` + let mut n: UInt = exponent.shl_vartime((LIMBS * Word::BITS as usize) - exponent_bits); + + let mut i = 0; + while i < exponent_bits { + // Peel off one bit at a time from the left side + let (next_n, overflow) = n.shl_1(); + n = next_n; + + let mut product: UInt = x1; + product = mul_montgomery_form(&product, &x2, modulus, mod_neg_inv); + + let mut square = UInt::ct_select(x1, x2, overflow); + square = square_montgomery_form(&square, modulus, mod_neg_inv); + + x1 = UInt::::ct_select(square, product, overflow); + x2 = UInt::::ct_select(product, square, overflow); + + i += 1; + } + + x1 +} diff --git a/src/uint/modular/reduction.rs b/src/uint/modular/reduction.rs new file mode 100644 index 00000000..da4895a2 --- /dev/null +++ b/src/uint/modular/reduction.rs @@ -0,0 +1,57 @@ +use crate::{Limb, UInt, WideWord, Word}; + +/// Algorithm 14.32 in Handbook of Applied Cryptography (https://cacr.uwaterloo.ca/hac/about/chap14.pdf) +pub(crate) const fn montgomery_reduction( + lower_upper: (UInt, UInt), + modulus: UInt, + mod_neg_inv: Limb, +) -> UInt { + let (mut lower, mut upper) = lower_upper; + + let mut meta_carry = 0; + + let mut i = 0; + while i < LIMBS { + let u = (lower.limbs[i].0.wrapping_mul(mod_neg_inv.0)) as WideWord; + + let new_limb = + (u * modulus.limbs[0].0 as WideWord).wrapping_add(lower.limbs[i].0 as WideWord); + let mut carry = new_limb >> Word::BITS; + + let mut j = 1; + while j < (LIMBS - i) { + let new_limb = (u * modulus.limbs[j].0 as WideWord) + .wrapping_add(lower.limbs[i + j].0 as WideWord) + .wrapping_add(carry); + carry = new_limb >> Word::BITS; + lower.limbs[i + j] = Limb(new_limb as Word); + + j += 1; + } + while j < LIMBS { + let new_limb = (u * 
modulus.limbs[j].0 as WideWord) + .wrapping_add(upper.limbs[i + j - LIMBS].0 as WideWord) + .wrapping_add(carry); + carry = new_limb >> Word::BITS; + upper.limbs[i + j - LIMBS] = Limb(new_limb as Word); + + j += 1; + } + + let new_sum = (upper.limbs[i].0 as WideWord) + .wrapping_add(carry) + .wrapping_add(meta_carry); + meta_carry = new_sum >> Word::BITS; + upper.limbs[i] = Limb(new_sum as Word); + + i += 1; + } + + // Division is simply taking the upper half of the limbs + // Final reduction (at this point, the value is at most 2 * modulus) + let must_reduce = (meta_carry as Word).saturating_mul(Word::MAX) + | ((upper.ct_cmp(&modulus) != -1) as Word).saturating_mul(Word::MAX); + upper = upper.wrapping_sub(&UInt::ct_select(UInt::ZERO, modulus, must_reduce)); + + upper +} diff --git a/src/uint/mul.rs b/src/uint/mul.rs index ecb32fd1..3600aa55 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -75,14 +75,19 @@ impl UInt { self.mul_wide(rhs).0 } - /// Square self, returning a "wide" result. + /// Square self, returning a concatenated "wide" result. pub fn square(&self) -> ::Output where Self: Concat, { - let (lo, hi) = self.mul_wide(self); + let (lo, hi) = self.square_wide(); hi.concat(&lo) } + + /// Square self, returning a "wide" result in two parts as (lo, hi). + pub const fn square_wide(&self) -> (Self, Self) { + self.mul_wide(self) + } } impl CheckedMul<&UInt> for UInt { diff --git a/src/uint/neg.rs b/src/uint/neg.rs new file mode 100644 index 00000000..0af1601c --- /dev/null +++ b/src/uint/neg.rs @@ -0,0 +1,22 @@ +use core::ops::Neg; + +use crate::{UInt, Word, Wrapping}; + +impl Neg for Wrapping> { + type Output = Self; + + fn neg(self) -> Self::Output { + let shifted = Wrapping(self.0.shl_vartime(1)); + self - shifted + } +} + +impl UInt { + /// Negates based on `choice` by wrapping the integer. 
+ pub(crate) const fn conditional_wrapping_neg(self, choice: Word) -> UInt { + let (shifted, _) = self.shl_1(); + let negated_self = self.wrapping_sub(&shifted); + + UInt::ct_select(self, negated_self, choice) + } +} diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 9d466913..1dbe0e79 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -1,9 +1,40 @@ //! [`UInt`] bitwise left shift operations. -use crate::{Limb, UInt, Word}; +use crate::{limb::HI_BIT, Limb, UInt, Word}; use core::ops::{Shl, ShlAssign}; impl UInt { + /// Computes `self << 1` in constant-time, returning the overflowing bit as a `Word` that is either 0...0 or 1...1. + pub(crate) const fn shl_1(&self) -> (Self, Word) { + let mut shifted_bits = [0; LIMBS]; + let mut i = 0; + while i < LIMBS { + shifted_bits[i] = self.limbs[i].0 << 1; + i += 1; + } + + let mut carry_bits = [0; LIMBS]; + let mut i = 0; + while i < LIMBS { + carry_bits[i] = self.limbs[i].0 >> HI_BIT; + i += 1; + } + + let mut limbs = [Limb(0); LIMBS]; + + limbs[0] = Limb(shifted_bits[0]); + let mut i = 1; + while i < LIMBS { + limbs[i] = Limb(shifted_bits[i] | carry_bits[i - 1]); + i += 1; + } + + ( + UInt::new(limbs), + carry_bits[LIMBS - 1].wrapping_mul(Word::MAX), + ) + } + /// Computes `self << shift`. /// /// NOTE: this operation is variable time with respect to `n` *ONLY*. @@ -36,6 +67,26 @@ impl UInt { Self { limbs } } + + /// Computes a left shift on a wide input as `(lo, hi)`. + /// + /// NOTE: this operation is variable time with respect to `n` *ONLY*. + /// + /// When used with a fixed `n`, this function is constant-time with respect + /// to `self`. 
+ #[inline(always)] + pub const fn shl_vartime_wide(lower_upper: (Self, Self), n: usize) -> (Self, Self) { + let (lower, mut upper) = lower_upper; + let new_lower = lower.shl_vartime(n); + upper = upper.shl_vartime(n); + if n >= LIMBS * Limb::BIT_SIZE { + upper = upper.bitor(&lower.shl_vartime(n - LIMBS * Limb::BIT_SIZE)); + } else { + upper = upper.bitor(&lower.shr_vartime(LIMBS * Limb::BIT_SIZE - n)); + } + + (new_lower, upper) + } } impl Shl for UInt { @@ -74,7 +125,7 @@ impl ShlAssign for UInt { #[cfg(test)] mod tests { - use crate::U256; + use crate::{Limb, UInt, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -131,4 +182,28 @@ mod tests { fn shl64() { assert_eq!(N << 64, SIXTY_FOUR); } + + #[test] + fn shl_wide_1_1_128() { + assert_eq!( + UInt::shl_vartime_wide((U128::ONE, U128::ONE), 128), + (U128::ZERO, U128::ONE) + ); + } + + #[test] + fn shl_wide_max_0_1() { + assert_eq!( + UInt::shl_vartime_wide((U128::MAX, U128::ZERO), 1), + (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE) + ); + } + + #[test] + fn shl_wide_max_max_256() { + assert_eq!( + UInt::shl_vartime_wide((U128::MAX, U128::MAX), 256), + (U128::ZERO, U128::ZERO) + ); + } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 54375ae7..f24dea46 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,10 +1,41 @@ //! [`UInt`] bitwise right shift operations. use super::UInt; -use crate::Limb; +use crate::{limb::HI_BIT, Limb, Word}; use core::ops::{Shr, ShrAssign}; impl UInt { + /// Computes `self >> 1` in constant-time, returning the overflowing bit as a `Word` that is either 0...0 or 1...1. 
+ pub(crate) const fn shr_1(&self) -> (Self, Word) { + let mut shifted_bits = [0; LIMBS]; + let mut i = 0; + while i < LIMBS { + shifted_bits[i] = self.limbs[i].0 >> 1; + i += 1; + } + + let mut carry_bits = [0; LIMBS]; + let mut i = 0; + while i < LIMBS { + carry_bits[i] = self.limbs[i].0 << HI_BIT; + i += 1; + } + + let mut limbs = [Limb(0); LIMBS]; + + let mut i = 0; + while i < (LIMBS - 1) { + limbs[i] = Limb(shifted_bits[i] | carry_bits[i + 1]); + i += 1; + } + limbs[LIMBS - 1] = Limb(shifted_bits[LIMBS - 1]); + + ( + UInt::new(limbs), + (carry_bits[0] >> HI_BIT).wrapping_mul(Word::MAX), + ) + } + /// Computes `self >> n`. /// /// NOTE: this operation is variable time with respect to `n` *ONLY*. @@ -44,6 +75,26 @@ impl UInt { Self { limbs } } + + /// Computes a right shift on a wide input as `(lo, hi)`. + /// + /// NOTE: this operation is variable time with respect to `n` *ONLY*. + /// + /// When used with a fixed `n`, this function is constant-time with respect + /// to `self`. + #[inline(always)] + pub const fn shr_vartime_wide(lower_upper: (Self, Self), n: usize) -> (Self, Self) { + let (mut lower, upper) = lower_upper; + let new_upper = upper.shr_vartime(n); + lower = lower.shr_vartime(n); + if n >= LIMBS * Limb::BIT_SIZE { + lower = lower.bitor(&upper.shr_vartime(n - LIMBS * Limb::BIT_SIZE)); + } else { + lower = lower.bitor(&upper.shl_vartime(LIMBS * Limb::BIT_SIZE - n)); + } + + (lower, new_upper) + } } impl Shr for UInt { @@ -78,7 +129,7 @@ impl ShrAssign for UInt { #[cfg(test)] mod tests { - use crate::U256; + use crate::{UInt, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -90,4 +141,28 @@ mod tests { fn shr1() { assert_eq!(N >> 1, N_2); } + + #[test] + fn shr_wide_1_1_128() { + assert_eq!( + UInt::shr_vartime_wide((U128::ONE, U128::ONE), 128), + (U128::ONE, U128::ZERO) + ); + } + + #[test] + fn shr_wide_0_max_1() { + assert_eq!( + UInt::shr_vartime_wide((U128::ZERO, U128::MAX), 
1), + (U128::ONE << 127, U128::MAX >> 1) + ); + } + + #[test] + fn shr_wide_max_max_256() { + assert_eq!( + UInt::shr_vartime_wide((U128::MAX, U128::MAX), 256), + (U128::ZERO, U128::ZERO) + ); + } } diff --git a/src/uint/sub.rs b/src/uint/sub.rs index 102f6b97..b0d1f49c 100644 --- a/src/uint/sub.rs +++ b/src/uint/sub.rs @@ -1,7 +1,7 @@ //! [`UInt`] addition operations. use super::UInt; -use crate::{Checked, CheckedSub, Limb, Wrapping, Zero}; +use crate::{Checked, CheckedSub, Limb, Word, Wrapping, Zero}; use core::ops::{Sub, SubAssign}; use subtle::CtOption; @@ -38,6 +38,15 @@ impl UInt { pub const fn wrapping_sub(&self, rhs: &Self) -> Self { self.sbb(rhs, Limb::ZERO).0 } + + /// Perform wrapping subtraction, returning the underflow bit as a `Word` that is either 0...0 or 1...1. + pub(crate) const fn conditional_wrapping_sub(&self, rhs: &Self, choice: Word) -> (Self, Word) { + let actual_rhs = UInt::ct_select(UInt::ZERO, *rhs, choice); + let (res, borrow) = self.sbb(&actual_rhs, Limb::ZERO); + + // Here we use a saturating multiplication to get the result to 0...0 or 1...1 + (res, borrow.0.saturating_mul(Word::MAX)) + } } impl CheckedSub<&UInt> for UInt {