Skip to content

Commit

Permalink
Merge pull request cupy#6217 from chainer-ci/bp-6170-v10-vectorize-la…
Browse files Browse the repository at this point in the history
…mbda

[backport] Support lambda function in `cupy.vectorize`
  • Loading branch information
kmaehashi authored Dec 8, 2021
2 parents 81dc54b + ce09e37 commit 8ab830d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cupy/_functional/vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(self, *args):
f'{t.dtype} in{i}' for i, t in enumerate(in_types))
in_args = ', '.join([f'in{i}' for i in range(len(in_types))])
out_params, out_lval = self._parse_out_param(result.return_type)
body = '{} = {}({})'.format(out_lval, func.name, in_args)
body = '{} = {}({})'.format(out_lval, result.func_name, in_args)
# note: we don't worry about -D not working on ROCm here, because
# we unroll all headers for HIP and so thrust::tuple et al are all
# defined regardless if CUPY_JIT_MODE is defined or not
Expand Down
4 changes: 1 addition & 3 deletions cupyx/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,4 @@
from cupyx.jit._builtin_funcs import shfl_down_sync # NOQA
from cupyx.jit._builtin_funcs import shfl_xor_sync # NOQA

import inspect as _inspect

_getsource_func = _inspect.getsource # NOQA
_getsource_func = None # NOQA
79 changes: 58 additions & 21 deletions cupyx/jit/_compile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import collections
import inspect
import math
import numbers
import re
import sys
Expand Down Expand Up @@ -54,6 +55,59 @@ def new_func(node, *args, **kwargs):
return new_func


def _parse_function_object(func):
# Parses function object into ast.FunctionDef object.
if not callable(func):
raise ValueError('`func` must be a callable object.')

if func.__name__ != '<lambda>':
if jit._getsource_func is None:
lines = inspect.getsource(func).split('\n')
num_indent = len(lines[0]) - len(lines[0].lstrip())
source = '\n'.join([
line.replace(' ' * num_indent, '', 1) for line in lines])
else:
source = jit._getsource_func(func)
tree = ast.parse(source)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
return tree.body[0], source

if jit._getsource_func is not None:
full_source = jit._getsource_func(func)
start_line, end_line = 0, math.inf
source = full_source
else:
try:
filename = inspect.getsourcefile(func)
except TypeError:
filename = None
if filename is None:
raise ValueError(f'JIT needs access to Python source for {func}'
'but could not be located')
with open(filename) as f:
full_source = f.read()
source, start_line = inspect.getsourcelines(func)
end_line = start_line + len(source)
source = ''.join(source)

tree = ast.parse(full_source)

nodes = [node for node in ast.walk(tree)
if isinstance(node, ast.Lambda)
and start_line <= node.lineno < end_line]
if len(nodes) > 1:
raise ValueError('Multiple callables are found near the'
f' definition of {func}, and JIT could not'
' identify the source code for it.')
node = nodes[0]
return ast.FunctionDef(
name='_lambda_kernel', args=node.args,
body=[ast.Return(node.body)],
decorator_list=[], returns=None, type_comment=None,
), source


def transpile(func, attributes, mode, in_types, ret_type):
"""Transpile the target function
Args:
Expand All @@ -63,32 +117,15 @@ def transpile(func, attributes, mode, in_types, ret_type):
in_types (list of _cuda_types.TypeBase): Types of the arguments.
ret_type (_cuda_types.TypeBase or None): Type of the return value.
"""

if not callable(func):
raise ValueError('`func` must be a callable object.')

if func.__name__ == '<lambda>':
raise NotImplementedError('Lambda function is not supported.')

attributes = ' '.join(attributes)
source = jit._getsource_func(func)
lines = source.split('\n')
num_indent = len(lines[0]) - len(lines[0].lstrip())
source = '\n'.join([
line.replace(' ' * num_indent, '', 1) for line in lines])

cvars = inspect.getclosurevars(func)
consts = dict(**cvars.globals, **cvars.nonlocals, **cvars.builtins)
tree = ast.parse(source)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
attributes = ' '.join(attributes)
tree, source = _parse_function_object(func)
cuda_code, env = _transpile_function(
tree.body[0], attributes, mode, consts, in_types, ret_type,
source=source
)
tree, attributes, mode, consts, in_types, ret_type, source=source)
cuda_code = ''.join([code + '\n' for code in env.preambles]) + cuda_code
return Result(
func_name=func.__name__,
func_name=tree.name,
code=cuda_code,
return_type=env.ret_type,
)
Expand Down
16 changes: 16 additions & 0 deletions tests/cupy_tests/functional_tests/test_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,19 @@ def my_func(x1, x2, x3):
x2 = testing.shaped_random((20, 30), xp, dtype, seed=2)
x3 = testing.shaped_random((20, 30), xp, dtype, seed=3)
return f(x1, x2, x3)

@testing.numpy_cupy_array_equal()
def test_vectorize_lambda(self, xp):
f = xp.vectorize(lambda a, b, c: a + b * c)
x1 = testing.shaped_random((20, 30), xp, numpy.int64, seed=1)
x2 = testing.shaped_random((20, 30), xp, numpy.int64, seed=2)
x3 = testing.shaped_random((20, 30), xp, numpy.int64, seed=3)
return f(x1, x2, x3)

def test_vectorize_lambda_xfail(self):
functions = [lambda a, b: a + b, lambda a, b: a * b]
f = cupy.vectorize(functions[0])
x1 = testing.shaped_random((20, 30), cupy, numpy.int64, seed=1)
x2 = testing.shaped_random((20, 30), cupy, numpy.int64, seed=2)
with pytest.raises(ValueError, match='Multiple callables are found'):
return f(x1, x2)
9 changes: 1 addition & 8 deletions tests/cupyx_tests/jit_tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,15 +618,8 @@ def test_error_msg(self):
def f(x):
return unknown_var # NOQA

import re
mes = re.escape('''Unbound name: unknown_var
@jit.rawkernel()
def f(x):
> return unknown_var # NOQA
''')
x = cupy.zeros((10,), dtype=numpy.float32)
with pytest.raises(NameError, match=mes):
with pytest.raises(NameError, match='Unbound name: unknown_var'):
f((1,), (1,), (x,))

def test_laneid(self):
Expand Down

0 comments on commit 8ab830d

Please sign in to comment.