Skip to content

Commit 1c57cd1

Browse files
authored
Some sqrt() fixes (#379)
* Disambiguate Uint::LOG2_BITS * Simplify sqrt() loops and add some comments * Add edge case tests for vartime sqrt(), and comments to the tests * Add a TODO for `Uint::sqrt()`
1 parent 7f93018 commit 1c57cd1

File tree

4 files changed

+64
-45
lines changed

4 files changed

+64
-45
lines changed

src/uint.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ impl<const LIMBS: usize> Uint<LIMBS> {
8888
/// Total size of the represented integer in bits.
8989
pub const BITS: u32 = LIMBS as u32 * Limb::BITS;
9090

91-
/// Bit size of `BITS`.
91+
/// `floor(log2(Self::BITS))`.
9292
// Note: assumes the type of `BITS` is `u32`. Any way to assert that?
93-
pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros();
93+
pub(crate) const LOG2_BITS: u32 = u32::BITS - Self::BITS.leading_zeros() - 1;
9494

9595
/// Total size of the represented integer in bytes.
9696
pub const BYTES: usize = LIMBS * Limb::BYTES;

src/uint/shl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
8282
let shift = shift % Self::BITS;
8383
let mut result = *self;
8484
let mut i = 0;
85-
while i < Self::LOG2_BITS {
85+
while i < Self::LOG2_BITS + 1 {
8686
let bit = CtChoice::from_u32_lsb((shift >> i) & 1);
8787
result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit);
8888
i += 1;

src/uint/shr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
105105
let shift = shift % Self::BITS;
106106
let mut result = *self;
107107
let mut i = 0;
108-
while i < Self::LOG2_BITS {
108+
while i < Self::LOG2_BITS + 1 {
109109
let bit = CtChoice::from_u32_lsb((shift >> i) & 1);
110110
result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit);
111111
i += 1;

src/uint/sqrt.rs

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,65 +5,71 @@ use subtle::{ConstantTimeEq, CtOption};
55

66
impl<const LIMBS: usize> Uint<LIMBS> {
77
/// Computes √(`self`) in constant time.
8-
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
98
///
109
/// Callers can check if `self` is a square by squaring the result
1110
pub const fn sqrt(&self) -> Self {
12-
let max_bits = (self.bits() + 1) >> 1;
13-
let cap = Self::ONE.shl(max_bits);
14-
let mut guess = cap; // ≥ √(`self`)
15-
let mut xn = {
16-
let q = self.wrapping_div(&guess);
17-
let t = guess.wrapping_add(&q);
18-
t.shr_vartime(1)
19-
};
11+
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
12+
//
13+
// See Hast, "Note on computation of integer square roots"
14+
// for the proof of the sufficiency of the bound on iterations.
15+
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
16+
17+
// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
18+
let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`)
2019

2120
// Repeat enough times to guarantee result has stabilized.
22-
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
23-
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
2421
let mut i = 0;
25-
while i < Self::LOG2_BITS {
26-
guess = xn;
27-
xn = {
28-
let (q, _, is_some) = self.const_div_rem(&guess);
29-
let q = Self::ct_select(&Self::ZERO, &q, is_some);
30-
let t = guess.wrapping_add(&q);
31-
t.shr_vartime(1)
32-
};
22+
let mut x_prev = x; // keep the previous iteration in case we need to roll back.
23+
24+
// TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
25+
while i < Self::LOG2_BITS + 2 {
26+
x_prev = x;
27+
28+
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
29+
30+
let (q, _, is_some) = self.const_div_rem(&x);
31+
32+
// A protection in case `self == 0`, which will make `x == 0`
33+
let q = Self::ct_select(&Self::ZERO, &q, is_some);
34+
35+
x = x.wrapping_add(&q).shr_vartime(1);
3336
i += 1;
3437
}
3538

36-
// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
37-
Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
39+
// At this point `x_prev == x_{n}` and `x == x_{n+1}`
40+
// where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
41+
// Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
42+
Self::ct_select(&x_prev, &x, Uint::ct_gt(&x_prev, &x))
3843
}
3944

4045
/// Computes √(`self`)
41-
/// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
4246
///
4347
/// Callers can check if `self` is a square by squaring the result
4448
pub const fn sqrt_vartime(&self) -> Self {
45-
let max_bits = (self.bits_vartime() + 1) >> 1;
46-
let cap = Self::ONE.shl_vartime(max_bits);
47-
let mut guess = cap; // ≥ √(`self`)
48-
let mut xn = {
49-
let q = self.wrapping_div_vartime(&guess);
50-
let t = guess.wrapping_add(&q);
51-
t.shr_vartime(1)
52-
};
53-
// Note, xn <= guess at this point.
54-
55-
// Repeat while guess decreases.
56-
while guess.cmp_vartime(&xn).is_gt() && !xn.cmp_vartime(&Self::ZERO).is_eq() {
57-
guess = xn;
58-
xn = {
59-
let q = self.wrapping_div_vartime(&guess);
60-
let t = guess.wrapping_add(&q);
61-
t.shr_vartime(1)
62-
};
49+
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
50+
51+
// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
52+
let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`)
53+
54+
// Stop right away if `x` is zero to avoid divizion by zero.
55+
while !x.cmp_vartime(&Self::ZERO).is_eq() {
56+
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
57+
let q = self.wrapping_div_vartime(&x);
58+
let t = x.wrapping_add(&q);
59+
let next_x = t.shr_vartime(1);
60+
61+
// If `next_x` is the same as `x` or greater, we reached convergence
62+
// (`x` is guaranteed to either go down or oscillate between
63+
// `sqrt(self)` and `sqrt(self) + 1`)
64+
if !x.cmp_vartime(&next_x).is_gt() {
65+
break;
66+
}
67+
68+
x = next_x;
6369
}
6470

6571
if self.ct_is_nonzero().is_true_vartime() {
66-
guess
72+
x
6773
} else {
6874
Self::ZERO
6975
}
@@ -121,16 +127,29 @@ mod tests {
121127
}
122128
assert_eq!(U256::MAX.sqrt(), half);
123129

130+
// Test edge cases that use up the maximum number of iterations.
131+
132+
// `x = (r + 1)^2 - 583`, where `r` is the expected square root.
124133
assert_eq!(
125134
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(),
126135
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
127136
);
137+
assert_eq!(
138+
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt_vartime(),
139+
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
140+
);
128141

142+
// `x = (r + 1)^2 - 205`, where `r` is the expected square root.
129143
assert_eq!(
130144
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
131145
.sqrt(),
132146
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
133147
);
148+
assert_eq!(
149+
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
150+
.sqrt_vartime(),
151+
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
152+
);
134153
}
135154

136155
#[test]

0 commit comments

Comments
 (0)