diff --git a/cupyx/jit/_compile.py b/cupyx/jit/_compile.py index ad4732f4a65..9197a652ecc 100644 --- a/cupyx/jit/_compile.py +++ b/cupyx/jit/_compile.py @@ -67,7 +67,7 @@ def reraise(self, pycode): def transpile_function_wrapper(func): - def new_func(node, *args, **kwargs): + def new_func(node: ast.AST, *args, **kwargs): try: return func(node, *args, **kwargs) except _JitCompileError: @@ -528,6 +528,11 @@ def _transpile_stmt( 'Nested functions are not supported currently.') if isinstance(stmt, ast.Return): value = _transpile_expr(stmt.value, env) + + if isinstance(value, Constant) and value.obj is None: + # `return None` or `return` without value + return ['return;'] + value = Data.init(value, env) t = value.ctype if env.ret_type is None: @@ -686,7 +691,7 @@ def _transpile_expr(expr: ast.expr, env: Environment) -> _internal_types.Expr: def _transpile_expr_internal( - expr: ast.expr, + expr: Optional[ast.expr], env: Environment, ) -> _internal_types.Expr: if isinstance(expr, ast.BoolOp): @@ -789,8 +794,11 @@ def _transpile_expr_internal( raise TypeError(f"Invalid function call '{func.__name__}'.") + if expr is None: + return Constant(None) if isinstance(expr, ast.Constant): return Constant(expr.value) + if isinstance(expr, ast.Subscript): array = _transpile_expr(expr.value, env) index = _transpile_expr(expr.slice, env) diff --git a/tests/cupyx_tests/jit_tests/test_raw.py b/tests/cupyx_tests/jit_tests/test_raw.py index c26666cc109..76b1c92e332 100644 --- a/tests/cupyx_tests/jit_tests/test_raw.py +++ b/tests/cupyx_tests/jit_tests/test_raw.py @@ -794,3 +794,27 @@ def f(x): x = cupy.zeros((30), dtype=dtype) f((5,), (6,), (x,)) testing.assert_array_equal(x, numpy.full_like(x, cupy.nan)) + + def test_return_empty(self): + @jit.rawkernel() + def f(x, y): + tid = jit.threadIdx.x + jit.blockDim.x * jit.blockIdx.x + y[tid] = x[tid] + return + + x = testing.shaped_random((30,), dtype=numpy.int32, seed=0) + y = testing.shaped_random((30,), dtype=numpy.int32, seed=1) + f((5,), (6,), (x, y)) + assert bool((x == y).all()) + + def test_return_none(self): + @jit.rawkernel() + def f(x, y): + tid = jit.threadIdx.x + jit.blockDim.x * jit.blockIdx.x + y[tid] = x[tid] + return None + + x = testing.shaped_random((30,), dtype=numpy.int32, seed=0) + y = testing.shaped_random((30,), dtype=numpy.int32, seed=1) + f((5,), (6,), (x, y)) + assert bool((x == y).all())