Skip to content

Commit c966aea

Browse files
authored
Make BoxedUint::rem_vartime infallible using NonZero (#336)
Now that #335 has been merged allowing `NonZero<BoxedUint>` to be expressed, we can use it as the operand for remainder, eliminating the need to handle division by zero at the type system level. This dramatically simplifies the implementation of `BoxedResidueParams::new` since now all of the remainder calculations are infallible, and is more consistent with the other remainder functions.
1 parent baaba0c commit c966aea

File tree

11 files changed

+92
-152
lines changed

11 files changed

+92
-152
lines changed

src/boxed.rs

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/modular/boxed_residue.rs

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
mod mul;
55

66
use super::reduction::montgomery_reduction_boxed;
7-
use crate::{BoxedUint, Limb, Word};
8-
use subtle::{Choice, CtOption};
7+
use crate::{BoxedUint, Limb, NonZero, Word};
8+
use subtle::CtOption;
99

1010
#[cfg(feature = "zeroize")]
1111
use zeroize::Zeroize;
@@ -37,19 +37,23 @@ impl BoxedResidueParams {
3737
let bits_precision = modulus.bits_precision();
3838
let is_odd = modulus.is_odd();
3939

40-
let r = BoxedUint::ct_map(
41-
bits_precision,
42-
BoxedUint::max(bits_precision).rem_vartime(&modulus), // TODO(tarcieri): constant time
43-
|r| r.wrapping_add(&BoxedUint::one()),
44-
);
40+
// Use a surrogate value of `1` in case a modulus of `0` is passed.
41+
// This will be rejected by the `is_odd` check above, which will fail and return `None`.
42+
let modulus_nz = NonZero::new(BoxedUint::conditional_select(
43+
&modulus,
44+
&BoxedUint::one_with_precision(modulus.bits_precision()),
45+
modulus.is_zero(),
46+
))
47+
.expect("modulus ensured non-zero");
4548

46-
let r2 = BoxedUint::ct_map(
47-
bits_precision,
48-
BoxedUint::ct_and_then(bits_precision, r.clone(), |r| {
49-
r.square().rem_vartime(&modulus.widen(bits_precision * 2)) // TODO(tarcieri): constant time
50-
}),
51-
|r2| r2.shorten(bits_precision),
52-
);
49+
let r = BoxedUint::max(bits_precision)
50+
.rem_vartime(&modulus_nz)
51+
.wrapping_add(&BoxedUint::one());
52+
53+
let r2 = r
54+
.square()
55+
.rem_vartime(&modulus_nz.widen(bits_precision * 2)) // TODO(tarcieri): constant time
56+
.shorten(bits_precision);
5357

5458
// Since we are calculating the inverse modulo (Word::MAX+1),
5559
// we can take the modulo right away and calculate the inverse of the first limb only.
@@ -58,31 +62,17 @@ impl BoxedResidueParams {
5862
let mod_neg_inv =
5963
Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS as usize).limbs[0].0));
6064

61-
let r3 = BoxedUint::ct_map(bits_precision, r2.clone(), |r2| {
62-
montgomery_reduction_boxed(&mut r2.square(), &modulus, mod_neg_inv)
63-
});
64-
65-
// Not quite constant time, but shouldn't be an issue in practice, hopefully.
66-
// The branching is just around constructing the return value.
67-
let r = Option::<BoxedUint>::from(r);
68-
let r2 = Option::<BoxedUint>::from(r2);
69-
let r3 = Option::<BoxedUint>::from(r3);
70-
71-
let params = r.and_then(|r| {
72-
r2.and_then(|r2| {
73-
r3.map(|r3| Self {
74-
modulus,
75-
r,
76-
r2,
77-
r3,
78-
mod_neg_inv,
79-
})
80-
})
81-
});
82-
83-
let is_some = Choice::from(params.is_some() as u8);
84-
let placeholder = Self::placeholder(bits_precision);
85-
CtOption::new(params.unwrap_or(placeholder), is_some & is_odd)
65+
let r3 = montgomery_reduction_boxed(&mut r2.square(), &modulus, mod_neg_inv);
66+
67+
let params = Self {
68+
modulus,
69+
r,
70+
r2,
71+
r3,
72+
mod_neg_inv,
73+
};
74+
75+
CtOption::new(params, is_odd)
8676
}
8777

8878
/// Modulus value.
@@ -94,19 +84,6 @@ impl BoxedResidueParams {
9484
pub fn bits_precision(&self) -> usize {
9585
self.modulus.bits_precision()
9686
}
97-
98-
/// Create a placeholder value with the given precision, used as a default for `CtOption`.
99-
fn placeholder(bits_precision: usize) -> Self {
100-
let zero = BoxedUint::zero_with_precision(bits_precision);
101-
102-
Self {
103-
modulus: zero.clone(),
104-
r: zero.clone(),
105-
r2: zero.clone(),
106-
r3: zero,
107-
mod_neg_inv: Limb::ZERO,
108-
}
109-
}
11087
}
11188

11289
/// A residue represented using heap-allocated limbs.

src/non_zero.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use serdect::serde::{
2222

2323
/// Wrapper type for non-zero integers.
2424
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
25-
pub struct NonZero<T: Zero>(T);
25+
pub struct NonZero<T: Zero>(pub(crate) T);
2626

2727
impl NonZero<Limb> {
2828
/// Creates a new non-zero limb in a const context.
@@ -49,6 +49,11 @@ where
4949
let is_zero = n.is_zero();
5050
CtOption::new(Self(n), !is_zero)
5151
}
52+
53+
/// Returns the inner value.
54+
pub fn get(self) -> T {
55+
self.0
56+
}
5257
}
5358

5459
impl<T> NonZero<T>

src/uint/add_mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ mod tests {
7777
];
7878

7979
for special in &moduli {
80-
let p = &NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from_word(special.0)))
81-
.unwrap();
80+
let p =
81+
&NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from(special.get()))).unwrap();
8282

8383
let minus_one = p.wrapping_sub(&Uint::ONE);
8484

src/uint/boxed.rs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ mod bit_and;
66
mod bit_or;
77
mod bits;
88
mod cmp;
9-
mod ct;
109
mod div;
1110
pub(crate) mod encoding;
1211
mod inv_mod;
@@ -16,7 +15,7 @@ mod shr;
1615
mod sub;
1716
mod sub_mod;
1817

19-
use crate::{Limb, Uint, Word, Zero, U128, U64};
18+
use crate::{Limb, NonZero, Uint, Word, Zero, U128, U64};
2019
use alloc::{boxed::Box, vec, vec::Vec};
2120
use core::fmt;
2221
use subtle::{Choice, ConditionallySelectable};
@@ -32,7 +31,7 @@ use zeroize::Zeroize;
3231
/// Unlike many other heap-allocated big integer libraries, this type is not
3332
/// arbitrary precision and will wrap at its fixed-precision rather than
3433
/// automatically growing.
35-
#[derive(Clone, Default)]
34+
#[derive(Clone)]
3635
pub struct BoxedUint {
3736
/// Boxed slice containing limbs.
3837
///
@@ -43,7 +42,9 @@ pub struct BoxedUint {
4342
impl BoxedUint {
4443
/// Get the value `0` represented as succinctly as possible.
4544
pub fn zero() -> Self {
46-
Self::default()
45+
Self {
46+
limbs: vec![Limb::ZERO; 1].into(),
47+
}
4748
}
4849

4950
/// Get the value `0` with the given number of bits of precision.
@@ -247,6 +248,15 @@ impl BoxedUint {
247248
}
248249
}
249250

251+
impl NonZero<BoxedUint> {
252+
/// Widen this type's precision to the given number of bits.
253+
///
254+
/// See [`BoxedUint::widen`] for more information, including panic conditions.
255+
pub fn widen(&self, bits_precision: usize) -> Self {
256+
NonZero(self.0.widen(bits_precision))
257+
}
258+
}
259+
250260
impl AsRef<[Word]> for BoxedUint {
251261
fn as_ref(&self) -> &[Word] {
252262
self.as_words()
@@ -271,15 +281,9 @@ impl AsMut<[Limb]> for BoxedUint {
271281
}
272282
}
273283

274-
impl fmt::Debug for BoxedUint {
275-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276-
write!(f, "BoxedUint(0x{self:X})")
277-
}
278-
}
279-
280-
impl fmt::Display for BoxedUint {
281-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282-
fmt::UpperHex::fmt(self, f)
284+
impl Default for BoxedUint {
285+
fn default() -> Self {
286+
Self::zero()
283287
}
284288
}
285289

@@ -355,6 +359,25 @@ impl Zero for BoxedUint {
355359
}
356360
}
357361

362+
#[cfg(feature = "zeroize")]
363+
impl Zeroize for BoxedUint {
364+
fn zeroize(&mut self) {
365+
self.limbs.zeroize();
366+
}
367+
}
368+
369+
impl fmt::Debug for BoxedUint {
370+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
371+
write!(f, "BoxedUint(0x{self:X})")
372+
}
373+
}
374+
375+
impl fmt::Display for BoxedUint {
376+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377+
fmt::UpperHex::fmt(self, f)
378+
}
379+
}
380+
358381
impl fmt::LowerHex for BoxedUint {
359382
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360383
if self.limbs.is_empty() {
@@ -381,13 +404,6 @@ impl fmt::UpperHex for BoxedUint {
381404
}
382405
}
383406

384-
#[cfg(feature = "zeroize")]
385-
impl Zeroize for BoxedUint {
386-
fn zeroize(&mut self) {
387-
self.limbs.zeroize();
388-
}
389-
}
390-
391407
#[cfg(test)]
392408
mod tests {
393409
use super::BoxedUint;

src/uint/boxed/ct.rs

Lines changed: 0 additions & 56 deletions
This file was deleted.

src/uint/boxed/div.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
//! [`BoxedUint`] division operations.
22
3-
use crate::{BoxedUint, Limb};
4-
use subtle::{ConstantTimeEq, CtOption};
3+
use crate::{BoxedUint, Limb, NonZero};
4+
use subtle::ConstantTimeEq;
55

66
impl BoxedUint {
77
/// Computes self % rhs, returns the remainder.
88
///
99
/// Variable-time with respect to `rhs`.
1010
///
11+
/// # Panics
12+
///
1113
/// Panics if `self` and `rhs` have different precisions.
12-
// TODO(tarcieri): make infallible by making `rhs` into `NonZero`; don't panic
13-
pub fn rem_vartime(&self, rhs: &Self) -> CtOption<Self> {
14+
// TODO(tarcieri): handle different precisions without panicking
15+
pub fn rem_vartime(&self, rhs: &NonZero<Self>) -> Self {
1416
debug_assert_eq!(self.nlimbs(), rhs.nlimbs());
1517
let mb = rhs.bits();
1618
let mut bd = self.bits_precision() - mb;
@@ -21,24 +23,22 @@ impl BoxedUint {
2123
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
2224
rem = Self::conditional_select(&r, &rem, !borrow.ct_eq(&Limb::ZERO));
2325
if bd == 0 {
24-
break;
26+
break rem;
2527
}
2628
bd -= 1;
2729
c = c.shr_vartime(1);
2830
}
29-
30-
CtOption::new(rem, !(mb as u32).ct_eq(&0))
3131
}
3232
}
3333

3434
#[cfg(test)]
3535
mod tests {
36-
use super::BoxedUint;
36+
use super::{BoxedUint, NonZero};
3737

3838
#[test]
3939
fn rem_vartime() {
4040
let n = BoxedUint::from(0xFFEECCBBAA99887766u128);
41-
let p = BoxedUint::from(997u128);
42-
assert_eq!(BoxedUint::from(648u128), n.rem_vartime(&p).unwrap());
41+
let p = NonZero::new(BoxedUint::from(997u128)).unwrap();
42+
assert_eq!(BoxedUint::from(648u128), n.rem_vartime(&p));
4343
}
4444
}

src/uint/mul_mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ mod tests {
106106
];
107107

108108
for special in &moduli {
109-
let p = &NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from_word(special.0)))
110-
.unwrap();
109+
let p =
110+
&NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from(special.get()))).unwrap();
111111

112112
let minus_one = p.wrapping_sub(&Uint::ONE);
113113

src/uint/sub_mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ mod tests {
140140
];
141141

142142
for special in &moduli {
143-
let p = &NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from_word(special.0)))
144-
.unwrap();
143+
let p =
144+
&NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from(special.get()))).unwrap();
145145

146146
let minus_one = p.wrapping_sub(&Uint::ONE);
147147

0 commit comments

Comments
 (0)