Skip to content

Commit d256160

Browse files
committed
MultiFab: to_numpy/cupy
Add numpy & cupy helpers for MultiFab.
1 parent 596f0e7 commit d256160

File tree

5 files changed

+39
-16
lines changed

5 files changed

+39
-16
lines changed

src/amrex/MultiFab.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99

10-
def mf_to_numpy(self, copy=False, order="F"):
10+
def mf_to_numpy(amr, self, copy=False, order="F"):
1111
"""
1212
Provide a Numpy view into a MultiFab.
1313
@@ -29,13 +29,24 @@ def mf_to_numpy(self, copy=False, order="F"):
2929
3030
Returns
3131
-------
32-
list of np.array
32+
list of numpy.array
3333
A list of numpy n-dimensional arrays, for each local block in the
3434
MultiFab.
3535
"""
36+
mf = self
37+
if copy:
38+
mf = amr.MultiFab(
39+
self.box_array(),
40+
self.dm(),
41+
self.n_comp(),
42+
self.n_grow_vect(),
43+
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
44+
)
45+
amr.dtoh_memcpy(mf, self)
46+
3647
views = []
37-
for mfi in self:
38-
views.append(self.array(mfi).to_numpy(copy, order))
48+
for mfi in mf:
49+
views.append(mf.array(mfi).to_numpy(copy=False, order=order))
3950

4051
return views
4152

@@ -80,15 +91,11 @@ def mf_to_cupy(self, copy=False, order="F"):
8091

8192
def register_MultiFab_extension(amr):
8293
"""MultiFab helper methods"""
83-
import inspect
84-
import sys
85-
86-
# register member functions for every MultiFab* type
87-
for _, MultiFab_type in inspect.getmembers(
88-
sys.modules[amr.__name__],
89-
lambda member: inspect.isclass(member)
90-
and member.__module__ == amr.__name__
91-
and member.__name__.startswith("MultiFab"),
92-
):
93-
MultiFab_type.to_numpy = mf_to_numpy
94-
MultiFab_type.to_cupy = mf_to_cupy
94+
95+
# register member functions for the MultiFab type
96+
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
97+
amr, self, copy, order
98+
)
99+
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__
100+
101+
amr.MultiFab.to_cupy = mf_to_cupy

src/amrex/space1d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_1d_pybind)
54+
register_MultiFab_extension(amrex_1d_pybind)
5355
register_PODVector_extension(amrex_1d_pybind)
5456
register_SoA_extension(amrex_1d_pybind)
5557
register_AoS_extension(amrex_1d_pybind)

src/amrex/space2d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_2d_pybind)
54+
register_MultiFab_extension(amrex_2d_pybind)
5355
register_PODVector_extension(amrex_2d_pybind)
5456
register_SoA_extension(amrex_2d_pybind)
5557
register_AoS_extension(amrex_2d_pybind)

src/amrex/space3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_3d_pybind)
54+
register_MultiFab_extension(amrex_3d_pybind)
5355
register_PODVector_extension(amrex_3d_pybind)
5456
register_SoA_extension(amrex_3d_pybind)
5557
register_AoS_extension(amrex_3d_pybind)

tests/test_multifab.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,13 @@ def test_mfab_dtoh_copy(make_mfab_device):
350350
device_max = mfab_device.max(0)
351351
assert device_min == device_max
352352
assert device_max == 11.0
353+
354+
# numpy bindings (w/ copy)
355+
local_boxes_host = mfab_device.to_numpy(copy=True)
356+
assert max([np.max(box) for box in local_boxes_host]) == device_max
357+
358+
# cupy bindings (w/o copy)
359+
import cupy as cp
360+
361+
local_boxes_device = mfab_device.to_cupy()
362+
assert max([cp.max(box) for box in local_boxes_device]) == device_max

0 commit comments

Comments
 (0)