Skip to content

Some sqrt() fixes #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Total size of the represented integer in bits.
pub const BITS: u32 = LIMBS as u32 * Limb::BITS;

/// Bit size of `BITS`.
/// `floor(log2(Self::BITS))`.
// Note: assumes the type of `BITS` is `u32`. Any way to assert that?
pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros();
pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros() - 1;

/// Total size of the represented integer in bytes.
pub const BYTES: usize = LIMBS * Limb::BYTES;
Expand Down
2 changes: 1 addition & 1 deletion src/uint/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let shift = shift % Self::BITS;
let mut result = *self;
let mut i = 0;
while i < Self::LOG2_BITS {
while i < Self::LOG2_BITS + 1 {
let bit = CtChoice::from_u32_lsb((shift >> i) & 1);
result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit);
i += 1;
Expand Down
2 changes: 1 addition & 1 deletion src/uint/shr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let shift = shift % Self::BITS;
let mut result = *self;
let mut i = 0;
while i < Self::LOG2_BITS {
while i < Self::LOG2_BITS + 1 {
let bit = CtChoice::from_u32_lsb((shift >> i) & 1);
result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit);
i += 1;
Expand Down
101 changes: 60 additions & 41 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,71 @@ use subtle::{ConstantTimeEq, CtOption};

impl<const LIMBS: usize> Uint<LIMBS> {
/// Computes √(`self`) in constant time.
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub const fn sqrt(&self) -> Self {
let max_bits = (self.bits() + 1) >> 1;
let cap = Self::ONE.shl(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
//
// See Hast, "Note on computation of integer square roots"
// for the proof of the sufficiency of the bound on iterations.
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf

// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`)

// Repeat enough times to guarantee result has stabilized.
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
let mut i = 0;
while i < Self::LOG2_BITS {
guess = xn;
xn = {
let (q, _, is_some) = self.const_div_rem(&guess);
let q = Self::ct_select(&Self::ZERO, &q, is_some);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
let mut x_prev = x; // keep the previous iteration in case we need to roll back.

// TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
while i < Self::LOG2_BITS + 2 {
x_prev = x;

// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`

let (q, _, is_some) = self.const_div_rem(&x);

// A protection in case `self == 0`, which will make `x == 0`
let q = Self::ct_select(&Self::ZERO, &q, is_some);

x = x.wrapping_add(&q).shr_vartime(1);
i += 1;
}

// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
// At this point `x_prev == x_{n}` and `x == x_{n+1}`
// where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
// Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
Self::ct_select(&x_prev, &x, Uint::ct_gt(&x_prev, &x))
}

/// Computes √(`self`)
/// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub const fn sqrt_vartime(&self) -> Self {
let max_bits = (self.bits_vartime() + 1) >> 1;
let cap = Self::ONE.shl_vartime(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
// Note, xn <= guess at this point.

// Repeat while guess decreases.
while guess.cmp_vartime(&xn).is_gt() && !xn.cmp_vartime(&Self::ZERO).is_eq() {
guess = xn;
xn = {
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13

// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`)

// Stop right away if `x` is zero to avoid divizion by zero.
while !x.cmp_vartime(&Self::ZERO).is_eq() {
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
let q = self.wrapping_div_vartime(&x);
let t = x.wrapping_add(&q);
let next_x = t.shr_vartime(1);

// If `next_x` is the same as `x` or greater, we reached convergence
// (`x` is guaranteed to either go down or oscillate between
// `sqrt(self)` and `sqrt(self) + 1`)
if !x.cmp_vartime(&next_x).is_gt() {
break;
}

x = next_x;
}

if self.ct_is_nonzero().is_true_vartime() {
guess
x
} else {
Self::ZERO
}
Expand Down Expand Up @@ -121,16 +127,29 @@ mod tests {
}
assert_eq!(U256::MAX.sqrt(), half);

// Test edge cases that use up the maximum number of iterations.

// `x = (r + 1)^2 - 583`, where `r` is the expected square root.
assert_eq!(
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(),
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
);
assert_eq!(
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt_vartime(),
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
);

// `x = (r + 1)^2 - 205`, where `r` is the expected square root.
assert_eq!(
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
.sqrt(),
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
);
assert_eq!(
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
.sqrt_vartime(),
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
);
}

#[test]
Expand Down