Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""


def mf_to_numpy(self, copy=False, order="F"):
def mf_to_numpy(amr, self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.

Expand All @@ -29,13 +29,24 @@ def mf_to_numpy(self, copy=False, order="F"):

Returns
-------
list of np.array
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
mf = self
if copy:
mf = amr.MultiFab(
self.box_array(),
self.dm(),
self.n_comp(),
self.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
amr.dtoh_memcpy(mf, self)

views = []
for mfi in self:
views.append(self.array(mfi).to_numpy(copy, order))
for mfi in mf:
views.append(mf.array(mfi).to_numpy(copy=False, order=order))

return views

Expand Down Expand Up @@ -80,15 +91,11 @@ def mf_to_cupy(self, copy=False, order="F"):

def register_MultiFab_extension(amr):
"""MultiFab helper methods"""
import inspect
import sys

# register member functions for every MultiFab* type
for _, MultiFab_type in inspect.getmembers(
sys.modules[amr.__name__],
lambda member: inspect.isclass(member)
and member.__module__ == amr.__name__
and member.__name__.startswith("MultiFab"),
):
MultiFab_type.to_numpy = mf_to_numpy
MultiFab_type.to_cupy = mf_to_cupy

# register member functions for the MultiFab type
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
amr, self, copy, order
)
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__

amr.MultiFab.to_cupy = mf_to_cupy
2 changes: 2 additions & 0 deletions src/amrex/space1d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_1d_pybind)
register_MultiFab_extension(amrex_1d_pybind)
register_PODVector_extension(amrex_1d_pybind)
register_SoA_extension(amrex_1d_pybind)
register_AoS_extension(amrex_1d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space1d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space1d.amrex_1d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space1d/amrex_1d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4043,6 +4043,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""

Provide a Cupy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.

Raises
------
ImportError
Raises an exception if cupy is not installed

"""
def to_numpy(self, copy=False, order="F"):
"""

Provide a Numpy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.

"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_2d_pybind)
register_MultiFab_extension(amrex_2d_pybind)
register_PODVector_extension(amrex_2d_pybind)
register_SoA_extension(amrex_2d_pybind)
register_AoS_extension(amrex_2d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space2d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space2d.amrex_2d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space2d/amrex_2d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4049,6 +4049,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""

Provide a Cupy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.

Raises
------
ImportError
Raises an exception if cupy is not installed

"""
def to_numpy(self, copy=False, order="F"):
"""

Provide a Numpy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.

"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_3d_pybind)
register_MultiFab_extension(amrex_3d_pybind)
register_PODVector_extension(amrex_3d_pybind)
register_SoA_extension(amrex_3d_pybind)
register_AoS_extension(amrex_3d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space3d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space3d.amrex_3d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space3d/amrex_3d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4052,6 +4052,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""

Provide a Cupy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.

Raises
------
ImportError
Raises an exception if cupy is not installed

"""
def to_numpy(self, copy=False, order="F"):
"""

Provide a Numpy view into a MultiFab.

Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074

Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.

Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.

"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,13 @@ def test_mfab_dtoh_copy(make_mfab_device):
device_max = mfab_device.max(0)
assert device_min == device_max
assert device_max == 11.0

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == device_max

# cupy bindings (w/o copy)
import cupy as cp

local_boxes_device = mfab_device.to_cupy()
assert max([cp.max(box) for box in local_boxes_device]) == device_max