diff --git a/src/impl_methods.rs b/src/impl_methods.rs index ea9c9a0d5..0fd3a5a71 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2538,6 +2538,72 @@ where unsafe { self.with_strides_dim(new_strides, new_dim) } } + /// Permute the axes in-place. + /// + /// This does not move any data, it just adjusts the array's dimensions + /// and strides. + /// + /// *i* in the *j*-th place in the axes sequence means `self`'s *i*-th axis + /// becomes `self`'s *j*-th axis + /// + /// **Panics** if any of the axes are out of bounds, if an axis is missing, + /// or if an axis is repeated more than once. + /// + /// # Example + /// ```rust + /// use ndarray::{arr2, Array3}; + /// + /// let mut a = arr2(&[[0, 1], [2, 3]]); + /// a.permute_axes([1, 0]); + /// assert_eq!(a, arr2(&[[0, 2], [1, 3]])); + /// + /// let mut b = Array3::::zeros((1, 2, 3)); + /// b.permute_axes([1, 0, 2]); + /// assert_eq!(b.shape(), &[2, 1, 3]); + /// ``` + #[track_caller] + pub fn permute_axes(&mut self, axes: T) + where T: IntoDimension + { + let axes = axes.into_dimension(); + // Ensure that each axis is used exactly once. + let mut usage_counts = D::zeros(self.ndim()); + for axis in axes.slice() { + usage_counts[*axis] += 1; + } + for count in usage_counts.slice() { + assert_eq!(*count, 1, "each axis must be listed exactly once"); + } + + let dim = self.layout.dim.slice_mut(); + let strides = self.layout.strides.slice_mut(); + let axes = axes.slice(); + + // The cycle detection is done using a bitmask to track visited positions. + // For example, axes from [0,1,2] to [2, 0, 1] + // For axis values [1, 0, 2]: + // 1 << 1 // 0b0001 << 1 = 0b0010 (decimal 2) + // 1 << 0 // 0b0001 << 0 = 0b0001 (decimal 1) + // 1 << 2 // 0b0001 << 2 = 0b0100 (decimal 4) + // + // Each axis gets its own unique bit position in the bitmask: + // - Axis 0: bit 0 (rightmost) + // - Axis 1: bit 1 + // - Axis 2: bit 2 + // + let mut visited = 0usize; + for (new_axis, &axis) in axes.iter().enumerate() { + if (visited & (1 << axis)) != 0 { + continue; + } + + dim.swap(axis, new_axis); + strides.swap(axis, new_axis); + + visited |= (1 << axis) | (1 << new_axis); + } + } + /// Transpose the array by reversing axes. /// /// Transposition reverses the order of the axes (dimensions and strides) @@ -2548,6 +2614,16 @@ where self.layout.strides.slice_mut().reverse(); self } + + /// Reverse the axes of the array in-place. + /// + /// This does not move any data, it just adjusts the array's dimensions + /// and strides. + pub fn reverse_axes(&mut self) + { + self.layout.dim.slice_mut().reverse(); + self.layout.strides.slice_mut().reverse(); + } } impl ArrayRef diff --git a/tests/array.rs b/tests/array.rs index f1426625c..3d6fa6715 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2828,3 +2828,81 @@ fn test_slice_assign() *a.slice_mut(s![1..3]) += 1; assert_eq!(a, array![0, 2, 3, 3, 4]); } + +#[test] +fn reverse_axes() +{ + let mut a = arr2(&[[1, 2], [3, 4]]); + a.reverse_axes(); + assert_eq!(a, arr2(&[[1, 3], [2, 4]])); + + let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); + a.reverse_axes(); + assert_eq!(a, arr2(&[[1, 4], [2, 5], [3, 6]])); + + let mut a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); + let original = a.clone(); + a.reverse_axes(); + for ((i0, i1, i2), elem) in original.indexed_iter() { + assert_eq!(*elem, a[(i2, i1, i0)]); + } +} + +#[test] +fn permute_axes() +{ + let mut a = arr2(&[[1, 2], [3, 4]]); + a.permute_axes([1, 0]); + assert_eq!(a, arr2(&[[1, 3], [2, 4]])); + + let mut a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); + let original = a.clone(); + a.permute_axes([2, 1, 0]); + for ((i0, i1, i2), elem) in original.indexed_iter() { + assert_eq!(*elem, a[(i2, i1, i0)]); + } + + let mut a = Array::from_iter(0..120) + .into_shape_with_order((2, 3, 4, 5)) + .unwrap(); + let original = a.clone(); + a.permute_axes([1, 0, 3, 2]); + for ((i0, i1, i2, i3), elem) in original.indexed_iter() { + assert_eq!(*elem, a[(i1, i0, i3, i2)]); + } +} + +#[should_panic] +#[test] +fn permute_axes_repeated_axis() +{ + let mut a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); + a.permute_axes([1, 0, 1]); +} + +#[should_panic] +#[test] +fn permute_axes_missing_axis() +{ + let mut a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap() + .into_dyn(); + a.permute_axes(&[2, 0][..]); +} + +#[should_panic] +#[test] +fn permute_axes_oob() +{ + let mut a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); + a.permute_axes([1, 0, 3]); +}