Skip to content

Commit a8aac58

Browse files
authored
Merge pull request #325 from PyO3/repr-transparent
Use repr(transparent) to enforce layout compatibility with PyAny
2 parents 8b2679a + e1e5385 commit a8aac58

File tree

5 files changed

+19
-29
lines changed

5 files changed

+19
-29
lines changed

src/array.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ use ndarray::{
1616
};
1717
use num_traits::AsPrimitive;
1818
use pyo3::{
19-
ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject,
20-
IntoPy, Py, PyAny, PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject,
21-
PyResult, PyTypeInfo, Python, ToPyObject,
19+
ffi, pyobject_native_type_named, types::PyModule, AsPyPointer, FromPyObject, IntoPy, Py, PyAny,
20+
PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
21+
Python, ToPyObject,
2222
};
2323

2424
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
@@ -95,6 +95,7 @@ use crate::slice_container::PySliceContainer;
9595
///
9696
/// [ndarray]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html
9797
/// [pyo3-memory]: https://pyo3.rs/main/memory.html
98+
#[repr(transparent)]
9899
pub struct PyArray<T, D>(PyAny, PhantomData<T>, PhantomData<D>);
99100

100101
/// Zero-dimensional array.
@@ -119,10 +120,6 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
119120
PyModule::import(py, npyffi::array::MOD_NAME)
120121
}
121122

122-
unsafe impl<T, D> type_object::PyLayout<PyArray<T, D>> for npyffi::PyArrayObject {}
123-
124-
impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {}
125-
126123
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
127124
type AsRefTarget = Self;
128125

src/dtype.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub use num_complex::{Complex32, Complex64};
4545
/// ```
4646
///
4747
/// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
48+
#[repr(transparent)]
4849
pub struct PyArrayDescr(PyAny);
4950

5051
pyobject_native_type_named!(PyArrayDescr);
@@ -61,12 +62,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {
6162
}
6263

6364
fn is_type_of(ob: &PyAny) -> bool {
64-
unsafe {
65-
ffi::PyObject_TypeCheck(
66-
ob.as_ptr(),
67-
PY_ARRAY_API.get_type_object(ob.py(), NpyTypes::PyArrayDescr_Type),
68-
) > 0
69-
}
65+
unsafe { ffi::PyObject_TypeCheck(ob.as_ptr(), Self::type_object_raw(ob.py())) > 0 }
7066
}
7167
}
7268

src/npyffi/array.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,14 +328,15 @@ impl PyArrayAPI {
328328
impl_api![303; PyArray_SetWritebackIfCopyBase(arr: *mut PyArrayObject, base: *mut PyArrayObject) -> c_int];
329329
}
330330

331-
// Define type objects that belongs to Numpy API
331+
// Define type objects associated with the NumPy API
332332
macro_rules! impl_array_type {
333333
($(($offset:expr, $tname:ident)),*) => {
334-
/// All type objects of numpy API.
334+
/// All type objects exported by the NumPy API.
335335
#[allow(non_camel_case_types)]
336336
pub enum NpyTypes { $($tname),* }
337+
337338
impl PyArrayAPI {
338-
/// Get the pointer of the type object that `self` refers.
339+
/// Get a pointer of the type object assocaited with `ty`.
339340
pub unsafe fn get_type_object(&self, py: Python, ty: NpyTypes) -> *mut PyTypeObject {
340341
match ty {
341342
$( NpyTypes::$tname => *(self.get(py, $offset)) as _ ),*
@@ -401,11 +402,11 @@ pub unsafe fn PyArray_CheckExact(py: Python, op: *mut PyObject) -> c_int {
401402

402403
#[cfg(test)]
403404
mod tests {
404-
use super::PY_ARRAY_API;
405+
use super::*;
405406

406407
#[test]
407408
fn call_api() {
408-
pyo3::Python::with_gil(|py| unsafe {
409+
Python::with_gil(|py| unsafe {
409410
assert_eq!(
410411
PY_ARRAY_API.PyArray_MultiplyIntList(py, [1, 2, 3].as_mut_ptr(), 3),
411412
6

src/npyffi/mod.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,24 @@ fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_vo
1818
let module = CString::new(module).unwrap();
1919
let capsule = CString::new(capsule).unwrap();
2020
unsafe {
21-
let numpy = ffi::PyImport_ImportModule(module.as_ptr());
22-
assert!(!numpy.is_null(), "Failed to import numpy module");
23-
let capsule = ffi::PyObject_GetAttrString(numpy as _, capsule.as_ptr());
24-
assert!(!capsule.is_null(), "Failed to get numpy capsule API");
21+
let module = ffi::PyImport_ImportModule(module.as_ptr());
22+
assert!(!module.is_null(), "Failed to import NumPy module");
23+
let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr());
24+
assert!(!capsule.is_null(), "Failed to get NumPy API capsule");
2525
ffi::PyCapsule_GetPointer(capsule, null_mut()) as _
2626
}
2727
}
2828

29-
// Define Array&UFunc APIs
29+
// Implements wrappers for NumPy's Array and UFunc API
3030
macro_rules! impl_api {
31-
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* ) $( -> $ret: ty )* ] => {
31+
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* $(,)?) $( -> $ret: ty )* ] => {
3232
#[allow(non_snake_case)]
3333
pub unsafe fn $fname(&self, py: Python, $($arg : $t), *) $( -> $ret )* {
3434
let fptr = self.get(py, $offset)
3535
as *const extern fn ($($arg : $t), *) $( -> $ret )*;
3636
(*fptr)($($arg), *)
3737
}
3838
};
39-
// To allow fn a(b: type,) -> ret
40-
[$offset: expr; $fname: ident ( $($arg: ident : $t:ty,)* ) $( -> $ret: ty )* ] => {
41-
impl_api![$offset; $fname( $($arg: $t),*) $( -> $ret )*];
42-
}
4339
}
4440

4541
pub mod array;

src/npyffi/ufunc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const MOD_NAME: &str = "numpy.core.umath";
1212
const CAPSULE_NAME: &str = "_UFUNC_API";
1313

1414
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
15-
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/array.html).
15+
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html).
1616
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();
1717

1818
pub struct PyUFuncAPI {

0 commit comments

Comments
 (0)