Skip to content

Commit 1671b00

Browse files
124C41padamreichold
authored andcommitted
Add PyArrayLike wrapper around PyReadonlyArray
Extracts a read-only reference if the correct NumPy array type is given. Tries to convert the input into the correct type using `numpy.asarray` otherwise.
1 parent b5af0ed commit 1671b00

File tree

4 files changed

+340
-0
lines changed

4 files changed

+340
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
55
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
66
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))
7+
- Add `PyArrayLike` type which extracts `PyReadonlyArray` if a NumPy array of the correct type is given and attempts a conversion using `numpy.asarray` otherwise. ([#383](https://github.com/PyO3/rust-numpy/pull/383))
78

89
- v0.19.0
910
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))

src/array_like.rs

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
use std::marker::PhantomData;
2+
use std::ops::Deref;
3+
4+
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5+
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};
6+
7+
use crate::sealed::Sealed;
8+
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
9+
10+
pub trait Coerce: Sealed {
11+
const VAL: bool;
12+
}
13+
14+
/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
15+
#[derive(Debug)]
16+
pub struct TypeMustMatch;
17+
18+
impl Sealed for TypeMustMatch {}
19+
20+
impl Coerce for TypeMustMatch {
21+
const VAL: bool = false;
22+
}
23+
24+
/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
25+
#[derive(Debug)]
26+
pub struct AllowTypeChange;
27+
28+
impl Sealed for AllowTypeChange {}
29+
30+
impl Coerce for AllowTypeChange {
31+
const VAL: bool = true;
32+
}
33+
34+
/// Receiver for arrays or array-like types.
35+
///
36+
/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
37+
/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
38+
/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
39+
/// or a temporary one created by converting the input type into a NumPy array.
40+
///
41+
/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
42+
/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
43+
///
44+
/// # Example
45+
///
46+
/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
47+
///
48+
/// ```rust
49+
/// # use pyo3::prelude::*;
50+
/// use pyo3::py_run;
51+
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
52+
///
53+
/// #[pyfunction]
54+
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
55+
/// array.as_array().sum()
56+
/// }
57+
///
58+
/// Python::with_gil(|py| {
59+
/// let np = get_array_module(py).unwrap();
60+
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
61+
///
62+
/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
63+
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
64+
/// });
65+
/// ```
66+
///
67+
/// but it will not cast the element type if that is required
68+
///
69+
/// ```rust,should_panic
70+
/// use pyo3::prelude::*;
71+
/// use pyo3::py_run;
72+
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
73+
///
74+
/// #[pyfunction]
75+
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
76+
/// array.as_array().sum()
77+
/// }
78+
///
79+
/// Python::with_gil(|py| {
80+
/// let np = get_array_module(py).unwrap();
81+
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
82+
///
83+
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
84+
/// });
85+
/// ```
86+
///
87+
/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
88+
///
89+
/// ```rust
90+
/// use pyo3::prelude::*;
91+
/// use pyo3::py_run;
92+
/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
93+
///
94+
/// #[pyfunction]
95+
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
96+
/// array.as_array().sum()
97+
/// }
98+
///
99+
/// Python::with_gil(|py| {
100+
/// let np = get_array_module(py).unwrap();
101+
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
102+
///
103+
/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
104+
/// });
105+
/// ```
106+
#[derive(Debug)]
107+
#[repr(transparent)]
108+
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
109+
where
110+
T: Element,
111+
D: Dimension,
112+
C: Coerce;
113+
114+
impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
115+
where
116+
T: Element,
117+
D: Dimension,
118+
C: Coerce,
119+
{
120+
type Target = PyReadonlyArray<'py, T, D>;
121+
122+
fn deref(&self) -> &Self::Target {
123+
&self.0
124+
}
125+
}
126+
127+
impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
128+
where
129+
T: Element,
130+
D: Dimension,
131+
C: Coerce,
132+
Vec<T>: FromPyObject<'py>,
133+
{
134+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
135+
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
136+
return Ok(Self(array.readonly(), PhantomData));
137+
}
138+
139+
let py = ob.py();
140+
141+
if matches!(D::NDIM, None | Some(1)) {
142+
if let Ok(vec) = ob.extract::<Vec<T>>() {
143+
let array = Array1::from(vec)
144+
.into_dimensionality()
145+
.expect("D being compatible to Ix1")
146+
.into_pyarray(py)
147+
.readonly();
148+
return Ok(Self(array, PhantomData));
149+
}
150+
}
151+
152+
static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
153+
154+
let as_array = AS_ARRAY
155+
.get_or_try_init(py, || {
156+
get_array_module(py)?.getattr("asarray").map(Into::into)
157+
})?
158+
.as_ref(py);
159+
160+
let kwargs = if C::VAL {
161+
let kwargs = PyDict::new(py);
162+
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
163+
Some(kwargs)
164+
} else {
165+
None
166+
};
167+
168+
let array = as_array.call((ob,), kwargs)?.extract()?;
169+
Ok(Self(array, PhantomData))
170+
}
171+
}
172+
173+
/// Receiver for zero-dimensional arrays or array-like types.
174+
pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
175+
176+
/// Receiver for one-dimensional arrays or array-like types.
177+
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
178+
179+
/// Receiver for two-dimensional arrays or array-like types.
180+
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
181+
182+
/// Receiver for three-dimensional arrays or array-like types.
183+
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
184+
185+
/// Receiver for four-dimensional arrays or array-like types.
186+
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
187+
188+
/// Receiver for five-dimensional arrays or array-like types.
189+
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
190+
191+
/// Receiver for six-dimensional arrays or array-like types.
192+
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
193+
194+
/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
195+
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as
7373
#![deny(missing_docs, missing_debug_implementations)]
7474

7575
pub mod array;
76+
mod array_like;
7677
pub mod borrow;
7778
pub mod convert;
7879
pub mod datetime;
@@ -94,6 +95,10 @@ pub use crate::array::{
9495
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
9596
PyArray6, PyArrayDyn,
9697
};
98+
pub use crate::array_like::{
99+
AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3,
100+
PyArrayLike4, PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
101+
};
97102
pub use crate::borrow::{
98103
PyReadonlyArray, PyReadonlyArray0, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3,
99104
PyReadonlyArray4, PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray,

tests/array_like.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use ndarray::array;
2+
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
3+
use pyo3::{
4+
types::{IntoPyDict, PyDict},
5+
Python,
6+
};
7+
8+
fn get_np_locals<'py>(py: Python<'py>) -> &'py PyDict {
9+
[("np", get_array_module(py).unwrap())].into_py_dict(py)
10+
}
11+
12+
#[test]
13+
fn extract_reference() {
14+
Python::with_gil(|py| {
15+
let locals = get_np_locals(py);
16+
let py_array = py
17+
.eval(
18+
"np.array([[1,2],[3,4]], dtype='float64')",
19+
Some(locals),
20+
None,
21+
)
22+
.unwrap();
23+
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();
24+
25+
assert_eq!(
26+
array![[1_f64, 2_f64], [3_f64, 4_f64]],
27+
extracted_array.as_array()
28+
);
29+
});
30+
}
31+
32+
#[test]
33+
fn convert_array_on_extract() {
34+
Python::with_gil(|py| {
35+
let locals = get_np_locals(py);
36+
let py_array = py
37+
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
38+
.unwrap();
39+
let extracted_array = py_array
40+
.extract::<PyArrayLike2<'_, f64, AllowTypeChange>>()
41+
.unwrap();
42+
43+
assert_eq!(
44+
array![[1_f64, 2_f64], [3_f64, 4_f64]],
45+
extracted_array.as_array()
46+
);
47+
});
48+
}
49+
50+
#[test]
51+
fn convert_list_on_extract() {
52+
Python::with_gil(|py| {
53+
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap();
54+
let extracted_array = py_list.extract::<PyArrayLike2<'_, f64>>().unwrap();
55+
56+
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
57+
});
58+
}
59+
60+
#[test]
61+
fn convert_array_in_list_on_extract() {
62+
Python::with_gil(|py| {
63+
let locals = get_np_locals(py);
64+
let py_array = py
65+
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None)
66+
.unwrap();
67+
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();
68+
69+
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
70+
});
71+
}
72+
73+
#[test]
74+
fn convert_list_on_extract_dyn() {
75+
Python::with_gil(|py| {
76+
let py_list = py
77+
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None)
78+
.unwrap();
79+
let extracted_array = py_list
80+
.extract::<PyArrayLikeDyn<'_, i64, AllowTypeChange>>()
81+
.unwrap();
82+
83+
assert_eq!(
84+
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
85+
extracted_array.as_array()
86+
);
87+
});
88+
}
89+
90+
#[test]
91+
fn convert_1d_list_on_extract() {
92+
Python::with_gil(|py| {
93+
let py_list = py.eval("[1,2,3,4]", None, None).unwrap();
94+
let extracted_array_1d = py_list.extract::<PyArrayLike1<'_, u32>>().unwrap();
95+
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<'_, f64>>().unwrap();
96+
97+
assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
98+
assert_eq!(
99+
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(),
100+
extracted_array_dyn.as_array()
101+
);
102+
});
103+
}
104+
105+
#[test]
106+
fn unsafe_cast_shall_fail() {
107+
Python::with_gil(|py| {
108+
let locals = get_np_locals(py);
109+
let py_list = py
110+
.eval(
111+
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
112+
Some(locals),
113+
None,
114+
)
115+
.unwrap();
116+
let extracted_array = py_list.extract::<PyArrayLike1<'_, i32>>();
117+
118+
assert!(extracted_array.is_err());
119+
});
120+
}
121+
122+
#[test]
123+
fn unsafe_cast_with_coerce_works() {
124+
Python::with_gil(|py| {
125+
let locals = get_np_locals(py);
126+
let py_list = py
127+
.eval(
128+
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
129+
Some(locals),
130+
None,
131+
)
132+
.unwrap();
133+
let extracted_array = py_list
134+
.extract::<PyArrayLike1<'_, i32, AllowTypeChange>>()
135+
.unwrap();
136+
137+
assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
138+
});
139+
}

0 commit comments

Comments
 (0)