Skip to content

Commit 1958ae1

Browse files
committed
Update sample
1 parent 4a9fbe3 commit 1958ae1

File tree

2 files changed

+39
-43
lines changed

2 files changed

+39
-43
lines changed

example/extensions/Cargo.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ crate-type = ["cdylib"]
99

1010
[dependencies]
1111
numpy = { path = "../.." }
12-
cpython = "0.1"
1312
ndarray = "0.10"
13+
14+
[dependencies.pyo3]
15+
version = "*"
16+
features = ["extension-module"]

example/extensions/src/lib.rs

+35-42
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,41 @@
1+
#![feature(proc_macro, proc_macro_path_invoc, specialization)]
12

2-
#[macro_use]
3-
extern crate pyo3;
4-
extern crate numpy;
53
extern crate ndarray;
4+
extern crate numpy;
5+
extern crate pyo3;
66

7-
use numpy::*;
87
use ndarray::*;
9-
use pyo3::{PyResult, Python, PyObject};
10-
11-
/* Pure rust-ndarray functions */
12-
13-
// immutable example
14-
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
15-
a * &x + &y
16-
}
17-
18-
// mutable example (no return)
19-
fn mult(a: f64, mut x: ArrayViewMutD<f64>) {
20-
x *= a;
21-
}
22-
23-
/* rust-pyo3 wrappers (to be exposed) */
24-
25-
// wrapper of `axpy`
26-
fn axpy_py(py: Python, a: f64, x: PyArray, y: PyArray) -> PyResult<PyArray> {
27-
let np = PyArrayModule::import(py)?;
28-
let x = x.as_array().into_pyresult(py, "x must be f64 array")?;
29-
let y = y.as_array().into_pyresult(py, "y must be f64 array")?;
30-
Ok(axpy(a, x, y).into_pyarray(py, &np))
31-
}
32-
33-
// wrapper of `mult`
34-
fn mult_py(py: Python, a: f64, x: PyArray) -> PyResult<PyObject> {
35-
let x = x.as_array_mut().into_pyresult(py, "x must be f64 array")?;
36-
mult(a, x);
37-
Ok(py.None()) // Python function must returns
38-
}
8+
use numpy::*;
9+
use pyo3::{py, PyModule, PyObject, PyResult, Python};
10+
11+
#[py::modinit(rust_ext)]
12+
fn init_module(py: Python, m: &PyModule) -> PyResult<()> {
13+
// immutable example
14+
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
15+
a * &x + &y
16+
}
17+
18+
// mutable example (no return)
19+
fn mult(a: f64, mut x: ArrayViewMutD<f64>) {
20+
x *= a;
21+
}
22+
23+
// wrapper of `axpy`
24+
#[pyfn(m, "axpy")]
25+
fn axpy_py(py: Python, a: f64, x: PyArray, y: PyArray) -> PyResult<PyArray> {
26+
let np = PyArrayModule::import(py)?;
27+
let x = x.as_array().into_pyresult(py, "x must be f64 array")?;
28+
let y = y.as_array().into_pyresult(py, "y must be f64 array")?;
29+
Ok(axpy(a, x, y).into_pyarray(py, &np))
30+
}
31+
32+
// wrapper of `mult`
33+
#[pyfn(m, "mult")]
34+
fn mult_py(py: Python, a: f64, x: PyArray) -> PyResult<PyObject> {
35+
let x = x.as_array_mut().into_pyresult(py, "x must be f64 array")?;
36+
mult(a, x);
37+
Ok(py.None()) // Python function must returns
38+
}
3939

40-
/* Define module "_rust_ext" */
41-
py_module_initializer!(_rust_ext, init_rust_ext, PyInit__rust_ext, |py, m| {
42-
m.add(py, "__doc__", "Rust extension for NumPy")?;
43-
m.add(py,
44-
"axpy",
45-
py_fn!(py, axpy_py(a: f64, x: PyArray, y: PyArray)))?;
46-
m.add(py, "mult", py_fn!(py, mult_py(a: f64, x: PyArray)))?;
4740
Ok(())
48-
});
41+
}

0 commit comments

Comments
 (0)