Skip to content

Commit 64add42

Browse files
tarcieriHastD
andauthored
Constant-time square root and division (#376)
Based on PR #277. The constant-time square root algorithm is described here: https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf Co-authored-by: Daniel Hast <[email protected]>
1 parent 7d174f0 commit 64add42

File tree

3 files changed

+175
-53
lines changed

3 files changed

+175
-53
lines changed

src/uint/div.rs

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! [`Uint`] division operations.
22
33
use super::div_limb::{div_rem_limb_with_reciprocal, Reciprocal};
4-
use crate::{CheckedDiv, CtChoice, Limb, NonZero, Uint, Wrapping};
4+
use crate::{CheckedDiv, CtChoice, Limb, NonZero, Uint, Word, Wrapping};
55
use core::ops::{Div, DivAssign, Rem, RemAssign};
66
use subtle::CtOption;
77

@@ -41,6 +41,43 @@ impl<const LIMBS: usize> Uint<LIMBS> {
4141
(quo, rem)
4242
}
4343

44+
/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
45+
/// and the truthy value for is_some or the falsy value for is_none.
46+
///
47+
/// NOTE: Use only if you need to access const fn. Otherwise use [`Self::div_rem`] because
48+
/// the value for is_some needs to be checked before using `q` and `r`.
49+
///
50+
/// This function is constant-time with respect to both `self` and `rhs`.
51+
#[allow(trivial_numeric_casts)]
52+
pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
53+
let mb = rhs.bits();
54+
let mut rem = *self;
55+
let mut quo = Self::ZERO;
56+
let mut c = rhs.shl(Self::BITS - mb);
57+
58+
let mut i = Self::BITS;
59+
let mut done = CtChoice::FALSE;
60+
loop {
61+
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
62+
rem = Self::ct_select(&r, &rem, CtChoice::from_word_mask(borrow.0).or(done));
63+
r = quo.bitor(&Self::ONE);
64+
quo = Self::ct_select(&r, &quo, CtChoice::from_word_mask(borrow.0).or(done));
65+
if i == 0 {
66+
break;
67+
}
68+
i -= 1;
69+
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
70+
// aren't modified further (but do the remaining iterations anyway to be constant-time)
71+
done = CtChoice::from_word_lt(i as Word, mb as Word);
72+
c = c.shr_vartime(1);
73+
quo = Self::ct_select(&quo.shl_vartime(1), &quo, done);
74+
}
75+
76+
let is_some = Limb(mb as Word).ct_is_nonzero();
77+
quo = Self::ct_select(&Self::ZERO, &quo, is_some);
78+
(quo, rem, is_some)
79+
}
80+
4481
/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
4582
/// and the truthy value for is_some or the falsy value for is_none.
4683
///
@@ -51,7 +88,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
5188
///
5289
/// When used with a fixed `rhs`, this function is constant-time with respect
5390
/// to `self`.
54-
pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
91+
pub(crate) const fn const_div_rem_vartime(&self, rhs: &Self) -> (Self, Self, CtChoice) {
5592
let mb = rhs.bits_vartime();
5693
let mut bd = Self::BITS - mb;
5794
let mut rem = *self;
@@ -169,7 +206,14 @@ impl<const LIMBS: usize> Uint<LIMBS> {
169206
/// Computes self / rhs, returns the quotient, remainder.
170207
pub fn div_rem(&self, rhs: &NonZero<Self>) -> (Self, Self) {
171208
// Since `rhs` is nonzero, this should always hold.
172-
let (q, r, _c) = self.ct_div_rem(rhs);
209+
let (q, r, _c) = self.const_div_rem(rhs);
210+
(q, r)
211+
}
212+
213+
/// Computes self / rhs, returns the quotient, remainder. Constant-time only for fixed `rhs`.
214+
pub fn div_rem_vartime(&self, rhs: &NonZero<Self>) -> (Self, Self) {
215+
// Since `rhs` is nonzero, this should always hold.
216+
let (q, r, _c) = self.const_div_rem_vartime(rhs);
173217
(q, r)
174218
}
175219

@@ -186,7 +230,18 @@ impl<const LIMBS: usize> Uint<LIMBS> {
186230
///
187231
/// Panics if `rhs == 0`.
188232
pub const fn wrapping_div(&self, rhs: &Self) -> Self {
189-
let (q, _, c) = self.ct_div_rem(rhs);
233+
let (q, _, c) = self.const_div_rem(rhs);
234+
assert!(c.is_true_vartime(), "divide by zero");
235+
q
236+
}
237+
238+
/// Wrapped division is just normal division i.e. `self` / `rhs`
239+
/// There’s no way wrapping could ever happen.
240+
/// This function exists, so that all operations are accounted for in the wrapping operations.
241+
///
242+
/// Panics if `rhs == 0`. Constant-time only for fixed `rhs`.
243+
pub const fn wrapping_div_vartime(&self, rhs: &Self) -> Self {
244+
let (q, _, c) = self.const_div_rem_vartime(rhs);
190245
assert!(c.is_true_vartime(), "divide by zero");
191246
q
192247
}
@@ -625,7 +680,11 @@ mod tests {
625680
] {
626681
let lhs = U256::from(*n);
627682
let rhs = U256::from(*d);
628-
let (q, r, is_some) = lhs.ct_div_rem(&rhs);
683+
let (q, r, is_some) = lhs.const_div_rem(&rhs);
684+
assert!(is_some.is_true_vartime());
685+
assert_eq!(U256::from(*e), q);
686+
assert_eq!(U256::from(*ee), r);
687+
let (q, r, is_some) = lhs.const_div_rem_vartime(&rhs);
629688
assert!(is_some.is_true_vartime());
630689
assert_eq!(U256::from(*e), q);
631690
assert_eq!(U256::from(*ee), r);
@@ -641,7 +700,10 @@ mod tests {
641700
let den = U256::random(&mut rng).shr_vartime(128);
642701
let n = num.checked_mul(&den);
643702
if n.is_some().into() {
644-
let (q, _, is_some) = n.unwrap().ct_div_rem(&den);
703+
let (q, _, is_some) = n.unwrap().const_div_rem(&den);
704+
assert!(is_some.is_true_vartime());
705+
assert_eq!(q, num);
706+
let (q, _, is_some) = n.unwrap().const_div_rem_vartime(&den);
645707
assert!(is_some.is_true_vartime());
646708
assert_eq!(q, num);
647709
}
@@ -663,15 +725,23 @@ mod tests {
663725

664726
#[test]
665727
fn div_zero() {
666-
let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO);
728+
let (q, r, is_some) = U256::ONE.const_div_rem(&U256::ZERO);
729+
assert!(!is_some.is_true_vartime());
730+
assert_eq!(q, U256::ZERO);
731+
assert_eq!(r, U256::ONE);
732+
let (q, r, is_some) = U256::ONE.const_div_rem_vartime(&U256::ZERO);
667733
assert!(!is_some.is_true_vartime());
668734
assert_eq!(q, U256::ZERO);
669735
assert_eq!(r, U256::ONE);
670736
}
671737

672738
#[test]
673739
fn div_one() {
674-
let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE);
740+
let (q, r, is_some) = U256::from(10u8).const_div_rem(&U256::ONE);
741+
assert!(is_some.is_true_vartime());
742+
assert_eq!(q, U256::from(10u8));
743+
assert_eq!(r, U256::ZERO);
744+
let (q, r, is_some) = U256::from(10u8).const_div_rem_vartime(&U256::ONE);
675745
assert!(is_some.is_true_vartime());
676746
assert_eq!(q, U256::from(10u8));
677747
assert_eq!(r, U256::ZERO);

src/uint/sqrt.rs

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,40 @@
11
//! [`Uint`] square root operations.
22
33
use super::Uint;
4-
use crate::CtChoice;
54
use subtle::{ConstantTimeEq, CtOption};
65

76
impl<const LIMBS: usize> Uint<LIMBS> {
8-
/// See [`Self::sqrt_vartime`].
9-
#[deprecated(
10-
since = "0.5.3",
11-
note = "This functionality will be moved to `sqrt_vartime` in a future release."
12-
)]
7+
/// Computes √(`self`) in constant time.
8+
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
9+
///
10+
/// Callers can check if `self` is a square by squaring the result
1311
pub const fn sqrt(&self) -> Self {
14-
self.sqrt_vartime()
12+
let max_bits = (self.bits() + 1) >> 1;
13+
let cap = Self::ONE.shl(max_bits);
14+
let mut guess = cap; // ≥ √(`self`)
15+
let mut xn = {
16+
let q = self.wrapping_div(&guess);
17+
let t = guess.wrapping_add(&q);
18+
t.shr_vartime(1)
19+
};
20+
21+
// Repeat enough times to guarantee result has stabilized.
22+
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
23+
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
24+
let mut i = 0;
25+
while i < Self::LOG2_BITS {
26+
guess = xn;
27+
xn = {
28+
let (q, _, is_some) = self.const_div_rem(&guess);
29+
let q = Self::ct_select(&Self::ZERO, &q, is_some);
30+
let t = guess.wrapping_add(&q);
31+
t.shr_vartime(1)
32+
};
33+
i += 1;
34+
}
35+
36+
// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
37+
Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
1538
}
1639

1740
/// Computes √(`self`)
@@ -23,62 +46,49 @@ impl<const LIMBS: usize> Uint<LIMBS> {
2346
let cap = Self::ONE.shl_vartime(max_bits);
2447
let mut guess = cap; // ≥ √(`self`)
2548
let mut xn = {
26-
let q = self.wrapping_div(&guess);
49+
let q = self.wrapping_div_vartime(&guess);
2750
let t = guess.wrapping_add(&q);
2851
t.shr_vartime(1)
2952
};
30-
31-
// If guess increased, the initial guess was low.
32-
// Repeat until reverse course.
33-
while Uint::ct_lt(&guess, &xn).is_true_vartime() {
34-
// Sometimes an increase is too far, especially with large
35-
// powers, and then takes a long time to walk back. The upper
36-
// bound is based on bit size, so saturate on that.
37-
let le = CtChoice::from_u32_le(xn.bits_vartime(), max_bits);
38-
guess = Self::ct_select(&cap, &xn, le);
39-
xn = {
40-
let q = self.wrapping_div(&guess);
41-
let t = guess.wrapping_add(&q);
42-
t.shr_vartime(1)
43-
};
44-
}
53+
// Note, xn <= guess at this point.
4554

4655
// Repeat while guess decreases.
47-
while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
56+
while guess.cmp_vartime(&xn).is_gt() && !xn.cmp_vartime(&Self::ZERO).is_eq() {
4857
guess = xn;
4958
xn = {
50-
let q = self.wrapping_div(&guess);
59+
let q = self.wrapping_div_vartime(&guess);
5160
let t = guess.wrapping_add(&q);
5261
t.shr_vartime(1)
5362
};
5463
}
5564

56-
Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero())
65+
if self.ct_is_nonzero().is_true_vartime() {
66+
guess
67+
} else {
68+
Self::ZERO
69+
}
5770
}
5871

59-
/// See [`Self::wrapping_sqrt_vartime`].
60-
#[deprecated(
61-
since = "0.5.3",
62-
note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
63-
)]
72+
/// Wrapped sqrt is just normal √(`self`)
73+
/// There’s no way wrapping could ever happen.
74+
/// This function exists so that all operations are accounted for in the wrapping operations.
6475
pub const fn wrapping_sqrt(&self) -> Self {
65-
self.wrapping_sqrt_vartime()
76+
self.sqrt()
6677
}
6778

6879
/// Wrapped sqrt is just normal √(`self`)
6980
/// There’s no way wrapping could ever happen.
70-
/// This function exists, so that all operations are accounted for in the wrapping operations.
81+
/// This function exists so that all operations are accounted for in the wrapping operations.
7182
pub const fn wrapping_sqrt_vartime(&self) -> Self {
7283
self.sqrt_vartime()
7384
}
7485

75-
/// See [`Self::checked_sqrt_vartime`].
76-
#[deprecated(
77-
since = "0.5.3",
78-
note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
79-
)]
86+
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
87+
/// only if the √(`self`)² == self
8088
pub fn checked_sqrt(&self) -> CtOption<Self> {
81-
self.checked_sqrt_vartime()
89+
let r = self.sqrt();
90+
let s = r.wrapping_mul(&r);
91+
CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
8292
}
8393

8494
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
@@ -92,7 +102,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
92102

93103
#[cfg(test)]
94104
mod tests {
95-
use crate::{Limb, U256};
105+
use crate::{Limb, U192, U256};
96106

97107
#[cfg(feature = "rand")]
98108
use {
@@ -103,13 +113,35 @@ mod tests {
103113

104114
#[test]
105115
fn edge() {
116+
assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
117+
assert_eq!(U256::ONE.sqrt(), U256::ONE);
118+
let mut half = U256::ZERO;
119+
for i in 0..half.limbs.len() / 2 {
120+
half.limbs[i] = Limb::MAX;
121+
}
122+
assert_eq!(U256::MAX.sqrt(), half);
123+
124+
assert_eq!(
125+
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(),
126+
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
127+
);
128+
129+
assert_eq!(
130+
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
131+
.sqrt(),
132+
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
133+
);
134+
}
135+
136+
#[test]
137+
fn edge_vartime() {
106138
assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
107139
assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
108140
let mut half = U256::ZERO;
109141
for i in 0..half.limbs.len() / 2 {
110142
half.limbs[i] = Limb::MAX;
111143
}
112-
assert_eq!(U256::MAX.sqrt_vartime(), half,);
144+
assert_eq!(U256::MAX.sqrt_vartime(), half);
113145
}
114146

115147
#[test]
@@ -131,13 +163,28 @@ mod tests {
131163
for (a, e) in &tests {
132164
let l = U256::from(*a);
133165
let r = U256::from(*e);
166+
assert_eq!(l.sqrt(), r);
134167
assert_eq!(l.sqrt_vartime(), r);
168+
assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
135169
assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
136170
}
137171
}
138172

139173
#[test]
140174
fn nonsquares() {
175+
assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
176+
assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
177+
assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
178+
assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
179+
assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
180+
assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
181+
assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
182+
assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
183+
assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
184+
}
185+
186+
#[test]
187+
fn nonsquares_vartime() {
141188
assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
142189
assert_eq!(
143190
U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
@@ -163,14 +210,17 @@ mod tests {
163210
let t = rng.next_u32() as u64;
164211
let s = U256::from(t);
165212
let s2 = s.checked_mul(&s).unwrap();
213+
assert_eq!(s2.sqrt(), s);
166214
assert_eq!(s2.sqrt_vartime(), s);
215+
assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
167216
assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
168217
}
169218

170219
for _ in 0..50 {
171220
let s = U256::random(&mut rng);
172221
let mut s2 = U512::ZERO;
173222
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
223+
assert_eq!(s.square().sqrt(), s2);
174224
assert_eq!(s.square().sqrt_vartime(), s2);
175225
}
176226
}

tests/uint_proptests.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,9 @@ proptest! {
195195
if !b_bi.is_zero() {
196196
let expected = to_uint(a_bi / b_bi);
197197
let actual = a.wrapping_div(&b);
198-
199198
assert_eq!(expected, actual);
199+
let actual_vartime = a.wrapping_div_vartime(&b);
200+
assert_eq!(expected, actual_vartime);
200201
}
201202
}
202203

@@ -279,9 +280,10 @@ proptest! {
279280
fn wrapping_sqrt(a in uint()) {
280281
let a_bi = to_biguint(&a);
281282
let expected = to_uint(a_bi.sqrt());
282-
let actual = a.wrapping_sqrt_vartime();
283-
284-
assert_eq!(expected, actual);
283+
let actual_ct = a.wrapping_sqrt();
284+
assert_eq!(expected, actual_ct);
285+
let actual_vartime = a.wrapping_sqrt_vartime();
286+
assert_eq!(expected, actual_vartime);
285287
}
286288

287289
#[test]

0 commit comments

Comments
 (0)