Skip to content

Commit 59100ff

Browse files
committed
rewriting parser
1 parent 519841c commit 59100ff

12 files changed

+458
-113
lines changed

luisa_lang/_builtin_decor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import typing
44
from luisa_lang import hir
55
import inspect
6-
from luisa_lang._utils import get_full_name, get_union_args
7-
from luisa_lang._classinfo import register_class, _class_typeinfo, MethodType, _get_cls_globalns
6+
from luisa_lang.utils import get_full_name, get_union_args
7+
from luisa_lang.classinfo import register_class, class_typeinfo, MethodType, _get_cls_globalns
88
import functools
99

1010
_T = TypeVar("_T", bound=type)
@@ -18,7 +18,7 @@ def decorator(cls: _T) -> _T:
1818
ctx.types[cls] = ty
1919

2020
register_class(cls)
21-
cls_info = _class_typeinfo(cls)
21+
cls_info = class_typeinfo(cls)
2222
globalns = _get_cls_globalns(cls)
2323

2424
def make_type_rule(

luisa_lang/_classinfo.py renamed to luisa_lang/classinfo.py

+114-42
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self, origin: type, args: List["VarType"]):
2929
def __repr__(self):
3030
return f"{self.origin}[{', '.join(map(repr, self.args))}]"
3131

32+
3233
class UnionType:
3334
types: List["VarType"]
3435

@@ -38,7 +39,23 @@ def __init__(self, types: List["VarType"]):
3839
def __repr__(self):
3940
return f"Union[{', '.join(map(repr, self.types))}]"
4041

41-
VarType = Union[TypeVar, Type, GenericInstance, UnionType]
42+
class AnyType:
43+
def __repr__(self):
44+
return "Any"
45+
46+
def __eq__(self, other):
47+
return isinstance(other, AnyType)
48+
49+
class SelfType:
50+
def __repr__(self):
51+
return "Self"
52+
53+
def __eq__(self, other):
54+
return isinstance(other, SelfType)
55+
56+
57+
VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType]
58+
4259

4360
def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
4461
match ty:
@@ -49,24 +66,27 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
4966
case _:
5067
return ty
5168

69+
5270
class MethodType:
5371
type_vars: List[TypeVar]
5472
args: List[VarType]
5573
return_type: VarType
5674
env: Dict[TypeVar, VarType]
75+
is_static: bool
5776

5877
def __init__(
59-
self, type_vars: List[TypeVar], args: List[VarType], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None
78+
self, type_vars: List[TypeVar], args: List[VarType], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None, is_static: bool = False
6079
):
6180
self.type_vars = type_vars
6281
self.args = args
6382
self.return_type = return_type
6483
self.env = env or {}
84+
self.is_static = is_static
6585

6686
def __repr__(self):
6787
# [a, b, c](x: T, y: U) -> V
6888
return f"[{', '.join(map(repr, self.type_vars))}]({', '.join(map(repr, self.args))}) -> {self.return_type}"
69-
89+
7090
def substitute(self, env: Dict[TypeVar, VarType]) -> "MethodType":
7191
return MethodType([], [subst_type(arg, env) for arg in self.args], subst_type(self.return_type, env), env)
7292

@@ -108,24 +128,27 @@ def instantiate(self, type_args: List[VarType]) -> "ClassType":
108128
)
109129
env = dict(zip(self.type_vars, type_args))
110130
return ClassType(
111-
[], {name: subst_type(ty, env) for name, ty in self.fields.items()}, {name: method.substitute(env) for name, method in self.methods.items()}
131+
[], {name: subst_type(ty, env) for name, ty in self.fields.items()}, {
132+
name: method.substitute(env) for name, method in self.methods.items()}
112133
)
113134

135+
114136
_CLS_TYPE_INFO: Dict[type, ClassType] = {}
115137

116138

117-
def _class_typeinfo(cls: type) -> ClassType:
139+
def class_typeinfo(cls: type) -> ClassType:
118140
if cls in _CLS_TYPE_INFO:
119141
return _CLS_TYPE_INFO[cls]
120142
raise RuntimeError(f"Class {cls} is not registered.")
121143

122144

123-
124145
def _is_class_registered(cls: type) -> bool:
125146
return cls in _CLS_TYPE_INFO
126147

148+
127149
_BUILTIN_ANNOTATION_BASES = set([typing.Generic, typing.Protocol, object])
128150

151+
129152
def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
130153
if not hasattr(cls, "__orig_bases__"):
131154
return []
@@ -140,23 +163,99 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
140163
)
141164
for arg in base.__args__:
142165
if isinstance(arg, typing.ForwardRef):
143-
arg: type = typing._eval_type(arg, globalns, globalns) #type: ignore
166+
arg: type = typing._eval_type( # type: ignore
167+
arg, globalns, globalns) # type: ignore
144168
base_params.append(arg)
145169
if base_orig in _BUILTIN_ANNOTATION_BASES:
146170
pass
147171
else:
148-
base_info = _class_typeinfo(base_orig)
149-
info.append((base.__name__, base_info.instantiate(base_params)))
172+
base_info = class_typeinfo(base_orig)
173+
info.append(
174+
(base.__name__, base_info.instantiate(base_params)))
150175
else:
151176
if _is_class_registered(base):
152-
info.append((base.__name__, _class_typeinfo(base)))
177+
info.append((base.__name__, class_typeinfo(base)))
153178
return info
154179

180+
155181
def _get_cls_globalns(cls: type) -> Dict[str, Any]:
156182
module = inspect.getmodule(cls)
157183
assert module is not None
158184
return module.__dict__
159185

186+
187+
def parse_type_hint(hint: Any) -> VarType:
188+
if hint is None:
189+
return NoneType
190+
if isinstance(hint, TypeVar):
191+
return hint
192+
origin = typing.get_origin(hint)
193+
if origin:
194+
if isinstance(origin, type):
195+
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
196+
args = list(typing.get_args(hint))
197+
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
198+
elif origin is Union:
199+
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
200+
else:
201+
raise RuntimeError(f"Unsupported origin type: {origin}")
202+
if isinstance(hint, type):
203+
return hint
204+
if hint == typing.Self:
205+
return SelfType()
206+
raise RuntimeError(f"Unsupported type hint: {hint}")
207+
208+
209+
def extract_type_vars_from_hint(hint: typing.Any) -> List[TypeVar]:
210+
if isinstance(hint, TypeVar):
211+
return [hint]
212+
if hasattr(hint, "__args__"): # Handle custom generic types like Foo[T]
213+
type_vars = []
214+
for arg in hint.__args__:
215+
type_vars.extend(extract_type_vars_from_hint(arg))
216+
return type_vars
217+
return []
218+
219+
220+
def get_type_vars(func: typing.Callable) -> List[TypeVar]:
221+
type_hints = typing.get_type_hints(func)
222+
type_vars = []
223+
for hint in type_hints.values():
224+
type_vars.extend(extract_type_vars_from_hint(hint))
225+
return list(set(type_vars)) # Return unique type vars
226+
227+
228+
def parse_func_signature(func: object, globalns: Dict[str, Any], foreign_type_vars: List[TypeVar], self_type: Optional[VarType] = None, is_static: bool = False) -> MethodType:
229+
assert inspect.isfunction(func)
230+
signature = inspect.signature(func)
231+
method_type_hints = typing.get_type_hints(func, globalns)
232+
param_types: List[VarType] = []
233+
type_vars = get_type_vars(func)
234+
for param in signature.parameters.values():
235+
if param.name == "self":
236+
assert self_type is not None
237+
param_types.append(self_type)
238+
else:
239+
param_types.append(parse_type_hint(
240+
method_type_hints[param.name]))
241+
if "return" in method_type_hints:
242+
return_type = parse_type_hint(method_type_hints.get("return"))
243+
else:
244+
return_type = AnyType()
245+
# remove foreign type vars from type_vars
246+
type_vars = [tv for tv in type_vars if tv not in foreign_type_vars]
247+
return MethodType(type_vars, param_types, return_type, is_static=is_static)
248+
249+
250+
def is_static(cls: type, method_name: str) -> bool:
251+
method = getattr(cls, method_name, None)
252+
if method is None:
253+
return False
254+
# Using inspect to retrieve the method directly from the class
255+
method = cls.__dict__.get(method_name, None)
256+
return isinstance(method, staticmethod)
257+
258+
160259
def register_class(cls: type) -> None:
161260
cls_qualname = cls.__qualname__
162261
globalns = _get_cls_globalns(cls)
@@ -202,47 +301,20 @@ def register_class(cls: type) -> None:
202301
continue
203302
local_methods.add(name)
204303

205-
def parse_type_hint(hint: Any) -> VarType:
206-
207-
if hint is None:
208-
return NoneType
209-
if isinstance(hint, TypeVar):
210-
return hint
211-
origin = typing.get_origin(hint)
212-
if origin:
213-
if isinstance(origin, type):
214-
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
215-
args = list(typing.get_args(hint))
216-
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
217-
elif origin is Union:
218-
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
219-
else:
220-
raise RuntimeError(f"Unsupported origin type: {origin}")
221-
if isinstance(hint, type):
222-
return hint
223-
raise RuntimeError(f"Unsupported type hint: {hint}")
224-
225304
cls_ty = ClassType([], {}, {})
226305
for _base_name, base_info in base_infos:
227306
cls_ty.fields.update(base_info.fields)
228307
cls_ty.methods.update(base_info.methods)
308+
229309
if type_vars:
230310
for tv in type_vars:
231311
cls_ty.type_vars.append(tv)
312+
self_ty: VarType = SelfType()
232313
for name, member in inspect.getmembers(cls):
233314
if name in local_methods:
234-
assert inspect.isfunction(member)
235-
signature = inspect.signature(member)
236-
method_type_hints = typing.get_type_hints(member)
237-
param_types: List[VarType] = []
238-
for param in signature.parameters.values():
239-
if param.name == "self":
240-
param_types.append(cls)
241-
else:
242-
param_types.append(parse_type_hint(
243-
method_type_hints[param.name]))
244-
return_type = parse_type_hint(method_type_hints.get("return"))
245-
cls_ty.methods[name] = MethodType([], param_types, return_type)
315+
# print(f'Found local method: {name} in {cls}')
316+
cls_ty.methods[name] = parse_func_signature(
317+
member, globalns, cls_ty.type_vars, self_ty, is_static=is_static(cls, name))
246318
for name in local_fields:
247319
cls_ty.fields[name] = parse_type_hint(type_hints[name])
248320
_CLS_TYPE_INFO[cls] = cls_ty

luisa_lang/codegen/cpp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import cache
22
from luisa_lang import hir
3-
from luisa_lang._utils import unique_hash, unwrap
3+
from luisa_lang.utils import unique_hash, unwrap
44
from luisa_lang.codegen import CodeGen, ScratchBuffer
55
from typing import Any, Callable, Dict, Set, Tuple, Union
66

luisa_lang/hir/defs.py

+5-42
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
cast,
1616
)
1717
from typing_extensions import override
18-
from luisa_lang._utils import Span
18+
from luisa_lang.utils import Span
1919
from abc import ABC, abstractmethod
2020

2121
PATH_PREFIX = "luisa_lang"
@@ -135,7 +135,7 @@ def __eq__(self, value: object) -> bool:
135135

136136
def __hash__(self) -> int:
137137
return hash(UnitType)
138-
138+
139139
def __str__(self) -> str:
140140
return "NoneType"
141141

@@ -642,7 +642,7 @@ def __init__(self, func_like: FunctionLike | FunctionTemplate) -> None:
642642
self.func_like = func_like
643643

644644
def __eq__(self, value: object) -> bool:
645-
return isinstance(value, FunctionType) and id(value.func_like) == id(self.func_like)
645+
return isinstance(value, FunctionType) and value.func_like is self.func_like
646646

647647
def __hash__(self) -> int:
648648
return hash((FunctionType, id(self.func_like)))
@@ -863,8 +863,8 @@ class CallOpKind(Enum):
863863
class Constant(Value):
864864
value: Any
865865

866-
def __init__(self, value: Any, span: Optional[Span] = None) -> None:
867-
super().__init__(None, span)
866+
def __init__(self, value: Any, type: Type | None = None, span: Optional[Span] = None) -> None:
867+
super().__init__(type, span)
868868
self.value = value
869869

870870
def __eq__(self, value: object) -> bool:
@@ -1224,45 +1224,14 @@ def match_func_template_args(sig: Function, args: FunctionTemplateResolvingArgs)
12241224
template_args.append((param.name, param.type))
12251225
matching_args = [arg[1] for arg in args]
12261226
return match_template_args(template_args, matching_args)
1227-
# K = TypeVar("K")
1228-
# V = TypeVar("V")
1229-
1230-
1231-
# class Env(Generic[K, V]):
1232-
# _map: Dict[K, V]
1233-
# _parent: Optional["Env[K, V]"]
1234-
1235-
# def __init__(self) -> None:
1236-
# self._map = {}
1237-
# self._parent = None
12381227

1239-
# def fork(self) -> "Env[K, V]":
1240-
# env = Env[K, V]()
1241-
# env._parent = self
1242-
# return env
1243-
1244-
# def lookup(self, key: K) -> Optional[V]:
1245-
# res = self._map.get(key)
1246-
# if res is not None:
1247-
# return res
1248-
# if self._parent is not None:
1249-
# return self._parent.lookup(key)
1250-
# return None
1251-
1252-
# def bind(self, key: K, value: V) -> None:
1253-
# self._map[key] = value
1254-
1255-
1256-
# Item = Union[Type, Function, BuiltinFunction]
1257-
# ItemEnv = Env[Path, Item]
12581228

12591229
_global_context: Optional["GlobalContext"] = None
12601230

12611231

12621232
class GlobalContext:
12631233
types: Dict[type, Type]
12641234
functions: Dict[Callable[..., Any], FunctionTemplate]
1265-
# deferred: List[Callable[[], None]]
12661235

12671236
@staticmethod
12681237
def get() -> "GlobalContext":
@@ -1280,12 +1249,6 @@ def __init__(self) -> None:
12801249
bool: BoolType(),
12811250
}
12821251
self.functions = {}
1283-
# self.deferred = []
1284-
1285-
# def flush(self) -> None:
1286-
# for fn in self.deferred:
1287-
# fn()
1288-
# self.deferred = []
12891252

12901253

12911254
class FuncMetadata:

luisa_lang/hir/infer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
22
from luisa_lang import hir
3-
from luisa_lang._utils import report_error
3+
from luisa_lang.utils import report_error
44
from luisa_lang.hir.defs import is_type_compatible_to
55
import traceback
66

luisa_lang/hir/template.py

-2
This file was deleted.

0 commit comments

Comments
 (0)