diff --git a/src/uint.rs b/src/uint.rs index fd116fc4c..35a285f40 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -88,9 +88,9 @@ impl Uint { /// Total size of the represented integer in bits. pub const BITS: u32 = LIMBS as u32 * Limb::BITS; - /// Bit size of `BITS`. + /// `floor(log2(Self::BITS))`. // Note: assumes the type of `BITS` is `u32`. Any way to assert that? - pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros(); + pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros() - 1; /// Total size of the represented integer in bytes. pub const BYTES: usize = LIMBS * Limb::BYTES; diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 97030b5a1..4509d61ef 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -82,7 +82,7 @@ impl Uint { let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS { + while i < Self::LOG2_BITS + 1 { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit); i += 1; diff --git a/src/uint/shr.rs b/src/uint/shr.rs index d49ad3afb..758714138 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -105,7 +105,7 @@ impl Uint { let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS { + while i < Self::LOG2_BITS + 1 { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit); i += 1; diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 6e56e0066..929524de6 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -5,65 +5,71 @@ use subtle::{ConstantTimeEq, CtOption}; impl Uint { /// Computes √(`self`) in constant time. - /// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 /// /// Callers can check if `self` is a square by squaring the result pub const fn sqrt(&self) -> Self { - let max_bits = (self.bits() + 1) >> 1; - let cap = Self::ONE.shl(max_bits); - let mut guess = cap; // ≥ √(`self`) - let mut xn = { - let q = self.wrapping_div(&guess); - let t = guess.wrapping_add(&q); - t.shr_vartime(1) - }; + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13. + // + // See Hast, "Note on computation of integer square roots" + // for the proof of the sufficiency of the bound on iterations. + // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) // Repeat enough times to guarantee result has stabilized. - // See Hast, "Note on computation of integer square roots" for a proof of this bound. - // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf let mut i = 0; - while i < Self::LOG2_BITS { - guess = xn; - xn = { - let (q, _, is_some) = self.const_div_rem(&guess); - let q = Self::ct_select(&Self::ZERO, &q, is_some); - let t = guess.wrapping_add(&q); - t.shr_vartime(1) - }; + let mut x_prev = x; // keep the previous iteration in case we need to roll back. + + // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough. + while i < Self::LOG2_BITS + 2 { + x_prev = x; + + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + + let (q, _, is_some) = self.const_div_rem(&x); + + // A protection in case `self == 0`, which will make `x == 0` + let q = Self::ct_select(&Self::ZERO, &q, is_some); + + x = x.wrapping_add(&q).shr_vartime(1); i += 1; } - // at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum - Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn)) + // At this point `x_prev == x_{n}` and `x == x_{n+1}` + // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`. + // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`. + Self::ct_select(&x_prev, &x, Uint::ct_gt(&x_prev, &x)) } /// Computes √(`self`) - /// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 /// /// Callers can check if `self` is a square by squaring the result pub const fn sqrt_vartime(&self) -> Self { - let max_bits = (self.bits_vartime() + 1) >> 1; - let cap = Self::ONE.shl_vartime(max_bits); - let mut guess = cap; // ≥ √(`self`) - let mut xn = { - let q = self.wrapping_div_vartime(&guess); - let t = guess.wrapping_add(&q); - t.shr_vartime(1) - }; - // Note, xn <= guess at this point. - - // Repeat while guess decreases. - while guess.cmp_vartime(&xn).is_gt() && !xn.cmp_vartime(&Self::ZERO).is_eq() { - guess = xn; - xn = { - let q = self.wrapping_div_vartime(&guess); - let t = guess.wrapping_add(&q); - t.shr_vartime(1) - }; + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) + + // Stop right away if `x` is zero to avoid divizion by zero. + while !x.cmp_vartime(&Self::ZERO).is_eq() { + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + let q = self.wrapping_div_vartime(&x); + let t = x.wrapping_add(&q); + let next_x = t.shr_vartime(1); + + // If `next_x` is the same as `x` or greater, we reached convergence + // (`x` is guaranteed to either go down or oscillate between + // `sqrt(self)` and `sqrt(self) + 1`) + if !x.cmp_vartime(&next_x).is_gt() { + break; + } + + x = next_x; } if self.ct_is_nonzero().is_true_vartime() { - guess + x } else { Self::ZERO } @@ -121,16 +127,29 @@ mod tests { } assert_eq!(U256::MAX.sqrt(), half); + // Test edge cases that use up the maximum number of iterations. + + // `x = (r + 1)^2 - 583`, where `r` is the expected square root. assert_eq!( U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(), U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21") ); + assert_eq!( + U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt_vartime(), + U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21") + ); + // `x = (r + 1)^2 - 205`, where `r` is the expected square root. assert_eq!( U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597") .sqrt(), U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075") ); + assert_eq!( + U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597") + .sqrt_vartime(), + U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075") + ); } #[test]