Skip to content

Commit e0320a1

Browse files
authored
Modular inversion improvements (#263)
1 parent ccfd366 commit e0320a1

File tree

7 files changed

+285
-27
lines changed

7 files changed

+285
-27
lines changed

benches/bench.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,59 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
153153
});
154154
}
155155

156+
fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
157+
group.bench_function("inv_odd_mod, U256", |b| {
158+
b.iter_batched(
159+
|| {
160+
let m = U256::random(&mut OsRng) | U256::ONE;
161+
loop {
162+
let x = U256::random(&mut OsRng);
163+
let (_, is_some) = x.inv_odd_mod(&m);
164+
if is_some.into() {
165+
break (x, m);
166+
}
167+
}
168+
},
169+
|(x, m)| x.inv_odd_mod(&m),
170+
BatchSize::SmallInput,
171+
)
172+
});
173+
174+
group.bench_function("inv_mod, U256, odd modulus", |b| {
175+
b.iter_batched(
176+
|| {
177+
let m = U256::random(&mut OsRng) | U256::ONE;
178+
loop {
179+
let x = U256::random(&mut OsRng);
180+
let (_, is_some) = x.inv_odd_mod(&m);
181+
if is_some.into() {
182+
break (x, m);
183+
}
184+
}
185+
},
186+
|(x, m)| x.inv_mod(&m),
187+
BatchSize::SmallInput,
188+
)
189+
});
190+
191+
group.bench_function("inv_mod, U256", |b| {
192+
b.iter_batched(
193+
|| {
194+
let m = U256::random(&mut OsRng);
195+
loop {
196+
let x = U256::random(&mut OsRng);
197+
let (_, is_some) = x.inv_mod(&m);
198+
if is_some.into() {
199+
break (x, m);
200+
}
201+
}
202+
},
203+
|(x, m)| x.inv_mod(&m),
204+
BatchSize::SmallInput,
205+
)
206+
});
207+
}
208+
156209
fn bench_wrapping_ops(c: &mut Criterion) {
157210
let mut group = c.benchmark_group("wrapping ops");
158211
bench_division(&mut group);
@@ -169,6 +222,7 @@ fn bench_montgomery(c: &mut Criterion) {
169222
fn bench_modular_ops(c: &mut Criterion) {
170223
let mut group = c.benchmark_group("modular ops");
171224
bench_shifts(&mut group);
225+
bench_inv_mod(&mut group);
172226
group.finish();
173227
}
174228

src/ct_choice.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ impl CtChoice {
2929
Self(value.wrapping_neg())
3030
}
3131

32+
/// Returns the truthy value if `value != 0`, and the falsy value otherwise.
33+
pub(crate) const fn from_usize_being_nonzero(value: usize) -> Self {
34+
const HI_BIT: u32 = usize::BITS - 1;
35+
Self::from_lsb(((value | value.wrapping_neg()) >> HI_BIT) as Word)
36+
}
37+
38+
/// Returns the truthy value if `x == y`, and the falsy value otherwise.
39+
pub(crate) const fn from_usize_equality(x: usize, y: usize) -> Self {
40+
Self::from_usize_being_nonzero(x.wrapping_sub(y)).not()
41+
}
42+
3243
/// Returns the truthy value if `x < y`, and the falsy value otherwise.
3344
pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self {
3445
let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1);
@@ -39,6 +50,10 @@ impl CtChoice {
3950
Self(!self.0)
4051
}
4152

53+
pub(crate) const fn or(&self, other: Self) -> Self {
54+
Self(self.0 | other.0)
55+
}
56+
4257
pub(crate) const fn and(&self, other: Self) -> Self {
4358
Self(self.0 & other.0)
4459
}

src/uint/bits.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
6969
/// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`.
7070
/// Returns the falsy value for indices out of range.
7171
pub const fn bit(&self, index: usize) -> CtChoice {
72-
let limb_num = Limb((index / Limb::BITS) as Word);
72+
let limb_num = index / Limb::BITS;
7373
let index_in_limb = index % Limb::BITS;
7474
let index_mask = 1 << index_in_limb;
7575

@@ -79,18 +79,36 @@ impl<const LIMBS: usize> Uint<LIMBS> {
7979
let mut i = 0;
8080
while i < LIMBS {
8181
let bit = limbs[i] & index_mask;
82-
let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word));
82+
let is_right_limb = CtChoice::from_usize_equality(i, limb_num);
8383
result |= is_right_limb.if_true(bit);
8484
i += 1;
8585
}
8686

8787
CtChoice::from_lsb(result >> index_in_limb)
8888
}
89+
90+
/// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`.
91+
pub(crate) const fn set_bit(self, index: usize, bit_value: CtChoice) -> Self {
92+
let mut result = self;
93+
let limb_num = index / Limb::BITS;
94+
let index_in_limb = index % Limb::BITS;
95+
let index_mask = 1 << index_in_limb;
96+
97+
let mut i = 0;
98+
while i < LIMBS {
99+
let is_right_limb = CtChoice::from_usize_equality(i, limb_num);
100+
let old_limb = result.limbs[i].0;
101+
let new_limb = bit_value.select(old_limb & !index_mask, old_limb | index_mask);
102+
result.limbs[i] = Limb(is_right_limb.select(old_limb, new_limb));
103+
i += 1;
104+
}
105+
result
106+
}
89107
}
90108

91109
#[cfg(test)]
92110
mod tests {
93-
use crate::U256;
111+
use crate::{CtChoice, U256};
94112

95113
fn uint_with_bits_at(positions: &[usize]) -> U256 {
96114
let mut result = U256::ZERO;
@@ -159,4 +177,31 @@ mod tests {
159177
let u = U256::ZERO;
160178
assert_eq!(u.trailing_zeros() as u32, 256);
161179
}
180+
181+
#[test]
182+
fn set_bit() {
183+
let u = uint_with_bits_at(&[16, 79, 150]);
184+
assert_eq!(
185+
u.set_bit(127, CtChoice::TRUE),
186+
uint_with_bits_at(&[16, 79, 127, 150])
187+
);
188+
189+
let u = uint_with_bits_at(&[16, 79, 150]);
190+
assert_eq!(
191+
u.set_bit(150, CtChoice::TRUE),
192+
uint_with_bits_at(&[16, 79, 150])
193+
);
194+
195+
let u = uint_with_bits_at(&[16, 79, 150]);
196+
assert_eq!(
197+
u.set_bit(127, CtChoice::FALSE),
198+
uint_with_bits_at(&[16, 79, 150])
199+
);
200+
201+
let u = uint_with_bits_at(&[16, 79, 150]);
202+
assert_eq!(
203+
u.set_bit(150, CtChoice::FALSE),
204+
uint_with_bits_at(&[16, 79])
205+
);
206+
}
162207
}

src/uint/inv_mod.rs

Lines changed: 124 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,68 @@
11
use super::Uint;
2-
use crate::{CtChoice, Limb};
2+
use crate::CtChoice;
33

44
impl<const LIMBS: usize> Uint<LIMBS> {
5-
/// Computes 1/`self` mod 2^k as specified in Algorithm 4 from
6-
/// A Secure Algorithm for Inversion Modulo 2k by
7-
/// Sadiel de la Fe and Carles Ferrer. See
8-
/// <https://www.mdpi.com/2410-387X/2/3/23>.
5+
/// Computes 1/`self` mod `2^k`.
6+
/// This method is constant-time w.r.t. `self` but not `k`.
97
///
108
/// Conditions: `self` < 2^k and `self` must be odd
11-
pub const fn inv_mod2k(&self, k: usize) -> Self {
12-
let mut x = Self::ZERO;
13-
let mut b = Self::ONE;
9+
pub const fn inv_mod2k_vartime(&self, k: usize) -> Self {
10+
// Using Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2^k"
11+
// by Sadiel de la Fe and Carles Ferrer.
12+
// See <https://www.mdpi.com/2410-387X/2/3/23>.
13+
14+
// Note that we are not using Algorithm 4, since we have a different approach
15+
// of enforcing constant-timeness w.r.t. `self`.
16+
17+
let mut x = Self::ZERO; // keeps `x` during iterations
18+
let mut b = Self::ONE; // keeps `b_i` during iterations
1419
let mut i = 0;
1520

1621
while i < k {
17-
let mut x_i = Self::ZERO;
18-
let j = b.limbs[0].0 & 1;
19-
x_i.limbs[0] = Limb(j);
20-
x = x.bitor(&x_i.shl_vartime(i));
22+
// X_i = b_i mod 2
23+
let x_i = b.limbs[0].0 & 1;
24+
let x_i_choice = CtChoice::from_lsb(x_i);
25+
// b_{i+1} = (b_i - a * X_i) / 2
26+
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
27+
// Store the X_i bit in the result (x = x | (X_i << i))
28+
x = x.bitor(&Uint::from_word(x_i).shl_vartime(i));
29+
30+
i += 1;
31+
}
32+
33+
x
34+
}
35+
36+
/// Computes 1/`self` mod `2^k`.
37+
///
38+
/// Conditions: `self` < 2^k and `self` must be odd
39+
pub const fn inv_mod2k(&self, k: usize) -> Self {
40+
// This is the same algorithm as in `inv_mod2k_vartime()`,
41+
// but made constant-time w.r.t `k` as well.
42+
43+
let mut x = Self::ZERO; // keeps `x` during iterations
44+
let mut b = Self::ONE; // keeps `b_i` during iterations
45+
let mut i = 0;
46+
47+
while i < Self::BITS {
48+
// Only iterations for i = 0..k need to change `x`,
49+
// the rest are dummy ones performed for the sake of constant-timeness.
50+
let within_range = CtChoice::from_usize_lt(i, k);
51+
52+
// X_i = b_i mod 2
53+
let x_i = b.limbs[0].0 & 1;
54+
let x_i_choice = CtChoice::from_lsb(x_i);
55+
// b_{i+1} = (b_i - a * X_i) / 2
56+
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
57+
58+
// Store the X_i bit in the result (x = x | (X_i << i))
59+
// Don't change the result in dummy iterations.
60+
let x_i_choice = x_i_choice.and(within_range);
61+
x = x.set_bit(i, x_i_choice);
2162

22-
let t = b.wrapping_sub(self);
23-
b = Self::ct_select(&b, &t, CtChoice::from_lsb(j)).shr_vartime(1);
2463
i += 1;
2564
}
65+
2666
x
2767
}
2868

@@ -97,10 +137,45 @@ impl<const LIMBS: usize> Uint<LIMBS> {
97137
}
98138

99139
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
100-
/// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
140+
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
141+
/// otherwise `(undefined, CtChoice::FALSE)`.
101142
pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) {
102143
self.inv_odd_mod_bounded(modulus, Uint::<LIMBS>::BITS, Uint::<LIMBS>::BITS)
103144
}
145+
146+
/// Computes the multiplicative inverse of `self` mod `modulus`.
147+
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
148+
/// otherwise `(undefined, CtChoice::FALSE)`.
149+
pub fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) {
150+
// Decompose `modulus = s * 2^k` where `s` is odd
151+
let k = modulus.trailing_zeros();
152+
let s = modulus.shr(k);
153+
154+
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
155+
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
156+
let (a, a_is_some) = self.inv_odd_mod(&s);
157+
let b = self.inv_mod2k(k);
158+
// inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
159+
let b_is_some = CtChoice::from_usize_being_nonzero(k)
160+
.not()
161+
.or(self.ct_is_odd());
162+
163+
// Restore from RNS:
164+
// self^{-1} = a mod s = b mod 2^k
165+
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
166+
// (essentially one step of Garner's algorithm for recovery from RNS).
167+
168+
let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists
169+
170+
// This part is mod 2^k
171+
let mask = (Uint::ONE << k).wrapping_sub(&Uint::ONE);
172+
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)) & mask;
173+
174+
// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
175+
// so `a + s * t <= s * 2^k - 1 == modulus - 1`.
176+
let result = a.wrapping_add(&s.wrapping_mul(&t));
177+
(result, a_is_some.and(b_is_some))
178+
}
104179
}
105180

106181
#[cfg(test)]
@@ -125,7 +200,7 @@ mod tests {
125200
}
126201

127202
#[test]
128-
fn test_invert() {
203+
fn test_invert_odd() {
129204
let a = U1024::from_be_hex(concat![
130205
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
131206
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
@@ -138,15 +213,45 @@ mod tests {
138213
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
139214
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
140215
]);
141-
142-
let (res, is_some) = a.inv_odd_mod(&m);
143-
144216
let expected = U1024::from_be_hex(concat![
145217
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
146218
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
147219
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
148220
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
149221
]);
222+
223+
let (res, is_some) = a.inv_odd_mod(&m);
224+
assert!(is_some.is_true_vartime());
225+
assert_eq!(res, expected);
226+
227+
// Even though it is less efficient, it still works
228+
let (res, is_some) = a.inv_mod(&m);
229+
assert!(is_some.is_true_vartime());
230+
assert_eq!(res, expected);
231+
}
232+
233+
#[test]
234+
fn test_invert_even() {
235+
let a = U1024::from_be_hex(concat![
236+
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
237+
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
238+
"BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
239+
"382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
240+
]);
241+
let m = U1024::from_be_hex(concat![
242+
"D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
243+
"37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
244+
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
245+
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
246+
]);
247+
let expected = U1024::from_be_hex(concat![
248+
"1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
249+
"DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
250+
"FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
251+
"5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
252+
]);
253+
254+
let (res, is_some) = a.inv_mod(&m);
150255
assert!(is_some.is_true_vartime());
151256
assert_eq!(res, expected);
152257
}

src/uint/modular/constant_mod/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ macro_rules! impl_modulus {
3232
const MOD_NEG_INV: $crate::Limb = $crate::Limb(
3333
$crate::Word::MIN.wrapping_sub(
3434
Self::MODULUS
35-
.inv_mod2k($crate::Word::BITS as usize)
35+
.inv_mod2k_vartime($crate::Word::BITS as usize)
3636
.as_limbs()[0]
3737
.0,
3838
),

src/uint/modular/runtime_mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ impl<const LIMBS: usize> DynResidueParams<LIMBS> {
4747
// Since we are calculating the inverse modulo (Word::MAX+1),
4848
// we can take the modulo right away and calculate the inverse of the first limb only.
4949
let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]);
50-
let mod_neg_inv =
51-
Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS as usize).limbs[0].0));
50+
let mod_neg_inv = Limb(
51+
Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS as usize).limbs[0].0),
52+
);
5253

5354
let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv);
5455

0 commit comments

Comments
 (0)