Skip to content

Commit 385ea44

Browse files
committed
fix passing arguments by ref
1 parent 0e83814 commit 385ea44

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

luisa_lang/_builtin_decor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
8080
func_sig = classinfo.parse_func_signature(f, func_globals, [])
8181

8282
func_sig_converted, sig_parser = parse.convert_func_signature(
83-
func_sig, func_name, func_globals, foreign_type_var_ns, {}, self_type)
83+
func_sig, func_name, props, func_globals, foreign_type_var_ns, {}, self_type)
8484
implicit_type_params = sig_parser.implicit_type_params
8585
implicit_generic_params: Set[hir.GenericParameter] = set()
8686
for p in implicit_type_params.values():
@@ -119,7 +119,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
119119
mapped_implicit_type_params[name] = mapped_type
120120

121121
func_sig_instantiated, _p = parse.convert_func_signature(
122-
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
122+
func_sig, func_name, props, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
123123
# print(func_name, func_sig)
124124
assert len(
125125
func_sig_instantiated.generic_params) == 0, f"generic params should be resolved but found {func_sig_instantiated.generic_params}"

luisa_lang/lang_builtins.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,11 @@ def device_assert(cond: bool, msg: str = "") -> typing.NoReturn:
121121
raise NotImplementedError(
122122
"device_assert should not be called in host-side Python code. ")
123123

124+
124125
def sizeof(t: type[T]) -> u64:
125-
raise NotImplementedError("sizeof should not be called in host-side Python code. ")
126+
raise NotImplementedError(
127+
"sizeof should not be called in host-side Python code. ")
128+
126129

127130
@overload
128131
def range(n: T) -> List[T]: ...
@@ -208,11 +211,11 @@ class Array(Generic[T, N]):
208211
def __init__(self) -> None:
209212
self = intrinsic("init.array", Array[T, N])
210213

211-
def __getitem__(self, index: int | u32 | u64) -> T:
214+
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
212215
return intrinsic("array.ref", T, byref(self), index) # type: ignore
213216

214-
def __setitem__(self, index: int | u32 | u64, value: T) -> None:
215-
pass
217+
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
218+
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
216219

217220
def __len__(self) -> u64:
218221
return intrinsic("array.size", u64, self) # type: ignore
@@ -233,10 +236,11 @@ def __len__(self) -> u64:
233236

234237
@opaque("Buffer")
235238
class Buffer(Generic[T]):
236-
def __getitem__(self, index: int | u32 | u64) -> T:
239+
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
237240
return intrinsic("buffer.ref", T, self, index) # type: ignore
238241

239-
def __setitem__(self, index: int | u32 | u64, value: T) -> None:
242+
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
243+
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
240244
pass
241245

242246
def __len__(self) -> u64:

luisa_lang/parse.py

+3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def handle_type_t():
153153

154154
def convert_func_signature(signature: classinfo.MethodType,
155155
ctx_name: str,
156+
props:hir.FuncProperties,
156157
globalns: Dict[str, Any],
157158
type_var_ns: Dict[typing.TypeVar, hir.Type],
158159
implicit_type_params: Dict[str, hir.Type],
@@ -173,6 +174,8 @@ def convert_func_signature(signature: classinfo.MethodType,
173174
assert self_type is not None
174175
param_type = self_type
175176
semantic = hir.ParameterSemantic.BYREF
177+
if arg[0] in props.byref:
178+
semantic = hir.ParameterSemantic.BYREF
176179
if param_type is None:
177180
raise RuntimeError(
178181
f"Unable to parse type of parameter {arg[0]}: {arg[1]}")

0 commit comments

Comments
 (0)