1
1
//! [`Uint`] square root operations.
2
2
3
3
use super :: Uint ;
4
- use crate :: CtChoice ;
5
4
use subtle:: { ConstantTimeEq , CtOption } ;
6
5
7
6
impl < const LIMBS : usize > Uint < LIMBS > {
8
- /// See [`Self::sqrt_vartime`].
9
- #[ deprecated(
10
- since = "0.5.3" ,
11
- note = "This functionality will be moved to `sqrt_vartime` in a future release."
12
- ) ]
7
+ /// Computes √(`self`) in constant time.
8
+ /// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
9
+ ///
10
+ /// Callers can check if `self` is a square by squaring the result
13
11
pub const fn sqrt ( & self ) -> Self {
14
- self . sqrt_vartime ( )
12
+ let max_bits = ( self . bits ( ) + 1 ) >> 1 ;
13
+ let cap = Self :: ONE . shl ( max_bits) ;
14
+ let mut guess = cap; // ≥ √(`self`)
15
+ let mut xn = {
16
+ let q = self . wrapping_div ( & guess) ;
17
+ let t = guess. wrapping_add ( & q) ;
18
+ t. shr_vartime ( 1 )
19
+ } ;
20
+
21
+ // Repeat enough times to guarantee result has stabilized.
22
+ // See Hast, "Note on computation of integer square roots" for a proof of this bound.
23
+ // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
24
+ let mut i = 0 ;
25
+ while i < Self :: LOG2_BITS {
26
+ guess = xn;
27
+ xn = {
28
+ let ( q, _, is_some) = self . const_div_rem ( & guess) ;
29
+ let q = Self :: ct_select ( & Self :: ZERO , & q, is_some) ;
30
+ let t = guess. wrapping_add ( & q) ;
31
+ t. shr_vartime ( 1 )
32
+ } ;
33
+ i += 1 ;
34
+ }
35
+
36
+ // at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
37
+ Self :: ct_select ( & guess, & xn, Uint :: ct_gt ( & guess, & xn) )
15
38
}
16
39
17
40
/// Computes √(`self`)
@@ -23,62 +46,49 @@ impl<const LIMBS: usize> Uint<LIMBS> {
23
46
let cap = Self :: ONE . shl_vartime ( max_bits) ;
24
47
let mut guess = cap; // ≥ √(`self`)
25
48
let mut xn = {
26
- let q = self . wrapping_div ( & guess) ;
49
+ let q = self . wrapping_div_vartime ( & guess) ;
27
50
let t = guess. wrapping_add ( & q) ;
28
51
t. shr_vartime ( 1 )
29
52
} ;
30
-
31
- // If guess increased, the initial guess was low.
32
- // Repeat until reverse course.
33
- while Uint :: ct_lt ( & guess, & xn) . is_true_vartime ( ) {
34
- // Sometimes an increase is too far, especially with large
35
- // powers, and then takes a long time to walk back. The upper
36
- // bound is based on bit size, so saturate on that.
37
- let le = CtChoice :: from_u32_le ( xn. bits_vartime ( ) , max_bits) ;
38
- guess = Self :: ct_select ( & cap, & xn, le) ;
39
- xn = {
40
- let q = self . wrapping_div ( & guess) ;
41
- let t = guess. wrapping_add ( & q) ;
42
- t. shr_vartime ( 1 )
43
- } ;
44
- }
53
+ // Note, xn <= guess at this point.
45
54
46
55
// Repeat while guess decreases.
47
- while Uint :: ct_gt ( & guess , & xn) . is_true_vartime ( ) && xn. ct_is_nonzero ( ) . is_true_vartime ( ) {
56
+ while guess . cmp_vartime ( & xn) . is_gt ( ) && ! xn. cmp_vartime ( & Self :: ZERO ) . is_eq ( ) {
48
57
guess = xn;
49
58
xn = {
50
- let q = self . wrapping_div ( & guess) ;
59
+ let q = self . wrapping_div_vartime ( & guess) ;
51
60
let t = guess. wrapping_add ( & q) ;
52
61
t. shr_vartime ( 1 )
53
62
} ;
54
63
}
55
64
56
- Self :: ct_select ( & Self :: ZERO , & guess, self . ct_is_nonzero ( ) )
65
+ if self . ct_is_nonzero ( ) . is_true_vartime ( ) {
66
+ guess
67
+ } else {
68
+ Self :: ZERO
69
+ }
57
70
}
58
71
59
- /// See [`Self::wrapping_sqrt_vartime`].
60
- #[ deprecated(
61
- since = "0.5.3" ,
62
- note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
63
- ) ]
72
+ /// Wrapped sqrt is just normal √(`self`)
73
+ /// There’s no way wrapping could ever happen.
74
+ /// This function exists so that all operations are accounted for in the wrapping operations.
64
75
pub const fn wrapping_sqrt ( & self ) -> Self {
65
- self . wrapping_sqrt_vartime ( )
76
+ self . sqrt ( )
66
77
}
67
78
68
79
/// Wrapped sqrt is just normal √(`self`)
69
80
/// There’s no way wrapping could ever happen.
70
- /// This function exists, so that all operations are accounted for in the wrapping operations.
81
+ /// This function exists so that all operations are accounted for in the wrapping operations.
71
82
pub const fn wrapping_sqrt_vartime ( & self ) -> Self {
72
83
self . sqrt_vartime ( )
73
84
}
74
85
75
- /// See [`Self::checked_sqrt_vartime`].
76
- #[ deprecated(
77
- since = "0.5.3" ,
78
- note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
79
- ) ]
86
+ /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
87
+ /// only if the √(`self`)² == self
80
88
pub fn checked_sqrt ( & self ) -> CtOption < Self > {
81
- self . checked_sqrt_vartime ( )
89
+ let r = self . sqrt ( ) ;
90
+ let s = r. wrapping_mul ( & r) ;
91
+ CtOption :: new ( r, ConstantTimeEq :: ct_eq ( self , & s) )
82
92
}
83
93
84
94
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
@@ -92,7 +102,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
92
102
93
103
#[ cfg( test) ]
94
104
mod tests {
95
- use crate :: { Limb , U256 } ;
105
+ use crate :: { Limb , U192 , U256 } ;
96
106
97
107
#[ cfg( feature = "rand" ) ]
98
108
use {
@@ -103,13 +113,35 @@ mod tests {
103
113
104
114
#[ test]
105
115
fn edge ( ) {
116
+ assert_eq ! ( U256 :: ZERO . sqrt( ) , U256 :: ZERO ) ;
117
+ assert_eq ! ( U256 :: ONE . sqrt( ) , U256 :: ONE ) ;
118
+ let mut half = U256 :: ZERO ;
119
+ for i in 0 ..half. limbs . len ( ) / 2 {
120
+ half. limbs [ i] = Limb :: MAX ;
121
+ }
122
+ assert_eq ! ( U256 :: MAX . sqrt( ) , half) ;
123
+
124
+ assert_eq ! (
125
+ U192 :: from_be_hex( "055fa39422bd9f281762946e056535badbf8a6864d45fa3d" ) . sqrt( ) ,
126
+ U192 :: from_be_hex( "0000000000000000000000002516f0832a538b2d98869e21" )
127
+ ) ;
128
+
129
+ assert_eq ! (
130
+ U256 :: from_be_hex( "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597" )
131
+ . sqrt( ) ,
132
+ U256 :: from_be_hex( "000000000000000000000000000000008b3956339e8315cff66eb6107b610075" )
133
+ ) ;
134
+ }
135
+
136
+ #[ test]
137
+ fn edge_vartime ( ) {
106
138
assert_eq ! ( U256 :: ZERO . sqrt_vartime( ) , U256 :: ZERO ) ;
107
139
assert_eq ! ( U256 :: ONE . sqrt_vartime( ) , U256 :: ONE ) ;
108
140
let mut half = U256 :: ZERO ;
109
141
for i in 0 ..half. limbs . len ( ) / 2 {
110
142
half. limbs [ i] = Limb :: MAX ;
111
143
}
112
- assert_eq ! ( U256 :: MAX . sqrt_vartime( ) , half, ) ;
144
+ assert_eq ! ( U256 :: MAX . sqrt_vartime( ) , half) ;
113
145
}
114
146
115
147
#[ test]
@@ -131,13 +163,28 @@ mod tests {
131
163
for ( a, e) in & tests {
132
164
let l = U256 :: from ( * a) ;
133
165
let r = U256 :: from ( * e) ;
166
+ assert_eq ! ( l. sqrt( ) , r) ;
134
167
assert_eq ! ( l. sqrt_vartime( ) , r) ;
168
+ assert_eq ! ( l. checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 1u8 ) ;
135
169
assert_eq ! ( l. checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) , 1u8 ) ;
136
170
}
137
171
}
138
172
139
173
#[ test]
140
174
fn nonsquares ( ) {
175
+ assert_eq ! ( U256 :: from( 2u8 ) . sqrt( ) , U256 :: from( 1u8 ) ) ;
176
+ assert_eq ! ( U256 :: from( 2u8 ) . checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 0 ) ;
177
+ assert_eq ! ( U256 :: from( 3u8 ) . sqrt( ) , U256 :: from( 1u8 ) ) ;
178
+ assert_eq ! ( U256 :: from( 3u8 ) . checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 0 ) ;
179
+ assert_eq ! ( U256 :: from( 5u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
180
+ assert_eq ! ( U256 :: from( 6u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
181
+ assert_eq ! ( U256 :: from( 7u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
182
+ assert_eq ! ( U256 :: from( 8u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
183
+ assert_eq ! ( U256 :: from( 10u8 ) . sqrt( ) , U256 :: from( 3u8 ) ) ;
184
+ }
185
+
186
+ #[ test]
187
+ fn nonsquares_vartime ( ) {
141
188
assert_eq ! ( U256 :: from( 2u8 ) . sqrt_vartime( ) , U256 :: from( 1u8 ) ) ;
142
189
assert_eq ! (
143
190
U256 :: from( 2u8 ) . checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) ,
@@ -163,14 +210,17 @@ mod tests {
163
210
let t = rng. next_u32 ( ) as u64 ;
164
211
let s = U256 :: from ( t) ;
165
212
let s2 = s. checked_mul ( & s) . unwrap ( ) ;
213
+ assert_eq ! ( s2. sqrt( ) , s) ;
166
214
assert_eq ! ( s2. sqrt_vartime( ) , s) ;
215
+ assert_eq ! ( s2. checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 1 ) ;
167
216
assert_eq ! ( s2. checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) , 1 ) ;
168
217
}
169
218
170
219
for _ in 0 ..50 {
171
220
let s = U256 :: random ( & mut rng) ;
172
221
let mut s2 = U512 :: ZERO ;
173
222
s2. limbs [ ..s. limbs . len ( ) ] . copy_from_slice ( & s. limbs ) ;
223
+ assert_eq ! ( s. square( ) . sqrt( ) , s2) ;
174
224
assert_eq ! ( s. square( ) . sqrt_vartime( ) , s2) ;
175
225
}
176
226
}
0 commit comments