Skip to content

Commit

Permalink
Merge pull request cupy#8989 from asi1024/empty-return
Browse files Browse the repository at this point in the history
JIT: Support empty return
  • Loading branch information
asi1024 authored and chainer-ci committed Feb 28, 2025
1 parent 4a84f28 commit ec21e44
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
12 changes: 10 additions & 2 deletions cupyx/jit/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def reraise(self, pycode):


def transpile_function_wrapper(func):
def new_func(node, *args, **kwargs):
def new_func(node: ast.AST, *args, **kwargs):
try:
return func(node, *args, **kwargs)
except _JitCompileError:
Expand Down Expand Up @@ -528,6 +528,11 @@ def _transpile_stmt(
'Nested functions are not supported currently.')
if isinstance(stmt, ast.Return):
value = _transpile_expr(stmt.value, env)

if isinstance(value, Constant) and value.obj is None:
# `return None` or `return` without value
return ['return;']

value = Data.init(value, env)
t = value.ctype
if env.ret_type is None:
Expand Down Expand Up @@ -686,7 +691,7 @@ def _transpile_expr(expr: ast.expr, env: Environment) -> _internal_types.Expr:


def _transpile_expr_internal(
expr: ast.expr,
expr: Optional[ast.expr],
env: Environment,
) -> _internal_types.Expr:
if isinstance(expr, ast.BoolOp):
Expand Down Expand Up @@ -789,8 +794,11 @@ def _transpile_expr_internal(

raise TypeError(f"Invalid function call '{func.__name__}'.")

if expr is None:
return Constant(None)
if isinstance(expr, ast.Constant):
return Constant(expr.value)

if isinstance(expr, ast.Subscript):
array = _transpile_expr(expr.value, env)
index = _transpile_expr(expr.slice, env)
Expand Down
24 changes: 24 additions & 0 deletions tests/cupyx_tests/jit_tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,3 +794,27 @@ def f(x):
x = cupy.zeros((30), dtype=dtype)
f((5,), (6,), (x,))
testing.assert_array_equal(x, numpy.full_like(x, cupy.nan))

def test_return_empty(self):
@jit.rawkernel()
def f(x, y):
tid = jit.threadIdx.x + jit.blockDim.x * jit.blockIdx.x
y[tid] = x[tid]
return

x = testing.shaped_random((30,), dtype=numpy.int32, seed=0)
y = testing.shaped_random((30,), dtype=numpy.int32, seed=1)
f((5,), (6,), (x, y))
assert bool((x == y).all())

def test_return_none(self):
@jit.rawkernel()
def f(x, y):
tid = jit.threadIdx.x + jit.blockDim.x * jit.blockIdx.x
y[tid] = x[tid]
return None

x = testing.shaped_random((30,), dtype=numpy.int32, seed=0)
y = testing.shaped_random((30,), dtype=numpy.int32, seed=1)
f((5,), (6,), (x, y))
assert bool((x == y).all())

0 comments on commit ec21e44

Please sign in to comment.