Skip to content

Commit d4c1044

Browse files
committed
generics works
1 parent 87b0988 commit d4c1044

File tree

8 files changed

+391
-213
lines changed

8 files changed

+391
-213
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ import luisa_lang as lc
2626

2727

2828
### Functions
29-
Functions are defined using the `@lc.func` decorator. The function body can contain any valid LuisaCompute code. You can also include normal Python code that will be executed at DSL comile time using `lc.constexpr()`. (See [Metaprogramming](#metaprogramming) for more details)
29+
Functions are defined using the `@lc.func` decorator. The function body can contain any valid LuisaCompute code. You can also include normal Python code that will be executed at DSL comile time using `lc.comptime()`. (See [Metaprogramming](#metaprogramming) for more details)
3030

3131
```python
3232
@lc.func
3333
def add(a: lc.float, b: lc.float) -> lc.float:
34-
with lc.constexpr():
34+
with lc.comptime():
3535
print('compiliing add function')
3636
return a + b
3737

@@ -69,8 +69,8 @@ luisa_lang provides a metaprogramming feature similar to C++ that allows users t
6969
# Compile time reflection
7070
@lc.func
7171
def get_x_or_zero(x: Any):
72-
t = lc.constexpr(type(x))
73-
if lc.constexpr(hasattr(t, 'x')):
72+
t = lc.comptime(type(x))
73+
if lc.comptime(hasattr(t, 'x')):
7474
return x.x
7575
else:
7676
return 0.0
@@ -86,7 +86,7 @@ def apply_func(f: F, x: T):
8686
# Generate code at compile time
8787
@lc.func
8888
def call_n_times(f: F):
89-
with lc.constexpr():
89+
with lc.comptime():
9090
n = input('how many times to call?')
9191
for i in range(n):
9292
# lc.embed_code(expr) will generate add expr to the DSL code
@@ -95,12 +95,13 @@ def call_n_times(f: F):
9595
lc.embed_code('apply_func(f, i)')
9696

9797
# Hint a parameter is constexpr
98-
@lc.func(n=lc.constexpr) # without this, n will be treated as a runtime variable and result in an error
98+
@lc.func(n=lc.comptime) # without this, n will be treated as a runtime variable and result in an error
9999
def pow(x: lc.float, n: int) -> lc.float:
100100
p = 1.0
101-
with lc.constexpr():
101+
with lc.comptime():
102102
for _ in range(n):
103103
lc.embed_code('p *= x')
104+
return p
104105
```
105106
### Limitation & Caveats
106107
- Lambda and nested function do not support updating nonlocal variables.

luisa_lang/_builtin_decor.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,31 @@ def decorator(cls: _T) -> _T:
2323

2424
def make_type_rule(
2525
name: str, method: MethodType
26-
) -> Callable[[List[hir.Type]], hir.Type]:
26+
) -> hir.BuiltinTypeRule:
2727

2828
# # print(f'{cls_name}.{name}', signature)
2929
member = getattr(cls, name)
3030
signature = inspect.signature(member, globals=globalns)
3131
type_hints = typing.get_type_hints(member, globalns=globalns)
3232
parameters = signature.parameters
3333
return_type = method.return_type
34+
semantics: List[hir.ParameterSemantic] = []
3435
if not isinstance(return_type, type):
3536
raise hir.TypeInferenceError(None,
3637
f"Valid return type annotation required for {cls_name}.{name}"
3738
)
39+
parameters_list = list(parameters.values())
40+
for i, arg in enumerate(args):
41+
param = parameters_list[i]
42+
if param.name == "self":
43+
# self is always passed by reference
44+
semantics.append(hir.ParameterSemantic.BYREF)
45+
else:
46+
# other parameters are passed by value
47+
semantics.append(hir.ParameterSemantic.BYVAL)
3848

3949
def type_rule(args: List[hir.Type]) -> hir.Type:
4050

41-
parameters_list = list(parameters.values())
4251
if len(args) > len(parameters_list):
4352
raise hir.TypeInferenceError(None,
4453
f"Too many arguments for {cls_name}.{name} expected at most {len(parameters_list)} but got {len(args)}"
@@ -55,6 +64,7 @@ def type_rule(args: List[hir.Type]) -> hir.Type:
5564
raise hir.TypeInferenceError(None,
5665
f"Expected {cls_name}.{name} to be called with an instance of {cls_name} but got {arg}"
5766
)
67+
5868
continue
5969
if param_ty is None:
6070
raise hir.TypeInferenceError(None,
@@ -90,14 +100,14 @@ def check(anno_tys: List[type | Any]):
90100
else:
91101
return hir.UnitType()
92102

93-
return type_rule
103+
return hir.BuiltinTypeRule(type_rule, semantics)
94104

95105
def make_builtin():
96106
for name, member in cls_info.methods.items():
97107
type_rule = make_type_rule(name, member)
98108
builtin = hir.BuiltinFunction(
99109
f"{cls_name}.{name}",
100-
hir.TypeRule.from_fn(type_rule),
110+
type_rule,
101111
)
102112
ty.methods[name] = builtin
103113

luisa_lang/classinfo.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,19 +228,20 @@ def get_type_vars(func: typing.Callable) -> List[TypeVar]:
228228
return list(set(type_vars)) # Return unique type vars
229229

230230

231-
def parse_func_signature(func: object, globalns: Dict[str, Any], foreign_type_vars: List[TypeVar], self_type: Optional[VarType] = None, is_static: bool = False) -> MethodType:
231+
def parse_func_signature(func: object, globalns: Dict[str, Any], foreign_type_vars: List[TypeVar], is_static: bool = False) -> MethodType:
232232
assert inspect.isfunction(func)
233233
signature = inspect.signature(func)
234234
method_type_hints = typing.get_type_hints(func, globalns)
235235
param_types: List[Tuple[str, VarType]] = []
236236
type_vars = get_type_vars(func)
237237
for param in signature.parameters.values():
238238
if param.name == "self":
239-
assert self_type is not None
240-
param_types.append((param.name, self_type))
241-
else:
239+
param_types.append((param.name, SelfType()))
240+
elif param.name in method_type_hints:
242241
param_types.append((param.name, parse_type_hint(
243242
method_type_hints[param.name])))
243+
else:
244+
param_types.append((param.name, AnyType()))
244245
if "return" in method_type_hints:
245246
return_type = parse_type_hint(method_type_hints.get("return"))
246247
else:
@@ -312,12 +313,11 @@ def register_class(cls: type) -> None:
312313
if type_vars:
313314
for tv in type_vars:
314315
cls_ty.type_vars.append(tv)
315-
self_ty: VarType = SelfType()
316316
for name, member in inspect.getmembers(cls):
317317
if name in local_methods:
318318
# print(f'Found local method: {name} in {cls}')
319319
cls_ty.methods[name] = parse_func_signature(
320-
member, globalns, cls_ty.type_vars, self_ty, is_static=is_static(cls, name))
320+
member, globalns, cls_ty.type_vars, is_static=is_static(cls, name))
321321
for name in local_fields:
322322
cls_ty.fields[name] = parse_type_hint(type_hints[name])
323323
_CLS_TYPE_INFO[cls] = cls_ty

luisa_lang/codegen/cpp.py

Lines changed: 76 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def gen_impl(self, ty: hir.Type) -> str:
4242
self.impl.writeln(f' {self.gen(field[1])} {field[0]};')
4343
self.impl.writeln('};')
4444
return name
45+
case hir.UnitType():
46+
return 'void'
4547
case _:
4648
raise NotImplementedError(f"unsupported type: {ty}")
4749

@@ -101,6 +103,8 @@ def mangle(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
101103
def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
102104

103105
match obj:
106+
case hir.UnitType():
107+
return 'u'
104108
case hir.IntType(bits=bits, signed=signed):
105109
if signed:
106110
return f"i{bits}"
@@ -177,10 +181,10 @@ class FuncCodeGen:
177181
def gen_var(self, var: hir.Var) -> str:
178182
assert var.type
179183
ty = self.base.type_cache.gen(var.type)
180-
if var.byval:
184+
if var.semantic == hir.ParameterSemantic.BYVAL:
181185
return f"{ty} {var.name}"
182186
else:
183-
return f"{ty}& {var.name}"
187+
return f"{ty} & {var.name}"
184188

185189
def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
186190
self.base = base
@@ -199,17 +203,18 @@ def new_vid(self) -> int:
199203
return self.vid_cnt
200204

201205
def gen_ref(self, ref: hir.Ref) -> str:
206+
if ref in self.node_map:
207+
return self.node_map[ref]
202208
match ref:
203209
case hir.Var() as var:
204210
return var.name
205-
case hir.Member() as member:
206-
if isinstance(member.base, hir.Ref):
207-
base = self.gen_ref(member.base)
208-
else:
209-
base = self.gen_expr(member.base)
211+
case hir.MemberRef() as member:
212+
base = self.gen_ref(member.base)
210213
return f"{base}.{member.field}"
211-
case hir.ValueRef() as value_ref:
212-
return self.gen_expr(value_ref.value)
214+
case hir.IndexRef() as index:
215+
base = self.gen_ref(index.base)
216+
idx = self.gen_expr(index.index)
217+
return f"{base}[{idx}]"
213218
case _:
214219
raise NotImplementedError(f"unsupported reference: {ref}")
215220

@@ -232,62 +237,51 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
232237
f"unsupported value or reference: {value}")
233238

234239
def gen_expr(self, expr: hir.Value) -> str:
235-
match expr:
236-
case hir.Load() as load:
237-
return self.gen_ref(load.ref)
238-
case hir.Call() as call:
239-
op = self.gen_func(call.op)
240-
return f"{op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)})"
241-
case hir.Constant() as constant:
242-
value = constant.value
243-
if isinstance(value, int):
244-
return str(value)
245-
elif isinstance(value, float):
246-
return str(value)
247-
elif isinstance(value, bool):
248-
return "true" if value else "false"
249-
elif isinstance(value, str):
250-
return f'"{value}"'
251-
elif isinstance(value, hir.Function) or isinstance(value, hir.BuiltinFunction):
252-
return self.gen_func(value)
253-
else:
240+
if expr in self.node_map:
241+
return self.node_map[expr]
242+
vid = self.new_vid()
243+
244+
def impl() -> None:
245+
match expr:
246+
case hir.Load() as load:
247+
self.body.writeln(
248+
f"const auto &v{vid} = {self.gen_ref(load.ref)};")
249+
case hir.Index() as index:
250+
base = self.gen_expr(index.base)
251+
idx = self.gen_expr(index.index)
252+
self.body.writeln(f"const auto v{vid} = {base}[{idx}];")
253+
case hir.Member() as member:
254+
base = self.gen_expr(member.base)
255+
self.body.writeln(
256+
f"const auto v{vid} = {base}.{member.field};")
257+
case hir.Call() as call:
258+
op = self.gen_func(call.op)
259+
self.body.writeln(
260+
f"auto v{vid} ={op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)});")
261+
case hir.Constant() as constant:
262+
value = constant.value
263+
if isinstance(value, int):
264+
self.body.writeln(f"const auto v{vid} = {value};")
265+
elif isinstance(value, float):
266+
self.body.writeln(f"const auto v{vid} = {value};")
267+
elif isinstance(value, bool):
268+
s = "true" if value else "false"
269+
self.body.writeln(f"const auto v{vid} = {s};")
270+
elif isinstance(value, str):
271+
self.body.writeln(f"const auto v{vid} = \"{value}\";")
272+
elif isinstance(value, hir.Function) or isinstance(value, hir.BuiltinFunction):
273+
name = self.gen_func(value)
274+
self.body.writeln(f"auto&& v{vid} = {name};")
275+
else:
276+
raise NotImplementedError(
277+
f"unsupported constant: {constant}")
278+
case _:
254279
raise NotImplementedError(
255-
f"unsupported constant: {constant}")
256-
case hir.Init() as init:
257-
return f"([&]() {{ {self.gen_expr(init.value)}; }})()"
258-
case _:
259-
raise NotImplementedError(f"unsupported expression: {expr}")
260-
261-
# def gen_stmt(self, stmt: hir.Stmt):
262-
# match stmt:
263-
# case hir.Return() as ret:
264-
# if ret.value:
265-
# self.body.writeln(f"return {self.gen_expr(ret.value)};")
266-
# else:
267-
# self.body.writeln("return;")
268-
# case hir.Assign() as assign:
269-
# ref = self.gen_ref(assign.ref)
270-
# value = self.gen_expr(assign.value)
271-
# self.body.writeln(f"{ref} = {value};")
272-
# case hir.If() as if_stmt:
273-
# cond = self.gen_expr(if_stmt.cond)
274-
# self.body.writeln(f"if ({cond}) {{")
275-
# self.body.indent += 1
276-
# for stmt in if_stmt.then_body:
277-
# self.gen_stmt(stmt)
278-
# self.body.indent -= 1
279-
# self.body.writeln("}")
280-
# if if_stmt.else_body:
281-
# self.body.writeln("else {")
282-
# self.body.indent += 1
283-
# for stmt in if_stmt.else_body:
284-
# self.gen_stmt(stmt)
285-
# self.body.indent -= 1
286-
# self.body.writeln("}")
287-
# case hir.VarDecl() as var_decl:
288-
# pass
289-
# case _:
290-
# raise NotImplementedError(f"unsupported statement: {stmt}")
280+
f"unsupported expression: {expr}")
281+
282+
impl()
283+
self.node_map[expr] = f'v{vid}'
284+
return f'v{vid}'
291285

292286
def gen_node(self, node: hir.Node):
293287
match node:
@@ -303,43 +297,55 @@ def gen_node(self, node: hir.Node):
303297
case hir.If() as if_stmt:
304298
cond = self.gen_expr(if_stmt.cond)
305299
self.body.writeln(f"if ({cond})")
300+
self.body.indent += 1
306301
self.gen_bb(if_stmt.then_body)
302+
self.body.indent -= 1
307303
if if_stmt.else_body:
308304
self.body.writeln("else")
305+
self.body.indent += 1
309306
self.gen_bb(if_stmt.else_body)
307+
self.body.indent -= 1
310308
self.gen_bb(if_stmt.merge)
311309
case hir.Loop() as loop:
312310
vid = self.new_vid()
313311
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
314-
self.gen_bb(loop.prepare)
312+
self.body.indent += 1
313+
self.gen_bb(loop.prepare)
315314
if loop.cond:
316315
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
317316
else:
318317
self.body.writeln("return true;")
318+
self.body.indent -=1
319319
self.body.writeln("};")
320320
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
321+
self.body.indent += 1
321322
self.gen_bb(loop.body)
323+
self.body.indent -= 1
322324
self.body.writeln("};")
323325
self.body.writeln(f"auto loop{vid}_update = [&]() {{")
326+
self.body.indent += 1
324327
if loop.update:
325328
self.gen_bb(loop.update)
329+
self.body.indent -= 1
326330
self.body.writeln("};")
327331
self.body.writeln(
328332
f"for(;loop{vid}_prepare();loop{vid}_update());")
329333
self.gen_bb(loop.merge)
330334
case hir.Alloca() as alloca:
331-
pass
332-
case hir.Call() as call:
333-
self.gen_expr(call)
335+
vid = self.new_vid()
336+
assert alloca.type
337+
ty = self.base.type_cache.gen(alloca.type)
338+
self.body.writeln(f"{ty} v{vid}{{}};")
339+
self.node_map[alloca] = f"v{vid}"
340+
case hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member():
341+
self.gen_expr(node)
334342
case hir.Member() | hir.Index():
335343
pass
336344

337345
def gen_bb(self, bb: hir.BasicBlock):
338346
self.body.writeln(f"{{ // BasicBlock Begin {bb.span}")
339-
self.body.indent += 1
340347
for node in bb.nodes:
341348
self.gen_node(node)
342-
self.body.indent -= 1
343349
self.body.writeln(f"}} // BasicBlock End {bb.span}")
344350

345351
def gen_locals(self):

0 commit comments

Comments
 (0)