diff --git a/benches/int.rs b/benches/int.rs index e3c17ffc6..d9f9eaba9 100644 --- a/benches/int.rs +++ b/benches/int.rs @@ -1,12 +1,10 @@ use std::ops::Div; use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -use num_traits::WrappingSub; +use crypto_bigint::{I128, I256, I512, I1024, I2048, I4096, NonZero, Random}; use rand_chacha::ChaChaRng; use rand_core::SeedableRng; -use crypto_bigint::{I128, I256, I512, I1024, I2048, I4096, NonZero, Random}; - fn bench_mul(c: &mut Criterion) { let mut rng = ChaChaRng::from_os_rng(); let mut group = c.benchmark_group("wrapping ops"); diff --git a/benches/uint.rs b/benches/uint.rs index 7bbc9e172..19b1853df 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1,7 +1,10 @@ -use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; +use criterion::measurement::WallTime; +use criterion::{ + BatchSize, BenchmarkGroup, BenchmarkId, Criterion, black_box, criterion_group, criterion_main, +}; use crypto_bigint::{ - Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, U128, U256, U512, U1024, U2048, - U4096, Uint, + Gcd, Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, U128, U256, U512, U1024, + U2048, U4096, Uint, }; use rand_chacha::ChaCha8Rng; use rand_core::{RngCore, SeedableRng}; @@ -325,33 +328,88 @@ fn bench_division(c: &mut Criterion) { group.finish(); } -fn bench_gcd(c: &mut Criterion) { - let mut rng = make_rng(); - let mut group = c.benchmark_group("greatest common divisor"); - - group.bench_function("gcd, U256", |b| { +fn gcd_bench(g: &mut BenchmarkGroup, rng: &mut impl RngCore) +where + Uint: Gcd>, +{ + g.bench_function(BenchmarkId::new("gcd", LIMBS), |b| { b.iter_batched( - || { - let f = U256::random(&mut rng); - let g = U256::random(&mut rng); - (f, g) - }, + || (Uint::::random(rng), Uint::::random(rng)), |(f, g)| black_box(f.gcd(&g)), BatchSize::SmallInput, ) }); + g.bench_function(BenchmarkId::new("bingcd", LIMBS), |b| { + b.iter_batched( + || (Uint::::random(rng), Uint::::random(rng)), + |(f, g)| black_box(Uint::bingcd(&f, &g)), + BatchSize::SmallInput, + ) + }); + g.bench_function(BenchmarkId::new("bingcd (classic)", LIMBS), |b| { + b.iter_batched( + || (Odd::>::random(rng), Uint::::random(rng)), + |(f, g)| black_box(f.classic_bingcd(&g)), + BatchSize::SmallInput, + ) + }); + g.bench_function(BenchmarkId::new("bingcd (optimized)", LIMBS), |b| { + b.iter_batched( + || (Odd::>::random(rng), Uint::::random(rng)), + |(f, g)| black_box(f.optimized_bingcd(&g)), + BatchSize::SmallInput, + ) + }); +} - group.bench_function("gcd_vartime, U256", |b| { +fn bench_gcd(c: &mut Criterion) { + let mut rng = make_rng(); + let mut group = c.benchmark_group("greatest common divisor"); + + gcd_bench::<1>(&mut group, &mut rng); + gcd_bench::<2>(&mut group, &mut rng); + gcd_bench::<3>(&mut group, &mut rng); + gcd_bench::<4>(&mut group, &mut rng); + gcd_bench::<5>(&mut group, &mut rng); + gcd_bench::<6>(&mut group, &mut rng); + gcd_bench::<7>(&mut group, &mut rng); + gcd_bench::<8>(&mut group, &mut rng); + gcd_bench::<16>(&mut group, &mut rng); + gcd_bench::<32>(&mut group, &mut rng); + gcd_bench::<64>(&mut group, &mut rng); + gcd_bench::<128>(&mut group, &mut rng); + gcd_bench::<256>(&mut group, &mut rng); + + group.finish(); +} + +fn xgcd_bench(g: &mut BenchmarkGroup, rng: &mut impl RngCore) { + g.bench_function(BenchmarkId::new("binxgcd", LIMBS), |b| { b.iter_batched( - || { - let f = Odd::::random(&mut rng); - let g = U256::random(&mut rng); - (f, g) - }, - |(f, g)| black_box(f.gcd_vartime(&g)), + || (Uint::::random(rng), Uint::::random(rng)), + |(f, g)| black_box(f.binxgcd(&g)), BatchSize::SmallInput, ) }); +} + +fn bench_xgcd(c: &mut Criterion) { + let mut rng = make_rng(); + let mut group = c.benchmark_group("greatest common divisor"); + + xgcd_bench::<1>(&mut group, &mut rng); + xgcd_bench::<2>(&mut group, &mut rng); + xgcd_bench::<3>(&mut group, &mut rng); + xgcd_bench::<4>(&mut group, &mut rng); + xgcd_bench::<5>(&mut group, &mut rng); + xgcd_bench::<6>(&mut group, &mut rng); + xgcd_bench::<7>(&mut group, &mut rng); + xgcd_bench::<8>(&mut group, &mut rng); + xgcd_bench::<16>(&mut group, &mut rng); + xgcd_bench::<32>(&mut group, &mut rng); + xgcd_bench::<64>(&mut group, &mut rng); + xgcd_bench::<128>(&mut group, &mut rng); + xgcd_bench::<256>(&mut group, &mut rng); group.finish(); } @@ -518,6 +576,7 @@ criterion_group!( bench_mul, bench_division, bench_gcd, + bench_xgcd, bench_shl, bench_shr, bench_inv_mod, diff --git a/src/const_choice.rs b/src/const_choice.rs index d35d7282b..e3f7a7843 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -64,6 +64,13 @@ impl ConstChoice { Self(value.wrapping_neg() as Word) } + #[inline] + pub(crate) const fn from_u8_lsb(value: u8) -> Self { + debug_assert!(value == 0 || value == 1); + #[allow(trivial_numeric_casts)] + Self((value as Word).wrapping_neg()) + } + #[inline] pub(crate) const fn from_u32_lsb(value: u32) -> Self { debug_assert!(value == 0 || value == 1); @@ -78,6 +85,12 @@ impl ConstChoice { Self((value as Word).wrapping_neg()) } + /// Returns the truthy value if `value != 0`, and the falsy value otherwise. + #[inline] + pub(crate) const fn from_u8_nonzero(value: u8) -> Self { + Self::from_u8_lsb((value | value.wrapping_neg()) >> (u8::BITS - 1)) + } + /// Returns the truthy value if `value != 0`, and the falsy value otherwise. #[inline] pub(crate) const fn from_u32_nonzero(value: u32) -> Self { @@ -174,6 +187,14 @@ impl ConstChoice { Self::from_u64_lt(y, x) } + /// Returns the truthy value if `x = y`, and the falsy value otherwise. + #[inline] + pub(crate) const fn from_i8_eq(x: i8, y: i8) -> Self { + let x = x as u8; + let y = y as u8; + Self::from_u8_nonzero(x ^ y).not() + } + #[inline] pub(crate) const fn not(&self) -> Self { Self(!self.0) @@ -413,6 +434,20 @@ impl ConstCtOption<(Uint, Uint)> { } } +impl ConstCtOption<(Uint, ConstChoice)> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> (Uint, ConstChoice) { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption>> { /// Returns the contained value, consuming the `self` value. /// @@ -461,6 +496,34 @@ impl ConstCtOption> { } } +impl ConstCtOption>> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> NonZero> { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + +impl ConstCtOption>> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> Odd> { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption> { /// Returns the contained value, consuming the `self` value. /// @@ -497,6 +560,19 @@ mod tests { use super::ConstChoice; + #[test] + fn from_u8_nonzero() { + assert_eq!(ConstChoice::from_u8_nonzero(0), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_u8_nonzero(1), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_u8_nonzero(123), ConstChoice::TRUE); + } + + #[test] + fn from_u8_lsb() { + assert_eq!(ConstChoice::from_u8_lsb(0), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_u8_lsb(1), ConstChoice::TRUE); + } + #[test] fn from_u64_lsb() { assert_eq!(ConstChoice::from_u64_lsb(0), ConstChoice::FALSE); @@ -524,6 +600,19 @@ mod tests { assert_eq!(ConstChoice::from_wide_word_le(6, 5), ConstChoice::FALSE); } + #[test] + fn from_i8_eq() { + assert_eq!(ConstChoice::from_i8_eq(-1, -1), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_i8_eq(-1, 0), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(-1, 1), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(0, -1), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(0, 0), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_i8_eq(0, 1), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(1, -1), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(1, 0), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_i8_eq(1, 1), ConstChoice::TRUE); + } + #[test] fn select_u32() { let a: u32 = 1; diff --git a/src/int.rs b/src/int.rs index 553529dbb..577cc794d 100644 --- a/src/int.rs +++ b/src/int.rs @@ -12,6 +12,7 @@ use crate::Encoding; use crate::{Bounded, ConstChoice, ConstCtOption, Constants, Limb, NonZero, Odd, Uint, Word}; mod add; +mod bingcd; mod bit_and; mod bit_not; mod bit_or; @@ -33,6 +34,8 @@ mod sign; mod sub; pub(crate) mod types; +pub use bingcd::{IntBinxgcdOutput, NonZeroIntBinxgcdOutput, OddIntBinxgcdOutput}; + #[cfg(feature = "rand_core")] mod rand; diff --git a/src/int/bingcd.rs b/src/int/bingcd.rs new file mode 100644 index 000000000..44c334478 --- /dev/null +++ b/src/int/bingcd.rs @@ -0,0 +1,517 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +use crate::modular::bingcd::OddUintBinxgcdOutput; +use crate::modular::bingcd::tools::const_min; +use crate::{ConstChoice, Int, NonZero, Odd, Uint}; + +#[derive(Debug)] +pub struct BaseIntBinxgcdOutput { + pub gcd: T, + pub x: Int, + pub y: Int, + pub lhs_on_gcd: Int, + pub rhs_on_gcd: Int, +} + +/// Output of the Binary XGCD algorithm applied to two [Int]s. +pub type IntBinxgcdOutput = BaseIntBinxgcdOutput, LIMBS>; + +/// Output of the Binary XGCD algorithm applied to two [`NonZero>`]s. +pub type NonZeroIntBinxgcdOutput = + BaseIntBinxgcdOutput>, LIMBS>; + +/// Output of the Binary XGCD algorithm applied to two [`Odd>`]s. +pub type OddIntBinxgcdOutput = BaseIntBinxgcdOutput>, LIMBS>; + +impl BaseIntBinxgcdOutput { + /// Return the quotients `lhs.gcd` and `rhs/gcd`. + pub fn quotients(&self) -> (Int, Int) { + (self.lhs_on_gcd, self.rhs_on_gcd) + } + + /// Provide mutable access to the quotients `lhs.gcd` and `rhs/gcd`. + pub fn quotients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.lhs_on_gcd, &mut self.rhs_on_gcd) + } + + /// Return the Bézout coefficients `x` and `y` s.t. `lhs * x + rhs * y = gcd`. + pub fn bezout_coefficients(&self) -> (Int, Int) { + (self.x, self.y) + } + + /// Provide mutable access to the Bézout coefficients. + pub fn bezout_coefficients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.x, &mut self.y) + } +} + +impl Int { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub fn bingcd(&self, rhs: &Self) -> Uint { + self.abs().bingcd(&rhs.abs()) + } + + /// Executes the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &Self) -> IntBinxgcdOutput { + // Make sure `self` and `rhs` are nonzero. + let self_is_zero = self.is_nonzero().not(); + let self_nz = Int::select(self, &Int::ONE, self_is_zero) + .to_nz() + .expect("self is non zero by construction"); + let rhs_is_zero = rhs.is_nonzero().not(); + let rhs_nz = Int::select(rhs, &Int::ONE, rhs_is_zero) + .to_nz() + .expect("rhs is non zero by construction"); + + let NonZeroIntBinxgcdOutput { + gcd, + mut x, + mut y, + mut lhs_on_gcd, + mut rhs_on_gcd, + } = self_nz.binxgcd(&rhs_nz); + + // Correct the gcd in case self and/or rhs was zero + let mut gcd = *gcd.as_ref(); + gcd = Uint::select(&gcd, &rhs.abs(), self_is_zero); + gcd = Uint::select(&gcd, &self.abs(), rhs_is_zero); + + // Correct the Bézout coefficients in case self and/or rhs was zero. + let signum_self = Int::new_from_abs_sign(Uint::ONE, self.is_negative()).expect("+/- 1"); + let signum_rhs = Int::new_from_abs_sign(Uint::ONE, rhs.is_negative()).expect("+/- 1"); + x = Int::select(&x, &Int::ZERO, self_is_zero); + y = Int::select(&y, &signum_rhs, self_is_zero); + x = Int::select(&x, &signum_self, rhs_is_zero); + y = Int::select(&y, &Int::ZERO, rhs_is_zero); + + // Correct the quotients in case self and/or rhs was zero. + lhs_on_gcd = Int::select(&lhs_on_gcd, &signum_self, rhs_is_zero); + lhs_on_gcd = Int::select(&lhs_on_gcd, &Int::ZERO, self_is_zero); + rhs_on_gcd = Int::select(&rhs_on_gcd, &signum_rhs, self_is_zero); + rhs_on_gcd = Int::select(&rhs_on_gcd, &Int::ZERO, rhs_is_zero); + + IntBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +impl NonZero> { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub fn bingcd(&self, rhs: &Self) -> NonZero> { + self.abs().bingcd(&rhs.as_ref().abs()) + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &Self) -> NonZeroIntBinxgcdOutput { + let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref()); + + // Leverage the property that gcd(2^k * a, 2^k *b) = 2^k * gcd(a, b) + let i = lhs.0.trailing_zeros(); + let j = rhs.0.trailing_zeros(); + let k = const_min(i, j); + lhs = lhs.shr(k); + rhs = rhs.shr(k); + + // Note: at this point, either lhs or rhs is odd (or both). + // Swap to make sure lhs is odd. + let swap = ConstChoice::from_u32_lt(j, i); + Int::conditional_swap(&mut lhs, &mut rhs, swap); + let lhs = lhs.to_odd().expect("odd by construction"); + + let rhs = rhs.to_nz().expect("non-zero by construction"); + let OddIntBinxgcdOutput { + gcd, + mut x, + mut y, + mut lhs_on_gcd, + mut rhs_on_gcd, + } = lhs.binxgcd(&rhs); + + // Account for the parameter swap + Int::conditional_swap(&mut x, &mut y, swap); + Int::conditional_swap(&mut lhs_on_gcd, &mut rhs_on_gcd, swap); + + // Reintroduce the factor 2^k to the gcd. + let gcd = gcd + .as_ref() + .shl(k) + .to_nz() + .expect("is non-zero by construction"); + + NonZeroIntBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +impl Odd> { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub fn bingcd(&self, rhs: &Self) -> Odd> { + self.abs().bingcd(&rhs.as_ref().abs()) + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &NonZero>) -> OddIntBinxgcdOutput { + let (abs_lhs, sgn_lhs) = self.abs_sign(); + let (abs_rhs, sgn_rhs) = rhs.abs_sign(); + + let OddUintBinxgcdOutput { + gcd, + mut x, + mut y, + lhs_on_gcd: abs_lhs_on_gcd, + rhs_on_gcd: abs_rhs_on_gcd, + } = abs_lhs.binxgcd_nz(&abs_rhs); + + x = x.wrapping_neg_if(sgn_lhs); + y = y.wrapping_neg_if(sgn_rhs); + let lhs_on_gcd = Int::new_from_abs_sign(abs_lhs_on_gcd, sgn_lhs).expect("no overflow"); + let rhs_on_gcd = Int::new_from_abs_sign(abs_rhs_on_gcd, sgn_rhs).expect("no overflow"); + + OddIntBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +#[cfg(all(test, not(miri)))] +mod test { + use crate::int::bingcd::{IntBinxgcdOutput, NonZeroIntBinxgcdOutput, OddIntBinxgcdOutput}; + use crate::{ConcatMixed, Int, Uint}; + use num_traits::Zero; + + #[cfg(feature = "rand_core")] + use rand_chacha::ChaChaRng; + #[cfg(feature = "rand_core")] + use rand_core::SeedableRng; + + impl From> for IntBinxgcdOutput { + fn from(value: NonZeroIntBinxgcdOutput) -> Self { + let NonZeroIntBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } = value; + IntBinxgcdOutput { + gcd: *gcd.as_ref(), + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } + } + + impl From> for IntBinxgcdOutput { + fn from(value: OddIntBinxgcdOutput) -> Self { + let OddIntBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } = value; + IntBinxgcdOutput { + gcd: *gcd.as_ref(), + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } + } + + #[cfg(feature = "rand_core")] + pub(crate) fn make_rng() -> ChaChaRng { + ChaChaRng::from_seed([0; 32]) + } + + fn binxgcd_test( + lhs: Int, + rhs: Int, + output: IntBinxgcdOutput, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let gcd = lhs.bingcd(&rhs); + assert_eq!(gcd, output.gcd); + + // Test quotients + let (lhs_on_gcd, rhs_on_gcd) = output.quotients(); + if gcd.is_zero() { + assert_eq!(lhs_on_gcd, Int::ZERO); + assert_eq!(rhs_on_gcd, Int::ZERO); + } else { + assert_eq!(lhs_on_gcd, lhs.div_uint(&gcd.to_nz().unwrap())); + assert_eq!(rhs_on_gcd, rhs.div_uint(&gcd.to_nz().unwrap())); + } + + // Test the Bezout coefficients on minimality + let (x, y) = output.bezout_coefficients(); + assert!(x.abs() <= rhs_on_gcd.abs() || rhs_on_gcd.is_zero()); + assert!(y.abs() <= lhs_on_gcd.abs() || lhs_on_gcd.is_zero()); + if lhs.abs() != rhs.abs() { + assert!(x.abs() <= rhs_on_gcd.abs().shr(1) || rhs_on_gcd.is_zero()); + assert!(y.abs() <= lhs_on_gcd.abs().shr(1) || lhs_on_gcd.is_zero()); + } + + // Test the Bezout coefficients for correctness + assert_eq!( + x.widening_mul(&lhs).wrapping_add(&y.widening_mul(&rhs)), + gcd.resize().as_int() + ); + } + + mod test_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Gcd, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, + U8192, Uint, + }; + + #[cfg(feature = "rand_core")] + use crate::Random; + #[cfg(feature = "rand_core")] + use crate::int::bingcd::test::make_rng; + + fn int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + binxgcd_test(lhs, rhs, lhs.binxgcd(&rhs)) + } + + #[cfg(feature = "rand_core")] + fn int_binxgcd_randomized_tests(iterations: u32) + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Int::random(&mut rng); + let y = Int::random(&mut rng); + int_binxgcd_test(x, y); + } + } + + fn int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + int_binxgcd_test(Int::MIN, Int::MIN); + int_binxgcd_test(Int::MIN, Int::MINUS_ONE); + int_binxgcd_test(Int::MIN, Int::ZERO); + int_binxgcd_test(Int::MIN, Int::ONE); + int_binxgcd_test(Int::MIN, Int::MAX); + int_binxgcd_test(Int::ONE, Int::MIN); + int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + int_binxgcd_test(Int::ONE, Int::ZERO); + int_binxgcd_test(Int::ONE, Int::ONE); + int_binxgcd_test(Int::ONE, Int::MAX); + int_binxgcd_test(Int::ZERO, Int::MIN); + int_binxgcd_test(Int::ZERO, Int::MINUS_ONE); + int_binxgcd_test(Int::ZERO, Int::ZERO); + int_binxgcd_test(Int::ZERO, Int::ONE); + int_binxgcd_test(Int::ZERO, Int::MAX); + int_binxgcd_test(Int::ONE, Int::MIN); + int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + int_binxgcd_test(Int::ONE, Int::ZERO); + int_binxgcd_test(Int::ONE, Int::ONE); + int_binxgcd_test(Int::ONE, Int::MAX); + int_binxgcd_test(Int::MAX, Int::MIN); + int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + int_binxgcd_test(Int::MAX, Int::ZERO); + int_binxgcd_test(Int::MAX, Int::ONE); + int_binxgcd_test(Int::MAX, Int::MAX); + + #[cfg(feature = "rand_core")] + int_binxgcd_randomized_tests(100); + } + + #[test] + fn test_int_binxgcd() { + int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_nonzero_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Gcd, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, + U8192, Uint, + }; + + #[cfg(feature = "rand_core")] + use crate::{Random, int::bingcd::test::make_rng}; + + fn nz_int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + let output = lhs.to_nz().unwrap().binxgcd(&rhs.to_nz().unwrap()); + binxgcd_test(lhs, rhs, output.into()); + } + + #[cfg(feature = "rand_core")] + fn nz_int_binxgcd_randomized_tests(iterations: u32) + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Uint::random(&mut rng).as_int(); + let y = Uint::random(&mut rng).as_int(); + nz_int_binxgcd_test(x, y); + } + } + + fn nz_int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + nz_int_binxgcd_test(Int::MIN, Int::MIN); + nz_int_binxgcd_test(Int::MIN, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MIN, Int::ONE); + nz_int_binxgcd_test(Int::MIN, Int::MAX); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MIN); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::ONE); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MAX); + nz_int_binxgcd_test(Int::ONE, Int::MIN); + nz_int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::ONE, Int::ONE); + nz_int_binxgcd_test(Int::ONE, Int::MAX); + nz_int_binxgcd_test(Int::MAX, Int::MIN); + nz_int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MAX, Int::ONE); + nz_int_binxgcd_test(Int::MAX, Int::MAX); + + #[cfg(feature = "rand_core")] + nz_int_binxgcd_randomized_tests(100); + } + + #[test] + fn test_nz_int_binxgcd() { + nz_int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + nz_int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + nz_int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + nz_int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + nz_int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + nz_int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + nz_int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + nz_int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + nz_int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_odd_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, U8192, + Uint, + }; + + #[cfg(feature = "rand_core")] + use crate::{Random, int::bingcd::test::make_rng}; + + fn odd_int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let output = lhs.to_odd().unwrap().binxgcd(&rhs.to_nz().unwrap()); + binxgcd_test(lhs, rhs, output.into()); + } + + #[cfg(feature = "rand_core")] + fn odd_int_binxgcd_randomized_tests( + iterations: u32, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Int::::random(&mut rng).bitor(&Int::ONE); + let y = Int::::random(&mut rng); + odd_int_binxgcd_test(x, y); + } + } + + fn odd_int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let neg_max = Int::MAX.wrapping_neg(); + odd_int_binxgcd_test(neg_max, neg_max); + odd_int_binxgcd_test(neg_max, Int::MINUS_ONE); + odd_int_binxgcd_test(neg_max, Int::ONE); + odd_int_binxgcd_test(neg_max, Int::MAX); + odd_int_binxgcd_test(Int::ONE, neg_max); + odd_int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + odd_int_binxgcd_test(Int::ONE, Int::ONE); + odd_int_binxgcd_test(Int::ONE, Int::MAX); + odd_int_binxgcd_test(Int::MAX, neg_max); + odd_int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + odd_int_binxgcd_test(Int::MAX, Int::ONE); + odd_int_binxgcd_test(Int::MAX, Int::MAX); + + #[cfg(feature = "rand_core")] + odd_int_binxgcd_randomized_tests(100); + } + + #[test] + fn test_odd_int_binxgcd() { + odd_int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + odd_int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + odd_int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + odd_int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + odd_int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + odd_int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + odd_int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + odd_int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + odd_int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } +} diff --git a/src/int/cmp.rs b/src/int/cmp.rs index 9f4d3a6e8..ce22989ce 100644 --- a/src/int/cmp.rs +++ b/src/int/cmp.rs @@ -15,6 +15,12 @@ impl Int { Self(Uint::select(&a.0, &b.0, c)) } + /// Swap `a` and `b` if `c` is truthy, otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_swap(a: &mut Self, b: &mut Self, c: ConstChoice) { + Uint::conditional_swap(&mut a.0, &mut b.0, c); + } + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. #[inline] pub(crate) const fn is_nonzero(&self) -> ConstChoice { diff --git a/src/int/gcd.rs b/src/int/gcd.rs index 0df124573..e6eff35ed 100644 --- a/src/int/gcd.rs +++ b/src/int/gcd.rs @@ -37,7 +37,7 @@ where #[cfg(test)] mod tests { - use crate::{Gcd, I256, U256}; + use crate::{Gcd, I64, I256, U256}; #[test] fn gcd_always_positive() { @@ -60,4 +60,10 @@ mod tests { assert_eq!(U256::from(61u32), f.gcd(&g)); assert_eq!(U256::from(61u32), f.wrapping_neg().gcd(&g)); } + + #[test] + fn gcd() { + assert_eq!(I64::MIN.gcd(&I64::ZERO), I64::MIN.abs()); + assert_eq!(I64::ZERO.gcd(&I64::MIN), I64::MIN.abs()); + } } diff --git a/src/int/mul.rs b/src/int/mul.rs index d978bd2d7..6d1628d5f 100644 --- a/src/int/mul.rs +++ b/src/int/mul.rs @@ -4,7 +4,7 @@ use core::ops::{Mul, MulAssign}; use subtle::CtOption; -use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint, Zero}; +use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint}; impl Int { /// Compute "wide" multiplication as a 3-tuple `(lo, hi, negate)`. @@ -49,6 +49,23 @@ impl Int { // always fits Int::from_bits(product_abs.wrapping_neg_if(product_sign)) } + + /// Multiply `self` with `rhs`, returning a [ConstCtOption] that `is_some` only if the result + /// fits in an `Int`. + pub(crate) const fn const_checked_mul( + &self, + rhs: &Int, + ) -> ConstCtOption> { + let (lo, hi, is_negative) = self.split_mul(rhs); + Self::new_from_abs_sign(lo, is_negative).and_choice(hi.is_nonzero().not()) + } + + /// Multiply `self` with `rhs`, returning a [ConstCtOption] that `is_some` only if the result + /// fits in an `Int`. + pub const fn wrapping_mul(&self, rhs: &Int) -> Int { + let (lo, _, is_negative) = self.split_mul(rhs); + Self(lo.wrapping_neg_if(is_negative)) + } } /// Squaring operations. @@ -80,9 +97,7 @@ impl Int { impl CheckedMul> for Int { #[inline] fn checked_mul(&self, rhs: &Int) -> CtOption { - let (lo, hi, is_negative) = self.split_mul(rhs); - let val = Self::new_from_abs_sign(lo, is_negative); - CtOption::from(val).and_then(|int| CtOption::new(int, hi.is_zero())) + Self::const_checked_mul(self, rhs).into() } } @@ -114,7 +129,7 @@ impl Mul<&Int> for &Int; fn mul(self, rhs: &Int) -> Self::Output { - self.checked_mul(rhs) + self.const_checked_mul(rhs) .expect("attempted to multiply with overflow") } } diff --git a/src/int/resize.rs b/src/int/resize.rs index 4aed28991..1dc9cb4d8 100644 --- a/src/int/resize.rs +++ b/src/int/resize.rs @@ -18,7 +18,6 @@ impl Int { #[cfg(test)] mod tests { - use num_traits::WrappingSub; use crate::{I128, I256}; diff --git a/src/int/sign.rs b/src/int/sign.rs index e8bd3ed0d..11b9bde5b 100644 --- a/src/int/sign.rs +++ b/src/int/sign.rs @@ -1,4 +1,4 @@ -use crate::{ConstChoice, ConstCtOption, Int, Uint, Word}; +use crate::{ConstChoice, ConstCtOption, Int, Odd, Uint, Word}; use num_traits::ConstZero; impl Int { @@ -49,6 +49,20 @@ impl Int { } } +impl Odd> { + /// The sign and magnitude of this [`Odd`]. + pub const fn abs_sign(&self) -> (Odd>, ConstChoice) { + let (abs, sgn) = Int::abs_sign(self.as_ref()); + let odd_abs = abs.to_odd().expect("abs value of an odd number is odd"); + (odd_abs, sgn) + } + + /// The magnitude of this [`Odd`]. + pub const fn abs(&self) -> Odd> { + self.abs_sign().0 + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/int/sub.rs b/src/int/sub.rs index 623ec149d..9afe69705 100644 --- a/src/int/sub.rs +++ b/src/int/sub.rs @@ -3,12 +3,14 @@ use core::ops::{Sub, SubAssign}; use num_traits::WrappingSub; -use subtle::{Choice, ConstantTimeEq, CtOption}; +use subtle::CtOption; -use crate::{Checked, CheckedSub, Int, Wrapping}; +use crate::{Checked, CheckedSub, ConstChoice, ConstCtOption, Int, Wrapping}; -impl CheckedSub for Int { - fn checked_sub(&self, rhs: &Self) -> CtOption { +impl Int { + /// Perform subtraction, returning the result along with a [ConstChoice] which `is_true` + /// only if the operation underflowed. + pub const fn underflowing_sub(&self, rhs: &Self) -> (Self, ConstChoice) { // Step 1. subtract operands let res = Self(self.0.wrapping_sub(&rhs.0)); @@ -18,12 +20,26 @@ impl CheckedSub for Int { // - underflow occurs if and only if the result and the lhs have opposing signs. // // We can thus express the overflow flag as: (self.msb != rhs.msb) & (self.msb != res.msb) - let self_msb: Choice = self.is_negative().into(); - let underflow = - self_msb.ct_ne(&rhs.is_negative().into()) & self_msb.ct_ne(&res.is_negative().into()); + let self_msb = self.is_negative(); + let underflow = self_msb + .ne(rhs.is_negative()) + .and(self_msb.ne(res.is_negative())); // Step 3. Construct result - CtOption::new(res, !underflow) + (res, underflow) + } + + /// Perform wrapping subtraction, discarding underflow and wrapping around the boundary of the + /// type. + pub const fn wrapping_sub(&self, rhs: &Self) -> Self { + self.underflowing_sub(rhs).0 + } +} + +impl CheckedSub for Int { + fn checked_sub(&self, rhs: &Self) -> CtOption { + let (res, underflow) = Self::underflowing_sub(self, rhs); + ConstCtOption::new(res, underflow.not()).into() } } @@ -77,8 +93,6 @@ impl WrappingSub for Int { #[cfg(test)] #[allow(clippy::init_numbered_fields)] mod tests { - use num_traits::WrappingSub; - use crate::{CheckedSub, I128, Int, U128}; #[test] diff --git a/src/lib.rs b/src/lib.rs index ae44a1670..6ec5c209a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -187,6 +187,7 @@ pub use crate::{ int::types::*, int::*, limb::{Limb, WideWord, Word}, + modular::bingcd::{NonZeroUintBinxgcdOutput, OddUintBinxgcdOutput, UintBinxgcdOutput}, non_zero::NonZero, odd::Odd, traits::*, diff --git a/src/modular.rs b/src/modular.rs index 988aa5fdc..c7182bfbf 100644 --- a/src/modular.rs +++ b/src/modular.rs @@ -22,6 +22,7 @@ mod monty_form; mod reduction; mod add; +pub(crate) mod bingcd; mod div_by_2; mod mul; mod pow; diff --git a/src/modular/bingcd.rs b/src/modular/bingcd.rs new file mode 100644 index 000000000..16523e83d --- /dev/null +++ b/src/modular/bingcd.rs @@ -0,0 +1,11 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin as Algorithm 2 in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +mod extension; +mod gcd; +mod matrix; +pub(crate) mod tools; +mod xgcd; + +pub use xgcd::{NonZeroUintBinxgcdOutput, OddUintBinxgcdOutput, UintBinxgcdOutput}; diff --git a/src/modular/bingcd/extension.rs b/src/modular/bingcd/extension.rs new file mode 100644 index 000000000..e2bf4f4e9 --- /dev/null +++ b/src/modular/bingcd/extension.rs @@ -0,0 +1,259 @@ +use crate::{ConstChoice, ConstCtOption, Int, Limb, Uint}; + +pub(crate) struct ExtendedUint( + Uint, + Uint, +); + +impl ExtendedUint {} + +impl ExtendedUint { + /// Construct an [ExtendedUint] from the product of a [Uint] and an [Uint]. + /// + /// Assumes the top bit of the product is not set. + #[inline] + pub const fn from_product(lhs: Uint, rhs: Uint) -> Self { + let (lo, hi) = lhs.split_mul(&rhs); + ExtendedUint(lo, hi) + } + + /// Wrapping multiply `self` with `rhs` + pub fn wrapping_mul(&self, rhs: &Uint) -> Self { + let (lo, hi) = self.0.split_mul(&rhs); + let hi = self + .1 + .wrapping_mul(&rhs) + .wrapping_add(&hi.resize::()); + Self(lo, hi) + } + + /// Interpret `self` as an [ExtendedInt] + #[inline] + pub const fn as_extended_int(&self) -> ExtendedInt { + ExtendedInt(self.0, self.1) + } + + /// Whether this form is `Self::ZERO`. + #[inline] + pub const fn is_zero(&self) -> ConstChoice { + self.0.is_nonzero().not().and(self.1.is_nonzero().not()) + } + + /// Drop the extension. + #[inline] + pub const fn checked_drop_extension(&self) -> ConstCtOption> { + ConstCtOption::new(self.0, self.1.is_nonzero().not()) + } + + /// Construction the binary negation of `self`, i.e., map `self` to `!self + 1`. + /// + /// Note: maps `0` to itself. + #[inline] + pub const fn wrapping_neg(&self) -> Self { + let (lhs, carry) = self.0.carrying_neg(); + let mut rhs = self.1.not(); + rhs = Uint::select(&rhs, &rhs.wrapping_add(&Uint::ONE), carry); + Self(lhs, rhs) + } + + /// Negate `self` if `negate` is truthy. Otherwise returns `self`. + #[inline] + pub const fn wrapping_neg_if(&self, negate: ConstChoice) -> Self { + let neg = self.wrapping_neg(); + Self( + Uint::select(&self.0, &neg.0, negate), + Uint::select(&self.1, &neg.1, negate), + ) + } + + /// Shift `self` right by `shift` bits. + /// + /// Assumes `shift <= Uint::::BITS`. + #[inline] + pub const fn shr(&self, shift: u32) -> Self { + debug_assert!(shift <= Uint::::BITS); + + let shift_is_zero = ConstChoice::from_u32_eq(shift, 0); + let left_shift = shift_is_zero.select_u32(Uint::::BITS - shift, 0); + + let hi = self.1.shr(shift); + // TODO: replace with carrying_shl + let carry = Uint::select(&self.1, &Uint::ZERO, shift_is_zero).shl(left_shift); + let mut lo = self.0.shr(shift); + + // Apply carry + let limb_diff = LIMBS.wrapping_sub(EXTRA) as u32; + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + let carry = carry.resize::().shl_vartime(limb_diff * Limb::BITS); + lo = lo.bitxor(&carry); + + Self(lo, hi) + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) struct ExtendedInt( + Uint, + Uint, +); + +impl ExtendedInt {} + +impl ExtendedInt { + pub(super) const ZERO: Self = Self(Uint::ZERO, Uint::ZERO); + pub(super) const ONE: Self = Self(Uint::ONE, Uint::ZERO); + + /// Construct an [ExtendedInt] from the product of a [Uint] and an [Int]. + /// + /// Assumes the top bit of the product is not set. + #[inline] + pub const fn from_product(lhs: Uint, rhs: Uint) -> Self { + ExtendedUint::from_product(lhs, rhs).as_extended_int() + } + + /// Wrapping multiply `self` with `rhs`, which is passed as a + pub(crate) fn wrapping_mul( + &self, + rhs: (&Uint, &ConstChoice), + ) -> Self { + let (abs_self, self_is_negative) = self.abs_sign(); + let (abs_rhs, rhs_is_negative) = rhs; + let mut abs_val = abs_self.wrapping_mul(abs_rhs); + + // Make sure the top bit of `abs_val` is not set + abs_val.1 = abs_val.1.bitand((!Int::::SIGN_MASK).as_uint()); + + let val_is_negative = self_is_negative.xor(*rhs_is_negative); + abs_val.wrapping_neg_if(val_is_negative).as_extended_int() + } + + /// Interpret this as an [ExtendedUint]. + #[inline] + pub const fn as_extended_uint(&self) -> ExtendedUint { + ExtendedUint(self.0, self.1) + } + + /// Return the negation of `self` if `negate` is truthy. Otherwise, return `self`. + #[inline] + pub const fn wrapping_neg_if(&self, negate: ConstChoice) -> Self { + self.as_extended_uint() + .wrapping_neg_if(negate) + .as_extended_int() + } + + #[inline] + pub(crate) fn wrapping_add(&self, rhs: &Self) -> Self { + let (lo, carry) = self.0.adc(&rhs.0, Limb::ZERO); + let (hi, _) = self.1.adc(&rhs.1, carry); + Self(lo, hi) + } + + /// Compute `self - rhs`, wrapping any underflow. + #[inline] + pub const fn wrapping_sub(&self, rhs: &Self) -> Self { + let (lo, borrow) = self.0.sbb(&rhs.0, Limb::ZERO); + let (hi, _) = self.1.sbb(&rhs.1, borrow); + Self(lo, hi) + } + + /// Returns self without the extension. + #[inline] + pub const fn wrapping_drop_extension(&self) -> (Uint, ConstChoice) { + let (abs, sgn) = self.abs_sign(); + (abs.0, sgn) + } + + /// Decompose `self` into is absolute value and signum. + #[inline] + pub const fn abs_sign(&self) -> (ExtendedUint, ConstChoice) { + let is_negative = self.1.as_int().is_negative(); + ( + self.wrapping_neg_if(is_negative).as_extended_uint(), + is_negative, + ) + } + + /// Divide self by `2^k`, rounding towards zero. + #[inline] + pub const fn div_2k(&self, k: u32) -> Self { + let (abs, sgn) = self.abs_sign(); + abs.shr(k).wrapping_neg_if(sgn).as_extended_int() + } +} + +#[cfg(test)] +mod tests { + use crate::modular::bingcd::extension::{ExtendedInt, ExtendedUint}; + use crate::{ConstChoice, U64, U128, Uint}; + + impl ExtendedInt { + /// Construct an [ExtendedInt] from the product of a [Uint] and an [Int]. + /// + /// Assumes the top bit of the product is not set. + #[inline] + pub const fn from_i64(val: i64) -> Self { + let abs_val = val.unsigned_abs(); + let is_negative = ConstChoice::from_u64_gt(abs_val, 0x7FFFFFFF); + ExtendedUint::from_product(Uint::from_u64(abs_val), Uint::ZERO) + .wrapping_neg_if(is_negative) + .as_extended_int() + } + } + + const A: ExtendedUint<{ U64::LIMBS }, { U64::LIMBS }> = ExtendedUint::from_product( + U64::from_u64(68146184546341u64), + U64::from_u64(873817114763u64), + ); + const B: ExtendedUint<{ U64::LIMBS }, { U64::LIMBS }> = ExtendedUint::from_product( + U64::from_u64(7772181434148543u64), + U64::from_u64(6665138352u64), + ); + + impl ExtendedUint { + /// Decompose `self` into the bottom and top limbs. + #[inline] + fn as_elements(&self) -> (Uint, Uint) { + (self.0, self.1) + } + } + + #[test] + fn test_from_product() { + assert_eq!( + A.as_elements(), + (U64::from(13454091406951429143u64), U64::from(3228065u64)) + ); + assert_eq!( + B.as_elements(), + (U64::from(1338820589698724688u64), U64::from(2808228u64)) + ); + } + + #[test] + fn test_wrapping_sub() { + assert_eq!( + A.as_extended_int() + .wrapping_sub(&B.as_extended_int()) + .as_extended_uint() + .as_elements(), + (U64::from(12115270817252704455u64), U64::from(419837u64)) + ) + } + + #[test] + fn test_wrapping_mul() { + let a = ExtendedInt( + U128::from_be_hex("39F1B23EBAB019658E5A4C15C3FBC4D5"), + U64::from_be_hex("5BF731833CE465C7"), + ); + let b = (&U64::MAX, &ConstChoice::TRUE); + let res = a.wrapping_mul(b); + + let target = ExtendedInt( + Uint::from_be_hex("AB976628F6B454908E5A4C15C3FBC4D5"), + Uint::from_be_hex("A2057F4482344C61"), + ); + assert_eq!(res, target); + } +} diff --git a/src/modular/bingcd/gcd.rs b/src/modular/bingcd/gcd.rs new file mode 100644 index 000000000..814a8172d --- /dev/null +++ b/src/modular/bingcd/gcd.rs @@ -0,0 +1,244 @@ +use crate::modular::bingcd::tools::const_max; +use crate::{ConstChoice, Odd, U64, U128, Uint}; + +impl Odd> { + /// The minimal number of iterations required to ensure the Binary GCD algorithm terminates and + /// returns the proper value. + const MINIMAL_BINGCD_ITERATIONS: u32 = 2 * Self::BITS - 1; + + /// Computes `gcd(self, rhs)`, leveraging (a constant time implementation of) the classic + /// Binary GCD algorithm. + /// + /// Note: this algorithm is efficient for [Uint]s with relatively few `LIMBS`. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 1. + /// + #[inline] + pub fn classic_bingcd(&self, rhs: &Uint) -> Self { + // (self, rhs) corresponds to (m, y) in the Algorithm 1 notation. + let (mut a, mut b) = (*rhs, *self.as_ref()); + let mut j = 0; + while j < Self::MINIMAL_BINGCD_ITERATIONS { + Self::bingcd_step(&mut a, &mut b); + j += 1; + } + + b.to_odd() + .expect("gcd of an odd value with something else is always odd") + } + + /// Binary GCD update step. + /// + /// This is a condensed, constant time execution of the following algorithm: + /// ```text + /// if a mod 2 == 1 + /// if a < b + /// (a, b) ← (b, a) + /// a ← a - b + /// a ← a/2 + /// ``` + /// + /// Note: assumes `b` to be odd. Might yield an incorrect result if this is not the case. + /// + /// Ref: Pornin, Algorithm 1, L3-9, . + #[inline] + fn bingcd_step(a: &mut Uint, b: &mut Uint) { + let a_odd = a.is_odd(); + let a_lt_b = Uint::lt(a, b); + Uint::conditional_swap(a, b, a_odd.and(a_lt_b)); + *a = a + .wrapping_sub(&Uint::select(&Uint::ZERO, b, a_odd)) + .shr_vartime(1); + } + + /// Computes `gcd(self, rhs)`, leveraging the optimized Binary GCD algorithm. + /// + /// Note: this algorithm becomes more efficient than the classical algorithm for [Uint]s with + /// relatively many `LIMBS`. A best-effort threshold is presented in [Self::bingcd]. + /// + /// Note: the full algorithm has an additional parameter; this function selects the best-effort + /// value for this parameter. You might be able to further tune your performance by calling the + /// [Self::optimized_bingcd_] function directly. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + #[inline] + pub fn optimized_bingcd(&self, rhs: &Uint) -> Self { + self.optimized_bingcd_::<{ U64::BITS }, { U64::LIMBS }, { U128::LIMBS }>(rhs) + } + + /// Computes `gcd(self, rhs)`, leveraging the optimized Binary GCD algorithm. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + /// + /// In summary, the optimized algorithm does not operate on `self` and `rhs` directly, but + /// instead of condensed summaries that fit in few registers. Based on these summaries, an + /// update matrix is constructed by which `self` and `rhs` are updated in larger steps. + /// + /// This function is generic over the following three values: + /// - `K`: the number of bits used when summarizing `self` and `rhs` for the inner loop. The + /// `K+1` top bits and `K-1` least significant bits are selected. It is recommended to keep + /// `K` close to a (multiple of) the number of bits that fit in a single register. + /// - `LIMBS_K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ K`, + /// - `LIMBS_2K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ 2K`. + #[inline] + pub fn optimized_bingcd_( + &self, + rhs: &Uint, + ) -> Self { + let (mut a, mut b) = (*self.as_ref(), *rhs); + + let mut i = 0; + while i < Self::MINIMAL_BINGCD_ITERATIONS.div_ceil(K - 1) { + i += 1; + + // Construct a_ and b_ as the summary of a and b, respectively. + let n = const_max(2 * K, const_max(a.bits(), b.bits())); + let a_ = a.compact::(n); + let b_ = b.compact::(n); + + // Compute the K-1 iteration update matrix from a_ and b_ + // Safe to vartime; function executes in time variable in `iterations` only, which is + // a public constant K-1 here. + let (.., matrix) = a_ + .to_odd() + .expect("a_ is always odd") + .partial_binxgcd_vartime::(&b_, K - 1, ConstChoice::FALSE); + + // Update `a` and `b` using the update matrix + let (updated_a, updated_b) = matrix.extended_apply_to((a, b)); + (a, _) = updated_a.wrapping_drop_extension(); + (b, _) = updated_b.wrapping_drop_extension(); + } + + a.to_odd() + .expect("gcd of an odd value with something else is always odd") + } +} + +#[cfg(all(test, feature = "rand_core"))] +mod tests { + use rand_chacha::ChaChaRng; + use rand_core::SeedableRng; + + fn make_rng() -> ChaChaRng { + ChaChaRng::from_seed([0; 32]) + } + + mod test_classic_bingcd { + use crate::modular::bingcd::gcd::tests::make_rng; + use crate::{ + Gcd, Int, Random, U64, U128, U192, U256, U384, U512, U1024, U2048, U4096, Uint, + }; + + fn classic_bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.to_odd().unwrap().classic_bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + fn classic_bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + classic_bingcd_test(Uint::ONE, Uint::ZERO); + classic_bingcd_test(Uint::ONE, Uint::ONE); + classic_bingcd_test(Uint::ONE, Int::MAX.abs()); + classic_bingcd_test(Uint::ONE, Int::MIN.abs()); + classic_bingcd_test(Uint::ONE, Uint::MAX); + classic_bingcd_test(Int::MAX.abs(), Uint::ZERO); + classic_bingcd_test(Int::MAX.abs(), Uint::ONE); + classic_bingcd_test(Int::MAX.abs(), Int::MAX.abs()); + classic_bingcd_test(Int::MAX.abs(), Int::MIN.abs()); + classic_bingcd_test(Int::MAX.abs(), Uint::MAX); + classic_bingcd_test(Uint::MAX, Uint::ZERO); + classic_bingcd_test(Uint::MAX, Uint::ONE); + classic_bingcd_test(Uint::MAX, Int::MAX.abs()); + classic_bingcd_test(Uint::MAX, Int::MIN.abs()); + classic_bingcd_test(Uint::MAX, Uint::MAX); + + // Randomized test cases + let mut rng = make_rng(); + for _ in 0..1000 { + let x = Uint::::random(&mut rng).bitor(&Uint::ONE); + let y = Uint::::random(&mut rng); + classic_bingcd_test(x, y); + } + } + + #[test] + fn test_classic_bingcd() { + classic_bingcd_tests::<{ U64::LIMBS }>(); + classic_bingcd_tests::<{ U128::LIMBS }>(); + classic_bingcd_tests::<{ U192::LIMBS }>(); + classic_bingcd_tests::<{ U256::LIMBS }>(); + classic_bingcd_tests::<{ U384::LIMBS }>(); + classic_bingcd_tests::<{ U512::LIMBS }>(); + classic_bingcd_tests::<{ U1024::LIMBS }>(); + classic_bingcd_tests::<{ U2048::LIMBS }>(); + classic_bingcd_tests::<{ U4096::LIMBS }>(); + } + } + + mod test_optimized_bingcd { + use crate::modular::bingcd::gcd::tests::make_rng; + use crate::{Gcd, Int, Random, U128, U192, U256, U384, U512, U1024, U2048, U4096, Uint}; + + fn optimized_bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.to_odd().unwrap().optimized_bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + fn optimized_bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + optimized_bingcd_test(Uint::ONE, Uint::ZERO); + optimized_bingcd_test(Uint::ONE, Uint::ONE); + optimized_bingcd_test(Uint::ONE, Int::MAX.abs()); + optimized_bingcd_test(Uint::ONE, Int::MIN.abs()); + optimized_bingcd_test(Uint::ONE, Uint::MAX); + optimized_bingcd_test(Int::MAX.abs(), Uint::ZERO); + optimized_bingcd_test(Int::MAX.abs(), Uint::ONE); + optimized_bingcd_test(Int::MAX.abs(), Int::MAX.abs()); + optimized_bingcd_test(Int::MAX.abs(), Int::MIN.abs()); + optimized_bingcd_test(Int::MAX.abs(), Uint::MAX); + optimized_bingcd_test(Uint::MAX, Uint::ZERO); + optimized_bingcd_test(Uint::MAX, Uint::ONE); + optimized_bingcd_test(Uint::MAX, Int::MAX.abs()); + optimized_bingcd_test(Uint::MAX, Int::MIN.abs()); + optimized_bingcd_test(Uint::MAX, Uint::MAX); + + // Randomized testing + let mut rng = make_rng(); + for _ in 0..1000 { + let x = Uint::::random(&mut rng).bitor(&Uint::ONE); + let y = Uint::::random(&mut rng); + optimized_bingcd_test(x, y); + } + } + + #[test] + fn test_optimized_bingcd() { + // Not applicable for U64 + optimized_bingcd_tests::<{ U128::LIMBS }>(); + optimized_bingcd_tests::<{ U192::LIMBS }>(); + optimized_bingcd_tests::<{ U256::LIMBS }>(); + optimized_bingcd_tests::<{ U384::LIMBS }>(); + optimized_bingcd_tests::<{ U512::LIMBS }>(); + optimized_bingcd_tests::<{ U1024::LIMBS }>(); + optimized_bingcd_tests::<{ U2048::LIMBS }>(); + optimized_bingcd_tests::<{ U4096::LIMBS }>(); + } + } +} diff --git a/src/modular/bingcd/matrix.rs b/src/modular/bingcd/matrix.rs new file mode 100644 index 000000000..d4f0a536a --- /dev/null +++ b/src/modular/bingcd/matrix.rs @@ -0,0 +1,602 @@ +use crate::modular::bingcd::extension::ExtendedInt; +use crate::{ConstChoice, Uint}; + +type Vector = (T, T); + +/// [`Int`] with an extra limb. +type ExtraLimbInt = ExtendedInt; + +/// Matrix used by the Binary XGCD algorithm to represent the state. +/// +/// ### Representation +/// The internal state represents the matrix +/// ```text +/// [ m00 m01 ] +/// [ m10 m11 ] / 2^k +/// ``` +/// with `k_upper_bound ≥ k`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub(super) struct StateMatrix { + m00: ExtraLimbInt, + m01: ExtraLimbInt, + m10: ExtraLimbInt, + m11: ExtraLimbInt, + k: u32, + k_upper_bound: u32, +} + +impl StateMatrix { + /// The unit matrix + pub(super) const UNIT: Self = Self::new( + ExtraLimbInt::ONE, + ExtraLimbInt::ZERO, + ExtraLimbInt::ZERO, + ExtraLimbInt::ONE, + 0, + 0, + ); + + /// Construct a new [`StateMatrix`]. + pub(super) const fn new( + m00: ExtraLimbInt, + m01: ExtraLimbInt, + m10: ExtraLimbInt, + m11: ExtraLimbInt, + k: u32, + k_upper_bound: u32, + ) -> Self { + Self { + m00, + m01, + m10, + m11, + k, + k_upper_bound, + } + } + + /// Negate the top row of this matrix if `negate`; otherwise do nothing. + pub(super) const fn conditional_negate_top_row(&mut self, negate: ConstChoice) { + self.m00 = self.m00.wrapping_neg_if(negate); + self.m01 = self.m01.wrapping_neg_if(negate); + } + + /// Negate the bottom row of this matrix if `negate`; otherwise do nothing. + pub(super) const fn conditional_negate_bottom_row(&mut self, negate: ConstChoice) { + self.m10 = self.m10.wrapping_neg_if(negate); + self.m11 = self.m11.wrapping_neg_if(negate); + } + + pub(super) const fn to_update_matrix(&self) -> UpdateMatrix { + let (abs_m00, m00_is_negative) = self.m00.abs_sign(); + let (abs_m01, m01_is_negative) = self.m01.abs_sign(); + let (abs_m10, m10_is_negative) = self.m10.abs_sign(); + let (abs_m11, m11_is_negative) = self.m11.abs_sign(); + + // Construct the pattern. + let m00_is_zero = abs_m00.is_zero(); + let m01_is_zero = abs_m01.is_zero(); + let pattern_vote_1 = m00_is_zero.not().and(m00_is_negative.not()); + let pattern_vote_2 = m01_is_zero.not().and(m01_is_negative); + + let m00_and_m01_are_zero = m00_is_zero.and(m01_is_zero); + let m10_is_zero = abs_m10.is_zero(); + let m11_is_zero = abs_m11.is_zero(); + let pattern_vote_3 = m00_and_m01_are_zero.and(m10_is_zero.not().and(m10_is_negative)); + let pattern_vote_4 = m00_and_m01_are_zero.and(m11_is_zero.not().and(m11_is_negative.not())); + let pattern = pattern_vote_1 + .or(pattern_vote_2) + .or(pattern_vote_3) + .or(pattern_vote_4); + + UpdateMatrix::new( + abs_m00.checked_drop_extension().expect("m00 fits"), + abs_m01.checked_drop_extension().expect("m01 fits"), + abs_m10.checked_drop_extension().expect("m10 fits"), + abs_m11.checked_drop_extension().expect("m11 fits"), + pattern, + self.k, + self.k_upper_bound, + ) + } +} + +/// Matrix used to compute the Extended GCD using the Binary Extended GCD algorithm. +/// +/// The internal state represents the matrix +/// ```text +/// true false +/// [ m00 -m01 ] [ -m00 m01 ] +/// [ -m10 m11 ] / 2^k or [ m10 -m11 ] / 2^k +/// ``` +/// depending on whether `pattern` is respectively truthy or not. +/// +/// Since some of the operations conditionally increase `k`, this struct furthermore keeps track of +/// `k_upper_bound`; an upper bound on the value of `k`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct UpdateMatrix { + m00: Uint, + m01: Uint, + m10: Uint, + m11: Uint, + pattern: ConstChoice, + k: u32, + k_upper_bound: u32, +} + +impl UpdateMatrix { + /// The unit matrix. + pub(crate) const UNIT: Self = Self::new( + Uint::ONE, + Uint::ZERO, + Uint::ZERO, + Uint::ONE, + ConstChoice::TRUE, + 0, + 0, + ); + + /// Construct the matrix representing the subtraction of one vector element from the other. + /// Subtracts the top element from the bottom if `top_from_bottom` is truthy, and the one + /// subtracting the bottom element from the top otherwise. + /// + /// In other words, returns one of the following matrices, given `top_from_bottom` + /// ```text + /// true false + /// [ 1 0 ] [ 1 -1 ] + /// [ -1 1 ] or [ 0 1 ] + /// ``` + pub(crate) fn get_subtraction_matrix(top_from_bottom: ConstChoice, k_upper_bound: u32) -> Self { + let (mut m01, mut m10) = (Uint::ONE, Uint::ZERO); + Uint::conditional_swap(&mut m01, &mut m10, top_from_bottom); + Self::new( + Uint::ONE, + m01, + m10, + Uint::ONE, + ConstChoice::TRUE, + 0, + k_upper_bound, + ) + } + + pub(crate) const fn new( + m00: Uint, + m01: Uint, + m10: Uint, + m11: Uint, + pattern: ConstChoice, + k: u32, + k_upper_bound: u32, + ) -> Self { + Self { + m00, + m01, + m10, + m11, + pattern, + k, + k_upper_bound, + } + } + + pub(crate) fn as_elements( + &self, + ) -> ( + &Uint, + &Uint, + &Uint, + &Uint, + ConstChoice, + u32, + u32, + ) { + ( + &self.m00, + &self.m01, + &self.m10, + &self.m11, + self.pattern, + self.k, + self.k_upper_bound, + ) + } + + pub(crate) fn as_elements_mut( + &mut self, + ) -> ( + &mut Uint, + &mut Uint, + &mut Uint, + &mut Uint, + &mut ConstChoice, + &mut u32, + &mut u32, + ) { + ( + &mut self.m00, + &mut self.m01, + &mut self.m10, + &mut self.m11, + &mut self.pattern, + &mut self.k, + &mut self.k_upper_bound, + ) + } + + /// Return `b` if `c` is truthy, otherwise return `a`. + #[inline] + pub(crate) fn select(a: &Self, b: &Self, c: ConstChoice) -> Self { + Self { + m00: Uint::select(&a.m00, &b.m00, c), + m01: Uint::select(&a.m01, &b.m01, c), + m10: Uint::select(&a.m10, &b.m10, c), + m11: Uint::select(&a.m11, &b.m11, c), + k: c.select_u32(a.k, b.k), + k_upper_bound: c.select_u32(a.k_upper_bound, b.k_upper_bound), + pattern: b.pattern.and(c).or(a.pattern.and(c.not())), + } + } + + /// Apply this matrix to a vector of [Uint]s, returning the result as a vector of + /// [ExtendedInt]s. + #[inline] + pub(crate) fn extended_apply_to( + &self, + vec: Vector>, + ) -> Vector> { + let (a, b) = vec; + let m00a = ExtendedInt::from_product(a, self.m00); + let m10a = ExtendedInt::from_product(a, self.m10); + let m01b = ExtendedInt::from_product(b, self.m01); + let m11b = ExtendedInt::from_product(b, self.m11); + ( + m00a.wrapping_sub(&m01b) + .div_2k(self.k) + .wrapping_neg_if(self.pattern.not()), + m11b.wrapping_sub(&m10a) + .div_2k(self.k) + .wrapping_neg_if(self.pattern.not()), + ) + } + + /// Wrapping apply this matrix to `rhs`. Return the result in `RHS_LIMBS`. + #[inline] + pub(super) fn mul_right( + &self, + rhs: &StateMatrix, + ) -> StateMatrix { + let a0 = rhs.m00.wrapping_mul((&self.m00, &self.pattern.not())); + let a1 = rhs.m10.wrapping_mul((&self.m01, &self.pattern)); + let a = a0.wrapping_add(&a1); + + let b0 = rhs.m01.wrapping_mul((&self.m00, &self.pattern.not())); + let b1 = rhs.m11.wrapping_mul((&self.m01, &self.pattern)); + let b = b0.wrapping_add(&b1); + + let c0 = rhs.m00.wrapping_mul((&self.m10, &self.pattern)); + let c1 = rhs.m10.wrapping_mul((&self.m11, &self.pattern.not())); + let c = c0.wrapping_add(&c1); + + let d0 = rhs.m01.wrapping_mul((&self.m10, &self.pattern)); + let d1 = rhs.m11.wrapping_mul((&self.m11, &self.pattern.not())); + let d = d0.wrapping_add(&d1); + + StateMatrix::new( + a, + b, + c, + d, + self.k + rhs.k, + self.k_upper_bound + rhs.k_upper_bound, + ) + } + + /// Swap the rows of this matrix if `swap` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_swap_rows(&mut self, swap: ConstChoice) { + Uint::conditional_swap(&mut self.m00, &mut self.m10, swap); + Uint::conditional_swap(&mut self.m01, &mut self.m11, swap); + self.pattern = self.pattern.xor(swap); + } + + /// Swap the rows of this matrix. + #[inline] + pub(crate) fn swap_rows(&mut self) { + self.conditional_swap_rows(ConstChoice::TRUE); + } + + /// Subtract the bottom row from the top if `subtract` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_subtract_bottom_row_from_top(&mut self, subtract: ConstChoice) { + // Note: because the signs of the internal representation are stored in `pattern`, + // subtracting one row from another involves _adding_ these rows instead. + self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m10), subtract); + self.m01 = Uint::select(&self.m01, &self.m01.wrapping_add(&self.m11), subtract); + } + + /// Subtract the right column from the left if `subtract` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_subtract_right_column_from_left(&mut self, subtract: ConstChoice) { + // Note: because the signs of the internal representation are stored in `pattern`, + // subtracting one column from another involves _adding_ these columns instead. + self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m01), subtract); + self.m10 = Uint::select(&self.m10, &self.m10.wrapping_add(&self.m11), subtract); + } + + /// If `add` is truthy, add the right column to the left. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_add_right_column_to_left(&mut self, add: ConstChoice) { + // Note: because the signs of the internal representation are stored in `pattern`, + // subtracting one column from another involves _adding_ these columns instead. + self.m00 = Uint::select(&self.m00, &self.m01.wrapping_sub(&self.m00), add); + self.m10 = Uint::select(&self.m10, &self.m11.wrapping_sub(&self.m10), add); + } + + /// Double the bottom row of this matrix if `double` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_double_bottom_row(&mut self, double: ConstChoice) { + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + self.m10 = Uint::select(&self.m10, &self.m10.shl_vartime(1), double); + self.m11 = Uint::select(&self.m11, &self.m11.shl_vartime(1), double); + self.k = double.select_u32(self.k, self.k + 1); + self.k_upper_bound += 1; + } + + /// Negate the elements in this matrix if `negate` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) fn conditional_negate(&mut self, negate: ConstChoice) { + self.pattern = self.pattern.xor(negate); + } +} + +#[cfg(test)] +mod tests { + use crate::modular::bingcd::matrix::{ExtraLimbInt, StateMatrix, UpdateMatrix}; + use crate::{ConstChoice, U64, U256, Uint}; + + const X: UpdateMatrix<{ U256::LIMBS }> = UpdateMatrix::new( + U256::from_u64(1u64), + U256::from_u64(7u64), + U256::from_u64(23u64), + U256::from_u64(53u64), + ConstChoice::TRUE, + 1, + 2, + ); + + const Y: StateMatrix<{ U256::LIMBS }> = StateMatrix::new( + ExtraLimbInt::from_i64(1i64), + ExtraLimbInt::from_i64(-7i64), + ExtraLimbInt::from_i64(-23i64), + ExtraLimbInt::from_i64(53i64), + 1, + 2, + ); + + #[test] + fn test_wrapping_apply_to() { + let a = U64::from_be_hex("CA048AFA63CD6A1F"); + let b = U64::from_be_hex("AE693BF7BE8E5566"); + let matrix = UpdateMatrix { + m00: U64::from_be_hex("0000000000000120"), + m01: U64::from_be_hex("00000000000000D0"), + m10: U64::from_be_hex("0000000000000136"), + m11: U64::from_be_hex("00000000000002A7"), + pattern: ConstChoice::TRUE, + k: 17, + k_upper_bound: 17, + }; + + let (a_, b_) = matrix.extended_apply_to((a, b)); + assert_eq!( + a_.wrapping_drop_extension().0, + Uint::from_be_hex("002AC7CDD032B9B9") + ); + assert_eq!( + b_.wrapping_drop_extension().0, + Uint::from_be_hex("006CFBCEE172C863") + ); + } + + #[test] + fn test_swap() { + let mut y = X; + y.swap_rows(); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(23u32), + Uint::from(53u32), + Uint::from(1u32), + Uint::from(7u32), + ConstChoice::FALSE, + 1, + 2 + ) + ); + } + + #[test] + fn test_conditional_swap() { + let mut y = X; + y.conditional_swap_rows(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_swap_rows(ConstChoice::TRUE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(23u32), + Uint::from(53u32), + Uint::from(1u32), + Uint::from(7u32), + ConstChoice::FALSE, + 1, + 2 + ) + ); + } + + #[test] + fn test_conditional_add_right_column_to_left() { + let mut y = X; + y.conditional_add_right_column_to_left(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_add_right_column_to_left(ConstChoice::TRUE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(6u32), + Uint::from(7u32), + Uint::from(30u32), + Uint::from(53u32), + ConstChoice::TRUE, + 1, + 2 + ) + ); + } + + #[test] + fn test_conditional_subtract_bottom_row_from_top() { + let mut y = X; + y.conditional_subtract_bottom_row_from_top(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_subtract_bottom_row_from_top(ConstChoice::TRUE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(24u32), + Uint::from(60u32), + Uint::from(23u32), + Uint::from(53u32), + ConstChoice::TRUE, + 1, + 2 + ) + ); + } + + #[test] + fn test_conditional_subtract_right_column_from_left() { + let mut y = X; + y.conditional_subtract_right_column_from_left(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_subtract_right_column_from_left(ConstChoice::TRUE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(8u32), + Uint::from(7u32), + Uint::from(76u32), + Uint::from(53u32), + ConstChoice::TRUE, + 1, + 2 + ) + ); + } + + #[test] + fn test_conditional_double() { + let mut y = X; + y.conditional_double_bottom_row(ConstChoice::FALSE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(1u32), + Uint::from(7u32), + Uint::from(23u32), + Uint::from(53u32), + ConstChoice::TRUE, + 1, + 3 + ) + ); + y.conditional_double_bottom_row(ConstChoice::TRUE); + assert_eq!( + y, + UpdateMatrix::new( + Uint::from(1u32), + Uint::from(7u32), + Uint::from(46u32), + Uint::from(106u32), + ConstChoice::TRUE, + 2, + 4 + ) + ); + } + + #[test] + fn test_mul_right() { + let res = X.mul_right(&Y); + assert_eq!( + res, + StateMatrix::new( + ExtraLimbInt::<{ U256::LIMBS }>::from_i64(162i64), + ExtraLimbInt::<{ U256::LIMBS }>::from_i64(-378i64), + ExtraLimbInt::<{ U256::LIMBS }>::from_i64(-1242i64), + ExtraLimbInt::<{ U256::LIMBS }>::from_i64(2970i64), + 2, + 4 + ) + ) + } + + #[test] + fn test_select() { + let x = UpdateMatrix::new( + U64::from_u64(0), + U64::from_u64(1), + U64::from_u64(2), + U64::from_u64(3), + ConstChoice::FALSE, + 4, + 5, + ); + let y = UpdateMatrix::new( + U64::from_u64(6), + U64::from_u64(7), + U64::from_u64(8), + U64::from_u64(9), + ConstChoice::TRUE, + 11, + 12, + ); + assert_eq!(UpdateMatrix::select(&x, &y, ConstChoice::FALSE), x); + assert_eq!(UpdateMatrix::select(&x, &y, ConstChoice::TRUE), y); + } + + #[test] + fn test_get_subtraction_matrix() { + let x = UpdateMatrix::get_subtraction_matrix(ConstChoice::TRUE, 35); + assert_eq!( + x, + UpdateMatrix::new( + U64::ONE, + U64::ZERO, + U64::ONE, + U64::ONE, + ConstChoice::TRUE, + 0, + 35 + ) + ); + + let x = UpdateMatrix::get_subtraction_matrix(ConstChoice::FALSE, 63); + assert_eq!( + x, + UpdateMatrix::new( + U64::ONE, + U64::ONE, + U64::ZERO, + U64::ONE, + ConstChoice::TRUE, + 0, + 63 + ) + ); + } +} diff --git a/src/modular/bingcd/tools.rs b/src/modular/bingcd/tools.rs new file mode 100644 index 000000000..a3898c0d8 --- /dev/null +++ b/src/modular/bingcd/tools.rs @@ -0,0 +1,268 @@ +use crate::{ConstChoice, Limb, Odd, Uint}; + +/// `const` equivalent of `u32::max(a, b)`. +pub(crate) fn const_max(a: u32, b: u32) -> u32 { + ConstChoice::from_u32_lt(a, b).select_u32(a, b) +} + +/// `const` equivalent of `u32::min(a, b)`. +pub(crate) fn const_min(a: u32, b: u32) -> u32 { + ConstChoice::from_u32_lt(a, b).select_u32(b, a) +} + +impl Limb { + /// Compute `self / 2^k mod q`. Returns the result, as well as a factor `f` such that `2^k` + /// divides `self + q * f`. + /// + /// Executes in time variable in `k_bound`. This value should be chosen as an inclusive + /// upperbound to the value of `k`. + fn bounded_div2k_mod_q(mut self, k: u32, k_bound: u32, one_half_mod_q: Self) -> (Self, Self) { + let mut factor = Limb::ZERO; + let mut i = 0; + while i < k_bound { + let execute = ConstChoice::from_u32_lt(i, k); + + let (shifted, carry) = self.shr1(); + self = Self::select(self, shifted, execute); + + let overflow = ConstChoice::from_word_msb(carry.0); + let add_back_q = overflow.and(execute); + self = self.wrapping_add(Self::select(Self::ZERO, one_half_mod_q, add_back_q)); + factor = factor.bitxor(Self::select(Self::ZERO, Self::ONE.shl(i), add_back_q)); + i += 1; + } + + (self, factor) + } +} + +impl Uint { + /// Compute `self / 2^k mod q`. + /// + /// Executes in time variable in `k_bound`. This value should be chosen as an inclusive + /// upperbound to the value of `k`. + #[inline] + pub(crate) fn bounded_div_2k_mod_q(self, k: u32, k_bound: u32, q: &Odd) -> Self { + // 1 / 2 mod q + // = (q + 1) / 2 mod q + // = (q - 1) / 2 + 1 mod q + // = floor(q / 2) + 1 mod q, since q is odd. + let one_half_mod_q = q.as_ref().shr1().wrapping_add(&Uint::ONE); + + // invariant: x = self / 2^e mod q. + let (mut x, mut e) = (self, 0); + + let max_round_iters = Limb::BITS - 1; + let rounds = k_bound.div_ceil(max_round_iters); + + let mut r = 0; + while r < rounds { + let f_bound = const_min(k_bound - r * max_round_iters, max_round_iters); + let f = const_min(k - e, f_bound); + let (_, s) = x.limbs[0].bounded_div2k_mod_q(f, f_bound, one_half_mod_q.limbs[0]); + + // Compute (x * qs) / 2^f + // Note that 2^f divides x + qs by construction + let (x_qs_lo, mut x_qs_hi) = x.mac_limb(q.as_ref(), s, Limb::ZERO); + x_qs_hi = x_qs_hi.shl((Limb::BITS - f) % Limb::BITS); + let (mut x_qs_div_2f, _) = x_qs_lo.shr_limb(f); + x_qs_div_2f.limbs[LIMBS - 1] = x_qs_div_2f.limbs[LIMBS - 1].bitxor(x_qs_hi); + + (x, e) = (x_qs_div_2f, e + f); + + r += 1; + } + + x + } + + /// Computes `self + (b * c) + carry`, returning the result along with the new carry. + #[inline] + fn mac_limb(mut self, b: &Self, c: Limb, mut carry: Limb) -> (Self, Limb) { + let mut i = 0; + while i < LIMBS { + (self.limbs[i], carry) = self.limbs[i].mac(b.limbs[i], c, carry); + i += 1; + } + (self, carry) + } + + /// Construct a [Uint] containing the bits in `self` in the range `[idx, idx + length)`. + /// + /// Assumes `length ≤ Uint::::BITS` and `idx + length ≤ Self::BITS`. + /// + /// Executes in time variable in `length` only. + #[inline(always)] + pub(crate) fn section_vartime_length( + &self, + idx: u32, + length: u32, + ) -> Uint { + debug_assert!(length <= Uint::::BITS); + debug_assert!(idx + length <= Self::BITS); + + let mask = Uint::ONE.shl_vartime(length).wrapping_sub(&Uint::ONE); + self.shr(idx).resize::().bitand(&mask) + } + + /// Construct a [Uint] containing the bits in `self` in the range `[idx, idx + length)`. + /// + /// Assumes `length ≤ Uint::::BITS` and `idx + length ≤ Self::BITS`. + /// + /// Executes in time variable in `idx` and `length`. + #[inline(always)] + pub(crate) fn section_vartime( + &self, + idx: u32, + length: u32, + ) -> Uint { + debug_assert!(length <= Uint::::BITS); + debug_assert!(idx + length <= Self::BITS); + + let mask = Uint::ONE.shl_vartime(length).wrapping_sub(&Uint::ONE); + self.shr_vartime(idx) + .resize::() + .bitand(&mask) + } + + /// Compact `self` to a form containing the concatenation of its bit ranges `[0, K-1)` + /// and `[n-K-1, n)`. + /// + /// Assumes `K ≤ Uint::::BITS`, `n ≤ Self::BITS` and `n ≥ 2K`. + #[inline(always)] + pub(crate) fn compact( + &self, + n: u32, + ) -> Uint { + debug_assert!(K <= Uint::::BITS); + debug_assert!(n <= Self::BITS); + debug_assert!(n >= 2 * K); + + // safe to vartime; this function is vartime in length only, which is a public constant + let hi = self.section_vartime_length(n - K - 1, K + 1); + // safe to vartime; this function is vartime in idx and length only, which are both public + // constants + let lo = self.section_vartime(0, K - 1); + // safe to vartime; shl_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + hi.shl_vartime(K - 1).bitxor(&lo) + } +} + +#[cfg(test)] +mod tests { + use crate::{Limb, U128, Uint}; + + #[test] + fn test_bounded_div2k_mod_q() { + let x = Limb::MAX.wrapping_sub(Limb::from(15u32)); + let q = Limb::from(55u32); + let half_mod_q = q.shr1().0.wrapping_add(Limb::ONE); + + // Do nothing + let k = 0; + let k_bound = 3; + let (res, factor) = x.bounded_div2k_mod_q(k, k_bound, half_mod_q); + assert_eq!(res, x); + assert_eq!(factor, Limb::ZERO); + + // Divide by 2^4 without requiring the addition of q + let k = 4; + let k_bound = 4; + let (res, factor) = x.bounded_div2k_mod_q(k, k_bound, half_mod_q); + assert_eq!(res, x.shr(4)); + assert_eq!(factor, Limb::ZERO); + + // Divide by 2^5, requiring a single addition of q * 2^4 + let k = 5; + let k_bound = 5; + let (res, factor) = x.bounded_div2k_mod_q(k, k_bound, half_mod_q); + assert_eq!(res, x.shr(5).wrapping_add(half_mod_q)); + assert_eq!(factor, Limb::ONE.shl(4)); + + // Execute at most k_bound iterations + let k = 5; + let k_bound = 4; + let (res, factor) = x.bounded_div2k_mod_q(k, k_bound, half_mod_q); + assert_eq!(res, x.shr(4)); + assert_eq!(factor, Limb::ZERO); + } + + #[test] + fn test_mac_limb() { + // Do nothing + let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA"); + let q = U128::MAX; + let f = Limb::ZERO; + let (res, carry) = x.mac_limb(&q, f, Limb::ZERO); + assert_eq!(res, x); + assert_eq!(carry, Limb::ZERO); + + // f = 1 + let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA"); + let q = U128::MAX; + let f = Limb::ONE; + let (res, carry) = x.mac_limb(&q, f, Limb::ZERO); + assert_eq!(res, x.wrapping_add(&q)); + assert_eq!(carry, Limb::ONE); + + // f = max + let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA"); + let q = U128::MAX; + let f = Limb::MAX; + let (res, mac_carry) = x.mac_limb(&q, f, Limb::ZERO); + let (qf_lo, qf_hi) = q.split_mul(&Uint::new([f; 1])); + let (lo, carry) = qf_lo.adc(&x, Limb::ZERO); + let (hi, carry) = qf_hi.adc(&Uint::ZERO, carry); + assert_eq!(res, lo); + assert_eq!(mac_carry, hi.limbs[0]); + assert_eq!(carry, Limb::ZERO) + } + + #[test] + fn test_new_div2k_mod_q() { + // Do nothing + let q = U128::from(3u64).to_odd().unwrap(); + let res = U128::ONE.shl_vartime(64).bounded_div_2k_mod_q(0, 0, &q); + assert_eq!(res, U128::ONE.shl_vartime(64)); + + // Simply shift out 5 factors + let q = U128::from(3u64).to_odd().unwrap(); + let res = U128::ONE.shl_vartime(64).bounded_div_2k_mod_q(5, 5, &q); + assert_eq!(res, U128::ONE.shl_vartime(59)); + + // Add in one factor of q + let q = U128::from(3u64).to_odd().unwrap(); + let res = U128::ONE.bounded_div_2k_mod_q(1, 1, &q); + assert_eq!(res, U128::from(2u64)); + + // Add in many factors of q + let q = U128::from(3u64).to_odd().unwrap(); + let res = U128::from(8u64).bounded_div_2k_mod_q(17, 17, &q); + assert_eq!(res, U128::ONE); + + // Larger q + let q = U128::from(2864434311u64).to_odd().unwrap(); + let res = U128::from(8u64).bounded_div_2k_mod_q(17, 17, &q); + assert_eq!(res, U128::from(303681787u64)); + + // Shift greater than Limb::BITS + let q = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333") + .to_odd() + .unwrap(); + let res = U128::MAX.bounded_div_2k_mod_q(71, 71, &q); + assert_eq!(res, U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3")); + + // Have k_bound restrict the number of shifts to 0 + let res = U128::MAX.bounded_div_2k_mod_q(71, 0, &q); + assert_eq!(res, U128::MAX); + + // Have k_bound < k + let res = U128::MAX.bounded_div_2k_mod_q(71, 30, &q); + assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76")); + + // Have k_bound >> k + let res = U128::MAX.bounded_div_2k_mod_q(30, 127, &q); + assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76")); + } +} diff --git a/src/modular/bingcd/xgcd.rs b/src/modular/bingcd/xgcd.rs new file mode 100644 index 000000000..29215129d --- /dev/null +++ b/src/modular/bingcd/xgcd.rs @@ -0,0 +1,939 @@ +use crate::modular::bingcd::matrix::{StateMatrix, UpdateMatrix}; +use crate::modular::bingcd::tools::const_max; +use crate::{ConstChoice, Int, NonZero, Odd, U64, U128, Uint}; + +/// Container for the raw output of the Binary XGCD algorithm. +pub(crate) struct RawOddUintBinxgcdOutput { + gcd: Odd>, + matrix: UpdateMatrix, +} + +impl RawOddUintBinxgcdOutput { + /// Process raw output, constructing an [UintBinxgcdOutput] object. + pub(crate) fn process(&mut self) -> OddUintBinxgcdOutput { + self.remove_matrix_factors(); + let (x, y) = self.bezout_coefficients(); + let (lhs_on_gcd, rhs_on_gcd) = self.quotients(); + OddUintBinxgcdOutput { + gcd: self.gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } + + /// Divide `self.matrix` by `2^self.matrix.k`, i.e., remove the excess doublings from + /// `self.matrix`. + /// + /// The performed divisions are modulo `lhs` and `rhs` to maintain the correctness of the XGCD + /// state. + /// + /// This operation is 'fast' since it only applies the division to the top row of the matrix. + /// This is allowed since it is assumed that `self.matrix * (lhs, rhs) = (gcd, 0)`; dividing + /// the bottom row of the matrix by a constant has no impact since its inner-product with the + /// input vector is zero. + fn remove_matrix_factors(&mut self) { + let (lhs_div_gcd, rhs_div_gcd) = self.quotients(); + let (x, y, .., k, k_upper_bound) = self.matrix.as_elements_mut(); + if *k_upper_bound > 0 { + *x = x.bounded_div_2k_mod_q( + *k, + *k_upper_bound, + &rhs_div_gcd.to_odd().expect("odd by construction"), + ); + *y = y.bounded_div_2k_mod_q( + *k, + *k_upper_bound, + &lhs_div_gcd.to_odd().expect("odd by construction"), + ); + *k = 0; + *k_upper_bound = 0; + } + } + + /// Obtain the bezout coefficients `(x, y)` such that `lhs * x + rhs * y = gcd`. + fn bezout_coefficients(&self) -> (Int, Int) { + let (m00, m01, m10, m11, pattern, ..) = self.matrix.as_elements(); + let m10_sub_m00 = m10.wrapping_sub(m00); + let m11_sub_m01 = m11.wrapping_sub(m01); + let apply = Uint::lte(&m10_sub_m00, m00).and(Uint::lte(&m11_sub_m01, m01)); + let m00 = Uint::select(m00, &m10_sub_m00, apply) + .wrapping_neg_if(apply.xor(pattern.not())) + .as_int(); + let m01 = Uint::select(m01, &m11_sub_m01, apply) + .wrapping_neg_if(apply.xor(pattern)) + .as_int(); + (m00, m01) + } + + /// Obtain the quotients `lhs/gcd` and `rhs/gcd` from `matrix`. + fn quotients(&self) -> (Uint, Uint) { + let (.., rhs_div_gcd, lhs_div_gcd, _, _, _) = self.matrix.as_elements(); + (*lhs_div_gcd, *rhs_div_gcd) + } +} + +/// Output of the Binary XGCD algorithm applied to two [Uint]s. +pub type UintBinxgcdOutput = BaseUintBinxgcdOutput, LIMBS>; + +/// Output of the Binary XGCD algorithm applied to two [`NonZero>`]s. +pub type NonZeroUintBinxgcdOutput = + BaseUintBinxgcdOutput>, LIMBS>; + +/// Output of the Binary XGCD algorithm applied to two [`Odd>`]s. +pub type OddUintBinxgcdOutput = BaseUintBinxgcdOutput>, LIMBS>; + +/// Container for the processed output of the Binary XGCD algorithm. +#[derive(Debug)] +pub struct BaseUintBinxgcdOutput { + pub gcd: T, + pub x: Int, + pub y: Int, + pub lhs_on_gcd: Uint, + pub rhs_on_gcd: Uint, +} + +impl BaseUintBinxgcdOutput { + /// Borrow the elements in this struct. + pub fn to_components(&self) -> (T, Int, Int, Uint, Uint) { + (self.gcd, self.x, self.y, self.lhs_on_gcd, self.rhs_on_gcd) + } + + /// Mutably borrow the elements in this struct. + pub fn as_components_mut( + &mut self, + ) -> ( + &mut T, + &mut Int, + &mut Int, + &mut Uint, + &mut Uint, + ) { + ( + &mut self.gcd, + &mut self.x, + &mut self.y, + &mut self.lhs_on_gcd, + &mut self.rhs_on_gcd, + ) + } + + /// The greatest common divisor stored in this object. + pub fn gcd(&self) -> T { + self.gcd + } + + /// Obtain a copy of the Bézout coefficients. + pub fn bezout_coefficients(&self) -> (Int, Int) { + (self.x, self.y) + } + + /// Obtain a copy of the quotients `lhs/gcd` and `rhs/gcd`. + pub fn quotients(&self) -> (Uint, Uint) { + (self.lhs_on_gcd, self.rhs_on_gcd) + } +} + +/// Number of bits used by [Odd::>::optimized_binxgcd] to represent a "compact" [Uint]. +const SUMMARY_BITS: u32 = U64::BITS; + +/// Number of limbs used to represent [Self::SUMMARY_BITS]. +const SUMMARY_LIMBS: usize = U64::LIMBS; + +/// Twice the number of limbs used to represent [Self::SUMMARY_BITS], i.e., two times +/// [Self::SUMMARY_LIMBS]. +const DOUBLE_SUMMARY_LIMBS: usize = U128::LIMBS; + +impl Odd> { + /// The minimal number of binary GCD iterations required to guarantee successful completion. + const MIN_BINGCD_ITERATIONS: u32 = 2 * Self::BITS - 1; + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + pub(crate) fn binxgcd_nz(&self, rhs: &NonZero>) -> OddUintBinxgcdOutput { + let (lhs_, rhs_) = (self.as_ref(), rhs.as_ref()); + + // The `binxgcd` subroutine requires `rhs` needs to be odd. We leverage the equality + // gcd(lhs, rhs) = gcd(lhs, |lhs-rhs|) to deal with the case that `rhs` is even. + let rhs_gt_lhs = Uint::gt(rhs_, lhs_); + let rhs_is_even = rhs_.is_odd().not(); + let abs_lhs_sub_rhs = Uint::select( + &lhs_.wrapping_sub(rhs_), + &rhs_.wrapping_sub(lhs_), + rhs_gt_lhs, + ); + let rhs_ = Uint::select(rhs.as_ref(), &abs_lhs_sub_rhs, rhs_is_even) + .to_odd() + .expect("rhs is odd by construction"); + + let mut output = self.binxgcd_(&rhs_); + output.remove_matrix_factors(); + + // Modify the output to negate the transformation applied to the input. + let matrix = &mut output.matrix; + let case_one = rhs_is_even.and(rhs_gt_lhs); + matrix.conditional_subtract_right_column_from_left(case_one); + let case_two = rhs_is_even.and(rhs_gt_lhs.not()); + matrix.conditional_add_right_column_to_left(case_two); + matrix.conditional_negate(case_two); + + output.process() + } + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + /// + /// This function switches between the "classic" and "optimized" algorithm at a best-effort + /// threshold. When using [Uint]s with `LIMBS` close to the threshold, it may be useful to + /// manually test whether the classic or optimized algorithm is faster for your machine. + pub(crate) fn binxgcd_(&self, rhs: &Self) -> RawOddUintBinxgcdOutput { + if LIMBS < 4 { + self.classic_binxgcd(rhs) + } else { + self.optimized_binxgcd(rhs) + } + } + + /// Execute the classic Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 1. + /// . + pub(crate) fn classic_binxgcd(&self, rhs: &Self) -> RawOddUintBinxgcdOutput { + let (gcd, _, matrix) = self.partial_binxgcd_vartime::( + rhs.as_ref(), + Self::MIN_BINGCD_ITERATIONS, + ConstChoice::TRUE, + ); + + RawOddUintBinxgcdOutput { gcd, matrix } + } + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + /// + /// **Warning**: `self` and `rhs` must be contained in an [U128] or larger. + /// + /// Note: this algorithm becomes more efficient than the classical algorithm for [Uint]s with + /// relatively many `LIMBS`. A best-effort threshold is presented in [Self::binxgcd_]. + /// + /// Note: the full algorithm has an additional parameter; this function selects the best-effort + /// value for this parameter. You might be able to further tune your performance by calling the + /// [Self::optimized_bingcd_] function directly. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// . + pub(crate) fn optimized_binxgcd(&self, rhs: &Self) -> RawOddUintBinxgcdOutput { + assert!(Self::BITS >= U128::BITS); + self.optimized_binxgcd_::(rhs) + } + + /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the optimized Binary Extended GCD algorithm. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + /// + /// In summary, the optimized algorithm does not operate on `self` and `rhs` directly, but + /// instead of condensed summaries that fit in few registers. Based on these summaries, an + /// update matrix is constructed by which `self` and `rhs` are updated in larger steps. + /// + /// This function is generic over the following three values: + /// - `K`: the number of bits used when summarizing `self` and `rhs` for the inner loop. The + /// `K+1` top bits and `K-1` least significant bits are selected. It is recommended to keep + /// `K` close to a (multiple of) the number of bits that fit in a single register. + /// - `LIMBS_K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ K`, + /// - `LIMBS_2K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ 2K`. + pub(crate) fn optimized_binxgcd_( + &self, + rhs: &Self, + ) -> RawOddUintBinxgcdOutput { + let (mut a, mut b) = (*self.as_ref(), *rhs.as_ref()); + let mut state = StateMatrix::UNIT; + + let (mut a_is_negative, mut b_is_negative); + let mut i = 0; + while i < Self::MIN_BINGCD_ITERATIONS.div_ceil(K - 1) { + // Loop invariants: + // i) each iteration of this loop, `a.bits() + b.bits()` shrinks by at least K-1, + // until `b = 0`. + // ii) `a` is odd. + i += 1; + + // Construct compact_a and compact_b as the summary of a and b, respectively. + // TODO: deal with the case that a fits entirely in compact + let b_bits = b.bits(); + let n = const_max(2 * K, const_max(a.bits(), b_bits)); + let compact_a = a.compact::(n); + let compact_b = b.compact::(n); + let b_eq_compact_b = + ConstChoice::from_u32_le(b_bits, K - 1).or(ConstChoice::from_u32_eq(n, 2 * K)); + + // Compute the K-1 iteration update matrix from a_ and b_ + let (.., update_matrix) = compact_a + .to_odd() + .expect("a is always odd") + .partial_binxgcd_vartime::(&compact_b, K - 1, b_eq_compact_b); + + // Update `a` and `b` using the update matrix + let (updated_a, updated_b) = update_matrix.extended_apply_to((a, b)); + (a, a_is_negative) = updated_a.wrapping_drop_extension(); + (b, b_is_negative) = updated_b.wrapping_drop_extension(); + + state = update_matrix.mul_right(&state); + + state.conditional_negate_top_row(a_is_negative); + state.conditional_negate_bottom_row(b_is_negative); + } + + let gcd = a + .to_odd() + .expect("gcd of an odd value with something else is always odd"); + + let matrix = state.to_update_matrix(); + RawOddUintBinxgcdOutput { gcd, matrix } + } + + /// Executes the optimized Binary GCD inner loop. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// . + /// + /// The function outputs the reduced values `(a, b)` for the input values `(self, rhs)` as well + /// as the matrix that yields the former two when multiplied with the latter two. + /// + /// Additionally, the number doublings that were executed is returned. By construction, each + /// element in `M` lies in the interval `(-2^doublings, 2^doublings]`. + /// + /// Note: this implementation deviates slightly from the paper, in that it can be instructed to + /// "run in place" (i.e., execute iterations that do nothing) once `a` becomes zero. + /// This is done by passing a truthy `halt_at_zero`. + /// + /// The function executes in time variable in `iterations`. + #[inline] + pub(crate) fn partial_binxgcd_vartime( + &self, + rhs: &Uint, + iterations: u32, + halt_at_zero: ConstChoice, + ) -> (Self, Uint, UpdateMatrix) { + let (mut a, mut b) = (*self.as_ref(), *rhs); + // This matrix corresponds with (f0, g0, f1, g1) in the paper. + let mut matrix = UpdateMatrix::UNIT; + + // Compute the update matrix. + // Note: to be consistent with the paper, the `binxgcd_step` algorithm requires the second + // argument to be odd. Here, we have `a` odd, so we have to swap a and b before and after + // calling the subroutine. The columns of the matrix have to be swapped accordingly. + Uint::swap(&mut a, &mut b); + matrix.swap_rows(); + + let mut j = 0; + while j < iterations { + Self::binxgcd_step(&mut a, &mut b, &mut matrix, halt_at_zero); + j += 1; + } + + // Undo swap + Uint::swap(&mut a, &mut b); + matrix.swap_rows(); + + let a = a.to_odd().expect("a is always odd"); + (a, b, matrix) + } + + /// Binary XGCD update step. + /// + /// This is a condensed, constant time execution of the following algorithm: + /// ```text + /// if a mod 2 == 1 + /// if a < b + /// (a, b) ← (b, a) + /// (f0, g0, f1, g1) ← (f1, g1, f0, g0) + /// a ← a - b + /// (f0, g0) ← (f0 - f1, g0 - g1) + /// if a > 0 + /// a ← a/2 + /// (f1, g1) ← (2f1, 2g1) + /// ``` + /// where `matrix` represents + /// ```text + /// (f0 g0) + /// (f1 g1). + /// ``` + /// + /// Note: this algorithm assumes `b` to be an odd integer. The algorithm will likely not yield + /// the correct result when this is not the case. + /// + /// Ref: Pornin, Algorithm 2, L8-17, . + #[inline] + fn binxgcd_step( + a: &mut Uint, + b: &mut Uint, + matrix: &mut UpdateMatrix, + halt_at_zero: ConstChoice, + ) { + let a_odd = a.is_odd(); + let a_lt_b = Uint::lt(a, b); + + // swap if a odd and a < b + let swap = a_odd.and(a_lt_b); + Uint::conditional_swap(a, b, swap); + matrix.conditional_swap_rows(swap); + + // subtract b from a when a is odd + *a = a.wrapping_sub(&Uint::select(&Uint::ZERO, b, a_odd)); + matrix.conditional_subtract_bottom_row_from_top(a_odd); + + // Div a by 2. + let double = a.is_nonzero().or(halt_at_zero.not()); + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + *a = a.shr_vartime(1); + + // Double the bottom row of the matrix when a was ≠ 0 and when not halting. + matrix.conditional_double_bottom_row(double); + } +} + +#[cfg(all(test, not(miri)))] +mod tests { + use crate::modular::bingcd::xgcd::OddUintBinxgcdOutput; + use crate::{ConcatMixed, Gcd, Uint}; + use core::ops::Div; + use num_traits::Zero; + + #[cfg(feature = "rand_core")] + use rand_chacha::ChaChaRng; + #[cfg(feature = "rand_core")] + use rand_core::SeedableRng; + + mod test_extract_quotients { + use crate::modular::bingcd::matrix::UpdateMatrix; + use crate::modular::bingcd::xgcd::RawOddUintBinxgcdOutput; + use crate::{ConstChoice, U64, Uint}; + + fn raw_binxgcdoutput_setup( + matrix: UpdateMatrix, + ) -> RawOddUintBinxgcdOutput { + RawOddUintBinxgcdOutput { + gcd: Uint::::ONE.to_odd().unwrap(), + matrix, + } + } + + #[test] + fn test_extract_quotients_unit() { + let output = raw_binxgcdoutput_setup(UpdateMatrix::<{ U64::LIMBS }>::UNIT); + let (lhs_on_gcd, rhs_on_gcd) = output.quotients(); + assert_eq!(lhs_on_gcd, Uint::ONE); + assert_eq!(rhs_on_gcd, Uint::ZERO); + } + + #[test] + fn test_extract_quotients_basic() { + let output = raw_binxgcdoutput_setup(UpdateMatrix::<{ U64::LIMBS }>::new( + Uint::ZERO, + Uint::ZERO, + Uint::from(5u32), + Uint::from(7u32), + ConstChoice::FALSE, + 0, + 0, + )); + let (lhs_on_gcd, rhs_on_gcd) = output.quotients(); + assert_eq!(lhs_on_gcd, Uint::from(7u32)); + assert_eq!(rhs_on_gcd, Uint::from(5u32)); + + let output = raw_binxgcdoutput_setup(UpdateMatrix::<{ U64::LIMBS }>::new( + Uint::ZERO, + Uint::ZERO, + Uint::from(7u32), + Uint::from(5u32), + ConstChoice::TRUE, + 0, + 0, + )); + let (lhs_on_gcd, rhs_on_gcd) = output.quotients(); + assert_eq!(lhs_on_gcd, Uint::from(5u32)); + assert_eq!(rhs_on_gcd, Uint::from(7u32)); + } + } + + mod test_derive_bezout_coefficients { + use crate::modular::bingcd::matrix::UpdateMatrix; + use crate::modular::bingcd::xgcd::RawOddUintBinxgcdOutput; + use crate::{ConstChoice, Int, U64, Uint}; + + #[test] + fn test_derive_bezout_coefficients_unit() { + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::<{ U64::LIMBS }>::UNIT, + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::ONE); + assert_eq!(y, Int::ZERO); + } + + #[test] + fn test_derive_bezout_coefficients_basic() { + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::new( + U64::from(2u32), + U64::from(3u32), + U64::from(4u32), + U64::from(5u32), + ConstChoice::TRUE, + 0, + 0, + ), + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::from(-2i32)); + assert_eq!(y, Int::from(2i32)); + + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::new( + U64::from(2u32), + U64::from(3u32), + U64::from(3u32), + U64::from(5u32), + ConstChoice::FALSE, + 0, + 1, + ), + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::from(1i32)); + assert_eq!(y, Int::from(-2i32)); + } + + #[test] + fn test_derive_bezout_coefficients_removes_doublings_easy() { + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::new( + U64::from(2u32), + U64::from(6u32), + U64::from(3u32), + U64::from(5u32), + ConstChoice::TRUE, + 1, + 1, + ), + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::ONE); + assert_eq!(y, Int::from(-3i32)); + + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::new( + U64::from(120u32), + U64::from(64u32), + U64::from(7u32), + U64::from(5u32), + ConstChoice::FALSE, + 5, + 6, + ), + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::from(-9i32)); + assert_eq!(y, Int::from(2i32)); + } + + #[test] + fn test_derive_bezout_coefficients_removes_doublings_for_odd_numbers() { + let mut output = RawOddUintBinxgcdOutput { + gcd: Uint::ONE.to_odd().unwrap(), + matrix: UpdateMatrix::new( + U64::from(2u32), + U64::from(6u32), + U64::from(7u32), + U64::from(5u32), + ConstChoice::FALSE, + 3, + 7, + ), + }; + output.remove_matrix_factors(); + let (x, y) = output.bezout_coefficients(); + assert_eq!(x, Int::from(-2i32)); + assert_eq!(y, Int::from(2i32)); + } + } + + mod test_partial_binxgcd { + use crate::modular::bingcd::matrix::UpdateMatrix; + use crate::{ConstChoice, Odd, U64}; + + const A: Odd = U64::from_be_hex("CA048AFA63CD6A1F").to_odd().expect("odd"); + const B: U64 = U64::from_be_hex("AE693BF7BE8E5566"); + + #[test] + fn test_partial_binxgcd() { + let (.., matrix) = + A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&B, 5, ConstChoice::TRUE); + let (.., k, _) = matrix.as_elements(); + assert_eq!(k, 5); + assert_eq!( + matrix, + UpdateMatrix::new( + U64::from(8u64), + U64::from(4u64), + U64::from(2u64), + U64::from(5u64), + ConstChoice::TRUE, + 5, + 5 + ) + ); + } + + #[test] + fn test_partial_binxgcd_constructs_correct_matrix() { + let target_a = U64::from_be_hex("1CB3FB3FA1218FDB").to_odd().unwrap(); + let target_b = U64::from_be_hex("0EA028AF0F8966B6"); + + let (new_a, new_b, matrix) = + A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&B, 5, ConstChoice::TRUE); + + assert_eq!(new_a, target_a); + assert_eq!(new_b, target_b); + + let (computed_a, computed_b) = matrix.extended_apply_to((A.get(), B)); + let computed_a = computed_a.wrapping_drop_extension().0; + let computed_b = computed_b.wrapping_drop_extension().0; + + assert_eq!(computed_a, target_a); + assert_eq!(computed_b, target_b); + } + + const SMALL_A: Odd = U64::from_be_hex("0000000003CD6A1F").to_odd().expect("odd"); + const SMALL_B: U64 = U64::from_be_hex("000000000E8E5566"); + + #[test] + fn test_partial_binxgcd_halts() { + let (gcd, _, matrix) = + SMALL_A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&SMALL_B, 60, ConstChoice::TRUE); + let (.., k, k_upper_bound) = matrix.as_elements(); + assert_eq!(k, 35); + assert_eq!(k_upper_bound, 60); + assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B)); + } + + #[test] + fn test_partial_binxgcd_does_not_halt() { + let (gcd, .., matrix) = + SMALL_A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&SMALL_B, 60, ConstChoice::FALSE); + let (.., k, k_upper_bound) = matrix.as_elements(); + assert_eq!(k, 60); + assert_eq!(k_upper_bound, 60); + assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B)); + } + } + + /// Helper function to effectively test xgcd. + fn test_xgcd( + lhs: Uint, + rhs: Uint, + output: OddUintBinxgcdOutput, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Test the gcd + assert_eq!(lhs.gcd(&rhs), output.gcd, "{} {}", lhs, rhs); + + // Test the quotients + assert_eq!(output.lhs_on_gcd, lhs.div(output.gcd.as_nz_ref())); + assert_eq!(output.rhs_on_gcd, rhs.div(output.gcd.as_nz_ref())); + + // Test the Bezout coefficients for correctness + let (x, y) = output.bezout_coefficients(); + assert_eq!( + x.widening_mul_uint(&lhs) + y.widening_mul_uint(&rhs), + output.gcd.resize().as_int(), + ); + + // Test the Bezout coefficients for minimality + assert!(x.abs() <= rhs.div(output.gcd.as_nz_ref())); + assert!(y.abs() <= lhs.div(output.gcd.as_nz_ref())); + if lhs != rhs { + assert!(x.abs() <= output.rhs_on_gcd.shr(1) || output.rhs_on_gcd.is_zero()); + assert!(y.abs() <= output.lhs_on_gcd.shr(1) || output.lhs_on_gcd.is_zero()); + } + } + + #[cfg(feature = "rand_core")] + fn make_rng() -> ChaChaRng { + ChaChaRng::from_seed([0; 32]) + } + + mod test_binxgcd_nz { + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::{ + ConcatMixed, Gcd, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, + U8192, Uint, + }; + + #[cfg(feature = "rand_core")] + use super::make_rng; + #[cfg(feature = "rand_core")] + use crate::Random; + + fn binxgcd_nz_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let output = lhs.to_odd().unwrap().binxgcd_nz(&rhs.to_nz().unwrap()); + test_xgcd(lhs, rhs, output); + } + + #[cfg(feature = "rand_core")] + fn binxgcd_nz_randomized_tests(iterations: u32) + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Uint::random(&mut rng).bitor(&Uint::ONE); + let y = Uint::random(&mut rng).saturating_add(&Uint::ONE); + binxgcd_nz_test(x, y); + } + } + + fn binxgcd_nz_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let odd_upper_bound = *Int::MAX.as_uint(); + let even_upper_bound = Int::MIN.abs(); + binxgcd_nz_test(Uint::ONE, Uint::ONE); + binxgcd_nz_test(Uint::ONE, odd_upper_bound); + binxgcd_nz_test(Uint::ONE, even_upper_bound); + binxgcd_nz_test(odd_upper_bound, Uint::ONE); + binxgcd_nz_test(odd_upper_bound, odd_upper_bound); + binxgcd_nz_test(odd_upper_bound, even_upper_bound); + + #[cfg(feature = "rand_core")] + binxgcd_nz_randomized_tests(100); + } + + #[test] + fn test_binxgcd_nz() { + binxgcd_nz_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + binxgcd_nz_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + binxgcd_nz_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + binxgcd_nz_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + binxgcd_nz_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + binxgcd_nz_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + binxgcd_nz_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + binxgcd_nz_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + binxgcd_nz_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_classic_binxgcd { + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::{ + ConcatMixed, Gcd, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, + U8192, Uint, + }; + + #[cfg(feature = "rand_core")] + use super::make_rng; + #[cfg(feature = "rand_core")] + use crate::Random; + + fn classic_binxgcd_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let mut output = lhs + .to_odd() + .unwrap() + .classic_binxgcd(&rhs.to_odd().unwrap()); + test_xgcd(lhs, rhs, output.process()); + } + + #[cfg(feature = "rand_core")] + fn classic_binxgcd_randomized_tests( + iterations: u32, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Uint::::random(&mut rng).bitor(&Uint::ONE); + let y = Uint::::random(&mut rng).bitor(&Uint::ONE); + classic_binxgcd_test(x, y); + } + } + + fn classic_binxgcd_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let upper_bound = *Int::MAX.as_uint(); + classic_binxgcd_test(Uint::ONE, Uint::ONE); + classic_binxgcd_test(Uint::ONE, upper_bound); + classic_binxgcd_test(upper_bound, Uint::ONE); + classic_binxgcd_test(upper_bound, upper_bound); + + #[cfg(feature = "rand_core")] + classic_binxgcd_randomized_tests(100); + } + + #[test] + fn test_classic_binxgcd() { + classic_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + classic_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + classic_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + classic_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + classic_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + classic_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + classic_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + classic_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + classic_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_optimized_binxgcd { + #[cfg(feature = "rand_core")] + use super::make_rng; + #[cfg(feature = "rand_core")] + use crate::Random; + + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::modular::bingcd::xgcd::{DOUBLE_SUMMARY_LIMBS, SUMMARY_BITS, SUMMARY_LIMBS}; + use crate::{ + ConcatMixed, Gcd, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, + U8192, Uint, + }; + + fn optimized_binxgcd_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let mut output = lhs + .to_odd() + .unwrap() + .optimized_binxgcd(&rhs.to_odd().unwrap()); + test_xgcd(lhs, rhs, output.process()); + } + + #[cfg(feature = "rand_core")] + fn optimized_binxgcd_randomized_tests( + iterations: u32, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let mut rng = make_rng(); + for _ in 0..iterations { + let x = Uint::::random(&mut rng).bitor(&Uint::ONE); + let y = Uint::::random(&mut rng).bitor(&Uint::ONE); + optimized_binxgcd_test(x, y); + } + } + + fn optimized_binxgcd_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let upper_bound = *Int::MAX.as_uint(); + optimized_binxgcd_test(Uint::ONE, Uint::ONE); + optimized_binxgcd_test(Uint::ONE, upper_bound); + optimized_binxgcd_test(upper_bound, Uint::ONE); + optimized_binxgcd_test(upper_bound, upper_bound); + + #[cfg(feature = "rand_core")] + optimized_binxgcd_randomized_tests(100); + } + + #[test] + fn test_optimized_binxgcd_edge_cases() { + // If one of these tests fails, you have probably tweaked the SUMMARY_BITS, + // SUMMARY_LIMBS or DOUBLE_SUMMARY_LIMBS settings. Please make sure to update these + // tests accordingly. + assert_eq!(SUMMARY_BITS, 64); + assert_eq!(SUMMARY_LIMBS, U64::LIMBS); + assert_eq!(DOUBLE_SUMMARY_LIMBS, U128::LIMBS); + + // Case #1: a > b but a.compact() < b.compact() + let a = U256::from_be_hex( + "1234567890ABCDEF80000000000000000000000000000000BEDCBA0987654321", + ); + let b = U256::from_be_hex( + "1234567890ABCDEF800000000000000000000000000000007EDCBA0987654321", + ); + assert!(a > b); + assert!( + a.compact::(U256::BITS) + < b.compact::(U256::BITS) + ); + optimized_binxgcd_test(a, b); + + // Case #2: a < b but a.compact() > b.compact() + optimized_binxgcd_test(b, a); + + // Case #3: a > b but a.compact() = b.compact() + let a = U256::from_be_hex( + "1234567890ABCDEF80000000000000000000000000000000FEDCBA0987654321", + ); + let b = U256::from_be_hex( + "1234567890ABCDEF800000000000000000000000000000007EDCBA0987654321", + ); + assert!(a > b); + assert_eq!( + a.compact::(U256::BITS), + b.compact::(U256::BITS) + ); + optimized_binxgcd_test(a, b); + + // Case #4: a < b but a.compact() = b.compact() + optimized_binxgcd_test(b, a); + } + + #[test] + fn test_optimized_binxgcd() { + optimized_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + optimized_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + optimized_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + optimized_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + optimized_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + optimized_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + optimized_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + optimized_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } +} diff --git a/src/non_zero.rs b/src/non_zero.rs index e76b5b970..3bc484bf3 100644 --- a/src/non_zero.rs +++ b/src/non_zero.rs @@ -177,6 +177,11 @@ impl NonZero> { // Note: a NonZero always has a non-zero magnitude, so it is safe to unwrap. (NonZero::>::new_unwrap(abs), sign) } + + /// Convert a [`NonZero`] to its [`NonZero`] magnitude. + pub const fn abs(&self) -> NonZero> { + self.abs_sign().0 + } } #[cfg(feature = "hybrid-array")] diff --git a/src/odd.rs b/src/odd.rs index b442a6326..6db562c4c 100644 --- a/src/odd.rs +++ b/src/odd.rs @@ -8,7 +8,7 @@ use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::BoxedUint; #[cfg(feature = "rand_core")] -use crate::{Random, rand_core::TryRngCore}; +use crate::{Int, Random, rand_core::TryRngCore}; #[cfg(all(feature = "alloc", feature = "rand_core"))] use crate::RandomBits; @@ -58,6 +58,9 @@ impl Odd { } impl Odd> { + /// Total size of the represented integer in bits. + pub const BITS: u32 = Uint::::BITS; + /// Create a new [`Odd>`] from the provided big endian hex string. /// /// Panics if the hex is malformed or not zero-padded accordingly for the size, or if the value is even. @@ -160,6 +163,14 @@ impl Random for Odd> { } } +#[cfg(feature = "rand_core")] +impl Random for Odd> { + /// Generate a random `Odd>`. + fn try_random(rng: &mut R) -> Result { + Odd::>::try_random(rng).map(|r| Odd(r.as_int())) + } +} + #[cfg(all(feature = "alloc", feature = "rand_core"))] impl Odd { /// Generate a random `Odd>`. diff --git a/src/uint.rs b/src/uint.rs index 417056c95..92fd423c9 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -24,6 +24,7 @@ mod macros; mod add; mod add_mod; +mod bingcd; mod bit_and; mod bit_not; mod bit_or; diff --git a/src/uint/bingcd.rs b/src/uint/bingcd.rs new file mode 100644 index 000000000..9494486d3 --- /dev/null +++ b/src/uint/bingcd.rs @@ -0,0 +1,351 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin as Algorithm 2 in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +use crate::modular::bingcd::tools::const_min; +use crate::modular::bingcd::{NonZeroUintBinxgcdOutput, OddUintBinxgcdOutput, UintBinxgcdOutput}; +use crate::{ConstChoice, Int, NonZero, Odd, Uint}; + +impl Uint { + /// Compute the greatest common divisor of `self` and `rhs`. + pub fn bingcd(&self, rhs: &Self) -> Self { + let self_is_zero = self.is_nonzero().not(); + let self_nz = Uint::select(self, &Uint::ONE, self_is_zero) + .to_nz() + .expect("self is non zero by construction"); + Uint::select(self_nz.bingcd(rhs).as_ref(), rhs, self_is_zero) + } + + /// Executes the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &Self) -> UintBinxgcdOutput { + // Make sure `self` and `rhs` are nonzero. + let self_is_zero = self.is_nonzero().not(); + let self_nz = Uint::select(self, &Uint::ONE, self_is_zero) + .to_nz() + .expect("self is non zero by construction"); + let rhs_is_zero = rhs.is_nonzero().not(); + let rhs_nz = Uint::select(rhs, &Uint::ONE, rhs_is_zero) + .to_nz() + .expect("rhs is non zero by construction"); + + let NonZeroUintBinxgcdOutput { + gcd, + mut x, + mut y, + mut lhs_on_gcd, + mut rhs_on_gcd, + } = self_nz.binxgcd(&rhs_nz); + + // Correct the gcd in case self and/or rhs was zero + let mut gcd = *gcd.as_ref(); + gcd = Uint::select(&gcd, rhs, self_is_zero); + gcd = Uint::select(&gcd, self, rhs_is_zero); + + // Correct the Bézout coefficients in case self and/or rhs was zero. + x = Int::select(&x, &Int::ZERO, self_is_zero); + y = Int::select(&y, &Int::ONE, self_is_zero); + x = Int::select(&x, &Int::ONE, rhs_is_zero); + y = Int::select(&y, &Int::ZERO, rhs_is_zero); + + // Correct the quotients in case self and/or rhs was zero. + lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ZERO, self_is_zero); + rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ONE, self_is_zero); + lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ONE, rhs_is_zero); + rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ZERO, rhs_is_zero); + + UintBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +impl NonZero> { + /// Compute the greatest common divisor of `self` and `rhs`. + pub fn bingcd(&self, rhs: &Uint) -> Self { + let val = self.as_ref(); + // Leverage two GCD identity rules to make self odd. + // 1) gcd(2a, 2b) = 2 * gcd(a, b) + // 2) gcd(a, 2b) = gcd(a, b) if a is odd. + let i = val.trailing_zeros(); + let j = rhs.trailing_zeros(); + let k = const_min(i, j); + + val.shr(i) + .to_odd() + .expect("val.shr(i) is odd by construction") + .bingcd(rhs) + .as_ref() + .shl(k) + .to_nz() + .expect("gcd of non-zero element with another element is non-zero") + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &Self) -> NonZeroUintBinxgcdOutput { + let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref()); + + // Leverage the property that gcd(2^k * a, 2^k *b) = 2^k * gcd(a, b) + let i = lhs.trailing_zeros(); + let j = rhs.trailing_zeros(); + let k = const_min(i, j); + lhs = lhs.shr(k); + rhs = rhs.shr(k); + + // Note: at this point, either lhs or rhs is odd (or both). + // Swap to make sure lhs is odd. + let swap = ConstChoice::from_u32_lt(j, i); + Uint::conditional_swap(&mut lhs, &mut rhs, swap); + let lhs = lhs.to_odd().expect("odd by construction"); + + let rhs = rhs.to_nz().expect("non-zero by construction"); + let OddUintBinxgcdOutput { + gcd, + mut x, + mut y, + mut lhs_on_gcd, + mut rhs_on_gcd, + } = lhs.binxgcd_nz(&rhs); + + let gcd = gcd + .as_ref() + .shl(k) + .to_nz() + .expect("is non-zero by construction"); + Int::conditional_swap(&mut x, &mut y, swap); + Uint::conditional_swap(&mut lhs_on_gcd, &mut rhs_on_gcd, swap); + + NonZeroUintBinxgcdOutput { + gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +impl Odd> { + /// Compute the greatest common divisor of `self` and `rhs` using the Binary GCD algorithm. + /// + /// This function switches between the "classic" and "optimized" algorithm at a best-effort + /// threshold. When using [Uint]s with `LIMBS` close to the threshold, it may be useful to + /// manually test whether the classic or optimized algorithm is faster for your machine. + #[inline(always)] + pub fn bingcd(&self, rhs: &Uint) -> Self { + if LIMBS < 6 { + self.classic_bingcd(rhs) + } else { + self.optimized_bingcd(rhs) + } + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub fn binxgcd(&self, rhs: &Self) -> OddUintBinxgcdOutput { + self.binxgcd_(rhs).process() + } +} + +#[cfg(all(test, not(miri)))] +mod tests { + mod bincgd_test { + #[cfg(feature = "rand_core")] + use crate::Random; + #[cfg(feature = "rand_core")] + use rand_chacha::ChaChaRng; + #[cfg(feature = "rand_core")] + use rand_core::SeedableRng; + + use crate::{Gcd, Int, U64, U128, U256, U512, U1024, U2048, U4096, U8192, U16384, Uint}; + + fn bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + #[cfg(feature = "rand_core")] + fn bingcd_randomized_tests(iterations: u32) + where + Uint: Gcd>, + { + let mut rng = ChaChaRng::from_seed([0; 32]); + for _ in 0..iterations { + let x = Uint::::random(&mut rng); + let y = Uint::::random(&mut rng); + bingcd_test(x, y); + } + } + + fn bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + let min = Int::MIN.abs(); + bingcd_test(Uint::ZERO, Uint::ZERO); + bingcd_test(Uint::ZERO, Uint::ONE); + bingcd_test(Uint::ZERO, min); + bingcd_test(Uint::ZERO, Uint::MAX); + bingcd_test(Uint::ONE, Uint::ZERO); + bingcd_test(Uint::ONE, Uint::ONE); + bingcd_test(Uint::ONE, min); + bingcd_test(Uint::ONE, Uint::MAX); + bingcd_test(min, Uint::ZERO); + bingcd_test(min, Uint::ONE); + bingcd_test(min, Int::MIN.abs()); + bingcd_test(min, Uint::MAX); + bingcd_test(Uint::MAX, Uint::ZERO); + bingcd_test(Uint::MAX, Uint::ONE); + bingcd_test(Uint::ONE, min); + bingcd_test(Uint::MAX, Uint::MAX); + + #[cfg(feature = "rand_core")] + bingcd_randomized_tests(100); + } + + #[test] + fn test_bingcd() { + bingcd_tests::<{ U64::LIMBS }>(); + bingcd_tests::<{ U128::LIMBS }>(); + bingcd_tests::<{ U256::LIMBS }>(); + bingcd_tests::<{ U512::LIMBS }>(); + bingcd_tests::<{ U1024::LIMBS }>(); + bingcd_tests::<{ U2048::LIMBS }>(); + bingcd_tests::<{ U4096::LIMBS }>(); + bingcd_tests::<{ U8192::LIMBS }>(); + bingcd_tests::<{ U16384::LIMBS }>(); + } + } + + mod binxgcd_test { + use core::ops::Div; + + #[cfg(feature = "rand_core")] + use crate::Random; + #[cfg(feature = "rand_core")] + use rand_chacha::ChaChaRng; + #[cfg(feature = "rand_core")] + use rand_core::SeedableRng; + + use crate::{ + Concat, Gcd, Int, U64, U128, U256, U512, U1024, U2048, U4096, U8192, U16384, Uint, + }; + + fn binxgcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd> + Concat>, + { + let output = lhs.binxgcd(&rhs); + + assert_eq!(output.gcd, lhs.gcd(&rhs)); + + if output.gcd > Uint::ZERO { + assert_eq!(output.lhs_on_gcd, lhs.div(output.gcd.to_nz().unwrap())); + assert_eq!(output.rhs_on_gcd, rhs.div(output.gcd.to_nz().unwrap())); + } + + let (x, y) = output.bezout_coefficients(); + assert_eq!( + x.widening_mul_uint(&lhs) + y.widening_mul_uint(&rhs), + output.gcd.resize().as_int() + ); + } + + #[cfg(feature = "rand_core")] + fn binxgcd_randomized_tests(iterations: u32) + where + Uint: Gcd> + Concat>, + { + let mut rng = ChaChaRng::from_seed([0; 32]); + for _ in 0..iterations { + let x = Uint::::random(&mut rng); + let y = Uint::::random(&mut rng); + binxgcd_test(x, y); + } + } + + fn binxgcd_tests() + where + Uint: Gcd> + Concat>, + { + // Edge cases + let min = Int::MIN.abs(); + binxgcd_test(Uint::ZERO, Uint::ZERO); + binxgcd_test(Uint::ZERO, Uint::ONE); + binxgcd_test(Uint::ZERO, min); + binxgcd_test(Uint::ZERO, Uint::MAX); + binxgcd_test(Uint::ONE, Uint::ZERO); + binxgcd_test(Uint::ONE, Uint::ONE); + binxgcd_test(Uint::ONE, min); + binxgcd_test(Uint::ONE, Uint::MAX); + binxgcd_test(min, Uint::ZERO); + binxgcd_test(min, Uint::ONE); + binxgcd_test(min, Int::MIN.abs()); + binxgcd_test(min, Uint::MAX); + binxgcd_test(Uint::MAX, Uint::ZERO); + binxgcd_test(Uint::MAX, Uint::ONE); + binxgcd_test(Uint::ONE, min); + binxgcd_test(Uint::MAX, Uint::MAX); + + #[cfg(feature = "rand_core")] + binxgcd_randomized_tests(100); + } + + #[test] + fn test_binxgcd() { + binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + binxgcd_tests::<{ U8192::LIMBS }, { U16384::LIMBS }>(); + } + + #[test] + fn test_binxgcd_regression_tests() { + // Sent in by @kayabaNerve (https://github.com/RustCrypto/crypto-bigint/pull/761#issuecomment-2771564732) + let a = U256::from_be_hex( + "000000000000000000000000000000000000001B5DFB3BA1D549DFAF611B8D4C", + ); + let b = U256::from_be_hex( + "000000000000345EAEDFA8CA03C1F0F5B578A787FE2D23B82A807F178B37FD8E", + ); + binxgcd_test(a, b); + + // Sent in by @kayabaNerve (https://github.com/RustCrypto/crypto-bigint/pull/761#issuecomment-2771581512) + let a = U256::from_be_hex( + "000000000000000000000000000000000000001A0DEEF6F3AC2566149D925044", + ); + let b = U256::from_be_hex( + "000000000000072B69C9DD0AA15F135675EA9C5180CF8FF0A59298CFC92E87FA", + ); + binxgcd_test(a, b); + + // Sent in by @kayabaNerve (https://github.com/RustCrypto/crypto-bigint/pull/761#issuecomment-2782912608) + let a = U512::from_be_hex(concat![ + "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364142", + "4EB38E6AC0E34DE2F34BFAF22DE683E1F4B92847B6871C780488D797042229E1" + ]); + let b = U512::from_be_hex(concat![ + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C8283", + "9D671CD581C69BC5E697F5E45BCD07C52EC373A8BDC598B4493F50A1380E1281" + ]); + binxgcd_test(a, b); + } + } +} diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs index eb8a109e6..cd19b9e6c 100644 --- a/src/uint/cmp.rs +++ b/src/uint/cmp.rs @@ -25,6 +25,18 @@ impl Uint { Uint { limbs } } + /// Swap `a` and `b` if `c` is truthy, otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_swap(a: &mut Self, b: &mut Self, c: ConstChoice) { + (*a, *b) = (Self::select(a, b, c), Self::select(b, a, c)); + } + + /// Swap `a` and `b`. + #[inline] + pub(crate) const fn swap(a: &mut Self, b: &mut Self) { + Self::conditional_swap(a, b, ConstChoice::TRUE) + } + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. #[inline] pub(crate) const fn is_nonzero(&self) -> ConstChoice { diff --git a/src/uint/shr.rs b/src/uint/shr.rs index df6db1f7e..ae34b40f0 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -157,6 +157,49 @@ impl Uint { (ret, ConstChoice::from_word_lsb(carry.0 >> Limb::HI_BIT)) } + + /// Computes `self >> shift` where `0 <= shift < Limb::BITS`, + /// returning the result and the carry. + #[inline(always)] + pub(crate) const fn shr_limb(&self, shift: u32) -> (Self, Limb) { + assert!(shift < Limb::BITS); + let nz = ConstChoice::from_u32_nonzero(shift); + let shift = nz.select_u32(1, shift); + let (res, carry) = self.shr_limb_nonzero(shift); + ( + Uint::select(self, &res, nz), + Limb::select(Limb::ZERO, carry, nz), + ) + } + + /// Computes `self >> shift` where `0 < shift < Limb::BITS`, returning the result and the carry. + /// + /// Note: this operation should not be used in situations where `shift == 0`; the compiler can + /// sometimes sniff this case out and optimize it away, possibly leading to variable time + /// behaviour. + #[inline(always)] + pub(crate) const fn shr_limb_nonzero(&self, shift: u32) -> (Self, Limb) { + assert!(0 < shift); + assert!(shift < Limb::BITS); + + let mut limbs = [Limb::ZERO; LIMBS]; + + let rshift = shift; + let lshift = Limb::BITS - shift; + + let mut carry = Limb::ZERO; + let mut i = LIMBS; + while i > 0 { + i -= 1; + + let limb = self.limbs[i].shr(rshift); + let new_carry = self.limbs[i].shl(lshift); + limbs[i] = limb.bitor(carry); + carry = new_carry; + } + + (Uint::::new(limbs), carry) + } } macro_rules! impl_shr { @@ -208,7 +251,7 @@ impl ShrVartime for Uint { #[cfg(test)] mod tests { - use crate::{U128, U256, Uint}; + use crate::{Limb, U128, U256, Uint}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -258,4 +301,41 @@ mod tests { .is_true_vartime() ); } + + #[test] + #[should_panic] + fn shr_limb_shift_too_large() { + let _ = U128::ONE.shr_limb(Limb::BITS); + } + + #[test] + #[should_panic] + fn shr_limb_nz_panics_at_zero_shift() { + let _ = U128::ONE.shr_limb_nonzero(0); + } + + #[test] + fn shr_limb() { + let val = U128::from_be_hex("876543210FEDCBA90123456FEDCBA987"); + + // Shift by zero + let (res, carry) = val.shr_limb(0); + assert_eq!(res, val); + assert_eq!(carry, Limb::ZERO); + + // Shift by one + let (res, carry) = val.shr_limb(1); + assert_eq!(res, val.shr_vartime(1)); + assert_eq!(carry, val.limbs[0].shl(Limb::BITS - 1)); + + // Shift by any + let (res, carry) = val.shr_limb(13); + assert_eq!(res, val.shr_vartime(13)); + assert_eq!(carry, val.limbs[0].shl(Limb::BITS - 13)); + + // Shift by max + let (res, carry) = val.shr_limb(Limb::BITS - 1); + assert_eq!(res, val.shr_vartime(Limb::BITS - 1)); + assert_eq!(carry, val.limbs[0].shl(1)); + } }