diff --git a/src/array_like.rs b/src/array_like.rs index 1c7a0b36b..554c960d5 100644 --- a/src/array_like.rs +++ b/src/array_like.rs @@ -149,7 +149,7 @@ where let py = ob.py(); - if matches!(D::NDIM, None | Some(1)) { + if matches!(D::NDIM, Some(1)) { if let Ok(vec) = ob.extract::>() { let array = Array1::from(vec) .into_dimensionality() diff --git a/tests/array_like.rs b/tests/array_like.rs index aa185b174..4bae2b9e0 100644 --- a/tests/array_like.rs +++ b/tests/array_like.rs @@ -1,5 +1,8 @@ use ndarray::array; -use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn}; +use numpy::{ + get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn, + PyUntypedArrayMethods, +}; use pyo3::{ ffi::c_str, types::{IntoPyDict, PyAnyMethods, PyDict}, @@ -105,7 +108,9 @@ fn convert_1d_list_on_extract() { Python::with_gil(|py| { let py_list = py.eval(c_str!("[1,2,3,4]"), None, None).unwrap(); let extracted_array_1d = py_list.extract::>().unwrap(); - let extracted_array_dyn = py_list.extract::>().unwrap(); + let extracted_array_dyn = py_list + .extract::>() + .unwrap(); assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array()); assert_eq!( @@ -115,6 +120,25 @@ fn convert_1d_list_on_extract() { }); } +#[test] +fn preserve_trailing_singleton_dims() { + Python::with_gil(|py| { + let locals = get_np_locals(py); + let py_array = py + .eval( + c_str!("np.array([[1], [2], [3]], dtype='int32')"), + Some(&locals), + None, + ) + .unwrap(); + let extracted_array = py_array + .extract::>() + .unwrap(); + + assert_eq!(extracted_array.shape(), &[3, 1]); + }) +} + #[test] fn unsafe_cast_shall_fail() { Python::with_gil(|py| {