Skip to content

Commit 1a13201

Browse files
committed
Add inv_mod() that supports any moduli
1 parent cf4eb1d commit 1a13201

File tree

4 files changed

+144
-6
lines changed

4 files changed

+144
-6
lines changed

benches/bench.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,59 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
153153
});
154154
}
155155

156+
fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
157+
group.bench_function("inv_odd_mod, U256", |b| {
158+
b.iter_batched(
159+
|| {
160+
let m = U256::random(&mut OsRng) | U256::ONE;
161+
loop {
162+
let x = U256::random(&mut OsRng);
163+
let (_, is_some) = x.inv_odd_mod(&m);
164+
if is_some.into() {
165+
break (x, m);
166+
}
167+
}
168+
},
169+
|(x, m)| x.inv_odd_mod(&m),
170+
BatchSize::SmallInput,
171+
)
172+
});
173+
174+
group.bench_function("inv_mod, U256, odd modulus", |b| {
175+
b.iter_batched(
176+
|| {
177+
let m = U256::random(&mut OsRng) | U256::ONE;
178+
loop {
179+
let x = U256::random(&mut OsRng);
180+
let (_, is_some) = x.inv_odd_mod(&m);
181+
if is_some.into() {
182+
break (x, m);
183+
}
184+
}
185+
},
186+
|(x, m)| x.inv_mod(&m),
187+
BatchSize::SmallInput,
188+
)
189+
});
190+
191+
group.bench_function("inv_mod, U256", |b| {
192+
b.iter_batched(
193+
|| {
194+
let m = U256::random(&mut OsRng);
195+
loop {
196+
let x = U256::random(&mut OsRng);
197+
let (_, is_some) = x.inv_mod(&m);
198+
if is_some.into() {
199+
break (x, m);
200+
}
201+
}
202+
},
203+
|(x, m)| x.inv_mod(&m),
204+
BatchSize::SmallInput,
205+
)
206+
});
207+
}
208+
156209
fn bench_wrapping_ops(c: &mut Criterion) {
157210
let mut group = c.benchmark_group("wrapping ops");
158211
bench_division(&mut group);
@@ -169,6 +222,7 @@ fn bench_montgomery(c: &mut Criterion) {
169222
fn bench_modular_ops(c: &mut Criterion) {
170223
let mut group = c.benchmark_group("modular ops");
171224
bench_shifts(&mut group);
225+
bench_inv_mod(&mut group);
172226
group.finish();
173227
}
174228

src/ct_choice.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ impl CtChoice {
5050
Self(!self.0)
5151
}
5252

53+
pub(crate) const fn or(&self, other: Self) -> Self {
54+
Self(self.0 | other.0)
55+
}
56+
5357
pub(crate) const fn and(&self, other: Self) -> Self {
5458
Self(self.0 & other.0)
5559
}

src/uint/inv_mod.rs

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,43 @@ impl<const LIMBS: usize> Uint<LIMBS> {
137137
}
138138

139139
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
140-
/// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
140+
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
141+
/// otherwise `(undefined, CtChoice::FALSE)`.
141142
pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) {
142143
self.inv_odd_mod_bounded(modulus, Uint::<LIMBS>::BITS, Uint::<LIMBS>::BITS)
143144
}
145+
146+
/// Computes the multiplicative inverse of `self` mod `modulus`.
147+
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
148+
/// otherwise `(undefined, CtChoice::FALSE)`.
149+
pub fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) {
150+
// Decompose `modulus = s * 2^k` where `s` is odd
151+
let k = modulus.trailing_zeros();
152+
let s = modulus.shr(k);
153+
154+
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
155+
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`.
156+
let (a, a_is_some) = self.inv_odd_mod(&s);
157+
let b = self.inv_mod2k(k);
158+
// inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
159+
let b_is_some = CtChoice::from_usize_being_nonzero(k)
160+
.not()
161+
.or(self.ct_is_odd());
162+
163+
// Restore from RNS:
164+
// self^{-1} = a mod s = b mod 2^k
165+
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
166+
let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists
167+
168+
// This part is mod 2^k
169+
let mask = (Uint::ONE << k).wrapping_sub(&Uint::ONE);
170+
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)) & mask;
171+
172+
// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
173+
// so `a + s * t <= s * 2^k - 1 == modulus - 1`.
174+
let result = a.wrapping_add(&s.wrapping_mul(&t));
175+
(result, a_is_some.and(b_is_some))
176+
}
144177
}
145178

146179
#[cfg(test)]
@@ -165,7 +198,7 @@ mod tests {
165198
}
166199

167200
#[test]
168-
fn test_invert() {
201+
fn test_invert_odd() {
169202
let a = U1024::from_be_hex(concat![
170203
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
171204
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
@@ -178,15 +211,45 @@ mod tests {
178211
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
179212
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
180213
]);
181-
182-
let (res, is_some) = a.inv_odd_mod(&m);
183-
184214
let expected = U1024::from_be_hex(concat![
185215
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
186216
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
187217
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
188218
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
189219
]);
220+
221+
let (res, is_some) = a.inv_odd_mod(&m);
222+
assert!(is_some.is_true_vartime());
223+
assert_eq!(res, expected);
224+
225+
// Even though it is less efficient, it still works
226+
let (res, is_some) = a.inv_mod(&m);
227+
assert!(is_some.is_true_vartime());
228+
assert_eq!(res, expected);
229+
}
230+
231+
#[test]
232+
fn test_invert_even() {
233+
let a = U1024::from_be_hex(concat![
234+
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
235+
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
236+
"BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
237+
"382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
238+
]);
239+
let m = U1024::from_be_hex(concat![
240+
"D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
241+
"37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
242+
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
243+
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
244+
]);
245+
let expected = U1024::from_be_hex(concat![
246+
"1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
247+
"DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
248+
"FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
249+
"5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
250+
]);
251+
252+
let (res, is_some) = a.inv_mod(&m);
190253
assert!(is_some.is_true_vartime());
191254
assert_eq!(res, expected);
192255
}

tests/proptests.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use crypto_bigint::{
44
modular::runtime_mod::{DynResidue, DynResidueParams},
5-
Encoding, Limb, NonZero, Word, U256,
5+
CtChoice, Encoding, Limb, NonZero, Word, U256,
66
};
77
use num_bigint::BigUint;
88
use num_integer::Integer;
@@ -233,6 +233,23 @@ proptest! {
233233
}
234234
}
235235

236+
#[test]
237+
fn inv_mod(a in uint(), b in uint()) {
238+
let a_bi = to_biguint(&a);
239+
let b_bi = to_biguint(&b);
240+
241+
let expected_is_some = if a_bi.gcd(&b_bi) == BigUint::one() { CtChoice::TRUE } else { CtChoice::FALSE };
242+
let (actual, actual_is_some) = a.inv_mod(&b);
243+
244+
assert_eq!(bool::from(expected_is_some), bool::from(actual_is_some));
245+
246+
if actual_is_some.into() {
247+
let inv_bi = to_biguint(&actual);
248+
let res = (inv_bi * a_bi) % b_bi;
249+
assert_eq!(res, BigUint::one());
250+
}
251+
}
252+
236253
#[test]
237254
fn wrapping_sqrt(a in uint()) {
238255
let a_bi = to_biguint(&a);

0 commit comments

Comments
 (0)