Skip to content

Constant-time square root and division #376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions src/uint/div.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! [`Uint`] division operations.

use super::div_limb::{div_rem_limb_with_reciprocal, Reciprocal};
use crate::{CheckedDiv, CtChoice, Limb, NonZero, Uint, Wrapping};
use crate::{CheckedDiv, CtChoice, Limb, NonZero, Uint, Word, Wrapping};
use core::ops::{Div, DivAssign, Rem, RemAssign};
use subtle::CtOption;

Expand Down Expand Up @@ -41,6 +41,43 @@ impl<const LIMBS: usize> Uint<LIMBS> {
(quo, rem)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
/// NOTE: Use only if you need to access const fn. Otherwise use [`Self::div_rem`] because
/// the value for is_some needs to be checked before using `q` and `r`.
///
/// This function is constant-time with respect to both `self` and `rhs`.
#[allow(trivial_numeric_casts)]
pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits();
let mut rem = *self;
let mut quo = Self::ZERO;
let mut c = rhs.shl(Self::BITS - mb);

let mut i = Self::BITS;
let mut done = CtChoice::FALSE;
loop {
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
rem = Self::ct_select(&r, &rem, CtChoice::from_word_mask(borrow.0).or(done));
r = quo.bitor(&Self::ONE);
quo = Self::ct_select(&r, &quo, CtChoice::from_word_mask(borrow.0).or(done));
if i == 0 {
break;
}
i -= 1;
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
// aren't modified further (but do the remaining iterations anyway to be constant-time)
done = CtChoice::from_word_lt(i as Word, mb as Word);
c = c.shr_vartime(1);
quo = Self::ct_select(&quo.shl_vartime(1), &quo, done);
}

let is_some = Limb(mb as Word).ct_is_nonzero();
quo = Self::ct_select(&Self::ZERO, &quo, is_some);
(quo, rem, is_some)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
Expand All @@ -51,7 +88,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
pub(crate) const fn const_div_rem_vartime(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits_vartime();
let mut bd = Self::BITS - mb;
let mut rem = *self;
Expand Down Expand Up @@ -169,7 +206,14 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Computes self / rhs, returns the quotient, remainder.
pub fn div_rem(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.ct_div_rem(rhs);
let (q, r, _c) = self.const_div_rem(rhs);
(q, r)
}

/// Computes self / rhs, returns the quotient, remainder. Constant-time only for fixed `rhs`.
pub fn div_rem_vartime(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.const_div_rem_vartime(rhs);
(q, r)
}

Expand All @@ -186,7 +230,18 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// Panics if `rhs == 0`.
pub const fn wrapping_div(&self, rhs: &Self) -> Self {
let (q, _, c) = self.ct_div_rem(rhs);
let (q, _, c) = self.const_div_rem(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}

/// Wrapped division is just normal division i.e. `self` / `rhs`
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
///
/// Panics if `rhs == 0`. Constant-time only for fixed `rhs`.
pub const fn wrapping_div_vartime(&self, rhs: &Self) -> Self {
let (q, _, c) = self.const_div_rem_vartime(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}
Expand Down Expand Up @@ -625,7 +680,11 @@ mod tests {
] {
let lhs = U256::from(*n);
let rhs = U256::from(*d);
let (q, r, is_some) = lhs.ct_div_rem(&rhs);
let (q, r, is_some) = lhs.const_div_rem(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
let (q, r, is_some) = lhs.const_div_rem_vartime(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
Expand All @@ -641,7 +700,10 @@ mod tests {
let den = U256::random(&mut rng).shr_vartime(128);
let n = num.checked_mul(&den);
if n.is_some().into() {
let (q, _, is_some) = n.unwrap().ct_div_rem(&den);
let (q, _, is_some) = n.unwrap().const_div_rem(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
let (q, _, is_some) = n.unwrap().const_div_rem_vartime(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
}
Expand All @@ -663,15 +725,23 @@ mod tests {

#[test]
fn div_zero() {
let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO);
let (q, r, is_some) = U256::ONE.const_div_rem(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
let (q, r, is_some) = U256::ONE.const_div_rem_vartime(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
}

#[test]
fn div_one() {
let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE);
let (q, r, is_some) = U256::from(10u8).const_div_rem(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
let (q, r, is_some) = U256::from(10u8).const_div_rem_vartime(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
Expand Down
132 changes: 91 additions & 41 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
//! [`Uint`] square root operations.

use super::Uint;
use crate::CtChoice;
use subtle::{ConstantTimeEq, CtOption};

impl<const LIMBS: usize> Uint<LIMBS> {
/// See [`Self::sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `sqrt_vartime` in a future release."
)]
/// Computes √(`self`) in constant time.
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub const fn sqrt(&self) -> Self {
self.sqrt_vartime()
let max_bits = (self.bits() + 1) >> 1;
let cap = Self::ONE.shl(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// Repeat enough times to guarantee result has stabilized.
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
let mut i = 0;
while i < Self::LOG2_BITS {
guess = xn;
xn = {
let (q, _, is_some) = self.const_div_rem(&guess);
let q = Self::ct_select(&Self::ZERO, &q, is_some);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
i += 1;
}

// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
}

/// Computes √(`self`)
Expand All @@ -23,62 +46,49 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let cap = Self::ONE.shl_vartime(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// If guess increased, the initial guess was low.
// Repeat until reverse course.
while Uint::ct_lt(&guess, &xn).is_true_vartime() {
// Sometimes an increase is too far, especially with large
// powers, and then takes a long time to walk back. The upper
// bound is based on bit size, so saturate on that.
let le = CtChoice::from_u32_le(xn.bits_vartime(), max_bits);
guess = Self::ct_select(&cap, &xn, le);
xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}
// Note, xn <= guess at this point.

// Repeat while guess decreases.
while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
while guess.cmp_vartime(&xn).is_gt() && !xn.cmp_vartime(&Self::ZERO).is_eq() {
guess = xn;
xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}

Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero())
if self.ct_is_nonzero().is_true_vartime() {
guess
} else {
Self::ZERO
}
}

/// See [`Self::wrapping_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
)]
/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt(&self) -> Self {
self.wrapping_sqrt_vartime()
self.sqrt()
}

/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt_vartime(&self) -> Self {
self.sqrt_vartime()
}

/// See [`Self::checked_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
)]
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
/// only if the √(`self`)² == self
pub fn checked_sqrt(&self) -> CtOption<Self> {
self.checked_sqrt_vartime()
let r = self.sqrt();
let s = r.wrapping_mul(&r);
CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
}

/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
Expand All @@ -92,7 +102,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {

#[cfg(test)]
mod tests {
use crate::{Limb, U256};
use crate::{Limb, U192, U256};

#[cfg(feature = "rand")]
use {
Expand All @@ -103,13 +113,35 @@ mod tests {

#[test]
fn edge() {
assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
assert_eq!(U256::ONE.sqrt(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt(), half);

assert_eq!(
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(),
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
);

assert_eq!(
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
.sqrt(),
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
);
}

#[test]
fn edge_vartime() {
assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt_vartime(), half,);
assert_eq!(U256::MAX.sqrt_vartime(), half);
}

#[test]
Expand All @@ -131,13 +163,28 @@ mod tests {
for (a, e) in &tests {
let l = U256::from(*a);
let r = U256::from(*e);
assert_eq!(l.sqrt(), r);
assert_eq!(l.sqrt_vartime(), r);
assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
}
}

#[test]
fn nonsquares() {
assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
}

#[test]
fn nonsquares_vartime() {
assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
assert_eq!(
U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
Expand All @@ -163,14 +210,17 @@ mod tests {
let t = rng.next_u32() as u64;
let s = U256::from(t);
let s2 = s.checked_mul(&s).unwrap();
assert_eq!(s2.sqrt(), s);
assert_eq!(s2.sqrt_vartime(), s);
assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
}

for _ in 0..50 {
let s = U256::random(&mut rng);
let mut s2 = U512::ZERO;
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
assert_eq!(s.square().sqrt(), s2);
assert_eq!(s.square().sqrt_vartime(), s2);
}
}
Expand Down
10 changes: 6 additions & 4 deletions tests/uint_proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ proptest! {
if !b_bi.is_zero() {
let expected = to_uint(a_bi / b_bi);
let actual = a.wrapping_div(&b);

assert_eq!(expected, actual);
let actual_vartime = a.wrapping_div_vartime(&b);
assert_eq!(expected, actual_vartime);
}
}

Expand Down Expand Up @@ -279,9 +280,10 @@ proptest! {
fn wrapping_sqrt(a in uint()) {
let a_bi = to_biguint(&a);
let expected = to_uint(a_bi.sqrt());
let actual = a.wrapping_sqrt_vartime();

assert_eq!(expected, actual);
let actual_ct = a.wrapping_sqrt();
assert_eq!(expected, actual_ct);
let actual_vartime = a.wrapping_sqrt_vartime();
assert_eq!(expected, actual_vartime);
}

#[test]
Expand Down