Skip to content

Commit 78733cb

Browse files
committed
various fix
1 parent baf2020 commit 78733cb

File tree

6 files changed

+424
-306
lines changed

6 files changed

+424
-306
lines changed

luisa_lang/_builtin_decor.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
8585
assert isinstance(p, hir.SymbolicType)
8686
implicit_generic_params.add(p.param)
8787

88-
def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
88+
def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
8989
type_var_ns: Dict[TypeVar, hir.Type |
9090
hir.ComptimeValue] = foreign_type_var_ns.copy()
9191
mapped_implicit_type_params: Dict[str,
@@ -127,7 +127,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
127127
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
128128
ret = func_parser.parse_body()
129129
ret.inline_hint = props.inline
130-
ret.export = props.export
130+
ret.export = props.export
131131
return ret
132132
params = [v[0] for v in func_sig.args]
133133
is_generic = len(func_sig_converted.generic_params) > 0
@@ -271,9 +271,9 @@ def volume(self) -> float:
271271
return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {})
272272

273273

274-
def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[type[_TT]], _TT]:
275-
def decorator(cls: type[_TT]) -> _TT:
276-
return typing.cast(_TT, _dsl_struct_impl(cls, {}, ir_ty_override=ty))
274+
def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[type[_TT]], type[_TT]]:
275+
def decorator(cls: type[_TT]) -> type[_TT]:
276+
return typing.cast(type[_TT], _dsl_struct_impl(cls, {}, ir_ty_override=ty))
277277
return decorator
278278

279279

luisa_lang/codegen/cpp.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,20 @@ def mangle_name(name: str) -> str:
120120

121121

122122
class Mangling:
123-
cache: Dict[hir.Type | hir.FunctionLike, str]
123+
cache: Dict[hir.Type | hir.Function, str]
124124

125125
def __init__(self) -> None:
126126
self.cache = {}
127127

128-
def mangle(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
128+
def mangle(self, obj: Union[hir.Type, hir.Function]) -> str:
129129
if obj in self.cache:
130130
return self.cache[obj]
131131
else:
132132
res = self.mangle_impl(obj)
133133
self.cache[obj] = res
134134
return res
135135

136-
def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
136+
def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
137137

138138
match obj:
139139
case hir.UnitType():
@@ -266,7 +266,7 @@ def gen_ref(self, ref: hir.Ref) -> str:
266266
case _:
267267
raise NotImplementedError(f"unsupported reference: {ref}")
268268

269-
def gen_func(self, f: hir.FunctionLike) -> str:
269+
def gen_func(self, f: hir.Function) -> str:
270270
if isinstance(f, hir.Function):
271271
return self.base.gen_function(f)
272272
else:
@@ -346,7 +346,12 @@ def do():
346346
comps = intrin_name.split('.')
347347
gened_args = [self.gen_value_or_ref(
348348
arg) for arg in intrin.args]
349-
if comps[0] == 'cmp':
349+
if comps[0] == 'init':
350+
assert expr.type
351+
ty = self.base.type_cache.gen(expr.type)
352+
self.body.writeln(
353+
f"{ty} v{vid}{{ {','.join(gened_args)} }};")
354+
elif comps[0] == 'cmp':
350355
cmp_dict = {
351356
'__eq__': '==',
352357
'__ne__': '!=',

luisa_lang/hir.py

+117-19
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222

2323
PATH_PREFIX = "luisa_lang"
2424

25-
FunctionLike = Union["Function"]
26-
2725

2826
# @dataclass
2927
# class FunctionTemplateResolveResult:
30-
# func: Optional[FunctionLike]
28+
# func: Optional[Function]
3129
# matched: bool
3230

3331

@@ -40,17 +38,18 @@
4038

4139

4240
FunctionTemplateResolvingFunc = Callable[[
43-
FunctionTemplateResolvingArgs], Union[FunctionLike, 'TemplateMatchingError']]
41+
FunctionTemplateResolvingArgs], Union['Function', 'TemplateMatchingError']]
42+
4443

4544
class FuncProperties:
46-
inline: bool | Literal["always"]
45+
inline: bool | Literal["never", "always"]
4746
export: bool
4847
byref: Set[str]
4948

5049
def __init__(self):
5150
self.inline = False
5251
self.export = False
53-
self.byref = set()
52+
self.byref = set()
5453

5554

5655
class FunctionTemplate:
@@ -63,7 +62,7 @@ class FunctionTemplate:
6362
"""
6463
parsing_func: FunctionTemplateResolvingFunc
6564
__resolved: Dict[Tuple[Tuple[str,
66-
Union['Type', Any]], ...], FunctionLike]
65+
Union['Type', Any]], ...], "Function"]
6766
is_generic: bool
6867
name: str
6968
params: List[str]
@@ -78,7 +77,7 @@ def __init__(self, name: str, params: List[str], parsing_func: FunctionTemplateR
7877
self.name = name
7978
self.props = None
8079

81-
def resolve(self, args: FunctionTemplateResolvingArgs | None) -> Union[FunctionLike, 'TemplateMatchingError']:
80+
def resolve(self, args: FunctionTemplateResolvingArgs | None) -> Union["Function", 'TemplateMatchingError']:
8281
args = args or []
8382
if not self.is_generic:
8483
key = tuple(args)
@@ -101,7 +100,7 @@ class DynamicIndex:
101100

102101

103102
class Type(ABC):
104-
methods: Dict[str, Union[FunctionLike]]
103+
methods: Dict[str, Union["Function", FunctionTemplate]]
105104
is_builtin: bool
106105

107106
def __init__(self):
@@ -132,7 +131,7 @@ def member(self, field: Any) -> Optional['Type']:
132131
return FunctionType(m, None)
133132
return None
134133

135-
def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:
134+
def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:
136135
m = self.methods.get(name)
137136
if m:
138137
return m
@@ -738,7 +737,7 @@ def member(self, field) -> Optional['Type']:
738737
raise RuntimeError("member access on uninstantiated BoundType")
739738

740739
@override
741-
def method(self, name) -> Optional[FunctionLike | FunctionTemplate]:
740+
def method(self, name) -> Optional[Union["Function", FunctionTemplate]]:
742741
if self.instantiated is not None:
743742
return self.instantiated.method(name)
744743
else:
@@ -766,10 +765,10 @@ def __hash__(self) -> int:
766765

767766

768767
class FunctionType(Type):
769-
func_like: FunctionLike | FunctionTemplate
768+
func_like: Union["Function", FunctionTemplate]
770769
bound_object: Optional['Ref']
771770

772-
def __init__(self, func_like: FunctionLike | FunctionTemplate, bound_object: Optional['Ref']) -> None:
771+
def __init__(self, func_like: Union["Function", FunctionTemplate], bound_object: Optional['Ref']) -> None:
773772
super().__init__()
774773
self.func_like = func_like
775774
self.bound_object = bound_object
@@ -950,6 +949,8 @@ def __eq__(self, value: object) -> bool:
950949

951950
def __hash__(self) -> int:
952951
return hash(self.value)
952+
953+
953954
class TypeValue(Value):
954955
def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
955956
super().__init__(TypeConstructorType(ty), span)
@@ -958,10 +959,12 @@ def inner_type(self) -> Type:
958959
assert isinstance(self.type, TypeConstructorType)
959960
return self.type.inner
960961

962+
961963
class FunctionValue(Value):
962-
def __init__(self, ty:FunctionType, span: Optional[Span] = None) -> None:
964+
def __init__(self, ty: FunctionType, span: Optional[Span] = None) -> None:
963965
super().__init__(ty, span)
964966

967+
965968
class Alloca(Ref):
966969
"""
967970
A temporary variable
@@ -1003,14 +1006,14 @@ def __repr__(self) -> str:
10031006

10041007

10051008
class Call(Value):
1006-
op: FunctionLike
1009+
op: "Function"
10071010
"""After type inference, op should be a Value."""
10081011

10091012
args: List[Value | Ref]
10101013

10111014
def __init__(
10121015
self,
1013-
op: FunctionLike,
1016+
op: "Function",
10141017
args: List[Value | Ref],
10151018
type: Type,
10161019
span: Optional[Span] = None,
@@ -1077,7 +1080,7 @@ class Assign(Node):
10771080
value: Value
10781081

10791082
def __init__(self, ref: Ref, value: Value, span: Optional[Span] = None) -> None:
1080-
assert not isinstance(value.type, (FunctionType, TypeConstructorType))
1083+
assert not isinstance(value.type, (FunctionType, TypeConstructorType))
10811084
super().__init__(span)
10821085
self.ref = ref
10831086
self.value = value
@@ -1206,7 +1209,7 @@ class Function:
12061209
locals: List[Var]
12071210
complete: bool
12081211
is_method: bool
1209-
inline_hint: Literal[True, 'always', 'never'] | None
1212+
inline_hint: bool | Literal['always', 'never']
12101213

12111214
def __init__(
12121215
self,
@@ -1223,7 +1226,7 @@ def __init__(
12231226
self.locals = []
12241227
self.complete = False
12251228
self.is_method = is_method
1226-
self.inline_hint = None
1229+
self.inline_hint = False
12271230

12281231

12291232
def match_template_args(
@@ -1408,3 +1411,98 @@ def is_type_compatible_to(ty: Type, target: Type) -> bool:
14081411
if isinstance(target, IntType):
14091412
return isinstance(ty, GenericIntType)
14101413
return False
1414+
1415+
1416+
class FunctionInliner:
1417+
mapping: Dict[Ref | Value, Ref | Value]
1418+
ret: Value | None
1419+
1420+
def __init__(self, func: Function, args: List[Value | Ref], body: BasicBlock, span: Optional[Span] = None) -> None:
1421+
self.mapping = {}
1422+
for param, arg in zip(func.params, args):
1423+
self.mapping[param] = arg
1424+
assert func.body
1425+
self.do_inline(func.body, body)
1426+
1427+
def do_inline(self, func_body: BasicBlock, body: BasicBlock) -> None:
1428+
for node in func_body.nodes:
1429+
assert node not in self.mapping
1430+
1431+
match node:
1432+
case Var():
1433+
assert node.type
1434+
assert node.semantic == ParameterSemantic.BYVAL
1435+
self.mapping[node] = Alloca(node.type, node.span)
1436+
case Load():
1437+
mapped_var = self.mapping[node.ref]
1438+
assert isinstance(mapped_var, Ref)
1439+
body.append(Load(mapped_var))
1440+
case Index():
1441+
base = self.mapping.get(node.base)
1442+
assert isinstance(base, Value)
1443+
index = self.mapping.get(node.index)
1444+
assert isinstance(index, Value)
1445+
assert node.type
1446+
self.mapping[node] = body.append(
1447+
Index(base, index, node.type, node.span))
1448+
case IndexRef():
1449+
base = self.mapping.get(node.base)
1450+
index = self.mapping.get(node.index)
1451+
assert isinstance(base, Ref)
1452+
assert isinstance(index, Value)
1453+
assert node.type
1454+
self.mapping[node] = body.append(IndexRef(
1455+
base, index, node.type, node.span))
1456+
case Member():
1457+
base = self.mapping.get(node.base)
1458+
assert isinstance(base, Value)
1459+
assert node.type
1460+
self.mapping[node] = body.append(Member(
1461+
base, node.field, node.type, node.span))
1462+
case MemberRef():
1463+
base = self.mapping.get(node.base)
1464+
assert isinstance(base, Ref)
1465+
assert node.type
1466+
self.mapping[node] = body.append(MemberRef(
1467+
base, node.field, node.type, node.span))
1468+
case Call() as call:
1469+
def do():
1470+
args: List[Ref | Value] = []
1471+
for arg in call.args:
1472+
mapped_arg = self.mapping.get(arg)
1473+
if mapped_arg is None:
1474+
raise ParsingError(node, "unable to inline call")
1475+
args.append(mapped_arg)
1476+
assert call.type
1477+
self.mapping[call] = body.append(
1478+
Call(call.op, args, call.type, node.span))
1479+
do()
1480+
case Intrinsic() as intrin:
1481+
def do():
1482+
args: List[Ref | Value] = []
1483+
for arg in intrin.args:
1484+
mapped_arg = self.mapping.get(arg)
1485+
if mapped_arg is None:
1486+
raise ParsingError(
1487+
node, "unable to inline intrinsic")
1488+
args.append(mapped_arg)
1489+
assert intrin.type
1490+
self.mapping[intrin] = body.append(
1491+
Intrinsic(intrin.name, args, intrin.type, node.span))
1492+
do()
1493+
case Return():
1494+
if self.ret is not None:
1495+
raise ParsingError(node, "multiple return statement")
1496+
assert node.value is not None
1497+
mapped_value = self.mapping.get(node.value)
1498+
if mapped_value is None or isinstance(mapped_value, Ref):
1499+
raise ParsingError(node, "unable to inline return")
1500+
self.ret = mapped_value
1501+
case _:
1502+
raise ParsingError(node, "invalid node for inlining")
1503+
1504+
@staticmethod
1505+
def inline(func: Function, args: List[Value | Ref], body: BasicBlock, span: Optional[Span] = None) -> Value:
1506+
inliner = FunctionInliner(func, args, body, span)
1507+
assert inliner.ret
1508+
return inliner.ret

0 commit comments

Comments
 (0)