@@ -289,35 +289,23 @@ const fn mul_encrypt(a: u32, b: u32) -> u32 {
289
289
}
290
290
291
291
// Multiplication with multiplicative inverse, modulo (2^32)
292
- fn mul_decrypt ( a : u32 , b : u32 ) -> u32 {
293
- a. wrapping_mul ( mul_inverse ( b | 1 ) )
292
+ #[ inline( always) ]
293
+ const fn mul_decrypt ( a : u32 , b : u32 ) -> u32 {
294
+ a. wrapping_mul ( mod_inverse ( b | 1 ) )
294
295
}
295
296
296
- // Computes the multiplicative inverse of @word, modulo (2^32).
297
- // Original MIPS R5900 coding converted to C, and now to Rust.
298
- fn mul_inverse ( word : u32 ) -> u32 {
299
- if word == 1 {
300
- return 1 ;
301
- }
302
- let mut a2 = 0u32 . wrapping_sub ( word) % word;
303
- if a2 == 0 {
304
- return 1 ;
305
- }
306
- let mut t1 = 1u32 ;
307
- let mut a3 = word;
308
- let mut a0 = 0u32 . wrapping_sub ( 0xffff_ffff / word) ;
309
- while a2 != 0 {
310
- let mut v0 = a3 / a2;
311
- let v1 = a3 % a2;
312
- let a1 = a2;
313
- a3 = a1;
314
- let a1 = a0;
315
- a2 = v1;
316
- v0 = v0. wrapping_mul ( a1) ;
317
- a0 = t1. wrapping_sub ( v0) ;
318
- t1 = a1;
319
- }
320
- t1
297
+ // Computes the multiplicative inverse of x modulo (2^32). x must be odd!
298
+ // The code is based on Newton's method as explained in this blog post:
299
+ // https://lemire.me/blog/2017/09/18/computing-the-inverse-of-odd-integers/
300
+ const fn mod_inverse ( x : u32 ) -> u32 {
301
+ let mut y = x;
302
+ // Call this recurrence formula 4 times for 32-bit values:
303
+ // f(y) = y * (2 - y * x) modulo 2^32
304
+ y = y. wrapping_mul ( 2u32 . wrapping_sub ( y. wrapping_mul ( x) ) ) ;
305
+ y = y. wrapping_mul ( 2u32 . wrapping_sub ( y. wrapping_mul ( x) ) ) ;
306
+ y = y. wrapping_mul ( 2u32 . wrapping_sub ( y. wrapping_mul ( x) ) ) ;
307
+ y = y. wrapping_mul ( 2u32 . wrapping_sub ( y. wrapping_mul ( x) ) ) ;
308
+ y
321
309
}
322
310
323
311
// RSA encryption/decryption
@@ -507,7 +495,7 @@ mod tests {
507
495
}
508
496
509
497
#[ test]
510
- fn test_mul_inverse ( ) {
498
+ fn test_mod_inverse ( ) {
511
499
let tests = vec ! [
512
500
( 0x0d31_3243 , 0x6c7b_2a6b ) ,
513
501
( 0x0efd_8231 , 0xd4c0_96d1 ) ,
@@ -517,9 +505,12 @@ mod tests {
517
505
( 0x9ab2_af6d , 0x1043_b265 ) ,
518
506
( 0xa686_d3b7 , 0x57ed_7a07 ) ,
519
507
( 0xec35_a92f , 0xd274_3dcf ) ,
508
+ ( 0x0000_0000 , 0x0000_0000 ) , // Technically, 0 has no inverse
509
+ ( 0x0000_0001 , 0x0000_0001 ) ,
510
+ ( 0xffff_ffff , 0xffff_ffff ) ,
520
511
] ;
521
512
for t in tests. iter ( ) {
522
- assert_eq ! ( t. 1 , mul_inverse ( t. 0 ) ) ;
513
+ assert_eq ! ( t. 1 , mod_inverse ( t. 0 ) ) ;
523
514
}
524
515
}
525
516
0 commit comments