Skip to content

Commit b51cfb6

Browse files
committed
while loop and multi assignment
1 parent da68f83 commit b51cfb6

File tree

5 files changed

+254
-74
lines changed

5 files changed

+254
-74
lines changed

luisa_lang/codegen/cpp.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def gen_impl(self, ty: hir.Type) -> str:
4444
return name
4545
case hir.UnitType():
4646
return 'void'
47+
case hir.TupleType():
48+
def do():
49+
elements = [self.gen(e) for e in ty.elements]
50+
name = f'Tuple_{unique_hash("".join(elements))}'
51+
self.impl.writeln(f'struct {name} {{')
52+
for i, element in enumerate(elements):
53+
self.impl.writeln(f' {element} _{i};')
54+
self.impl.writeln('};')
55+
return name
56+
return do()
4757
case _:
4858
raise NotImplementedError(f"unsupported type: {ty}")
4959

@@ -129,6 +139,9 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
129139
return f"__builtin_{name}"
130140
case hir.StructType(name=name):
131141
return name
142+
case hir.TupleType():
143+
elements = [self.mangle(e) for e in obj.elements]
144+
return f"T{unique_hash(''.join(elements))}"
132145
case _:
133146
raise NotImplementedError(f"unsupported object: {obj}")
134147

@@ -275,6 +288,10 @@ def impl() -> None:
275288
else:
276289
raise NotImplementedError(
277290
f"unsupported constant: {constant}")
291+
case hir.AggregateInit():
292+
assert expr.type
293+
ty = self.base.type_cache.gen(expr.type)
294+
self.body.writeln(f"{ty} v{vid}{{}};")
278295
case _:
279296
raise NotImplementedError(
280297
f"unsupported expression: {expr}")
@@ -310,12 +327,12 @@ def gen_node(self, node: hir.Node):
310327
vid = self.new_vid()
311328
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
312329
self.body.indent += 1
313-
self.gen_bb(loop.prepare)
330+
self.gen_bb(loop.prepare)
314331
if loop.cond:
315332
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
316333
else:
317334
self.body.writeln("return true;")
318-
self.body.indent -=1
335+
self.body.indent -= 1
319336
self.body.writeln("};")
320337
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
321338
self.body.indent += 1
@@ -354,7 +371,7 @@ def gen_locals(self):
354371
continue
355372
assert (
356373
local.type
357-
), f"Local variable {local.name} contains unresolved type, please resolve it via TypeInferencer"
374+
), f"Local variable `{local.name}` contains unresolved type"
358375
self.body.writeln(
359376
f"{self.base.type_cache.gen(local.type)} {local.name}{{}};"
360377
)

luisa_lang/hir.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Tuple,
1313
Dict,
1414
Union,
15-
cast,
1615
)
1716
import typing
1817
from typing_extensions import override
@@ -128,6 +127,9 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:
128127

129128
def is_concrete(self) -> bool:
130129
return True
130+
131+
def __len__(self) -> int:
132+
return 1
131133

132134

133135
class UnitType(Type):
@@ -337,7 +339,9 @@ def member(self, field: Any) -> Optional['Type']:
337339
return self.element
338340
return Type.member(self, field)
339341

340-
342+
def __len__(self) -> int:
343+
return self.count
344+
341345
class ArrayType(Type):
342346
element: Type
343347
count: Union[int, "SymbolicConstant"]
@@ -789,7 +793,7 @@ class Index(Value):
789793
index: Value
790794

791795
def __init__(self, base: Value, index: Value, type: Type, span: Optional[Span]) -> None:
792-
super().__init__(None, span)
796+
super().__init__(type, span)
793797
self.base = base
794798
self.index = index
795799

@@ -857,6 +861,12 @@ def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
857861
# super().__init__(ty, span)
858862
# self.init_call = init_call
859863

864+
class AggregateInit(Value):
865+
args: List[Value]
866+
867+
def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -> None:
868+
super().__init__(type, span)
869+
self.args = args
860870

861871
class Call(Value):
862872
op: FunctionLike

luisa_lang/lang.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from luisa_lang.classinfo import VarType, GenericInstance, UnionType, _get_cls_globalns, register_class, class_typeinfo
22
from enum import Enum, auto
33
from typing_extensions import TypeAliasType
4+
import typing
45
from typing import (
56
Callable,
67
Dict,
@@ -14,7 +15,6 @@
1415
Union,
1516
Generic,
1617
Literal,
17-
cast,
1818
overload,
1919
Any,
2020
)
@@ -109,7 +109,7 @@ def _dsl_func_impl(f: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
109109
template = _make_func_template(f, func_name, func_globals)
110110
ctx.functions[f] = template
111111
setattr(f, "__luisa_func__", template)
112-
return cast(_T, f)
112+
return typing.cast(_T, f)
113113
else:
114114
raise NotImplementedError()
115115
# return cast(_T, f)
@@ -150,7 +150,7 @@ def get_ir_type(var_ty: VarType) -> hir.Type:
150150
def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
151151
if kind == _ObjKind.STRUCT:
152152
assert isinstance(obj, type), f"{obj} is not a type"
153-
return cast(_T, _dsl_struct_impl(obj, attrs))
153+
return typing.cast(_T, _dsl_struct_impl(obj, attrs))
154154
elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL:
155155
return _dsl_func_impl(obj, kind, attrs)
156156
raise NotImplementedError()

luisa_lang/lang_builtins.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def block_id() -> uint3:
2323

2424

2525
@_builtin
26-
def convert(target: type[_T], value: Any) -> _T:
26+
def cast(target: type[_T], value: Any) -> _T:
2727
"""
2828
Attempt to convert the value to the target type.
2929
"""
@@ -185,4 +185,7 @@ def value(self, value: _T) -> None:
185185
'static_assert',
186186
'type_of_opt',
187187
'typeof',
188+
"dispatch_id",
189+
"thread_id",
190+
"block_id",
188191
]

0 commit comments

Comments
 (0)