Skip to content

Commit 9aa7cb8

Browse files
authored
MultiFab: to_numpy/cupy (#192)
* MultiFab: to_numpy/cupy Add numpy & cupy helpers for MultiFab. * Update Stub Files --------- Co-authored-by: ax3l <[email protected]>
1 parent 596f0e7 commit 9aa7cb8

File tree

11 files changed

+228
-16
lines changed

11 files changed

+228
-16
lines changed

src/amrex/MultiFab.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99

10-
def mf_to_numpy(self, copy=False, order="F"):
10+
def mf_to_numpy(amr, self, copy=False, order="F"):
1111
"""
1212
Provide a Numpy view into a MultiFab.
1313
@@ -29,13 +29,24 @@ def mf_to_numpy(self, copy=False, order="F"):
2929
3030
Returns
3131
-------
32-
list of np.array
32+
list of numpy.array
3333
A list of numpy n-dimensional arrays, for each local block in the
3434
MultiFab.
3535
"""
36+
mf = self
37+
if copy:
38+
mf = amr.MultiFab(
39+
self.box_array(),
40+
self.dm(),
41+
self.n_comp(),
42+
self.n_grow_vect(),
43+
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
44+
)
45+
amr.dtoh_memcpy(mf, self)
46+
3647
views = []
37-
for mfi in self:
38-
views.append(self.array(mfi).to_numpy(copy, order))
48+
for mfi in mf:
49+
views.append(mf.array(mfi).to_numpy(copy=False, order=order))
3950

4051
return views
4152

@@ -80,15 +91,11 @@ def mf_to_cupy(self, copy=False, order="F"):
8091

8192
def register_MultiFab_extension(amr):
8293
"""MultiFab helper methods"""
83-
import inspect
84-
import sys
85-
86-
# register member functions for every MultiFab* type
87-
for _, MultiFab_type in inspect.getmembers(
88-
sys.modules[amr.__name__],
89-
lambda member: inspect.isclass(member)
90-
and member.__module__ == amr.__name__
91-
and member.__name__.startswith("MultiFab"),
92-
):
93-
MultiFab_type.to_numpy = mf_to_numpy
94-
MultiFab_type.to_cupy = mf_to_cupy
94+
95+
# register member functions for the MultiFab type
96+
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
97+
amr, self, copy, order
98+
)
99+
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__
100+
101+
amr.MultiFab.to_cupy = mf_to_cupy

src/amrex/space1d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_1d_pybind)
54+
register_MultiFab_extension(amrex_1d_pybind)
5355
register_PODVector_extension(amrex_1d_pybind)
5456
register_SoA_extension(amrex_1d_pybind)
5557
register_AoS_extension(amrex_1d_pybind)

src/amrex/space1d/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import os as os
3838

3939
from amrex.Array4 import register_Array4_extension
4040
from amrex.ArrayOfStructs import register_AoS_extension
41+
from amrex.MultiFab import register_MultiFab_extension
4142
from amrex.PODVector import register_PODVector_extension
4243
from amrex.StructOfArrays import register_SoA_extension
4344
from amrex.space1d.amrex_1d_pybind import (
@@ -461,6 +462,7 @@ __all__ = [
461462
"refine",
462463
"register_AoS_extension",
463464
"register_Array4_extension",
465+
"register_MultiFab_extension",
464466
"register_PODVector_extension",
465467
"register_SoA_extension",
466468
"size",

src/amrex/space1d/amrex_1d_pybind/__init__.pyi

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4043,6 +4043,67 @@ class MultiFab(FabArray_FArrayBox):
40434043
"""
40444044
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
40454045
"""
4046+
def to_cupy(self, copy=False, order="F"):
4047+
"""
4048+
4049+
Provide a Cupy view into a MultiFab.
4050+
4051+
Note on the order of indices:
4052+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4053+
x,y,z. This has performance implications for use in external libraries such
4054+
as cupy.
4055+
The order="C" option will index as z,y,x and perform better with cupy.
4056+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4057+
4058+
Parameters
4059+
----------
4060+
self : amrex.MultiFab
4061+
A MultiFab class in pyAMReX
4062+
copy : bool, optional
4063+
Copy the data if true, otherwise create a view (default).
4064+
order : string, optional
4065+
F order (default) or C. C is faster with external libraries.
4066+
4067+
Returns
4068+
-------
4069+
list of cupy.array
4070+
A list of cupy n-dimensional arrays, for each local block in the
4071+
MultiFab.
4072+
4073+
Raises
4074+
------
4075+
ImportError
4076+
Raises an exception if cupy is not installed
4077+
4078+
"""
4079+
def to_numpy(self, copy=False, order="F"):
4080+
"""
4081+
4082+
Provide a Numpy view into a MultiFab.
4083+
4084+
Note on the order of indices:
4085+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4086+
x,y,z. This has performance implications for use in external libraries such
4087+
as cupy.
4088+
The order="C" option will index as z,y,x and perform better with cupy.
4089+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4090+
4091+
Parameters
4092+
----------
4093+
self : amrex.MultiFab
4094+
A MultiFab class in pyAMReX
4095+
copy : bool, optional
4096+
Copy the data if true, otherwise create a view (default).
4097+
order : string, optional
4098+
F order (default) or C. C is faster with external libraries.
4099+
4100+
Returns
4101+
-------
4102+
list of numpy.array
4103+
A list of numpy n-dimensional arrays, for each local block in the
4104+
MultiFab.
4105+
4106+
"""
40464107
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...
40474108

40484109
class PIdx:

src/amrex/space2d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_2d_pybind)
54+
register_MultiFab_extension(amrex_2d_pybind)
5355
register_PODVector_extension(amrex_2d_pybind)
5456
register_SoA_extension(amrex_2d_pybind)
5557
register_AoS_extension(amrex_2d_pybind)

src/amrex/space2d/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import os as os
3838

3939
from amrex.Array4 import register_Array4_extension
4040
from amrex.ArrayOfStructs import register_AoS_extension
41+
from amrex.MultiFab import register_MultiFab_extension
4142
from amrex.PODVector import register_PODVector_extension
4243
from amrex.StructOfArrays import register_SoA_extension
4344
from amrex.space2d.amrex_2d_pybind import (
@@ -461,6 +462,7 @@ __all__ = [
461462
"refine",
462463
"register_AoS_extension",
463464
"register_Array4_extension",
465+
"register_MultiFab_extension",
464466
"register_PODVector_extension",
465467
"register_SoA_extension",
466468
"size",

src/amrex/space2d/amrex_2d_pybind/__init__.pyi

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4049,6 +4049,67 @@ class MultiFab(FabArray_FArrayBox):
40494049
"""
40504050
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
40514051
"""
4052+
def to_cupy(self, copy=False, order="F"):
4053+
"""
4054+
4055+
Provide a Cupy view into a MultiFab.
4056+
4057+
Note on the order of indices:
4058+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4059+
x,y,z. This has performance implications for use in external libraries such
4060+
as cupy.
4061+
The order="C" option will index as z,y,x and perform better with cupy.
4062+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4063+
4064+
Parameters
4065+
----------
4066+
self : amrex.MultiFab
4067+
A MultiFab class in pyAMReX
4068+
copy : bool, optional
4069+
Copy the data if true, otherwise create a view (default).
4070+
order : string, optional
4071+
F order (default) or C. C is faster with external libraries.
4072+
4073+
Returns
4074+
-------
4075+
list of cupy.array
4076+
A list of cupy n-dimensional arrays, for each local block in the
4077+
MultiFab.
4078+
4079+
Raises
4080+
------
4081+
ImportError
4082+
Raises an exception if cupy is not installed
4083+
4084+
"""
4085+
def to_numpy(self, copy=False, order="F"):
4086+
"""
4087+
4088+
Provide a Numpy view into a MultiFab.
4089+
4090+
Note on the order of indices:
4091+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4092+
x,y,z. This has performance implications for use in external libraries such
4093+
as cupy.
4094+
The order="C" option will index as z,y,x and perform better with cupy.
4095+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4096+
4097+
Parameters
4098+
----------
4099+
self : amrex.MultiFab
4100+
A MultiFab class in pyAMReX
4101+
copy : bool, optional
4102+
Copy the data if true, otherwise create a view (default).
4103+
order : string, optional
4104+
F order (default) or C. C is faster with external libraries.
4105+
4106+
Returns
4107+
-------
4108+
list of numpy.array
4109+
A list of numpy n-dimensional arrays, for each local block in the
4110+
MultiFab.
4111+
4112+
"""
40524113
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...
40534114

40544115
class PIdx:

src/amrex/space3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def Print(*args, **kwargs):
4646

4747
from ..Array4 import register_Array4_extension
4848
from ..ArrayOfStructs import register_AoS_extension
49+
from ..MultiFab import register_MultiFab_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_3d_pybind)
54+
register_MultiFab_extension(amrex_3d_pybind)
5355
register_PODVector_extension(amrex_3d_pybind)
5456
register_SoA_extension(amrex_3d_pybind)
5557
register_AoS_extension(amrex_3d_pybind)

src/amrex/space3d/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import os as os
3838

3939
from amrex.Array4 import register_Array4_extension
4040
from amrex.ArrayOfStructs import register_AoS_extension
41+
from amrex.MultiFab import register_MultiFab_extension
4142
from amrex.PODVector import register_PODVector_extension
4243
from amrex.StructOfArrays import register_SoA_extension
4344
from amrex.space3d.amrex_3d_pybind import (
@@ -461,6 +462,7 @@ __all__ = [
461462
"refine",
462463
"register_AoS_extension",
463464
"register_Array4_extension",
465+
"register_MultiFab_extension",
464466
"register_PODVector_extension",
465467
"register_SoA_extension",
466468
"size",

src/amrex/space3d/amrex_3d_pybind/__init__.pyi

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4052,6 +4052,67 @@ class MultiFab(FabArray_FArrayBox):
40524052
"""
40534053
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
40544054
"""
4055+
def to_cupy(self, copy=False, order="F"):
4056+
"""
4057+
4058+
Provide a Cupy view into a MultiFab.
4059+
4060+
Note on the order of indices:
4061+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4062+
x,y,z. This has performance implications for use in external libraries such
4063+
as cupy.
4064+
The order="C" option will index as z,y,x and perform better with cupy.
4065+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4066+
4067+
Parameters
4068+
----------
4069+
self : amrex.MultiFab
4070+
A MultiFab class in pyAMReX
4071+
copy : bool, optional
4072+
Copy the data if true, otherwise create a view (default).
4073+
order : string, optional
4074+
F order (default) or C. C is faster with external libraries.
4075+
4076+
Returns
4077+
-------
4078+
list of cupy.array
4079+
A list of cupy n-dimensional arrays, for each local block in the
4080+
MultiFab.
4081+
4082+
Raises
4083+
------
4084+
ImportError
4085+
Raises an exception if cupy is not installed
4086+
4087+
"""
4088+
def to_numpy(self, copy=False, order="F"):
4089+
"""
4090+
4091+
Provide a Numpy view into a MultiFab.
4092+
4093+
Note on the order of indices:
4094+
By default, this is as in AMReX in Fortran contiguous order, indexing as
4095+
x,y,z. This has performance implications for use in external libraries such
4096+
as cupy.
4097+
The order="C" option will index as z,y,x and perform better with cupy.
4098+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
4099+
4100+
Parameters
4101+
----------
4102+
self : amrex.MultiFab
4103+
A MultiFab class in pyAMReX
4104+
copy : bool, optional
4105+
Copy the data if true, otherwise create a view (default).
4106+
order : string, optional
4107+
F order (default) or C. C is faster with external libraries.
4108+
4109+
Returns
4110+
-------
4111+
list of numpy.array
4112+
A list of numpy n-dimensional arrays, for each local block in the
4113+
MultiFab.
4114+
4115+
"""
40554116
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...
40564117

40574118
class PIdx:

0 commit comments

Comments
 (0)