4
4
from luisa_lang .codegen import CodeGen , ScratchBuffer
5
5
from typing import Any , Callable , Dict , Set , Tuple , Union
6
6
7
- from luisa_lang .hir .defs import GlobalContext
8
7
from luisa_lang .hir import get_dsl_func
9
- from luisa_lang .hir .infer import run_inference_on_function
10
8
11
9
12
10
class TypeCodeGenCache :
@@ -150,11 +148,11 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
150
148
assert dsl_func is not None
151
149
assert not dsl_func .is_generic , f"Generic functions should be resolved before codegen: { func } "
152
150
func_tmp = dsl_func .resolve ([])
153
- assert isinstance (func_tmp , hir .Function ), f"Expected function, got { func_tmp } "
151
+ assert isinstance (
152
+ func_tmp , hir .Function ), f"Expected function, got { func_tmp } "
154
153
func = func_tmp
155
154
if id (func ) in self .func_cache :
156
155
return self .func_cache [id (func )][1 ]
157
- run_inference_on_function (func )
158
156
func_code_gen = FuncCodeGen (self , func )
159
157
name = func_code_gen .name
160
158
self .func_cache [id (func )] = (func , name )
@@ -207,17 +205,31 @@ def gen_ref(self, ref: hir.Ref) -> str:
207
205
case _:
208
206
raise NotImplementedError (f"unsupported reference: { ref } " )
209
207
208
+ def gen_func (self , f : hir .FunctionLike ) -> str :
209
+ if isinstance (f , hir .Function ):
210
+ return self .base .gen_function (f )
211
+ elif isinstance (f , hir .BuiltinFunction ):
212
+ return self .base .mangling .mangle (f )
213
+ else :
214
+ raise NotImplementedError (f"unsupported constant" )
215
+
216
+ def gen_value_or_ref (self , value : hir .Value | hir .Ref ) -> str :
217
+ match value :
218
+ case hir .Value () as value :
219
+ return self .gen_expr (value )
220
+ case hir .Ref () as ref :
221
+ return self .gen_ref (ref )
222
+ case _:
223
+ raise NotImplementedError (
224
+ f"unsupported value or reference: { value } " )
225
+
210
226
def gen_expr (self , expr : hir .Value ) -> str :
211
227
match expr :
212
228
case hir .Load () as load :
213
229
return self .gen_ref (load .ref )
214
230
case hir .Call () as call :
215
- assert call .resolved , f"unresolved call: { call } "
216
- kind = call .kind
217
- assert kind == hir .CallOpKind .FUNC and isinstance (
218
- call .op , hir .Value )
219
- op = self .gen_expr (call .op )
220
- return f"{ op } ({ ',' .join (self .gen_expr (arg ) for arg in call .args )} )"
231
+ op = self .gen_func (call .op )
232
+ return f"{ op } ({ ',' .join (self .gen_value_or_ref (arg ) for arg in call .args )} )"
221
233
case hir .Constant () as constant :
222
234
value = constant .value
223
235
if isinstance (value , int ):
@@ -228,10 +240,8 @@ def gen_expr(self, expr: hir.Value) -> str:
228
240
return "true" if value else "false"
229
241
elif isinstance (value , str ):
230
242
return f'"{ value } "'
231
- elif isinstance (value , hir .Function ):
232
- return self .base .gen_function (value )
233
- elif isinstance (value , hir .BuiltinFunction ):
234
- return self .base .mangling .mangle (value )
243
+ elif isinstance (value , hir .Function ) or isinstance (value , hir .BuiltinFunction ):
244
+ return self .gen_func (value )
235
245
else :
236
246
raise NotImplementedError (
237
247
f"unsupported constant: { constant } " )
0 commit comments