@@ -5,65 +5,71 @@ use subtle::{ConstantTimeEq, CtOption};
5
5
6
6
impl < const LIMBS : usize > Uint < LIMBS > {
7
7
/// Computes √(`self`) in constant time.
8
- /// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
9
8
///
10
9
/// Callers can check if `self` is a square by squaring the result
11
10
pub const fn sqrt ( & self ) -> Self {
12
- let max_bits = ( self . bits ( ) + 1 ) >> 1 ;
13
- let cap = Self :: ONE . shl ( max_bits ) ;
14
- let mut guess = cap ; // ≥ √(`self`)
15
- let mut xn = {
16
- let q = self . wrapping_div ( & guess ) ;
17
- let t = guess . wrapping_add ( & q ) ;
18
- t . shr_vartime ( 1 )
19
- } ;
11
+ // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
12
+ //
13
+ // See Hast, "Note on computation of integer square roots"
14
+ // for the proof of the sufficiency of the bound on iterations.
15
+ // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
16
+
17
+ // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
18
+ let mut x = Self :: ONE . shl ( ( self . bits ( ) + 1 ) >> 1 ) ; // ≥ √(`self`)
20
19
21
20
// Repeat enough times to guarantee result has stabilized.
22
- // See Hast, "Note on computation of integer square roots" for a proof of this bound.
23
- // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
24
21
let mut i = 0 ;
25
- while i < Self :: LOG2_BITS {
26
- guess = xn;
27
- xn = {
28
- let ( q, _, is_some) = self . const_div_rem ( & guess) ;
29
- let q = Self :: ct_select ( & Self :: ZERO , & q, is_some) ;
30
- let t = guess. wrapping_add ( & q) ;
31
- t. shr_vartime ( 1 )
32
- } ;
22
+ let mut x_prev = x; // keep the previous iteration in case we need to roll back.
23
+
24
+ // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
25
+ while i < Self :: LOG2_BITS + 2 {
26
+ x_prev = x;
27
+
28
+ // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
29
+
30
+ let ( q, _, is_some) = self . const_div_rem ( & x) ;
31
+
32
+ // A protection in case `self == 0`, which will make `x == 0`
33
+ let q = Self :: ct_select ( & Self :: ZERO , & q, is_some) ;
34
+
35
+ x = x. wrapping_add ( & q) . shr_vartime ( 1 ) ;
33
36
i += 1 ;
34
37
}
35
38
36
- // at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
37
- Self :: ct_select ( & guess, & xn, Uint :: ct_gt ( & guess, & xn) )
39
+ // At this point `x_prev == x_{n}` and `x == x_{n+1}`
40
+ // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
41
+ // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
42
+ Self :: ct_select ( & x_prev, & x, Uint :: ct_gt ( & x_prev, & x) )
38
43
}
39
44
40
45
/// Computes √(`self`)
41
- /// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
42
46
///
43
47
/// Callers can check if `self` is a square by squaring the result
44
48
pub const fn sqrt_vartime ( & self ) -> Self {
45
- let max_bits = ( self . bits_vartime ( ) + 1 ) >> 1 ;
46
- let cap = Self :: ONE . shl_vartime ( max_bits) ;
47
- let mut guess = cap; // ≥ √(`self`)
48
- let mut xn = {
49
- let q = self . wrapping_div_vartime ( & guess) ;
50
- let t = guess. wrapping_add ( & q) ;
51
- t. shr_vartime ( 1 )
52
- } ;
53
- // Note, xn <= guess at this point.
54
-
55
- // Repeat while guess decreases.
56
- while guess. cmp_vartime ( & xn) . is_gt ( ) && !xn. cmp_vartime ( & Self :: ZERO ) . is_eq ( ) {
57
- guess = xn;
58
- xn = {
59
- let q = self . wrapping_div_vartime ( & guess) ;
60
- let t = guess. wrapping_add ( & q) ;
61
- t. shr_vartime ( 1 )
62
- } ;
49
+ // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
50
+
51
+ // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
52
+ let mut x = Self :: ONE . shl ( ( self . bits ( ) + 1 ) >> 1 ) ; // ≥ √(`self`)
53
+
54
+ // Stop right away if `x` is zero to avoid divizion by zero.
55
+ while !x. cmp_vartime ( & Self :: ZERO ) . is_eq ( ) {
56
+ // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
57
+ let q = self . wrapping_div_vartime ( & x) ;
58
+ let t = x. wrapping_add ( & q) ;
59
+ let next_x = t. shr_vartime ( 1 ) ;
60
+
61
+ // If `next_x` is the same as `x` or greater, we reached convergence
62
+ // (`x` is guaranteed to either go down or oscillate between
63
+ // `sqrt(self)` and `sqrt(self) + 1`)
64
+ if !x. cmp_vartime ( & next_x) . is_gt ( ) {
65
+ break ;
66
+ }
67
+
68
+ x = next_x;
63
69
}
64
70
65
71
if self . ct_is_nonzero ( ) . is_true_vartime ( ) {
66
- guess
72
+ x
67
73
} else {
68
74
Self :: ZERO
69
75
}
@@ -121,16 +127,29 @@ mod tests {
121
127
}
122
128
assert_eq ! ( U256 :: MAX . sqrt( ) , half) ;
123
129
130
+ // Test edge cases that use up the maximum number of iterations.
131
+
132
+ // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
124
133
assert_eq ! (
125
134
U192 :: from_be_hex( "055fa39422bd9f281762946e056535badbf8a6864d45fa3d" ) . sqrt( ) ,
126
135
U192 :: from_be_hex( "0000000000000000000000002516f0832a538b2d98869e21" )
127
136
) ;
137
+ assert_eq ! (
138
+ U192 :: from_be_hex( "055fa39422bd9f281762946e056535badbf8a6864d45fa3d" ) . sqrt_vartime( ) ,
139
+ U192 :: from_be_hex( "0000000000000000000000002516f0832a538b2d98869e21" )
140
+ ) ;
128
141
142
+ // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
129
143
assert_eq ! (
130
144
U256 :: from_be_hex( "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597" )
131
145
. sqrt( ) ,
132
146
U256 :: from_be_hex( "000000000000000000000000000000008b3956339e8315cff66eb6107b610075" )
133
147
) ;
148
+ assert_eq ! (
149
+ U256 :: from_be_hex( "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597" )
150
+ . sqrt_vartime( ) ,
151
+ U256 :: from_be_hex( "000000000000000000000000000000008b3956339e8315cff66eb6107b610075" )
152
+ ) ;
134
153
}
135
154
136
155
#[ test]
0 commit comments