1
1
use super :: Uint ;
2
- use crate :: { CtChoice , Limb } ;
2
+ use crate :: CtChoice ;
3
3
4
4
impl < const LIMBS : usize > Uint < LIMBS > {
5
- /// Computes 1/`self` mod 2^k as specified in Algorithm 4 from
6
- /// A Secure Algorithm for Inversion Modulo 2k by
7
- /// Sadiel de la Fe and Carles Ferrer. See
8
- /// <https://www.mdpi.com/2410-387X/2/3/23>.
5
+ /// Computes 1/`self` mod `2^k`.
6
+ /// This method is constant-time w.r.t. `self` but not `k`.
9
7
///
10
8
/// Conditions: `self` < 2^k and `self` must be odd
11
- pub const fn inv_mod2k ( & self , k : usize ) -> Self {
12
- let mut x = Self :: ZERO ;
13
- let mut b = Self :: ONE ;
9
+ pub const fn inv_mod2k_vartime ( & self , k : usize ) -> Self {
10
+ // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k"
11
+ // by Sadiel de la Fe and Carles Ferrer.
12
+ // See <https://www.mdpi.com/2410-387X/2/3/23>.
13
+
14
+ // Note that we are not using Alrgorithm 4, since we have a different approach
15
+ // of enforcing constant-timeness w.r.t. `self`.
16
+
17
+ let mut x = Self :: ZERO ; // keeps `x` during iterations
18
+ let mut b = Self :: ONE ; // keeps `b_i` during iterations
14
19
let mut i = 0 ;
15
20
16
21
while i < k {
17
- let mut x_i = Self :: ZERO ;
18
- let j = b. limbs [ 0 ] . 0 & 1 ;
19
- x_i. limbs [ 0 ] = Limb ( j) ;
20
- x = x. bitor ( & x_i. shl_vartime ( i) ) ;
22
+ // X_i = b_i mod 2
23
+ let x_i = b. limbs [ 0 ] . 0 & 1 ;
24
+ let x_i_choice = CtChoice :: from_lsb ( x_i) ;
25
+ // b_{i+1} = (b_i - a * X_i) / 2
26
+ b = Self :: ct_select ( & b, & b. wrapping_sub ( self ) , x_i_choice) . shr_vartime ( 1 ) ;
27
+ // Store the X_i bit in the result (x = x | (1 << X_i))
28
+ x = x. bitor ( & Uint :: from_word ( x_i) . shl_vartime ( i) ) ;
29
+
30
+ i += 1 ;
31
+ }
32
+
33
+ x
34
+ }
35
+
36
+ /// Computes 1/`self` mod `2^k`.
37
+ ///
38
+ /// Conditions: `self` < 2^k and `self` must be odd
39
+ pub const fn inv_mod2k ( & self , k : usize ) -> Self {
40
+ // This is the same algorithm as in `inv_mod2k_vartime()`,
41
+ // but made constant-time w.r.t `k` as well.
42
+
43
+ let mut x = Self :: ZERO ; // keeps `x` during iterations
44
+ let mut b = Self :: ONE ; // keeps `b_i` during iterations
45
+ let mut i = 0 ;
46
+
47
+ while i < Self :: BITS {
48
+ // Only iterations for i = 0..k need to change `x`,
49
+ // the rest are dummy ones performed for the sake of constant-timeness.
50
+ let within_range = CtChoice :: from_usize_lt ( i, k) ;
51
+
52
+ // X_i = b_i mod 2
53
+ let x_i = b. limbs [ 0 ] . 0 & 1 ;
54
+ let x_i_choice = CtChoice :: from_lsb ( x_i) ;
55
+ // b_{i+1} = (b_i - a * X_i) / 2
56
+ b = Self :: ct_select ( & b, & b. wrapping_sub ( self ) , x_i_choice) . shr_vartime ( 1 ) ;
57
+
58
+ // Store the X_i bit in the result (x = x | (1 << X_i))
59
+ // Don't change the result in dummy iterations.
60
+ let x_i_choice = x_i_choice. and ( within_range) ;
61
+ x = x. set_bit ( i, x_i_choice) ;
21
62
22
- let t = b. wrapping_sub ( self ) ;
23
- b = Self :: ct_select ( & b, & t, CtChoice :: from_lsb ( j) ) . shr_vartime ( 1 ) ;
24
63
i += 1 ;
25
64
}
65
+
26
66
x
27
67
}
28
68
@@ -97,10 +137,45 @@ impl<const LIMBS: usize> Uint<LIMBS> {
97
137
}
98
138
99
139
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
100
- /// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
140
+ /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
141
+ /// otherwise `(undefined, CtChoice::FALSE)`.
101
142
pub const fn inv_odd_mod ( & self , modulus : & Self ) -> ( Self , CtChoice ) {
102
143
self . inv_odd_mod_bounded ( modulus, Uint :: < LIMBS > :: BITS , Uint :: < LIMBS > :: BITS )
103
144
}
145
+
146
+ /// Computes the multiplicative inverse of `self` mod `modulus`.
147
+ /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
148
+ /// otherwise `(undefined, CtChoice::FALSE)`.
149
+ pub fn inv_mod ( & self , modulus : & Self ) -> ( Self , CtChoice ) {
150
+ // Decompose `modulus = s * 2^k` where `s` is odd
151
+ let k = modulus. trailing_zeros ( ) ;
152
+ let s = modulus. shr ( k) ;
153
+
154
+ // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
155
+ // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
156
+ let ( a, a_is_some) = self . inv_odd_mod ( & s) ;
157
+ let b = self . inv_mod2k ( k) ;
158
+ // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
159
+ let b_is_some = CtChoice :: from_usize_being_nonzero ( k)
160
+ . not ( )
161
+ . or ( self . ct_is_odd ( ) ) ;
162
+
163
+ // Restore from RNS:
164
+ // self^{-1} = a mod s = b mod 2^k
165
+ // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
166
+ // (essentially one step of the Garner's algorithm for recovery from RNS).
167
+
168
+ let m_odd_inv = s. inv_mod2k ( k) ; // `s` is odd, so this always exists
169
+
170
+ // This part is mod 2^k
171
+ let mask = ( Uint :: ONE << k) . wrapping_sub ( & Uint :: ONE ) ;
172
+ let t = ( b. wrapping_sub ( & a) . wrapping_mul ( & m_odd_inv) ) & mask;
173
+
174
+ // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
175
+ // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
176
+ let result = a. wrapping_add ( & s. wrapping_mul ( & t) ) ;
177
+ ( result, a_is_some. and ( b_is_some) )
178
+ }
104
179
}
105
180
106
181
#[ cfg( test) ]
@@ -125,7 +200,7 @@ mod tests {
125
200
}
126
201
127
202
#[ test]
128
- fn test_invert ( ) {
203
+ fn test_invert_odd ( ) {
129
204
let a = U1024 :: from_be_hex ( concat ! [
130
205
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E" ,
131
206
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8" ,
@@ -138,15 +213,45 @@ mod tests {
138
213
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" ,
139
214
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
140
215
] ) ;
141
-
142
- let ( res, is_some) = a. inv_odd_mod ( & m) ;
143
-
144
216
let expected = U1024 :: from_be_hex ( concat ! [
145
217
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55" ,
146
218
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57" ,
147
219
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA" ,
148
220
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
149
221
] ) ;
222
+
223
+ let ( res, is_some) = a. inv_odd_mod ( & m) ;
224
+ assert ! ( is_some. is_true_vartime( ) ) ;
225
+ assert_eq ! ( res, expected) ;
226
+
227
+ // Even though it is less efficient, it still works
228
+ let ( res, is_some) = a. inv_mod ( & m) ;
229
+ assert ! ( is_some. is_true_vartime( ) ) ;
230
+ assert_eq ! ( res, expected) ;
231
+ }
232
+
233
+ #[ test]
234
+ fn test_invert_even ( ) {
235
+ let a = U1024 :: from_be_hex ( concat ! [
236
+ "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E" ,
237
+ "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8" ,
238
+ "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8" ,
239
+ "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
240
+ ] ) ;
241
+ let m = U1024 :: from_be_hex ( concat ! [
242
+ "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F" ,
243
+ "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A" ,
244
+ "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" ,
245
+ "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
246
+ ] ) ;
247
+ let expected = U1024 :: from_be_hex ( concat ! [
248
+ "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357" ,
249
+ "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225" ,
250
+ "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3" ,
251
+ "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D" ,
252
+ ] ) ;
253
+
254
+ let ( res, is_some) = a. inv_mod ( & m) ;
150
255
assert ! ( is_some. is_true_vartime( ) ) ;
151
256
assert_eq ! ( res, expected) ;
152
257
}
0 commit comments