Skip to content

Commit 915507c

Browse files
committed
refactoring parser
1 parent 59100ff commit 915507c

File tree

11 files changed

+532
-1583
lines changed

11 files changed

+532
-1583
lines changed

luisa_lang/classinfo.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Literal,
99
Optional,
1010
Set,
11+
Tuple,
1112
TypeVar,
1213
Generic,
1314
Dict,
@@ -39,13 +40,15 @@ def __init__(self, types: List["VarType"]):
3940
def __repr__(self):
4041
return f"Union[{', '.join(map(repr, self.types))}]"
4142

43+
4244
class AnyType:
4345
def __repr__(self):
4446
return "Any"
45-
47+
4648
def __eq__(self, other):
4749
return isinstance(other, AnyType)
4850

51+
4952
class SelfType:
5053
def __repr__(self):
5154
return "Self"
@@ -69,13 +72,13 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
6972

7073
class MethodType:
7174
type_vars: List[TypeVar]
72-
args: List[VarType]
75+
args: List[Tuple[str, VarType]]
7376
return_type: VarType
7477
env: Dict[TypeVar, VarType]
7578
is_static: bool
7679

7780
def __init__(
78-
self, type_vars: List[TypeVar], args: List[VarType], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None, is_static: bool = False
81+
self, type_vars: List[TypeVar], args: List[Tuple[str, VarType]], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None, is_static: bool = False
7982
):
8083
self.type_vars = type_vars
8184
self.args = args
@@ -88,7 +91,7 @@ def __repr__(self):
8891
return f"[{', '.join(map(repr, self.type_vars))}]({', '.join(map(repr, self.args))}) -> {self.return_type}"
8992

9093
def substitute(self, env: Dict[TypeVar, VarType]) -> "MethodType":
91-
return MethodType([], [subst_type(arg, env) for arg in self.args], subst_type(self.return_type, env), env)
94+
return MethodType([], [(arg[0], subst_type(arg[1], env)) for arg in self.args], subst_type(self.return_type, env), env)
9295

9396

9497
class ClassType:
@@ -229,15 +232,15 @@ def parse_func_signature(func: object, globalns: Dict[str, Any], foreign_type_va
229232
assert inspect.isfunction(func)
230233
signature = inspect.signature(func)
231234
method_type_hints = typing.get_type_hints(func, globalns)
232-
param_types: List[VarType] = []
235+
param_types: List[Tuple[str, VarType]] = []
233236
type_vars = get_type_vars(func)
234237
for param in signature.parameters.values():
235238
if param.name == "self":
236239
assert self_type is not None
237-
param_types.append(self_type)
240+
param_types.append((param.name, self_type))
238241
else:
239-
param_types.append(parse_type_hint(
240-
method_type_hints[param.name]))
242+
param_types.append((param.name, parse_type_hint(
243+
method_type_hints[param.name])))
241244
if "return" in method_type_hints:
242245
return_type = parse_type_hint(method_type_hints.get("return"))
243246
else:

luisa_lang/codegen/cpp.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from luisa_lang.codegen import CodeGen, ScratchBuffer
55
from typing import Any, Callable, Dict, Set, Tuple, Union
66

7-
from luisa_lang.hir.defs import GlobalContext
87
from luisa_lang.hir import get_dsl_func
9-
from luisa_lang.hir.infer import run_inference_on_function
108

119

1210
class TypeCodeGenCache:
@@ -150,11 +148,11 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
150148
assert dsl_func is not None
151149
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {func}"
152150
func_tmp = dsl_func.resolve([])
153-
assert isinstance(func_tmp, hir.Function), f"Expected function, got {func_tmp}"
151+
assert isinstance(
152+
func_tmp, hir.Function), f"Expected function, got {func_tmp}"
154153
func = func_tmp
155154
if id(func) in self.func_cache:
156155
return self.func_cache[id(func)][1]
157-
run_inference_on_function(func)
158156
func_code_gen = FuncCodeGen(self, func)
159157
name = func_code_gen.name
160158
self.func_cache[id(func)] = (func, name)
@@ -207,17 +205,31 @@ def gen_ref(self, ref: hir.Ref) -> str:
207205
case _:
208206
raise NotImplementedError(f"unsupported reference: {ref}")
209207

208+
def gen_func(self, f: hir.FunctionLike) -> str:
209+
if isinstance(f, hir.Function):
210+
return self.base.gen_function(f)
211+
elif isinstance(f, hir.BuiltinFunction):
212+
return self.base.mangling.mangle(f)
213+
else:
214+
raise NotImplementedError(f"unsupported constant")
215+
216+
def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
217+
match value:
218+
case hir.Value() as value:
219+
return self.gen_expr(value)
220+
case hir.Ref() as ref:
221+
return self.gen_ref(ref)
222+
case _:
223+
raise NotImplementedError(
224+
f"unsupported value or reference: {value}")
225+
210226
def gen_expr(self, expr: hir.Value) -> str:
211227
match expr:
212228
case hir.Load() as load:
213229
return self.gen_ref(load.ref)
214230
case hir.Call() as call:
215-
assert call.resolved, f"unresolved call: {call}"
216-
kind = call.kind
217-
assert kind == hir.CallOpKind.FUNC and isinstance(
218-
call.op, hir.Value)
219-
op = self.gen_expr(call.op)
220-
return f"{op}({','.join(self.gen_expr(arg) for arg in call.args)})"
231+
op = self.gen_func(call.op)
232+
return f"{op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)})"
221233
case hir.Constant() as constant:
222234
value = constant.value
223235
if isinstance(value, int):
@@ -228,10 +240,8 @@ def gen_expr(self, expr: hir.Value) -> str:
228240
return "true" if value else "false"
229241
elif isinstance(value, str):
230242
return f'"{value}"'
231-
elif isinstance(value, hir.Function):
232-
return self.base.gen_function(value)
233-
elif isinstance(value, hir.BuiltinFunction):
234-
return self.base.mangling.mangle(value)
243+
elif isinstance(value, hir.Function) or isinstance(value, hir.BuiltinFunction):
244+
return self.gen_func(value)
235245
else:
236246
raise NotImplementedError(
237247
f"unsupported constant: {constant}")

0 commit comments

Comments
 (0)