Skip to content

Commit d92e550

Browse files
committed
fixed loop codegen
1 parent 87e7e52 commit d92e550

File tree

4 files changed

+140
-37
lines changed

4 files changed

+140
-37
lines changed

README.md

+54-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
55
## Content
66
- [Introduction](#introduction)
77
- [Basics](#basic-syntax)
8+
- [Difference from Python](#difference-from-python)
89
- [Types](#types)
10+
- [Value & Reference Semantics](#value--reference-semantics)
911
- [Functions](#functions)
1012
- [User-defined Structs](#user-defined-structs)
13+
- [Control Flow](#control-flow)
1114
- [Advanced Usage](#advanced-syntax)
1215
- [Generics](#generics)
1316
- [Metaprogramming](#metaprogramming)
@@ -20,11 +23,16 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
2023
import luisa_lang as lc
2124
```
2225
## Basic Syntax
26+
### Difference from Python
27+
There are some notable differences between luisa_lang and Python:
28+
- Variables have value semantics by default. Use `inout` to indicate that an argument that is passed by reference.
29+
- Generic functions and structs are implemented via monomorphization (a.k.a instantiation) at compile time rather than via type erasure.
30+
- Overloading subscript operator and attribute access is different from Python. Only `__getitem__` and `__getattr__` are needed, which returns a local reference.
31+
2332
### Types
2433
```python
2534
```
2635

27-
2836
### Functions
2937
Functions are defined using the `@lc.func` decorator. The function body can contain any valid LuisaCompute code. You can also include normal Python code that will be executed at DSL comile time using `lc.comptime()`. (See [Metaprogramming](#metaprogramming) for more details)
3038

@@ -37,13 +45,57 @@ def add(a: lc.float, b: lc.float) -> lc.float:
3745

3846
```
3947

40-
LuisaCompute uses value semantics, which means that all types are passed by value. You can use `inout` to indicate that a variable can be modified in place.
48+
49+
### Value & Reference Semantics
50+
Variables have value semantics by default. This means that when you assign a variable to another, a copy is made.
51+
```python
52+
a = lc.float3(1.0, 2.0, 3.0)
53+
b = a
54+
a.x = 2.0
55+
lc.print(f'{a.x} {b.x}') # prints 2.0 1.0
56+
```
57+
58+
You can use `inout` to indicate that a variable is passed as a *local reference*. Assigning to an `inout` variable will update the original variable.
4159
```python
4260
@luisa.func(a=inout, b=inout)
4361
def swap(a: int, b: int):
4462
a, b = b, a
63+
64+
a = lc.float3(1.0, 2.0, 3.0)
65+
b = lc.float3(4.0, 5.0, 6.0)
66+
swap(a.x, b.x)
67+
lc.print(f'{a.x} {b.x}') # prints 4.0 1.0
4568
```
4669

70+
When overloading subscript operator or attribute access, you actually return a local reference to the object.
71+
72+
#### Local References
73+
Local references are like pointers in C++. However, they cannot escape the expression boundary. This means that you cannot store a local reference in a variable and use it later. While you can return a local reference from a function, it must be returned from a uniform path. That is you cannot return different local references based on a condition.
74+
75+
76+
```python
77+
@lc.struct
78+
class InfiniteArray:
79+
def __getitem__(self, index: int) -> int:
80+
return self.data[index] # returns a local reference
81+
82+
# this method will be ignored by the compiler. but you can still put it here for linting
83+
def __setitem__(self, index: int, value: int):
84+
pass
85+
86+
# Not allowed, non-uniform return
87+
def __getitem__(self, index: int) -> int:
88+
if index == 0:
89+
return self.data[0]
90+
else:
91+
return self.data[1]
92+
93+
```
94+
95+
96+
97+
98+
4799
### User-defined Structs
48100
```python
49101
@lc.struct

luisa_lang/codegen/cpp.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
133133
case hir.Function(name=name, params=params, return_type=ret):
134134
assert ret
135135
name = mangle_name(name)
136+
params = list(filter(lambda p: not isinstance(
137+
p.type, (hir.FunctionType)), params))
136138
return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}")
137139
case hir.BuiltinFunction(name=name):
138140
name = map_builtin_to_cpp_func(name)
@@ -203,7 +205,8 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
203205
self.base = base
204206
self.name = base.mangling.mangle(func)
205207
self.func = func
206-
params = ",".join(self.gen_var(p) for p in func.params)
208+
params = ",".join(self.gen_var(
209+
p) for p in func.params if not isinstance(p.type, hir.FunctionType))
207210
assert func.return_type
208211
self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
209212
self.body = ScratchBuffer()
@@ -250,6 +253,8 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
250253
f"unsupported value or reference: {value}")
251254

252255
def gen_expr(self, expr: hir.Value) -> str:
256+
if expr.type and isinstance(expr.type, hir.FunctionType):
257+
return ''
253258
if expr in self.node_map:
254259
return self.node_map[expr]
255260
vid = self.new_vid()
@@ -269,8 +274,10 @@ def impl() -> None:
269274
f"const auto v{vid} = {base}.{member.field};")
270275
case hir.Call() as call:
271276
op = self.gen_func(call.op)
277+
args_s = ','.join(self.gen_value_or_ref(
278+
arg) for arg in call.args if not isinstance(arg.type, hir.FunctionType))
272279
self.body.writeln(
273-
f"auto v{vid} ={op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)});")
280+
f"auto v{vid} ={op}({args_s});")
274281
case hir.Constant() as constant:
275282
value = constant.value
276283
if isinstance(value, int):
@@ -302,6 +309,7 @@ def impl() -> None:
302309
return f'v{vid}'
303310

304311
def gen_node(self, node: hir.Node):
312+
305313
match node:
306314
case hir.Return() as ret:
307315
if ret.value:
@@ -324,31 +332,42 @@ def gen_node(self, node: hir.Node):
324332
self.gen_bb(if_stmt.else_body)
325333
self.body.indent -= 1
326334
self.gen_bb(if_stmt.merge)
335+
case hir.Break():
336+
self.body.writeln("__loop_break = true; break;")
337+
case hir.Continue():
338+
self.body.writeln("break;")
327339
case hir.Loop() as loop:
328-
vid = self.new_vid()
329-
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
340+
"""
341+
while(true) {
342+
bool loop_break = false;
343+
prepare();
344+
if (!cond()) break;
345+
do {
346+
// break => { loop_break = true; break; }
347+
// continue => { break; }
348+
} while(false);
349+
if (loop_break) break;
350+
update();
351+
}
352+
353+
"""
354+
self.body.writeln("while(true) {")
330355
self.body.indent += 1
356+
self.body.writeln("bool __loop_break = false;")
331357
self.gen_bb(loop.prepare)
332358
if loop.cond:
333-
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
334-
else:
335-
self.body.writeln("return true;")
336-
self.body.indent -= 1
337-
self.body.writeln("};")
338-
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
359+
cond = self.gen_expr(loop.cond)
360+
self.body.writeln(f"if (!{cond}) break;")
361+
self.body.writeln("do {")
339362
self.body.indent += 1
340363
self.gen_bb(loop.body)
341364
self.body.indent -= 1
342-
self.body.writeln("};")
343-
self.body.writeln(f"auto loop{vid}_update = [&]() {{")
344-
self.body.indent += 1
365+
self.body.writeln("} while(false);")
366+
self.body.writeln("if (__loop_break) break;")
345367
if loop.update:
346368
self.gen_bb(loop.update)
347369
self.body.indent -= 1
348-
self.body.writeln("};")
349-
self.body.writeln(
350-
f"for(;loop{vid}_prepare();loop{vid}_update());")
351-
self.gen_bb(loop.merge)
370+
self.body.writeln("}")
352371
case hir.Alloca() as alloca:
353372
vid = self.new_vid()
354373
assert alloca.type

luisa_lang/hir.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:
127127

128128
def is_concrete(self) -> bool:
129129
return True
130-
130+
131131
def __len__(self) -> int:
132132
return 1
133133

@@ -341,7 +341,8 @@ def member(self, field: Any) -> Optional['Type']:
341341

342342
def __len__(self) -> int:
343343
return self.count
344-
344+
345+
345346
class ArrayType(Type):
346347
element: Type
347348
count: Union[int, "SymbolicConstant"]
@@ -868,6 +869,7 @@ def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -
868869
super().__init__(type, span)
869870
self.args = args
870871

872+
871873
class Call(Value):
872874
op: FunctionLike
873875
"""After type inference, op should be a Value."""
@@ -988,17 +990,17 @@ def __init__(
988990

989991

990992
class Break(Terminator):
991-
target: Loop
993+
target: Loop | None
992994

993-
def __init__(self, target: Loop, span: Optional[Span] = None) -> None:
995+
def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None:
994996
super().__init__(span)
995997
self.target = target
996998

997999

9981000
class Continue(Terminator):
999-
target: Loop
1001+
target: Loop | None
10001002

1001-
def __init__(self, target: Loop, span: Optional[Span] = None) -> None:
1003+
def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None:
10021004
super().__init__(span)
10031005
self.target = target
10041006

@@ -1057,7 +1059,7 @@ def update(self, value: Any) -> None:
10571059
self.update_func(value)
10581060
else:
10591061
raise RuntimeError("unable to update comptime value")
1060-
1062+
10611063
def __str__(self) -> str:
10621064
return f"ComptimeValue({self.value})"
10631065

luisa_lang/parse.py

+41-11
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class FuncParser:
151151
type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue]
152152
bb_stack: List[hir.BasicBlock]
153153
type_parser: TypeParser
154+
break_and_continues: List[hir.Break | hir.Continue] | None
154155

155156
def __init__(self, name: str,
156157
func: object,
@@ -173,7 +174,7 @@ def __init__(self, name: str,
173174
self.parsed_func = hir.Function(name, [], None)
174175
self.type_var_ns = type_var_ns
175176
self.bb_stack = []
176-
177+
self.break_and_continues = None
177178
self.parsed_func.params = signature.params
178179
for p in self.parsed_func.params:
179180
self.vars[p.name] = p
@@ -262,11 +263,12 @@ def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir.
262263
if name.id in self.globalns:
263264
resolved = self.globalns[name.id]
264265
return self.convert_any_to_value(resolved, span)
265-
elif name.id in __builtins__: # type: ignore
266+
elif name.id in __builtins__: # type: ignore
266267
resolved = __builtins__[name.id] # type: ignore
267268
return self.convert_any_to_value(resolved, span)
268269
elif new_var_hint == 'comptime':
269270
self.globalns[name.id] = None
271+
270272
def update_fn(value: Any) -> None:
271273
self.globalns[name.id] = value
272274
return ComptimeValue(None, update_fn)
@@ -379,7 +381,9 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
379381
span,
380382
f"Expected {len(template_params)} arguments, got {len(args)}")
381383
for i, (param, arg) in enumerate(zip(template_params, args)):
382-
assert arg.type is not None
384+
if arg.type is None:
385+
raise hir.TypeInferenceError(
386+
span, f"failed to infer type of argument {i}")
383387
template_resolve_args.append((param, arg.type))
384388
resolved_f = f.resolve(template_resolve_args)
385389
if isinstance(resolved_f, hir.TemplateMatchingError):
@@ -467,6 +471,7 @@ def handle_range() -> hir.Value | ComptimeValue:
467471
args[i] = self.try_convert_comptime_value(
468472
arg, hir.Span.from_ast(expr.args[i]))
469473
converted_args = cast(List[hir.Value], args)
474+
470475
def make_int(i: int) -> hir.Value:
471476
return hir.Constant(i, type=hir.GenericIntType())
472477
if len(args) == 1:
@@ -516,10 +521,12 @@ def collect_args() -> List[hir.Value | hir.Ref]:
516521
raise hir.ParsingError(expr, call.message)
517522
assert isinstance(call, hir.Call)
518523
return self.cur_bb().append(hir.Load(tmp))
519-
520-
if not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)):
524+
if func.type is not None and isinstance(func.type, hir.FunctionType):
525+
func_like = func.type.func_like
526+
elif not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)):
521527
raise hir.ParsingError(expr, f"function expected")
522-
func_like = func.value
528+
else:
529+
func_like = func.value
523530
ret = self.parse_call_impl(
524531
hir.Span.from_ast(expr), func_like, collect_args())
525532
if isinstance(ret, hir.TemplateMatchingError):
@@ -791,13 +798,19 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
791798
stmt, "while loop condition must not be a comptime value")
792799
body = hir.BasicBlock(span)
793800
self.bb_stack.append(body)
801+
old_break_and_continues = self.break_and_continues
802+
self.break_and_continues = []
794803
for s in stmt.body:
795804
self.parse_stmt(s)
805+
break_and_continues = self.break_and_continues
806+
self.break_and_continues = old_break_and_continues
796807
body = self.bb_stack.pop()
797808
update = hir.BasicBlock(span)
798809
merge = hir.BasicBlock(span)
799-
pred_bb.append(
800-
hir.Loop(prepare, cond, body, update, merge, span))
810+
loop_node = hir.Loop(prepare, cond, body, update, merge, span)
811+
pred_bb.append(loop_node)
812+
for bc in break_and_continues:
813+
bc.target = loop_node
801814
self.bb_stack.append(merge)
802815
case ast.For():
803816
iter_val = self.parse_expr(stmt.iter)
@@ -828,12 +841,16 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
828841
self.bb_stack.pop()
829842
body = hir.BasicBlock(span)
830843
self.bb_stack.append(body)
844+
old_break_and_continues = self.break_and_continues
845+
self.break_and_continues = []
831846
for s in stmt.body:
832847
self.parse_stmt(s)
833848
body = self.bb_stack.pop()
849+
break_and_continues = self.break_and_continues
850+
self.break_and_continues = old_break_and_continues
834851
update = hir.BasicBlock(span)
835852
self.bb_stack.append(update)
836-
inc =loop_range.step
853+
inc = loop_range.step
837854
int_add = loop_var.type.method("__add__")
838855
assert int_add is not None
839856
add = self.parse_call_impl(
@@ -842,9 +859,22 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
842859
self.cur_bb().append(hir.Assign(loop_var, add))
843860
self.bb_stack.pop()
844861
merge = hir.BasicBlock(span)
845-
pred_bb.append(
846-
hir.Loop(prepare, cmp_result, body, update, merge, span))
862+
loop_node = hir.Loop(prepare, cmp_result,
863+
body, update, merge, span)
864+
pred_bb.append(loop_node)
865+
for bc in break_and_continues:
866+
bc.target = loop_node
847867
self.bb_stack.append(merge)
868+
case ast.Break():
869+
if self.break_and_continues is None:
870+
raise hir.ParsingError(
871+
stmt, "break statement must be inside a loop")
872+
self.cur_bb().append(hir.Break(None, span))
873+
case ast.Continue():
874+
if self.break_and_continues is None:
875+
raise hir.ParsingError(
876+
stmt, "continue statement must be inside a loop")
877+
self.cur_bb().append(hir.Continue(None, span))
848878
case ast.Return():
849879
def check_return_type(ty: hir.Type) -> None:
850880
assert self.parsed_func

0 commit comments

Comments
 (0)