Skip to content

Commit 7733b7c

Browse files
authored
Merge pull request #588 from jturner314/generalize-op-types
Allow ops on arrays with elems of different types
2 parents 638ac16 + 9b51170 commit 7733b7c

File tree

4 files changed

+34
-26
lines changed

4 files changed

+34
-26
lines changed

src/arraytraits.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,12 @@ impl<S, D, I> IndexMut<I> for ArrayBase<S, D>
8181

8282
/// Return `true` if the array shapes and all elements of `self` and
8383
/// `rhs` are equal. Return `false` otherwise.
84-
impl<S, S2, D> PartialEq<ArrayBase<S2, D>> for ArrayBase<S, D>
85-
where D: Dimension,
86-
S: Data,
87-
S2: Data<Elem = S::Elem>,
88-
S::Elem: PartialEq
84+
impl<A, B, S, S2, D> PartialEq<ArrayBase<S2, D>> for ArrayBase<S, D>
85+
where
86+
A: PartialEq<B>,
87+
S: Data<Elem = A>,
88+
S2: Data<Elem = B>,
89+
D: Dimension,
8990
{
9091
fn eq(&self, rhs: &ArrayBase<S2, D>) -> bool {
9192
if self.shape() != rhs.shape() {

src/impl_ops.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ macro_rules! impl_binary_op(
6060
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
6161
///
6262
/// **Panics** if broadcasting isn’t possible.
63-
impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
64-
where A: Clone + $trt<A, Output=A>,
65-
S: DataOwned<Elem=A> + DataMut,
66-
S2: Data<Elem=A>,
67-
D: Dimension,
68-
E: Dimension,
63+
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
64+
where
65+
A: Clone + $trt<B, Output=A>,
66+
B: Clone,
67+
S: DataOwned<Elem=A> + DataMut,
68+
S2: Data<Elem=B>,
69+
D: Dimension,
70+
E: Dimension,
6971
{
7072
type Output = ArrayBase<S, D>;
7173
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
@@ -82,12 +84,14 @@ impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
8284
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
8385
///
8486
/// **Panics** if broadcasting isn’t possible.
85-
impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
86-
where A: Clone + $trt<A, Output=A>,
87-
S: DataOwned<Elem=A> + DataMut,
88-
S2: Data<Elem=A>,
89-
D: Dimension,
90-
E: Dimension,
87+
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
88+
where
89+
A: Clone + $trt<B, Output=A>,
90+
B: Clone,
91+
S: DataOwned<Elem=A> + DataMut,
92+
S2: Data<Elem=B>,
93+
D: Dimension,
94+
E: Dimension,
9195
{
9296
type Output = ArrayBase<S, D>;
9397
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
@@ -107,12 +111,14 @@ impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
107111
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
108112
///
109113
/// **Panics** if broadcasting isn’t possible.
110-
impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
111-
where A: Clone + $trt<A, Output=A>,
112-
S: Data<Elem=A>,
113-
S2: Data<Elem=A>,
114-
D: Dimension,
115-
E: Dimension,
114+
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
115+
where
116+
A: Clone + $trt<B, Output=A>,
117+
B: Clone,
118+
S: Data<Elem=A>,
119+
S2: Data<Elem=B>,
120+
D: Dimension,
121+
E: Dimension,
116122
{
117123
type Output = Array<A, D>;
118124
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Array<A, D> {

src/numeric_util.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
9696
/// Compute pairwise equality
9797
///
9898
/// `xs` and `ys` must be the same length
99-
pub fn unrolled_eq<A>(xs: &[A], ys: &[A]) -> bool
100-
where A: PartialEq
99+
pub fn unrolled_eq<A, B>(xs: &[A], ys: &[B]) -> bool
100+
where
101+
A: PartialEq<B>,
101102
{
102103
debug_assert_eq!(xs.len(), ys.len());
103104
// eightfold unrolled for performance (this is not done by llvm automatically)

tests/iterators.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ fn inner_iter() {
131131

132132
#[test]
133133
fn inner_iter_corner_cases() {
134-
let a0 = ArcArray::zeros(());
134+
let a0 = ArcArray::<i32, _>::zeros(());
135135
assert_equal(a0.genrows(), vec![aview1(&[0])]);
136136

137137
let a2 = ArcArray::<i32, _>::zeros((0, 3));

0 commit comments

Comments
 (0)