@@ -137,10 +137,43 @@ impl<const LIMBS: usize> Uint<LIMBS> {
137
137
}
138
138
139
139
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
140
- /// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
140
+ /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
141
+ /// otherwise `(undefined, CtChoice::FALSE)`.
141
142
pub const fn inv_odd_mod ( & self , modulus : & Self ) -> ( Self , CtChoice ) {
142
143
self . inv_odd_mod_bounded ( modulus, Uint :: < LIMBS > :: BITS , Uint :: < LIMBS > :: BITS )
143
144
}
145
+
146
+ /// Computes the multiplicative inverse of `self` mod `modulus`.
147
+ /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
148
+ /// otherwise `(undefined, CtChoice::FALSE)`.
149
+ pub fn inv_mod ( & self , modulus : & Self ) -> ( Self , CtChoice ) {
150
+ // Decompose `modulus = s * 2^k` where `s` is odd
151
+ let k = modulus. trailing_zeros ( ) ;
152
+ let s = modulus. shr ( k) ;
153
+
154
+ // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
155
+ // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`.
156
+ let ( a, a_is_some) = self . inv_odd_mod ( & s) ;
157
+ let b = self . inv_mod2k ( k) ;
158
+ // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
159
+ let b_is_some = CtChoice :: from_usize_being_nonzero ( k)
160
+ . not ( )
161
+ . or ( self . ct_is_odd ( ) ) ;
162
+
163
+ // Restore from RNS:
164
+ // self^{-1} = a mod s = b mod 2^k
165
+ // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
166
+ let m_odd_inv = s. inv_mod2k ( k) ; // `s` is odd, so this always exists
167
+
168
+ // This part is mod 2^k
169
+ let mask = ( Uint :: ONE << k) . wrapping_sub ( & Uint :: ONE ) ;
170
+ let t = ( b. wrapping_sub ( & a) . wrapping_mul ( & m_odd_inv) ) & mask;
171
+
172
+ // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
173
+ // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
174
+ let result = a. wrapping_add ( & s. wrapping_mul ( & t) ) ;
175
+ ( result, a_is_some. and ( b_is_some) )
176
+ }
144
177
}
145
178
146
179
#[ cfg( test) ]
@@ -165,7 +198,7 @@ mod tests {
165
198
}
166
199
167
200
#[ test]
168
- fn test_invert ( ) {
201
+ fn test_invert_odd ( ) {
169
202
let a = U1024 :: from_be_hex ( concat ! [
170
203
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E" ,
171
204
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8" ,
@@ -178,15 +211,45 @@ mod tests {
178
211
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" ,
179
212
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
180
213
] ) ;
181
-
182
- let ( res, is_some) = a. inv_odd_mod ( & m) ;
183
-
184
214
let expected = U1024 :: from_be_hex ( concat ! [
185
215
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55" ,
186
216
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57" ,
187
217
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA" ,
188
218
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
189
219
] ) ;
220
+
221
+ let ( res, is_some) = a. inv_odd_mod ( & m) ;
222
+ assert ! ( is_some. is_true_vartime( ) ) ;
223
+ assert_eq ! ( res, expected) ;
224
+
225
+ // Even though it is less efficient, it still works
226
+ let ( res, is_some) = a. inv_mod ( & m) ;
227
+ assert ! ( is_some. is_true_vartime( ) ) ;
228
+ assert_eq ! ( res, expected) ;
229
+ }
230
+
231
+ #[ test]
232
+ fn test_invert_even ( ) {
233
+ let a = U1024 :: from_be_hex ( concat ! [
234
+ "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E" ,
235
+ "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8" ,
236
+ "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8" ,
237
+ "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
238
+ ] ) ;
239
+ let m = U1024 :: from_be_hex ( concat ! [
240
+ "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F" ,
241
+ "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A" ,
242
+ "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" ,
243
+ "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
244
+ ] ) ;
245
+ let expected = U1024 :: from_be_hex ( concat ! [
246
+ "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357" ,
247
+ "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225" ,
248
+ "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3" ,
249
+ "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D" ,
250
+ ] ) ;
251
+
252
+ let ( res, is_some) = a. inv_mod ( & m) ;
190
253
assert ! ( is_some. is_true_vartime( ) ) ;
191
254
assert_eq ! ( res, expected) ;
192
255
}
0 commit comments