Skip to content

Commit cf4eb1d

Browse files
committed
Mark the current inv_mod2k() as vartime and add an actual constant-time variant
1 parent d5e8906 commit cf4eb1d

File tree

6 files changed

+138
-20
lines changed

6 files changed

+138
-20
lines changed

src/ct_choice.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ impl CtChoice {
2929
Self(value.wrapping_neg())
3030
}
3131

32+
/// Returns the truthy value if `value != 0`, and the falsy value otherwise.
33+
pub(crate) const fn from_usize_being_nonzero(value: usize) -> Self {
34+
const HI_BIT: u32 = usize::BITS - 1;
35+
Self::from_lsb(((value | value.wrapping_neg()) >> HI_BIT) as Word)
36+
}
37+
38+
/// Returns the truthy value if `x == y`, and the falsy value otherwise.
39+
pub(crate) const fn from_usize_equality(x: usize, y: usize) -> Self {
40+
Self::from_usize_being_nonzero(x.wrapping_sub(y)).not()
41+
}
42+
3243
/// Returns the truthy value if `x < y`, and the falsy value otherwise.
3344
pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self {
3445
let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1);

src/uint/bits.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
6969
/// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`.
7070
/// Returns the falsy value for indices out of range.
7171
pub const fn bit(&self, index: usize) -> CtChoice {
72-
let limb_num = Limb((index / Limb::BITS) as Word);
72+
let limb_num = index / Limb::BITS;
7373
let index_in_limb = index % Limb::BITS;
7474
let index_mask = 1 << index_in_limb;
7575

@@ -79,18 +79,36 @@ impl<const LIMBS: usize> Uint<LIMBS> {
7979
let mut i = 0;
8080
while i < LIMBS {
8181
let bit = limbs[i] & index_mask;
82-
let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word));
82+
let is_right_limb = CtChoice::from_usize_equality(i, limb_num);
8383
result |= is_right_limb.if_true(bit);
8484
i += 1;
8585
}
8686

8787
CtChoice::from_lsb(result >> index_in_limb)
8888
}
89+
90+
/// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`.
91+
pub(crate) const fn set_bit(self, index: usize, bit_value: CtChoice) -> Self {
92+
let mut result = self;
93+
let limb_num = index / Limb::BITS;
94+
let index_in_limb = index % Limb::BITS;
95+
let index_mask = 1 << index_in_limb;
96+
97+
let mut i = 0;
98+
while i < LIMBS {
99+
let is_right_limb = CtChoice::from_usize_equality(i, limb_num);
100+
let old_limb = result.limbs[i].0;
101+
let new_limb = bit_value.select(old_limb & !index_mask, old_limb | index_mask);
102+
result.limbs[i] = Limb(is_right_limb.select(old_limb, new_limb));
103+
i += 1;
104+
}
105+
result
106+
}
89107
}
90108

91109
#[cfg(test)]
92110
mod tests {
93-
use crate::U256;
111+
use crate::{CtChoice, U256};
94112

95113
fn uint_with_bits_at(positions: &[usize]) -> U256 {
96114
let mut result = U256::ZERO;
@@ -159,4 +177,31 @@ mod tests {
159177
let u = U256::ZERO;
160178
assert_eq!(u.trailing_zeros() as u32, 256);
161179
}
180+
181+
#[test]
182+
fn set_bit() {
183+
let u = uint_with_bits_at(&[16, 79, 150]);
184+
assert_eq!(
185+
u.set_bit(127, CtChoice::TRUE),
186+
uint_with_bits_at(&[16, 79, 127, 150])
187+
);
188+
189+
let u = uint_with_bits_at(&[16, 79, 150]);
190+
assert_eq!(
191+
u.set_bit(150, CtChoice::TRUE),
192+
uint_with_bits_at(&[16, 79, 150])
193+
);
194+
195+
let u = uint_with_bits_at(&[16, 79, 150]);
196+
assert_eq!(
197+
u.set_bit(127, CtChoice::FALSE),
198+
uint_with_bits_at(&[16, 79, 150])
199+
);
200+
201+
let u = uint_with_bits_at(&[16, 79, 150]);
202+
assert_eq!(
203+
u.set_bit(150, CtChoice::FALSE),
204+
uint_with_bits_at(&[16, 79])
205+
);
206+
}
162207
}

src/uint/inv_mod.rs

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,68 @@
11
use super::Uint;
2-
use crate::{CtChoice, Limb};
2+
use crate::CtChoice;
33

44
impl<const LIMBS: usize> Uint<LIMBS> {
5-
/// Computes 1/`self` mod 2^k as specified in Algorithm 4 from
6-
/// A Secure Algorithm for Inversion Modulo 2k by
7-
/// Sadiel de la Fe and Carles Ferrer. See
8-
/// <https://www.mdpi.com/2410-387X/2/3/23>.
5+
/// Computes 1/`self` mod `2^k`.
6+
/// This method is constant-time w.r.t. `self` but not `k`.
97
///
108
/// Conditions: `self` < 2^k and `self` must be odd
11-
pub const fn inv_mod2k(&self, k: usize) -> Self {
12-
let mut x = Self::ZERO;
13-
let mut b = Self::ONE;
9+
pub const fn inv_mod2k_vartime(&self, k: usize) -> Self {
10+
// Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k"
11+
// by Sadiel de la Fe and Carles Ferrer.
12+
// See <https://www.mdpi.com/2410-387X/2/3/23>.
13+
14+
// Note that we are not using Alrgorithm 4, since we have a different approach
15+
// of enforcing constant-timeness w.r.t. `self`.
16+
17+
let mut x = Self::ZERO; // keeps `x` during iterations
18+
let mut b = Self::ONE; // keeps `b_i` during iterations
1419
let mut i = 0;
1520

1621
while i < k {
17-
let mut x_i = Self::ZERO;
18-
let j = b.limbs[0].0 & 1;
19-
x_i.limbs[0] = Limb(j);
20-
x = x.bitor(&x_i.shl_vartime(i));
22+
// X_i = b_i mod 2
23+
let x_i = b.limbs[0].0 & 1;
24+
let x_i_choice = CtChoice::from_lsb(x_i);
25+
// b_{i+1} = (b_i - a * X_i) / 2
26+
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
27+
// Store the X_i bit in the result (x = x | (1 << X_i))
28+
x = x.bitor(&Uint::from_word(x_i).shl_vartime(i));
2129

22-
let t = b.wrapping_sub(self);
23-
b = Self::ct_select(&b, &t, CtChoice::from_lsb(j)).shr_vartime(1);
2430
i += 1;
2531
}
32+
33+
x
34+
}
35+
36+
/// Computes 1/`self` mod `2^k`.
37+
///
38+
/// Conditions: `self` < 2^k and `self` must be odd
39+
pub const fn inv_mod2k(&self, k: usize) -> Self {
40+
// This is the same algorithm as in `inv_mod2k_vartime()`,
41+
// but made constant-time w.r.t `k` as well.
42+
43+
let mut x = Self::ZERO; // keeps `x` during iterations
44+
let mut b = Self::ONE; // keeps `b_i` during iterations
45+
let mut i = 0;
46+
47+
while i < Self::BITS {
48+
// Only iterations for i = 0..k need to change `x`,
49+
// the rest are dummy ones performed for the sake of constant-timeness.
50+
let within_range = CtChoice::from_usize_lt(i, k);
51+
52+
// X_i = b_i mod 2
53+
let x_i = b.limbs[0].0 & 1;
54+
let x_i_choice = CtChoice::from_lsb(x_i);
55+
// b_{i+1} = (b_i - a * X_i) / 2
56+
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
57+
58+
// Store the X_i bit in the result (x = x | (1 << X_i))
59+
// Don't change the result in dummy iterations.
60+
let x_i_choice = x_i_choice.and(within_range);
61+
x = x.set_bit(i, x_i_choice);
62+
63+
i += 1;
64+
}
65+
2666
x
2767
}
2868

src/uint/modular/constant_mod/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ macro_rules! impl_modulus {
3232
const MOD_NEG_INV: $crate::Limb = $crate::Limb(
3333
$crate::Word::MIN.wrapping_sub(
3434
Self::MODULUS
35-
.inv_mod2k($crate::Word::BITS as usize)
35+
.inv_mod2k_vartime($crate::Word::BITS as usize)
3636
.as_limbs()[0]
3737
.0,
3838
),

src/uint/modular/runtime_mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ impl<const LIMBS: usize> DynResidueParams<LIMBS> {
4747
// Since we are calculating the inverse modulo (Word::MAX+1),
4848
// we can take the modulo right away and calculate the inverse of the first limb only.
4949
let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]);
50-
let mod_neg_inv =
51-
Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS as usize).limbs[0].0));
50+
let mod_neg_inv = Limb(
51+
Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS as usize).limbs[0].0),
52+
);
5253

5354
let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv);
5455

tests/proptests.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,27 @@ proptest! {
212212
}
213213
}
214214

215+
#[test]
216+
fn inv_mod2k(a in uint(), k in any::<usize>()) {
217+
let a = a | U256::ONE; // make odd
218+
let k = k % (U256::BITS + 1);
219+
let a_bi = to_biguint(&a);
220+
let m_bi = BigUint::one() << k;
221+
222+
let actual = a.inv_mod2k(k);
223+
let actual_vartime = a.inv_mod2k_vartime(k);
224+
assert_eq!(actual, actual_vartime);
225+
226+
if k == 0 {
227+
assert_eq!(actual, U256::ZERO);
228+
}
229+
else {
230+
let inv_bi = to_biguint(&actual);
231+
let res = (inv_bi * a_bi) % m_bi;
232+
assert_eq!(res, BigUint::one());
233+
}
234+
}
235+
215236
#[test]
216237
fn wrapping_sqrt(a in uint()) {
217238
let a_bi = to_biguint(&a);

0 commit comments

Comments
 (0)