diff --git a/acc/components/ComponentBase.py b/acc/components/ComponentBase.py
index 856e33f9..235c72f8 100644
--- a/acc/components/ComponentBase.py
+++ b/acc/components/ComponentBase.py
@@ -129,24 +129,30 @@ def register_propagate_method(cls, propagate):
         return new_propagate
 
     @classmethod
-    def _adjust_propagate_type(cls, propagate):
+    def _change_kernel_floattype(cls, kernel_func):
+        '''
+        Changes the floattype of kernel_func based on the Component floattype
+        kernel_func should be a DeviceFunction object returned by cuda.jit
+        Returns a new DeviceFunction
+        '''
         # disable float switching if in cudasim mode
         if config.ENABLE_CUDASIM:
-            return propagate
+            return kernel_func
 
-        if not isinstance(propagate, DeviceFunction):
+        if not isinstance(kernel_func, DeviceFunction):
             raise RuntimeError(
-                "invalid propagate function ({}, {}) registered, ".format(
-                    propagate, type(propagate))
-                + "does propagate have a signature defined?")
+                "invalid kernel function ({}, {}), ".format(
+                    kernel_func, type(kernel_func))
+                + "does the function have a signature defined?")
 
-        args = propagate.args
+        args = kernel_func.args
 
         # reconstruct the numba args with the correct floattype
         newargs = []
         for arg in args:
             if isinstance(arg, Array) and isinstance(arg.dtype, Float):
-                newargs.append(arg.copy(dtype=getattr(numba, cls._floattype)))
+                newargs.append(
+                    arg.copy(dtype=getattr(numba, cls._floattype)))
             elif isinstance(arg, Float):
                 newargs.append(Float(name=cls._floattype))
             else:
@@ -156,19 +162,42 @@ def _adjust_propagate_type(cls, propagate):
         # DeviceFunction in Numba < 0.54.1 does not have a lineinfo property
         if int(numba.__version__.split(".")[1]) < 54:
-            new_propagate = DeviceFunction(pyfunc=propagate.py_func,
-                                           return_type=propagate.return_type,
-                                           args=newargs,
-                                           inline=propagate.inline,
-                                           debug=propagate.debug)
+            new_func = DeviceFunction(pyfunc=kernel_func.py_func,
+                                      return_type=kernel_func.return_type,
+                                      args=newargs,
+                                      inline=kernel_func.inline,
+                                      debug=kernel_func.debug)
         else:
-            new_propagate = DeviceFunction(pyfunc=propagate.py_func,
-                                           return_type=propagate.return_type,
-                                           args=newargs,
-                                           inline=propagate.inline,
-                                           debug=propagate.debug,
-                                           lineinfo=propagate.lineinfo)
-        #cls.print_kernel_info(new_propagate)
+            new_func = DeviceFunction(pyfunc=kernel_func.py_func,
+                                      return_type=kernel_func.return_type,
+                                      args=newargs,
+                                      inline=kernel_func.inline,
+                                      debug=kernel_func.debug,
+                                      lineinfo=kernel_func.lineinfo)
+
+        return new_func
+
+    @classmethod
+    def _adjust_propagate_type(cls, propagate):
+        # disable float switching if in cudasim mode
+        if config.ENABLE_CUDASIM:
+            return propagate
+
+        if not isinstance(propagate, DeviceFunction):
+            raise RuntimeError(
+                "invalid propagate function ({}, {}) registered, ".format(
+                    propagate, type(propagate))
+                + "does propagate have a signature defined?")
+
+        # adjust float types of any device function that propagate calls
+        for func in propagate.py_func.__globals__:
+            if isinstance(propagate.py_func.__globals__[func], DeviceFunction):
+                propagate.py_func.__globals__[func] = \
+                    cls._change_kernel_floattype(
+                        propagate.py_func.__globals__[func])
+
+        new_propagate = cls._change_kernel_floattype(propagate)
 
         return new_propagate
 
     @classmethod
diff --git a/tests/components/test_ComponentBase.py b/tests/components/test_ComponentBase.py
index 74ea5627..5fbc0b54 100644
--- a/tests/components/test_ComponentBase.py
+++ b/tests/components/test_ComponentBase.py
@@ -3,12 +3,23 @@
 import pytest
 
 import numba
-from numba import cuda, void, float32
+from numba import cuda, void, float32, float64
 from numba.core.types import Array, Float
 
 from mcvine.acc.components.ComponentBase import ComponentBase
+#from mcvine.acc.neutron import absorb
 from mcvine.acc import test
 
+@cuda.jit(float64(float64), device=True)
+def global_kernel(x):
+    return x
+
+@cuda.jit(float64(float64), device=True)
+def nested_global_kernel(x):
+    y = 1.0 + global_kernel(x)
+    return y
+
+
 def test_no_propagate_raises():
     with pytest.raises(TypeError):
         # check creating a component with no propagate method raises error
@@ -80,3 +91,111 @@ def propagate(x, y):
     assert isinstance(args[1], Array)
     assert args[1].dtype == float32
 
+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_global_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = global_kernel(x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the global kernel function args are changed
+    args = global_kernel.args
+    assert len(args) == 1
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+
+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_nested_global_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = nested_global_kernel(x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the nested kernel function args are changed
+    args = nested_global_kernel.args
+    assert len(args) == 1
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+
+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_local_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    @cuda.jit(NB_FLOAT(NB_FLOAT, NB_FLOAT), device=True)
+    def helper_kernel(x, y):
+        return x * y
+
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = helper_kernel(x, x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the local kernel function args are changed
+    args = helper_kernel.args
+    assert len(args) == 2
+    for arg in args:
+        assert isinstance(arg, Float)
+        assert arg.bitwidth == 32
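
Note (not part of the patch): a minimal, dependency-free sketch of the __globals__ rewriting pattern that the new _adjust_propagate_type relies on. FakeDeviceFunction, rewrite_globals, helper and propagate below are hypothetical stand-ins; the real code walks propagate.py_func.__globals__ for numba DeviceFunction objects and swaps each one for the result of _change_kernel_floattype.

    class FakeDeviceFunction:
        # hypothetical stand-in for numba's DeviceFunction
        def __init__(self, dtype):
            self.dtype = dtype

    helper = FakeDeviceFunction("float64")

    def propagate(x):
        # reaches the module-level "device function" through its globals
        return helper.dtype

    def rewrite_globals(func, target_type, transform):
        # same traversal as _adjust_propagate_type: walk func.__globals__ and
        # replace every entry of target_type with a transformed copy
        for name in func.__globals__:
            if isinstance(func.__globals__[name], target_type):
                func.__globals__[name] = transform(func.__globals__[name])

    rewrite_globals(propagate, FakeDeviceFunction,
                    lambda f: FakeDeviceFunction("float32"))
    print(propagate(0.0))  # -> float32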