|
| 1 | +#![feature(proc_macro, proc_macro_path_invoc, specialization)] |
1 | 2 |
|
2 |
| -#[macro_use] |
3 |
| -extern crate pyo3; |
4 |
| -extern crate numpy; |
5 | 3 | extern crate ndarray;
|
| 4 | +extern crate numpy; |
| 5 | +extern crate pyo3; |
6 | 6 |
|
7 |
| -use numpy::*; |
8 | 7 | use ndarray::*;
|
9 |
| -use pyo3::{PyResult, Python, PyObject}; |
10 |
| - |
11 |
| -/* Pure rust-ndarray functions */ |
12 |
| - |
13 |
| -// immutable example |
14 |
| -fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> { |
15 |
| - a * &x + &y |
16 |
| -} |
17 |
| - |
18 |
| -// mutable example (no return) |
19 |
| -fn mult(a: f64, mut x: ArrayViewMutD<f64>) { |
20 |
| - x *= a; |
21 |
| -} |
22 |
| - |
23 |
| -/* rust-pyo3 wrappers (to be exposed) */ |
24 |
| - |
25 |
| -// wrapper of `axpy` |
26 |
| -fn axpy_py(py: Python, a: f64, x: PyArray, y: PyArray) -> PyResult<PyArray> { |
27 |
| - let np = PyArrayModule::import(py)?; |
28 |
| - let x = x.as_array().into_pyresult(py, "x must be f64 array")?; |
29 |
| - let y = y.as_array().into_pyresult(py, "y must be f64 array")?; |
30 |
| - Ok(axpy(a, x, y).into_pyarray(py, &np)) |
31 |
| -} |
32 |
| - |
33 |
| -// wrapper of `mult` |
34 |
| -fn mult_py(py: Python, a: f64, x: PyArray) -> PyResult<PyObject> { |
35 |
| - let x = x.as_array_mut().into_pyresult(py, "x must be f64 array")?; |
36 |
| - mult(a, x); |
37 |
| - Ok(py.None()) // Python function must returns |
38 |
| -} |
| 8 | +use numpy::*; |
| 9 | +use pyo3::{py, PyModule, PyObject, PyResult, Python}; |
| 10 | + |
| 11 | +#[py::modinit(rust_ext)] |
| 12 | +fn init_module(py: Python, m: &PyModule) -> PyResult<()> { |
| 13 | + // immutable example |
| 14 | + fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> { |
| 15 | + a * &x + &y |
| 16 | + } |
| 17 | + |
| 18 | + // mutable example (no return) |
| 19 | + fn mult(a: f64, mut x: ArrayViewMutD<f64>) { |
| 20 | + x *= a; |
| 21 | + } |
| 22 | + |
| 23 | + // wrapper of `axpy` |
| 24 | + #[pyfn(m, "axpy")] |
| 25 | + fn axpy_py(py: Python, a: f64, x: PyArray, y: PyArray) -> PyResult<PyArray> { |
| 26 | + let np = PyArrayModule::import(py)?; |
| 27 | + let x = x.as_array().into_pyresult(py, "x must be f64 array")?; |
| 28 | + let y = y.as_array().into_pyresult(py, "y must be f64 array")?; |
| 29 | + Ok(axpy(a, x, y).into_pyarray(py, &np)) |
| 30 | + } |
| 31 | + |
| 32 | + // wrapper of `mult` |
| 33 | + #[pyfn(m, "mult")] |
| 34 | + fn mult_py(py: Python, a: f64, x: PyArray) -> PyResult<PyObject> { |
| 35 | + let x = x.as_array_mut().into_pyresult(py, "x must be f64 array")?; |
| 36 | + mult(a, x); |
| 37 | + Ok(py.None()) // Python function must returns |
| 38 | + } |
39 | 39 |
|
40 |
| -/* Define module "_rust_ext" */ |
41 |
| -py_module_initializer!(_rust_ext, init_rust_ext, PyInit__rust_ext, |py, m| { |
42 |
| - m.add(py, "__doc__", "Rust extension for NumPy")?; |
43 |
| - m.add(py, |
44 |
| - "axpy", |
45 |
| - py_fn!(py, axpy_py(a: f64, x: PyArray, y: PyArray)))?; |
46 |
| - m.add(py, "mult", py_fn!(py, mult_py(a: f64, x: PyArray)))?; |
47 | 40 | Ok(())
|
48 |
| -}); |
| 41 | +} |
0 commit comments