Skip to content

Commit 4239ddd

Browse files
bors[bot]cuviper
andauthored
Merge #142
142: Implement shift with broad RHS types r=cuviper a=cuviper The primitive integers all support shift operators with the right-hand side as any other primitive integer, by value or by reference. Now `BigInt` and `BigUint` support this too. This is a minor breaking change, because it can cause type inference failures if a user has an ambiguous integer for a shift count, whereas before it could only be `usize`. Co-authored-by: Josh Stone <[email protected]>
2 parents c076bb8 + e827fd5 commit 4239ddd

File tree

7 files changed

+244
-149
lines changed

7 files changed

+244
-149
lines changed

benches/bigint.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,10 @@ fn rand_131072(b: &mut Bencher) {
293293

294294
#[bench]
295295
fn shl(b: &mut Bencher) {
296-
let n = BigUint::one() << 1000;
296+
let n = BigUint::one() << 1000u32;
297+
let mut m = n.clone();
297298
b.iter(|| {
298-
let mut m = n.clone();
299+
m.clone_from(&n);
299300
for i in 0..50 {
300301
m <<= i;
301302
}
@@ -304,9 +305,10 @@ fn shl(b: &mut Bencher) {
304305

305306
#[bench]
306307
fn shr(b: &mut Bencher) {
307-
let n = BigUint::one() << 2000;
308+
let n = BigUint::one() << 2000u32;
309+
let mut m = n.clone();
308310
b.iter(|| {
309-
let mut m = n.clone();
311+
m.clone_from(&n);
310312
for i in 0..50 {
311313
m >>= i;
312314
}

src/algorithms.rs

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use core::cmp;
33
use core::cmp::Ordering::{self, Equal, Greater, Less};
44
use core::iter::repeat;
55
use core::mem;
6-
use num_traits::{One, Zero};
6+
use num_traits::{One, PrimInt, Zero};
77

88
use crate::biguint::biguint_from_vec;
99
use crate::biguint::BigUint;
@@ -720,34 +720,46 @@ fn div_rem_core(mut a: BigUint, b: &BigUint) -> (BigUint, BigUint) {
720720

721721
/// Find last set bit
722722
/// fls(0) == 0, fls(u32::MAX) == 32
723-
pub(crate) fn fls<T: num_traits::PrimInt>(v: T) -> usize {
723+
pub(crate) fn fls<T: PrimInt>(v: T) -> usize {
724724
mem::size_of::<T>() * 8 - v.leading_zeros() as usize
725725
}
726726

727-
pub(crate) fn ilog2<T: num_traits::PrimInt>(v: T) -> usize {
727+
pub(crate) fn ilog2<T: PrimInt>(v: T) -> usize {
728728
fls(v) - 1
729729
}
730730

731731
#[inline]
732-
pub(crate) fn biguint_shl(n: Cow<'_, BigUint>, bits: usize) -> BigUint {
733-
let n_unit = bits / big_digit::BITS;
734-
let mut data = match n_unit {
732+
pub(crate) fn biguint_shl<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
733+
if shift < T::zero() {
734+
panic!("attempt to shift left with negative");
735+
}
736+
if n.is_zero() {
737+
return n.into_owned();
738+
}
739+
let bits = T::from(big_digit::BITS).unwrap();
740+
let digits = (shift / bits).to_usize().expect("capacity overflow");
741+
let shift = (shift % bits).to_u8().unwrap();
742+
biguint_shl2(n, digits, shift)
743+
}
744+
745+
fn biguint_shl2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
746+
let mut data = match digits {
735747
0 => n.into_owned().data,
736748
_ => {
737-
let len = n_unit + n.data.len() + 1;
749+
let len = digits.saturating_add(n.data.len() + 1);
738750
let mut data = Vec::with_capacity(len);
739-
data.extend(repeat(0).take(n_unit));
740-
data.extend(n.data.iter().cloned());
751+
data.extend(repeat(0).take(digits));
752+
data.extend(n.data.iter());
741753
data
742754
}
743755
};
744756

745-
let n_bits = bits % big_digit::BITS;
746-
if n_bits > 0 {
757+
if shift > 0 {
747758
let mut carry = 0;
748-
for elem in data[n_unit..].iter_mut() {
749-
let new_carry = *elem >> (big_digit::BITS - n_bits);
750-
*elem = (*elem << n_bits) | carry;
759+
let carry_shift = big_digit::BITS as u8 - shift;
760+
for elem in data[digits..].iter_mut() {
761+
let new_carry = *elem >> carry_shift;
762+
*elem = (*elem << shift) | carry;
751763
carry = new_carry;
752764
}
753765
if carry != 0 {
@@ -759,25 +771,39 @@ pub(crate) fn biguint_shl(n: Cow<'_, BigUint>, bits: usize) -> BigUint {
759771
}
760772

761773
#[inline]
762-
pub(crate) fn biguint_shr(n: Cow<'_, BigUint>, bits: usize) -> BigUint {
763-
let n_unit = bits / big_digit::BITS;
764-
if n_unit >= n.data.len() {
765-
return Zero::zero();
774+
pub(crate) fn biguint_shr<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
775+
if shift < T::zero() {
776+
panic!("attempt to shift right with negative");
777+
}
778+
if n.is_zero() {
779+
return n.into_owned();
780+
}
781+
let bits = T::from(big_digit::BITS).unwrap();
782+
let digits = (shift / bits).to_usize().unwrap_or(core::usize::MAX);
783+
let shift = (shift % bits).to_u8().unwrap();
784+
biguint_shr2(n, digits, shift)
785+
}
786+
787+
fn biguint_shr2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
788+
if digits >= n.data.len() {
789+
let mut n = n.into_owned();
790+
n.set_zero();
791+
return n;
766792
}
767793
let mut data = match n {
768-
Cow::Borrowed(n) => n.data[n_unit..].to_vec(),
794+
Cow::Borrowed(n) => n.data[digits..].to_vec(),
769795
Cow::Owned(mut n) => {
770-
n.data.drain(..n_unit);
796+
n.data.drain(..digits);
771797
n.data
772798
}
773799
};
774800

775-
let n_bits = bits % big_digit::BITS;
776-
if n_bits > 0 {
801+
if shift > 0 {
777802
let mut borrow = 0;
803+
let borrow_shift = big_digit::BITS as u8 - shift;
778804
for elem in data.iter_mut().rev() {
779-
let new_borrow = *elem << (big_digit::BITS - n_bits);
780-
*elem = (*elem >> n_bits) | borrow;
805+
let new_borrow = *elem << borrow_shift;
806+
*elem = (*elem >> shift) | borrow;
781807
borrow = new_borrow;
782808
}
783809
}

src/bigint.rs

Lines changed: 90 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use serde;
2525

2626
use num_integer::{Integer, Roots};
2727
use num_traits::{
28-
CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Num, One, Pow, Signed,
28+
CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Num, One, Pow, PrimInt, Signed,
2929
ToPrimitive, Zero,
3030
};
3131

@@ -779,69 +779,105 @@ impl Num for BigInt {
779779
}
780780
}
781781

782-
impl Shl<usize> for BigInt {
783-
type Output = BigInt;
782+
macro_rules! impl_shift {
783+
(@ref $Shx:ident :: $shx:ident, $ShxAssign:ident :: $shx_assign:ident, $rhs:ty) => {
784+
impl<'b> $Shx<&'b $rhs> for BigInt {
785+
type Output = BigInt;
784786

785-
#[inline]
786-
fn shl(mut self, rhs: usize) -> BigInt {
787-
self <<= rhs;
788-
self
789-
}
790-
}
787+
#[inline]
788+
fn $shx(self, rhs: &'b $rhs) -> BigInt {
789+
$Shx::$shx(self, *rhs)
790+
}
791+
}
792+
impl<'a, 'b> $Shx<&'b $rhs> for &'a BigInt {
793+
type Output = BigInt;
791794

792-
impl<'a> Shl<usize> for &'a BigInt {
793-
type Output = BigInt;
795+
#[inline]
796+
fn $shx(self, rhs: &'b $rhs) -> BigInt {
797+
$Shx::$shx(self, *rhs)
798+
}
799+
}
800+
impl<'b> $ShxAssign<&'b $rhs> for BigInt {
801+
#[inline]
802+
fn $shx_assign(&mut self, rhs: &'b $rhs) {
803+
$ShxAssign::$shx_assign(self, *rhs);
804+
}
805+
}
806+
};
807+
($($rhs:ty),+) => {$(
808+
impl Shl<$rhs> for BigInt {
809+
type Output = BigInt;
794810

795-
#[inline]
796-
fn shl(self, rhs: usize) -> BigInt {
797-
BigInt::from_biguint(self.sign, &self.data << rhs)
798-
}
799-
}
811+
#[inline]
812+
fn shl(self, rhs: $rhs) -> BigInt {
813+
BigInt::from_biguint(self.sign, self.data << rhs)
814+
}
815+
}
816+
impl<'a> Shl<$rhs> for &'a BigInt {
817+
type Output = BigInt;
800818

801-
impl ShlAssign<usize> for BigInt {
802-
#[inline]
803-
fn shl_assign(&mut self, rhs: usize) {
804-
self.data <<= rhs;
805-
}
806-
}
819+
#[inline]
820+
fn shl(self, rhs: $rhs) -> BigInt {
821+
BigInt::from_biguint(self.sign, &self.data << rhs)
822+
}
823+
}
824+
impl ShlAssign<$rhs> for BigInt {
825+
#[inline]
826+
fn shl_assign(&mut self, rhs: $rhs) {
827+
self.data <<= rhs
828+
}
829+
}
830+
impl_shift! { @ref Shl::shl, ShlAssign::shl_assign, $rhs }
807831

808-
// Negative values need a rounding adjustment if there are any ones in the
809-
// bits that are getting shifted out.
810-
fn shr_round_down(i: &BigInt, rhs: usize) -> bool {
811-
i.is_negative() && i.trailing_zeros().map(|n| n < rhs).unwrap_or(false)
812-
}
832+
impl Shr<$rhs> for BigInt {
833+
type Output = BigInt;
813834

814-
impl Shr<usize> for BigInt {
815-
type Output = BigInt;
835+
#[inline]
836+
fn shr(self, rhs: $rhs) -> BigInt {
837+
let round_down = shr_round_down(&self, rhs);
838+
let data = self.data >> rhs;
839+
let data = if round_down { data + 1u8 } else { data };
840+
BigInt::from_biguint(self.sign, data)
841+
}
842+
}
843+
impl<'a> Shr<$rhs> for &'a BigInt {
844+
type Output = BigInt;
816845

817-
#[inline]
818-
fn shr(mut self, rhs: usize) -> BigInt {
819-
self >>= rhs;
820-
self
821-
}
846+
#[inline]
847+
fn shr(self, rhs: $rhs) -> BigInt {
848+
let round_down = shr_round_down(self, rhs);
849+
let data = &self.data >> rhs;
850+
let data = if round_down { data + 1u8 } else { data };
851+
BigInt::from_biguint(self.sign, data)
852+
}
853+
}
854+
impl ShrAssign<$rhs> for BigInt {
855+
#[inline]
856+
fn shr_assign(&mut self, rhs: $rhs) {
857+
let round_down = shr_round_down(self, rhs);
858+
self.data >>= rhs;
859+
if round_down {
860+
self.data += 1u8;
861+
} else if self.data.is_zero() {
862+
self.sign = NoSign;
863+
}
864+
}
865+
}
866+
impl_shift! { @ref Shr::shr, ShrAssign::shr_assign, $rhs }
867+
)*};
822868
}
823869

824-
impl<'a> Shr<usize> for &'a BigInt {
825-
type Output = BigInt;
826-
827-
#[inline]
828-
fn shr(self, rhs: usize) -> BigInt {
829-
let round_down = shr_round_down(self, rhs);
830-
let data = &self.data >> rhs;
831-
BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data })
832-
}
833-
}
870+
impl_shift! { u8, u16, u32, u64, u128, usize }
871+
impl_shift! { i8, i16, i32, i64, i128, isize }
834872

835-
impl ShrAssign<usize> for BigInt {
836-
#[inline]
837-
fn shr_assign(&mut self, rhs: usize) {
838-
let round_down = shr_round_down(self, rhs);
839-
self.data >>= rhs;
840-
if round_down {
841-
self.data += 1u8;
842-
} else if self.data.is_zero() {
843-
self.sign = NoSign;
844-
}
873+
// Negative values need a rounding adjustment if there are any ones in the
874+
// bits that are getting shifted out.
875+
fn shr_round_down<T: PrimInt>(i: &BigInt, shift: T) -> bool {
876+
if i.is_negative() {
877+
let zeros = i.trailing_zeros().expect("negative values are non-zero");
878+
shift > T::zero() && shift.to_usize().map(|shift| zeros < shift).unwrap_or(true)
879+
} else {
880+
false
845881
}
846882
}
847883

0 commit comments

Comments
 (0)