Skip to content

Constant-time square root and division #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions src/uint/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ impl<const LIMBS: usize> Uint<LIMBS> {
(quo, rem)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
/// NOTE: Use only if you need to access const fn. Otherwise use [`Self::div_rem`] because
/// the value for is_some needs to be checked before using `q` and `r`.
///
/// This function is constant-time with respect to both `self` and `rhs`.
pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits();
let mut rem = *self;
let mut quo = Self::ZERO;
let mut c = rhs.shl(Self::BITS - mb);

let mut i = Self::BITS;
let mut done = CtChoice::FALSE;
loop {
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
rem = Self::ct_select(&r, &rem, CtChoice::from_mask(borrow.0).or(done));
r = quo.bitor(&Self::ONE);
quo = Self::ct_select(&r, &quo, CtChoice::from_mask(borrow.0).or(done));
if i == 0 {
break;
}
i -= 1;
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
// aren't modified further (but do the remaining iterations anyway to be constant-time)
done = Limb::ct_lt(Limb(i as Word), Limb(mb as Word));
c = c.shr_vartime(1);
quo = Self::ct_select(&quo.shl_vartime(1), &quo, done);
}

let is_some = Limb(mb as Word).ct_is_nonzero();
quo = Self::ct_select(&Self::ZERO, &quo, is_some);
(quo, rem, is_some)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
Expand All @@ -51,7 +87,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
pub(crate) const fn const_div_rem_vartime(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits_vartime();
let mut bd = Self::BITS - mb;
let mut rem = *self;
Expand Down Expand Up @@ -168,7 +204,14 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Computes self / rhs, returns the quotient, remainder.
pub fn div_rem(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.ct_div_rem(rhs);
let (q, r, _c) = self.const_div_rem(rhs);
(q, r)
}

/// Computes self / rhs, returns the quotient, remainder. Constant-time only for fixed `rhs`.
pub fn div_rem_vartime(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.const_div_rem_vartime(rhs);
(q, r)
}

Expand All @@ -185,7 +228,18 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// Panics if `rhs == 0`.
pub const fn wrapping_div(&self, rhs: &Self) -> Self {
let (q, _, c) = self.ct_div_rem(rhs);
let (q, _, c) = self.const_div_rem(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}

/// Wrapped division is just normal division i.e. `self` / `rhs`
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
///
/// Panics if `rhs == 0`. Constant-time only for fixed `rhs`.
pub const fn wrapping_div_vartime(&self, rhs: &Self) -> Self {
let (q, _, c) = self.const_div_rem_vartime(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}
Expand Down Expand Up @@ -609,7 +663,11 @@ mod tests {
] {
let lhs = U256::from(*n);
let rhs = U256::from(*d);
let (q, r, is_some) = lhs.ct_div_rem(&rhs);
let (q, r, is_some) = lhs.const_div_rem(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
let (q, r, is_some) = lhs.const_div_rem_vartime(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
Expand All @@ -625,7 +683,10 @@ mod tests {
let den = U256::random(&mut rng).shr_vartime(128);
let n = num.checked_mul(&den);
if n.is_some().into() {
let (q, _, is_some) = n.unwrap().ct_div_rem(&den);
let (q, _, is_some) = n.unwrap().const_div_rem(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
let (q, _, is_some) = n.unwrap().const_div_rem_vartime(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
}
Expand All @@ -647,15 +708,23 @@ mod tests {

#[test]
fn div_zero() {
let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO);
let (q, r, is_some) = U256::ONE.const_div_rem(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
let (q, r, is_some) = U256::ONE.const_div_rem_vartime(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
}

#[test]
fn div_one() {
let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE);
let (q, r, is_some) = U256::from(10u8).const_div_rem(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
let (q, r, is_some) = U256::from(10u8).const_div_rem_vartime(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
Expand Down
110 changes: 72 additions & 38 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
//! [`Uint`] square root operations.

use super::Uint;
use crate::{Limb, Word};
use subtle::{ConstantTimeEq, CtOption};

impl<const LIMBS: usize> Uint<LIMBS> {
/// See [`Self::sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `sqrt_vartime` in a future release."
)]
/// Computes √(`self`) in constant time.
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub const fn sqrt(&self) -> Self {
self.sqrt_vartime()
let max_bits = (self.bits() + 1) >> 1;
let cap = Self::ONE.shl(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think wrapping_div() is currently constant-time (although it could be made so).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My mistake, I didn't see that wrapping_div is only constant-time for a fixed rhs. Actually, I think there might be a documentation issue here—I can't find anywhere in the public documentation that says this, only the documentation for ct_div_rem (which doesn't appear in the public docs since it's pub(crate)).

What do you think would be the better approach here: making a new function that's like ct_div_rem but constant-time with respect to both inputs, or modifying ct_div_rem to have this stronger constant-time guarantee and moving the "constant-time only for fixed rhs" behavior to a new function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the name is misleading - it should have had a vartime suffix (and the whole div.rs is kind of a mess in terms of naming - see #268). So the proper way to proceed I think would be to rename the current one to _vartime, and implement a constant-time one in its place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears this is a blocker on merging this PR. I guess we can go ahead and flip over to the v0.6 series per #268 and try to land this PR afterward.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking on this: wrapping_div calls const_div_rem, which claims:

    /// This function is constant-time with respect to both `self` and `rhs`.
    pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which commit are you looking at? The current master has no const_div_rem(), and Uint::wrapping_div() uses ct_div_rem(), which is not constant-time in rhs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR changes wrapping_div to call const_div_rem, which is also added by this PR

let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// Repeat enough times to guarantee result has stabilized.
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When your paper has a more permanent link (e.g. on arxiv), please make a PR referencing it here

let mut i = 0;
while i < usize::BITS - Self::BITS.leading_zeros() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self::LOG2_BITS could be used here

guess = xn;
xn = {
let (q, _, is_some) = self.const_div_rem(&guess);
let q = Self::ct_select(&Self::ZERO, &q, is_some);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed specifically to handle the case of self == 0?

let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
i += 1;
}

// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So at this point guess == x_n, and xn == x_{n+1}, where n = floor(log2(Self::BITS)). But in the paper it says that it should be n = floor(log2(Self::BITS)) + 1 - am I missing something?

Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
}

/// Computes √(`self`)
Expand All @@ -23,31 +45,17 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let cap = Self::ONE.shl_vartime(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// If guess increased, the initial guess was low.
// Repeat until reverse course.
while Uint::ct_lt(&guess, &xn).is_true_vartime() {
// Sometimes an increase is too far, especially with large
// powers, and then takes a long time to walk back. The upper
// bound is based on bit size, so saturate on that.
let le = Limb::ct_le(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word));
guess = Self::ct_select(&cap, &xn, le);
xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}
// Note, xn <= guess at this point.

// Repeat while guess decreases.
while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to use ct_gt() and ct_is_nonzero() here, those are constant-time

guess = xn;
xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
Expand All @@ -56,29 +64,26 @@ impl<const LIMBS: usize> Uint<LIMBS> {
Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, we don't need constant-timeness.

}

/// See [`Self::wrapping_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
)]
/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt(&self) -> Self {
self.wrapping_sqrt_vartime()
self.sqrt()
}

/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt_vartime(&self) -> Self {
self.sqrt_vartime()
}

/// See [`Self::checked_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
)]
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
/// only if the √(`self`)² == self
pub fn checked_sqrt(&self) -> CtOption<Self> {
self.checked_sqrt_vartime()
let r = self.sqrt();
let s = r.wrapping_mul(&r);
CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
}

/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
Expand All @@ -103,13 +108,24 @@ mod tests {

#[test]
fn edge() {
assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
assert_eq!(U256::ONE.sqrt(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt(), half);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An idea for the edge case test: a number that actually needs the maximum amount of iterations to converge. According to my tests, the ones before 10000 are 80, 99, 4224, 4355, 4488, 4623, 4760, 4899; but please check independently. (Also an interesting mathematical question - is there some rule for their distribution)

}

#[test]
fn edge_vartime() {
assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt_vartime(), half,);
assert_eq!(U256::MAX.sqrt_vartime(), half);
}

#[test]
Expand All @@ -131,13 +147,28 @@ mod tests {
for (a, e) in &tests {
let l = U256::from(*a);
let r = U256::from(*e);
assert_eq!(l.sqrt(), r);
assert_eq!(l.sqrt_vartime(), r);
assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
}
}

#[test]
fn nonsquares() {
assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
}

#[test]
fn nonsquares_vartime() {
assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
assert_eq!(
U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
Expand All @@ -163,14 +194,17 @@ mod tests {
let t = rng.next_u32() as u64;
let s = U256::from(t);
let s2 = s.checked_mul(&s).unwrap();
assert_eq!(s2.sqrt(), s);
assert_eq!(s2.sqrt_vartime(), s);
assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
}

for _ in 0..50 {
let s = U256::random(&mut rng);
let mut s2 = U512::ZERO;
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
assert_eq!(s.square().sqrt(), s2);
assert_eq!(s.square().sqrt_vartime(), s2);
}
}
Expand Down
10 changes: 6 additions & 4 deletions tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ proptest! {
if !b_bi.is_zero() {
let expected = to_uint(a_bi / b_bi);
let actual = a.wrapping_div(&b);

assert_eq!(expected, actual);
let actual_vartime = a.wrapping_div_vartime(&b);
assert_eq!(expected, actual_vartime);
}
}

Expand Down Expand Up @@ -254,9 +255,10 @@ proptest! {
fn wrapping_sqrt(a in uint()) {
let a_bi = to_biguint(&a);
let expected = to_uint(a_bi.sqrt());
let actual = a.wrapping_sqrt_vartime();

assert_eq!(expected, actual);
let actual_ct = a.wrapping_sqrt();
assert_eq!(expected, actual_ct);
let actual_vartime = a.wrapping_sqrt_vartime();
assert_eq!(expected, actual_vartime);
}

#[test]
Expand Down