Skip to content

Commit e6a21a5

Browse files
committed
added parsing for generic aliases
1 parent fea3ff3 commit e6a21a5

File tree

6 files changed

+95
-34
lines changed

6 files changed

+95
-34
lines changed

luisa_lang/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
1+
import sys
2+
3+
# check if is python 3.12 or higher
4+
if sys.version_info < (3, 12):
5+
raise Exception("luisa_lang requires Python 3.12 or higher")
6+
17
from luisa_lang.lang import *
28
from luisa_lang.lang_builtins import *

luisa_lang/_builtin_decor.py

-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
142142

143143

144144
def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
145-
import sourceinspect
146145
assert inspect.isfunction(f), f"{f} is not a function"
147146
# print(hir.GlobalContext.get)
148147
ctx = hir.GlobalContext.get()

luisa_lang/classinfo.py

+63-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from types import NoneType
2+
from types import GenericAlias, NoneType
33
import types
44
import typing
55
from typing import (
@@ -10,21 +10,23 @@
1010
Optional,
1111
Set,
1212
Tuple,
13+
TypeAliasType,
1314
TypeVar,
1415
Generic,
1516
Dict,
1617
Type,
1718
Union,
19+
cast,
1820
)
1921
import functools
2022
from dataclasses import dataclass
2123

2224

2325
class GenericInstance:
24-
origin: type
26+
origin: 'VarType'
2527
args: List["VarType"]
2628

27-
def __init__(self, origin: type, args: List["VarType"]):
29+
def __init__(self, origin: 'VarType', args: List["VarType"]):
2830
self.origin = origin
2931
self.args = args
3032

@@ -41,6 +43,9 @@ def __init__(self, types: List["VarType"]):
4143
def __repr__(self):
4244
return f"Union[{', '.join(map(repr, self.types))}]"
4345

46+
def substitute(self, env: Dict[TypeVar, 'VarType']) -> "UnionType":
47+
return UnionType([subst_type(ty, env) for ty in self.types])
48+
4449

4550
class AnyType:
4651
def __repr__(self):
@@ -56,7 +61,8 @@ def __repr__(self):
5661

5762
def __eq__(self, other):
5863
return isinstance(other, SelfType)
59-
64+
65+
6066
class LiteralType:
6167
value: Any
6268

@@ -70,7 +76,23 @@ def __eq__(self, other):
7076
return isinstance(other, LiteralType) and self.value == other.value
7177

7278

73-
VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType, LiteralType]
79+
class AnnotatedType:
80+
origin: 'VarType'
81+
annotations: List[Any]
82+
83+
def __init__(self, origin: 'VarType', annotations: List[Any]):
84+
self.origin = origin
85+
self.annotations = annotations
86+
87+
def __repr__(self):
88+
return f"Annotated[{self.origin}, {self.annotations}]"
89+
90+
def substitute(self, env: Dict[TypeVar, 'VarType']) -> "AnnotatedType":
91+
return AnnotatedType(subst_type(self.origin, env), self.annotations)
92+
93+
94+
type VarType = Union[TypeVar, Type, GenericInstance,
95+
UnionType, SelfType, AnyType, LiteralType, AnnotatedType]
7496

7597

7698
def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
@@ -79,6 +101,8 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
79101
return env.get(ty, ty)
80102
case GenericInstance(origin=origin, args=args):
81103
return GenericInstance(origin, [subst_type(arg, env) for arg in args])
104+
case MethodType() | UnionType() | AnnotatedType():
105+
return ty.substitute(env)
82106
case _:
83107
return ty
84108

@@ -140,7 +164,8 @@ def __repr__(self):
140164
def instantiate(self, type_args: List[VarType]) -> "ClassType":
141165
if len(type_args) != len(self.type_vars):
142166
raise RuntimeError(
143-
f"Expected {len(self.type_vars)} type arguments but got {len(type_args)}"
167+
f"Expected {len(self.type_vars)}" +
168+
f"type arguments but got {len(type_args)}"
144169
)
145170
env = dict(zip(self.type_vars, type_args))
146171
return ClassType(
@@ -172,7 +197,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
172197
for base in cls.__orig_bases__:
173198
if hasattr(base, "__origin__"):
174199
base_params = []
175-
base_orig = base.__origin__
200+
base_orig: Any = base.__origin__
201+
176202
if not _is_class_registered(base_orig) and base_orig not in _BUILTIN_ANNOTATION_BASES:
177203
raise RuntimeError(
178204
f"Base class {base_orig} of {cls} is not registered."
@@ -185,7 +211,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
185211
if base_orig in _BUILTIN_ANNOTATION_BASES:
186212
pass
187213
else:
188-
base_info = class_typeinfo(base_orig)
214+
assert isinstance(base_orig, type)
215+
base_info = class_typeinfo(cast(type, base_orig))
189216
info.append(
190217
(base.__name__, base_info.instantiate(base_params)))
191218
else:
@@ -210,19 +237,40 @@ def parse_type_hint(hint: Any) -> VarType:
210237
return UnionType([parse_type_hint(arg) for arg in hint.__args__])
211238
if hint is typing.Any:
212239
return AnyType()
240+
if isinstance(hint, TypeAliasType):
241+
return parse_type_hint(hint.__value__)
242+
213243
origin = typing.get_origin(hint)
214244
if origin:
215-
if isinstance(origin, type):
216-
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
217-
args = list(typing.get_args(hint))
218-
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
245+
if origin is typing.Annotated:
246+
annotate_args = typing.get_args(hint)
247+
return AnnotatedType(parse_type_hint(annotate_args[0]), list(annotate_args[1:]))
219248
elif origin is Union:
220249
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
221250
elif origin is Literal:
222251
return LiteralType(typing.get_args(hint)[0])
252+
elif isinstance(origin, TypeAliasType):
253+
def do() -> VarType:
254+
assert isinstance(hint, GenericAlias)
255+
args = list(typing.get_args(hint))
256+
assert len(args) == len(origin.__parameters__), f"Expected {
257+
len(origin.__parameters__)} type arguments but got {len(args)}"
258+
true_origin = origin.__value__
259+
parametric_args = origin.__parameters__
260+
parsed_args = [parse_type_hint(arg) for arg in args]
261+
env = dict(zip(parametric_args, parsed_args))
262+
parsed_origin = parse_type_hint(true_origin)
263+
return subst_type(parsed_origin, env)
264+
return do()
265+
elif isinstance(origin, type):
266+
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
267+
args = list(typing.get_args(hint))
268+
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
269+
223270
else:
224-
raise RuntimeError(f"Unsupported origin type: {origin}")
225-
271+
raise RuntimeError(f"Unsupported origin type: {
272+
origin}, {type(origin), type(hint)}")
273+
226274
if isinstance(hint, type):
227275
return hint
228276
if hint == typing.Self:
@@ -242,7 +290,7 @@ def extract_type_vars_from_hint(hint: typing.Any) -> List[TypeVar]:
242290

243291

244292
def get_type_vars(func: typing.Callable) -> List[TypeVar]:
245-
type_hints = typing.get_type_hints(func)
293+
type_hints = typing.get_type_hints(func, include_extras=True)
246294
type_vars = []
247295
for hint in type_hints.values():
248296
type_vars.extend(extract_type_vars_from_hint(hint))

luisa_lang/codegen/cpp.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def gen_impl(self, ty: hir.Type) -> str:
3535
match ty:
3636
case hir.IntType(bits=bits, signed=signed):
3737
int_names = {
38-
'8':'byte',
39-
'16':'short',
40-
'32':'int',
41-
'64':'long',
38+
'8': 'byte',
39+
'16': 'short',
40+
'32': 'int',
41+
'64': 'long',
4242
}
4343
if signed:
4444
return f"lc_{int_names[str(bits)]}"
@@ -85,11 +85,13 @@ def do():
8585
return self.gen(ty.instantiated)
8686
case hir.FunctionType():
8787
name = f'func_{unique_hash(ty.func_like.name)}_t'
88-
self.impl.writeln(f'struct {name} {{}}; // function type of {ty.func_like.name}')
88+
self.impl.writeln(
89+
f'struct {name} {{}}; // function type of {ty.func_like.name}')
8990
return name
9091
case hir.TypeConstructorType():
9192
name = f'type_{unique_hash(self.gen(ty.inner))}_t'
92-
self.impl.writeln(f'struct {name} {{}}; // type constructor of {ty.inner}')
93+
self.impl.writeln(
94+
f'struct {name} {{}}; // type constructor of {ty.inner}')
9395
return name
9496
case hir.OpaqueType():
9597
def do():
@@ -98,7 +100,8 @@ def do():
98100
elem_ty = self.gen(ty.extra_args[0])
99101
return f'__builtin__Buffer<{elem_ty}>'
100102
case _:
101-
raise NotImplementedError(f"unsupported opaque type: {ty.name}")
103+
raise NotImplementedError(
104+
f"unsupported opaque type: {ty.name}")
102105
return do()
103106
case hir.GenericIntType():
104107
return 'int'
@@ -225,7 +228,8 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
225228
if callable(func):
226229
dsl_func = get_dsl_func(func)
227230
assert dsl_func is not None
228-
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {func}"
231+
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {
232+
func}"
229233
func_tmp = dsl_func.resolve([])
230234
assert isinstance(
231235
func_tmp, hir.Function), f"Expected function, got {func_tmp}"
@@ -268,8 +272,9 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
268272
params = ",".join(self.gen_var(
269273
p) for p in func.params)
270274
assert func.return_type
271-
272-
self.signature = f'auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
275+
276+
self.signature = f'auto {
277+
self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
273278
if func.export:
274279
self.signature = f'extern "C" {self.signature}'
275280
if func.inline_hint == True:
@@ -304,14 +309,15 @@ def gen_ref(self, ref: hir.Ref) -> str:
304309
def do():
305310
intrin_name = intrin.name
306311
gened_args = [self.gen_value_or_ref(
307-
arg) for arg in intrin.args]
312+
arg) for arg in intrin.args]
308313
match intrin_name:
309314
case 'buffer.ref' | 'array.ref':
310315
return f"{gened_args[0]}[{gened_args[1]}]"
311316
case 'buffer.size' | 'array.size':
312317
return f"{gened_args[0]}.size"
313318
case _:
314-
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
319+
raise RuntimeError(
320+
f"unsupported intrinsic reference: {intrin_name}")
315321
return do()
316322
case _:
317323
raise NotImplementedError(f"unsupported reference: {ref}")
@@ -338,7 +344,7 @@ def gen_node_checked(self, node: hir.Node) -> str:
338344
if isinstance(node, hir.TypedNode) and isinstance(node.type, (hir.TypeConstructorType, hir.FunctionType)):
339345
assert node.type
340346
return f'{self.base.type_cache.gen(node.type)}{{}}'
341-
347+
342348
return self.node_map[node]
343349

344350
def gen_expr(self, expr: hir.Value) -> str:
@@ -440,7 +446,7 @@ def do():
440446
'__sub__': '-',
441447
'__mul__': '*',
442448
'__truediv__': '/',
443-
'__floordiv__': '/', # TODO: fix floordiv
449+
'__floordiv__': '/', # TODO: fix floordiv
444450
'__mod__': '%',
445451
'__pow__': '**',
446452
'__and__': '&',
@@ -460,7 +466,7 @@ def do():
460466
'__isub__': '-=',
461467
'__imul__': '*=',
462468
'__itruediv__': '/=',
463-
'__ifloordiv__': '/=', # TODO: fix floordiv
469+
'__ifloordiv__': '/=', # TODO: fix floordiv
464470
'__imod__': '%=',
465471
'__ipow__': '**=',
466472
'__iand__': '&=',
@@ -489,7 +495,7 @@ def do():
489495
args_s = ','.join(gened_args)
490496
self.body.writeln(
491497
f"auto v{vid} = __intrin__{intrin_name}({args_s});")
492-
498+
493499
do()
494500
case _:
495501
raise NotImplementedError(

luisa_lang/lang_builtins.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
N = TypeVar("N")
2828

2929

30+
3031
@func
3132
def dispatch_id() -> uint3:
3233
return intrinsic("dispatch_id", uint3)

luisa_lang/parse.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def handle_type_t():
153153

154154
def convert_func_signature(signature: classinfo.MethodType,
155155
ctx_name: str,
156-
props:hir.FuncProperties,
156+
props: hir.FuncProperties,
157157
globalns: Dict[str, Any],
158158
type_var_ns: Dict[typing.TypeVar, hir.Type],
159159
implicit_type_params: Dict[str, hir.Type],
@@ -194,7 +194,8 @@ def convert_func_signature(signature: classinfo.MethodType,
194194
params.append(
195195
Var(arg[0], implicit_type_params[arg[0]], span=None, semantic=semantic))
196196
return_type = type_parser.parse_type_ext(signature.return_type)
197-
assert return_type is not None, f"failed to parse return type {signature.return_type}"
197+
assert return_type is not None, f"failed to parse return type {
198+
signature.return_type}"
198199
if isinstance(return_type, hir.AnyBound):
199200
return_type = None
200201
elif isinstance(return_type, hir.TypeBound):

0 commit comments

Comments
 (0)