@@ -171,6 +171,8 @@ class FuncCodeGen:
171
171
signature : str
172
172
func : hir .Function
173
173
params : Set [str ]
174
+ node_map : Dict [hir .Node , str ]
175
+ vid_cnt : int
174
176
175
177
def gen_var (self , var : hir .Var ) -> str :
176
178
assert var .type
@@ -189,6 +191,12 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
189
191
self .signature = f'extern "C" auto { self .name } ({ params } ) -> { base .type_cache .gen (func .return_type )} '
190
192
self .body = ScratchBuffer ()
191
193
self .params = set (p .name for p in func .params )
194
+ self .node_map = {}
195
+ self .vid_cnt = 0
196
+
197
+ def new_vid (self ) -> int :
198
+ self .vid_cnt += 1
199
+ return self .vid_cnt
192
200
193
201
def gen_ref (self , ref : hir .Ref ) -> str :
194
202
match ref :
@@ -245,11 +253,44 @@ def gen_expr(self, expr: hir.Value) -> str:
245
253
else :
246
254
raise NotImplementedError (
247
255
f"unsupported constant: { constant } " )
256
+ case hir .Init () as init :
257
+ return f"([&]() {{ { self .gen_expr (init .value )} ; }})()"
248
258
case _:
249
259
raise NotImplementedError (f"unsupported expression: { expr } " )
250
260
251
- def gen_stmt (self , stmt : hir .Stmt ):
252
- match stmt :
261
+ # def gen_stmt(self, stmt: hir.Stmt):
262
+ # match stmt:
263
+ # case hir.Return() as ret:
264
+ # if ret.value:
265
+ # self.body.writeln(f"return {self.gen_expr(ret.value)};")
266
+ # else:
267
+ # self.body.writeln("return;")
268
+ # case hir.Assign() as assign:
269
+ # ref = self.gen_ref(assign.ref)
270
+ # value = self.gen_expr(assign.value)
271
+ # self.body.writeln(f"{ref} = {value};")
272
+ # case hir.If() as if_stmt:
273
+ # cond = self.gen_expr(if_stmt.cond)
274
+ # self.body.writeln(f"if ({cond}) {{")
275
+ # self.body.indent += 1
276
+ # for stmt in if_stmt.then_body:
277
+ # self.gen_stmt(stmt)
278
+ # self.body.indent -= 1
279
+ # self.body.writeln("}")
280
+ # if if_stmt.else_body:
281
+ # self.body.writeln("else {")
282
+ # self.body.indent += 1
283
+ # for stmt in if_stmt.else_body:
284
+ # self.gen_stmt(stmt)
285
+ # self.body.indent -= 1
286
+ # self.body.writeln("}")
287
+ # case hir.VarDecl() as var_decl:
288
+ # pass
289
+ # case _:
290
+ # raise NotImplementedError(f"unsupported statement: {stmt}")
291
+
292
+ def gen_node (self , node : hir .Node ):
293
+ match node :
253
294
case hir .Return () as ret :
254
295
if ret .value :
255
296
self .body .writeln (f"return { self .gen_expr (ret .value )} ;" )
@@ -261,23 +302,45 @@ def gen_stmt(self, stmt: hir.Stmt):
261
302
self .body .writeln (f"{ ref } = { value } ;" )
262
303
case hir .If () as if_stmt :
263
304
cond = self .gen_expr (if_stmt .cond )
264
- self .body .writeln (f"if ({ cond } ) {{" )
265
- self .body .indent += 1
266
- for stmt in if_stmt .then_body :
267
- self .gen_stmt (stmt )
268
- self .body .indent -= 1
269
- self .body .writeln ("}" )
305
+ self .body .writeln (f"if ({ cond } )" )
306
+ self .gen_bb (if_stmt .then_body )
270
307
if if_stmt .else_body :
271
- self .body .writeln ("else {" )
272
- self .body .indent += 1
273
- for stmt in if_stmt .else_body :
274
- self .gen_stmt (stmt )
275
- self .body .indent -= 1
276
- self .body .writeln ("}" )
277
- case hir .VarDecl () as var_decl :
308
+ self .body .writeln ("else" )
309
+ self .gen_bb (if_stmt .else_body )
310
+ self .gen_bb (if_stmt .merge )
311
+ case hir .Loop () as loop :
312
+ vid = self .new_vid ()
313
+ self .body .write (f"auto loop{ vid } _prepare = [&]()->bool {{" )
314
+ self .gen_bb (loop .prepare )
315
+ if loop .cond :
316
+ self .body .writeln (f"return { self .gen_expr (loop .cond )} ;" )
317
+ else :
318
+ self .body .writeln ("return true;" )
319
+ self .body .writeln ("};" )
320
+ self .body .writeln (f"auto loop{ vid } _body = [&]() {{" )
321
+ self .gen_bb (loop .body )
322
+ self .body .writeln ("};" )
323
+ self .body .writeln (f"auto loop{ vid } _update = [&]() {{" )
324
+ if loop .update :
325
+ self .gen_bb (loop .update )
326
+ self .body .writeln ("};" )
327
+ self .body .writeln (
328
+ f"for(;loop{ vid } _prepare();loop{ vid } _update());" )
329
+ self .gen_bb (loop .merge )
330
+ case hir .Alloca () as alloca :
278
331
pass
279
- case _:
280
- raise NotImplementedError (f"unsupported statement: { stmt } " )
332
+ case hir .Call () as call :
333
+ self .gen_expr (call )
334
+ case hir .Member () | hir .Index ():
335
+ pass
336
+
337
+ def gen_bb (self , bb : hir .BasicBlock ):
338
+ self .body .writeln (f"{{ // BasicBlock Begin { bb .span } " )
339
+ self .body .indent += 1
340
+ for node in bb .nodes :
341
+ self .gen_node (node )
342
+ self .body .indent -= 1
343
+ self .body .writeln (f"}} // BasicBlock End { bb .span } " )
281
344
282
345
def gen_locals (self ):
283
346
for local in self .func .locals :
@@ -294,7 +357,7 @@ def gen(self) -> None:
294
357
self .body .writeln (f"{ self .signature } {{" )
295
358
self .body .indent += 1
296
359
self .gen_locals ()
297
- for stmt in self .func .body :
298
- self .gen_stmt ( stmt )
360
+ if self .func .body :
361
+ self .gen_bb ( self . func . body )
299
362
self .body .indent -= 1
300
363
self .body .writeln ("}" )
0 commit comments