Skip to content

Commit b623239

Browse files
SparrowLiibluss
authored andcommitted
Add function broadcast_with
1 parent e3b73cc commit b623239

File tree

3 files changed

+66
-11
lines changed

3 files changed

+66
-11
lines changed

src/impl_methods.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ use rawpointer::PointerExt;
1414

1515
use crate::imp_prelude::*;
1616

17-
use crate::arraytraits;
17+
use crate::{arraytraits, BroadcastShape};
1818
use crate::dimension;
1919
use crate::dimension::IntoDimension;
2020
use crate::dimension::{
2121
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
2222
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2323
};
24-
use crate::error::{self, ErrorKind, ShapeError};
24+
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2525
use crate::math_cell::MathCell;
2626
use crate::itertools::zip;
2727
use crate::zip::Zip;
@@ -1766,6 +1766,36 @@ where
17661766
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
17671767
}
17681768

1769+
/// Calculate the views of two ArrayBases after broadcasting each other, if possible.
1770+
///
1771+
/// Return `ShapeError` if their shapes can not be broadcast together.
1772+
///
1773+
/// ```
1774+
/// use ndarray::{arr1, arr2};
1775+
///
1776+
/// let a = arr2(&[[2], [3], [4]]);
1777+
/// let b = arr1(&[5, 6, 7]);
1778+
/// let (a1, b1) = a.broadcast_with(&b).unwrap();
1779+
/// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
1780+
/// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
1781+
/// ```
1782+
pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
1783+
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
1784+
where
1785+
S: Data<Elem=A>,
1786+
S2: Data<Elem=B>,
1787+
D: Dimension + BroadcastShape<E>,
1788+
E: Dimension,
1789+
{
1790+
let shape = self.dim.broadcast_shape(&other.dim)?;
1791+
if let Some(view1) = self.broadcast(shape.clone()) {
1792+
if let Some(view2) = other.broadcast(shape) {
1793+
return Ok((view1, view2))
1794+
}
1795+
}
1796+
return Err(from_kind(ErrorKind::IncompatibleShape));
1797+
}
1798+
17691799
/// Swap axes `ax` and `bx`.
17701800
///
17711801
/// This does not move any data, it just adjusts the array’s dimensions

src/impl_ops.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ where
106106
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
107107
out
108108
} else {
109-
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
110-
let lhs = self.broadcast(shape.clone()).unwrap();
111-
let rhs = rhs.broadcast(shape).unwrap();
109+
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
112110
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
113111
}
114112
}
@@ -143,9 +141,7 @@ where
143141
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
144142
out
145143
} else {
146-
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
147-
let lhs = self.broadcast(shape.clone()).unwrap();
148-
let rhs = rhs.broadcast(shape).unwrap();
144+
let (rhs, lhs) = rhs.broadcast_with(self).unwrap();
149145
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
150146
}
151147
}
@@ -171,9 +167,7 @@ where
171167
{
172168
type Output = Array<A, <D as BroadcastShape<E>>::Output>;
173169
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
174-
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
175-
let lhs = self.broadcast(shape.clone()).unwrap();
176-
let rhs = rhs.broadcast(shape).unwrap();
170+
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
177171
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
178172
}
179173
}

tests/broadcast.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ndarray::prelude::*;
2+
use ndarray::{ShapeError, ErrorKind, arr3};
23

34
#[test]
45
#[cfg(feature = "std")]
@@ -81,3 +82,33 @@ fn test_broadcast_1d() {
8182
println!("b2=\n{:?}", b2);
8283
assert_eq!(b0, b2);
8384
}
85+
86+
#[test]
87+
fn test_broadcast_with() {
88+
let a = arr2(&[[1., 2.], [3., 4.]]);
89+
let b = aview0(&1.);
90+
let (a1, b1) = a.broadcast_with(&b).unwrap();
91+
assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
92+
assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]]));
93+
94+
let a = arr2(&[[2], [3], [4]]);
95+
let b = arr1(&[5, 6, 7]);
96+
let (a1, b1) = a.broadcast_with(&b).unwrap();
97+
assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
98+
assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
99+
100+
// Negative strides and non-contiguous memory
101+
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
102+
let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap();
103+
let a = s.slice(s![..;-1,..;2,..]);
104+
let b = s.slice(s![..2, -1, ..]);
105+
let (a1, b1) = a.broadcast_with(&b).unwrap();
106+
assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]]));
107+
assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]]));
108+
109+
// ShapeError
110+
let a = arr2(&[[2, 2], [3, 3], [4, 4]]);
111+
let b = arr1(&[5, 6, 7]);
112+
let e = a.broadcast_with(&b);
113+
assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
114+
}

0 commit comments

Comments
 (0)