diff --git a/src/uint/div.rs b/src/uint/div.rs index 7f5cda73..fb8395f1 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -41,6 +41,42 @@ impl Uint { (quo, rem) } + /// Computes `self` / `rhs`, returns the quotient (q), remainder (r) + /// and the truthy value for is_some or the falsy value for is_none. + /// + /// NOTE: Use only if you need to access const fn. Otherwise use [`Self::div_rem`] because + /// the value for is_some needs to be checked before using `q` and `r`. + /// + /// This function is constant-time with respect to both `self` and `rhs`. + pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) { + let mb = rhs.bits(); + let mut rem = *self; + let mut quo = Self::ZERO; + let mut c = rhs.shl(Self::BITS - mb); + + let mut i = Self::BITS; + let mut done = CtChoice::FALSE; + loop { + let (mut r, borrow) = rem.sbb(&c, Limb::ZERO); + rem = Self::ct_select(&r, &rem, CtChoice::from_mask(borrow.0).or(done)); + r = quo.bitor(&Self::ONE); + quo = Self::ct_select(&r, &quo, CtChoice::from_mask(borrow.0).or(done)); + if i == 0 { + break; + } + i -= 1; + // when `i < mb`, the computation is actually done, so we ensure `quo` and `rem` + // aren't modified further (but do the remaining iterations anyway to be constant-time) + done = Limb::ct_lt(Limb(i as Word), Limb(mb as Word)); + c = c.shr_vartime(1); + quo = Self::ct_select(&quo.shl_vartime(1), &quo, done); + } + + let is_some = Limb(mb as Word).ct_is_nonzero(); + quo = Self::ct_select(&Self::ZERO, &quo, is_some); + (quo, rem, is_some) + } + /// Computes `self` / `rhs`, returns the quotient (q), remainder (r) /// and the truthy value for is_some or the falsy value for is_none. /// @@ -51,7 +87,7 @@ impl Uint { /// /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. - pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) { + pub(crate) const fn const_div_rem_vartime(&self, rhs: &Self) -> (Self, Self, CtChoice) { let mb = rhs.bits_vartime(); let mut bd = Self::BITS - mb; let mut rem = *self; @@ -168,7 +204,14 @@ impl Uint { /// Computes self / rhs, returns the quotient, remainder. pub fn div_rem(&self, rhs: &NonZero) -> (Self, Self) { // Since `rhs` is nonzero, this should always hold. - let (q, r, _c) = self.ct_div_rem(rhs); + let (q, r, _c) = self.const_div_rem(rhs); + (q, r) + } + + /// Computes self / rhs, returns the quotient, remainder. Constant-time only for fixed `rhs`. + pub fn div_rem_vartime(&self, rhs: &NonZero) -> (Self, Self) { + // Since `rhs` is nonzero, this should always hold. + let (q, r, _c) = self.const_div_rem_vartime(rhs); (q, r) } @@ -185,7 +228,18 @@ impl Uint { /// /// Panics if `rhs == 0`. pub const fn wrapping_div(&self, rhs: &Self) -> Self { - let (q, _, c) = self.ct_div_rem(rhs); + let (q, _, c) = self.const_div_rem(rhs); + assert!(c.is_true_vartime(), "divide by zero"); + q + } + + /// Wrapped division is just normal division i.e. `self` / `rhs` + /// There’s no way wrapping could ever happen. + /// This function exists, so that all operations are accounted for in the wrapping operations. + /// + /// Panics if `rhs == 0`. Constant-time only for fixed `rhs`. + pub const fn wrapping_div_vartime(&self, rhs: &Self) -> Self { + let (q, _, c) = self.const_div_rem_vartime(rhs); assert!(c.is_true_vartime(), "divide by zero"); q } @@ -609,7 +663,11 @@ mod tests { ] { let lhs = U256::from(*n); let rhs = U256::from(*d); - let (q, r, is_some) = lhs.ct_div_rem(&rhs); + let (q, r, is_some) = lhs.const_div_rem(&rhs); + assert!(is_some.is_true_vartime()); + assert_eq!(U256::from(*e), q); + assert_eq!(U256::from(*ee), r); + let (q, r, is_some) = lhs.const_div_rem_vartime(&rhs); assert!(is_some.is_true_vartime()); assert_eq!(U256::from(*e), q); assert_eq!(U256::from(*ee), r); @@ -625,7 +683,10 @@ mod tests { let den = U256::random(&mut rng).shr_vartime(128); let n = num.checked_mul(&den); if n.is_some().into() { - let (q, _, is_some) = n.unwrap().ct_div_rem(&den); + let (q, _, is_some) = n.unwrap().const_div_rem(&den); + assert!(is_some.is_true_vartime()); + assert_eq!(q, num); + let (q, _, is_some) = n.unwrap().const_div_rem_vartime(&den); assert!(is_some.is_true_vartime()); assert_eq!(q, num); } @@ -647,7 +708,11 @@ mod tests { #[test] fn div_zero() { - let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO); + let (q, r, is_some) = U256::ONE.const_div_rem(&U256::ZERO); + assert!(!is_some.is_true_vartime()); + assert_eq!(q, U256::ZERO); + assert_eq!(r, U256::ONE); + let (q, r, is_some) = U256::ONE.const_div_rem_vartime(&U256::ZERO); assert!(!is_some.is_true_vartime()); assert_eq!(q, U256::ZERO); assert_eq!(r, U256::ONE); @@ -655,7 +720,11 @@ mod tests { #[test] fn div_one() { - let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE); + let (q, r, is_some) = U256::from(10u8).const_div_rem(&U256::ONE); + assert!(is_some.is_true_vartime()); + assert_eq!(q, U256::from(10u8)); + assert_eq!(r, U256::ZERO); + let (q, r, is_some) = U256::from(10u8).const_div_rem_vartime(&U256::ONE); assert!(is_some.is_true_vartime()); assert_eq!(q, U256::from(10u8)); assert_eq!(r, U256::ZERO); diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 5c96afb1..00c56181 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -1,17 +1,39 @@ //! [`Uint`] square root operations. use super::Uint; -use crate::{Limb, Word}; use subtle::{ConstantTimeEq, CtOption}; impl Uint { - /// See [`Self::sqrt_vartime`]. - #[deprecated( - since = "0.5.3", - note = "This functionality will be moved to `sqrt_vartime` in a future release." - )] + /// Computes √(`self`) in constant time. + /// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + /// + /// Callers can check if `self` is a square by squaring the result pub const fn sqrt(&self) -> Self { - self.sqrt_vartime() + let max_bits = (self.bits() + 1) >> 1; + let cap = Self::ONE.shl(max_bits); + let mut guess = cap; // ≥ √(`self`) + let mut xn = { + let q = self.wrapping_div(&guess); + let t = guess.wrapping_add(&q); + t.shr_vartime(1) + }; + + // Repeat enough times to guarantee result has stabilized. + // See Hast, "Note on computation of integer square roots" for a proof of this bound. + let mut i = 0; + while i < usize::BITS - Self::BITS.leading_zeros() { + guess = xn; + xn = { + let (q, _, is_some) = self.const_div_rem(&guess); + let q = Self::ct_select(&Self::ZERO, &q, is_some); + let t = guess.wrapping_add(&q); + t.shr_vartime(1) + }; + i += 1; + } + + // at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum + Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn)) } /// Computes √(`self`) @@ -23,31 +45,17 @@ impl Uint { let cap = Self::ONE.shl_vartime(max_bits); let mut guess = cap; // ≥ √(`self`) let mut xn = { - let q = self.wrapping_div(&guess); + let q = self.wrapping_div_vartime(&guess); let t = guess.wrapping_add(&q); t.shr_vartime(1) }; - - // If guess increased, the initial guess was low. - // Repeat until reverse course. - while Uint::ct_lt(&guess, &xn).is_true_vartime() { - // Sometimes an increase is too far, especially with large - // powers, and then takes a long time to walk back. The upper - // bound is based on bit size, so saturate on that. - let le = Limb::ct_le(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word)); - guess = Self::ct_select(&cap, &xn, le); - xn = { - let q = self.wrapping_div(&guess); - let t = guess.wrapping_add(&q); - t.shr_vartime(1) - }; - } + // Note, xn <= guess at this point. // Repeat while guess decreases. while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() { guess = xn; xn = { - let q = self.wrapping_div(&guess); + let q = self.wrapping_div_vartime(&guess); let t = guess.wrapping_add(&q); t.shr_vartime(1) }; @@ -56,29 +64,26 @@ impl Uint { Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero()) } - /// See [`Self::wrapping_sqrt_vartime`]. - #[deprecated( - since = "0.5.3", - note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release." - )] + /// Wrapped sqrt is just normal √(`self`) + /// There’s no way wrapping could ever happen. + /// This function exists so that all operations are accounted for in the wrapping operations. pub const fn wrapping_sqrt(&self) -> Self { - self.wrapping_sqrt_vartime() + self.sqrt() } /// Wrapped sqrt is just normal √(`self`) /// There’s no way wrapping could ever happen. - /// This function exists, so that all operations are accounted for in the wrapping operations. + /// This function exists so that all operations are accounted for in the wrapping operations. pub const fn wrapping_sqrt_vartime(&self) -> Self { self.sqrt_vartime() } - /// See [`Self::checked_sqrt_vartime`]. - #[deprecated( - since = "0.5.3", - note = "This functionality will be moved to `checked_sqrt_vartime` in a future release." - )] + /// Perform checked sqrt, returning a [`CtOption`] which `is_some` + /// only if the √(`self`)² == self pub fn checked_sqrt(&self) -> CtOption { - self.checked_sqrt_vartime() + let r = self.sqrt(); + let s = r.wrapping_mul(&r); + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) } /// Perform checked sqrt, returning a [`CtOption`] which `is_some` @@ -103,13 +108,24 @@ mod tests { #[test] fn edge() { + assert_eq!(U256::ZERO.sqrt(), U256::ZERO); + assert_eq!(U256::ONE.sqrt(), U256::ONE); + let mut half = U256::ZERO; + for i in 0..half.limbs.len() / 2 { + half.limbs[i] = Limb::MAX; + } + assert_eq!(U256::MAX.sqrt(), half); + } + + #[test] + fn edge_vartime() { assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO); assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE); let mut half = U256::ZERO; for i in 0..half.limbs.len() / 2 { half.limbs[i] = Limb::MAX; } - assert_eq!(U256::MAX.sqrt_vartime(), half,); + assert_eq!(U256::MAX.sqrt_vartime(), half); } #[test] @@ -131,13 +147,28 @@ mod tests { for (a, e) in &tests { let l = U256::from(*a); let r = U256::from(*e); + assert_eq!(l.sqrt(), r); assert_eq!(l.sqrt_vartime(), r); + assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8); assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8); } } #[test] fn nonsquares() { + assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8)); + assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8)); + assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8)); + assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8)); + assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8)); + assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8)); + assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8)); + } + + #[test] + fn nonsquares_vartime() { assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8)); assert_eq!( U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(), @@ -163,7 +194,9 @@ mod tests { let t = rng.next_u32() as u64; let s = U256::from(t); let s2 = s.checked_mul(&s).unwrap(); + assert_eq!(s2.sqrt(), s); assert_eq!(s2.sqrt_vartime(), s); + assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1); assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1); } @@ -171,6 +204,7 @@ mod tests { let s = U256::random(&mut rng); let mut s2 = U512::ZERO; s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); + assert_eq!(s.square().sqrt(), s2); assert_eq!(s.square().sqrt_vartime(), s2); } } diff --git a/tests/proptests.rs b/tests/proptests.rs index bad14bc1..695f0b1e 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -170,8 +170,9 @@ proptest! { if !b_bi.is_zero() { let expected = to_uint(a_bi / b_bi); let actual = a.wrapping_div(&b); - assert_eq!(expected, actual); + let actual_vartime = a.wrapping_div_vartime(&b); + assert_eq!(expected, actual_vartime); } } @@ -254,9 +255,10 @@ proptest! { fn wrapping_sqrt(a in uint()) { let a_bi = to_biguint(&a); let expected = to_uint(a_bi.sqrt()); - let actual = a.wrapping_sqrt_vartime(); - - assert_eq!(expected, actual); + let actual_ct = a.wrapping_sqrt(); + assert_eq!(expected, actual_ct); + let actual_vartime = a.wrapping_sqrt_vartime(); + assert_eq!(expected, actual_vartime); } #[test]