Skip to content

Commit 2dc42f6

Browse files
committed
refactoring hir
1 parent 915507c commit 2dc42f6

File tree

5 files changed

+344
-117
lines changed

5 files changed

+344
-117
lines changed

luisa_lang/_builtin_decor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ def make_type_rule(
3939
def type_rule(args: List[hir.Type]) -> hir.Type:
4040

4141
parameters_list = list(parameters.values())
42-
if name == '__init__':
43-
parameters_list = parameters_list[1:]
4442
if len(args) > len(parameters_list):
4543
raise hir.TypeInferenceError(None,
4644
f"Too many arguments for {cls_name}.{name} expected at most {len(parameters_list)} but got {len(args)}"

luisa_lang/codegen/cpp.py

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ class FuncCodeGen:
171171
signature: str
172172
func: hir.Function
173173
params: Set[str]
174+
node_map: Dict[hir.Node, str]
175+
vid_cnt: int
174176

175177
def gen_var(self, var: hir.Var) -> str:
176178
assert var.type
@@ -189,6 +191,12 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
189191
self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
190192
self.body = ScratchBuffer()
191193
self.params = set(p.name for p in func.params)
194+
self.node_map = {}
195+
self.vid_cnt = 0
196+
197+
def new_vid(self) -> int:
198+
self.vid_cnt += 1
199+
return self.vid_cnt
192200

193201
def gen_ref(self, ref: hir.Ref) -> str:
194202
match ref:
@@ -245,11 +253,44 @@ def gen_expr(self, expr: hir.Value) -> str:
245253
else:
246254
raise NotImplementedError(
247255
f"unsupported constant: {constant}")
256+
case hir.Init() as init:
257+
return f"([&]() {{ {self.gen_expr(init.value)}; }})()"
248258
case _:
249259
raise NotImplementedError(f"unsupported expression: {expr}")
250260

251-
def gen_stmt(self, stmt: hir.Stmt):
252-
match stmt:
261+
# def gen_stmt(self, stmt: hir.Stmt):
262+
# match stmt:
263+
# case hir.Return() as ret:
264+
# if ret.value:
265+
# self.body.writeln(f"return {self.gen_expr(ret.value)};")
266+
# else:
267+
# self.body.writeln("return;")
268+
# case hir.Assign() as assign:
269+
# ref = self.gen_ref(assign.ref)
270+
# value = self.gen_expr(assign.value)
271+
# self.body.writeln(f"{ref} = {value};")
272+
# case hir.If() as if_stmt:
273+
# cond = self.gen_expr(if_stmt.cond)
274+
# self.body.writeln(f"if ({cond}) {{")
275+
# self.body.indent += 1
276+
# for stmt in if_stmt.then_body:
277+
# self.gen_stmt(stmt)
278+
# self.body.indent -= 1
279+
# self.body.writeln("}")
280+
# if if_stmt.else_body:
281+
# self.body.writeln("else {")
282+
# self.body.indent += 1
283+
# for stmt in if_stmt.else_body:
284+
# self.gen_stmt(stmt)
285+
# self.body.indent -= 1
286+
# self.body.writeln("}")
287+
# case hir.VarDecl() as var_decl:
288+
# pass
289+
# case _:
290+
# raise NotImplementedError(f"unsupported statement: {stmt}")
291+
292+
def gen_node(self, node: hir.Node):
293+
match node:
253294
case hir.Return() as ret:
254295
if ret.value:
255296
self.body.writeln(f"return {self.gen_expr(ret.value)};")
@@ -261,23 +302,45 @@ def gen_stmt(self, stmt: hir.Stmt):
261302
self.body.writeln(f"{ref} = {value};")
262303
case hir.If() as if_stmt:
263304
cond = self.gen_expr(if_stmt.cond)
264-
self.body.writeln(f"if ({cond}) {{")
265-
self.body.indent += 1
266-
for stmt in if_stmt.then_body:
267-
self.gen_stmt(stmt)
268-
self.body.indent -= 1
269-
self.body.writeln("}")
305+
self.body.writeln(f"if ({cond})")
306+
self.gen_bb(if_stmt.then_body)
270307
if if_stmt.else_body:
271-
self.body.writeln("else {")
272-
self.body.indent += 1
273-
for stmt in if_stmt.else_body:
274-
self.gen_stmt(stmt)
275-
self.body.indent -= 1
276-
self.body.writeln("}")
277-
case hir.VarDecl() as var_decl:
308+
self.body.writeln("else")
309+
self.gen_bb(if_stmt.else_body)
310+
self.gen_bb(if_stmt.merge)
311+
case hir.Loop() as loop:
312+
vid = self.new_vid()
313+
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
314+
self.gen_bb(loop.prepare)
315+
if loop.cond:
316+
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
317+
else:
318+
self.body.writeln("return true;")
319+
self.body.writeln("};")
320+
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
321+
self.gen_bb(loop.body)
322+
self.body.writeln("};")
323+
self.body.writeln(f"auto loop{vid}_update = [&]() {{")
324+
if loop.update:
325+
self.gen_bb(loop.update)
326+
self.body.writeln("};")
327+
self.body.writeln(
328+
f"for(;loop{vid}_prepare();loop{vid}_update());")
329+
self.gen_bb(loop.merge)
330+
case hir.Alloca() as alloca:
278331
pass
279-
case _:
280-
raise NotImplementedError(f"unsupported statement: {stmt}")
332+
case hir.Call() as call:
333+
self.gen_expr(call)
334+
case hir.Member() | hir.Index():
335+
pass
336+
337+
def gen_bb(self, bb: hir.BasicBlock):
338+
self.body.writeln(f"{{ // BasicBlock Begin {bb.span}")
339+
self.body.indent += 1
340+
for node in bb.nodes:
341+
self.gen_node(node)
342+
self.body.indent -= 1
343+
self.body.writeln(f"}} // BasicBlock End {bb.span}")
281344

282345
def gen_locals(self):
283346
for local in self.func.locals:
@@ -294,7 +357,7 @@ def gen(self) -> None:
294357
self.body.writeln(f"{self.signature} {{")
295358
self.body.indent += 1
296359
self.gen_locals()
297-
for stmt in self.func.body:
298-
self.gen_stmt(stmt)
360+
if self.func.body:
361+
self.gen_bb(self.func.body)
299362
self.body.indent -= 1
300363
self.body.writeln("}")

0 commit comments

Comments
 (0)