Skip to content

Commit 0e83814

Browse files
committed
allowing type[T] to be passed to functions
1 parent 4e956b3 commit 0e83814

File tree

7 files changed

+321
-185
lines changed

7 files changed

+321
-185
lines changed

luisa_lang/_builtin_decor.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class _ObjKind(Enum):
6969

7070

7171
def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType],
72-
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue],
72+
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type],
7373
props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
7474
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
7575
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
@@ -88,8 +88,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
8888
implicit_generic_params.add(p.param)
8989

9090
def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
91-
type_var_ns: Dict[TypeVar, hir.Type |
92-
hir.ComptimeValue] = foreign_type_var_ns.copy()
91+
type_var_ns: Dict[TypeVar, hir.Type] = foreign_type_var_ns.copy()
9392
mapped_implicit_type_params: Dict[str,
9493
hir.Type] = dict()
9594
assert func_sig is not None
@@ -166,7 +165,7 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
166165

167166

168167
_MakeTemplateFn = Callable[[List[hir.GenericParameter]], hir.Type]
169-
_InstantiateFn = Callable[[List[Any]], hir.Type]
168+
_InstantiateFn = Callable[[List[hir.Type]], hir.Type]
170169

171170

172171
def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | Tuple[_MakeTemplateFn, _InstantiateFn] | None = None, opqaue_override: str | None = None) -> type[_TT]:
@@ -202,6 +201,8 @@ def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
202201

203202
def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,):
204203
for name in cls_info.methods:
204+
if name == '__setitem__': # __setitem__ is ignored deliberately
205+
continue
205206
method_object = getattr(cls, name)
206207
props: hir.FuncProperties
207208
if hasattr(method_object, '__luisa_func_props__'):
@@ -214,7 +215,7 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
214215
method_object, get_full_name(method_object), cls_info.methods[name], globalns, type_var_ns, props, self_type=self_ty)
215216
if isinstance(self_ty, hir.BoundType):
216217
assert isinstance(self_ty.instantiated,
217-
(hir.StructType, hir.OpaqueType))
218+
(hir.ArrayType, hir.StructType, hir.OpaqueType))
218219
self_ty.instantiated.methods[name] = template
219220
else:
220221
self_ty.methods[name] = template
@@ -235,7 +236,7 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
235236
parse_fields(type_parser, ir_ty)
236237
is_generic = len(cls_info.type_vars) > 0
237238
if is_generic:
238-
def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
239+
def monomorphization_func(args: List[hir.Type]) -> hir.Type:
239240
assert isinstance(ir_ty, hir.ParametricType)
240241
type_var_ns = {}
241242
if len(args) != len(cls_info.type_vars):

luisa_lang/classinfo.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,21 @@ def __repr__(self):
5656

5757
def __eq__(self, other):
5858
return isinstance(other, SelfType)
59+
60+
class LiteralType:
61+
value: Any
5962

63+
def __init__(self, value: Any):
64+
self.value = value
6065

61-
VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType]
66+
def __repr__(self):
67+
return f"Literal[{self.value}]"
68+
69+
def __eq__(self, other):
70+
return isinstance(other, LiteralType) and self.value == other.value
71+
72+
73+
VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType, LiteralType]
6274

6375

6476
def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
@@ -204,6 +216,8 @@ def parse_type_hint(hint: Any) -> VarType:
204216
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
205217
elif origin is Union:
206218
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
219+
elif origin is Literal:
220+
return LiteralType(typing.get_args(hint)[0])
207221
else:
208222
raise RuntimeError(f"Unsupported origin type: {origin}")
209223

luisa_lang/codegen/cpp.py

+53-34
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,12 @@ def gen_impl(self, ty: hir.Type) -> str:
5656
raise RuntimeError("invalid float type")
5757
case hir.BoolType():
5858
return "lc_bool"
59+
case hir.PointerType(element=element):
60+
return f"lc_ptr<{self.gen(element)}>"
5961
case hir.VectorType(element=element, count=count):
6062
return f"{self.gen(element)}{count}>"
63+
case hir.ArrayType(element=element, count=count):
64+
return f"lc_array<{self.gen(element)}, {count}>"
6165
case hir.StructType(name=name, fields=fields):
6266
self.impl.writeln(f'struct {name} {{')
6367
for field in fields:
@@ -80,9 +84,13 @@ def do():
8084
assert ty.instantiated
8185
return self.gen(ty.instantiated)
8286
case hir.FunctionType():
83-
return ''
87+
name = f'func_{unique_hash(ty.func_like.name)}_t'
88+
self.impl.writeln(f'struct {name} {{}}; // function type of {ty.func_like.name}')
89+
return name
8490
case hir.TypeConstructorType():
85-
return ''
91+
name = f'type_{unique_hash(self.gen(ty.inner))}_t'
92+
self.impl.writeln(f'struct {name} {{}}; // type constructor of {ty.inner}')
93+
return name
8694
case hir.OpaqueType():
8795
def do():
8896
match ty.name:
@@ -171,8 +179,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
171179
case hir.Function(name=name, params=params, return_type=ret):
172180
assert ret
173181
name = mangle_name(name)
174-
params = list(filter(lambda p: not isinstance(
175-
p.type, (hir.FunctionType)), params))
182+
# params = list(filter(lambda p: not isinstance(
183+
# p.type, (hir.FunctionType)), params))
176184
return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}")
177185
case hir.StructType(name=name):
178186
return name
@@ -184,6 +192,10 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
184192
return self.mangle(obj.instantiated)
185193
case hir.OpaqueType():
186194
return obj.name
195+
case hir.TypeConstructorType():
196+
return self.mangle(obj.inner)
197+
case hir.FunctionType():
198+
return f'func_{unique_hash(obj.func_like.name)}'
187199
case _:
188200
raise NotImplementedError(f"unsupported object: {obj}")
189201

@@ -246,7 +258,7 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
246258
self.name = base.mangling.mangle(func)
247259
self.func = func
248260
params = ",".join(self.gen_var(
249-
p) for p in func.params if not isinstance(p.type, hir.FunctionType))
261+
p) for p in func.params)
250262
assert func.return_type
251263

252264
self.signature = f'auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
@@ -285,10 +297,13 @@ def do():
285297
intrin_name = intrin.name
286298
gened_args = [self.gen_value_or_ref(
287299
arg) for arg in intrin.args]
288-
if intrin_name == 'buffer_ref':
289-
return f"{gened_args[0]}[{gened_args[1]}]"
290-
else:
291-
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
300+
match intrin_name:
301+
case 'buffer.ref' | 'array.ref':
302+
return f"{gened_args[0]}[{gened_args[1]}]"
303+
case 'buffer.size' | 'array.size':
304+
return f"{gened_args[0]}.size"
305+
case _:
306+
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
292307
return do()
293308
case _:
294309
raise NotImplementedError(f"unsupported reference: {ref}")
@@ -312,17 +327,40 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
312327
def gen_node_checked(self, node: hir.Node) -> str:
313328
if isinstance(node, hir.Constant):
314329
return self.gen_expr(node)
330+
if isinstance(node, hir.TypedNode) and isinstance(node.type, (hir.TypeConstructorType, hir.FunctionType)):
331+
assert node.type
332+
return f'{self.base.type_cache.gen(node.type)}{{}}'
333+
315334
return self.node_map[node]
316335

317336
def gen_expr(self, expr: hir.Value) -> str:
318-
if expr.type and isinstance(expr.type, hir.FunctionType):
319-
return ''
337+
# if expr.type and isinstance(expr.type, hir.FunctionType):
338+
# return ''
339+
if isinstance(expr, hir.Constant):
340+
value = expr.value
341+
if isinstance(value, int):
342+
return f"{value}"
343+
elif isinstance(value, float):
344+
return f"{value}f"
345+
elif isinstance(value, bool):
346+
return "true" if value else "false"
347+
elif isinstance(value, str):
348+
return f"\"{value}\""
349+
elif isinstance(value, hir.Function):
350+
return self.gen_func(value)
351+
else:
352+
raise NotImplementedError(
353+
f"unsupported constant: {expr}")
320354
if expr in self.node_map:
321355
return self.node_map[expr]
322356
vid = self.new_vid()
323357

324358
def impl() -> None:
325359
match expr:
360+
case hir.TypeValue() as type_value:
361+
assert type_value.type
362+
self.base.type_cache.gen(type_value.type)
363+
return
326364
case hir.Load() as load:
327365
self.body.writeln(
328366
f"const auto &v{vid} = {self.gen_ref(load.ref)}; // load")
@@ -337,36 +375,17 @@ def impl() -> None:
337375
case hir.Call() as call:
338376
op = self.gen_func(call.op)
339377
args_s = ','.join(self.gen_value_or_ref(
340-
arg) for arg in call.args if not isinstance(arg.type, hir.FunctionType))
378+
arg) for arg in call.args)
341379
if call.type != hir.UnitType():
342380
self.body.writeln(
343381
f"auto v{vid} = {op}({args_s});")
344382
else:
345383
self.body.writeln(f"{op}({args_s});")
346-
case hir.Constant() as constant:
347-
value = constant.value
348-
if isinstance(value, int):
349-
self.body.writeln(f"const auto v{vid} = {value};")
350-
elif isinstance(value, float):
351-
self.body.writeln(f"const auto v{vid} = {value};")
352-
elif isinstance(value, bool):
353-
s = "true" if value else "false"
354-
self.body.writeln(f"const auto v{vid} = {s};")
355-
elif isinstance(value, str):
356-
self.body.writeln(f"const auto v{vid} = \"{value}\";")
357-
elif isinstance(value, hir.Function):
358-
name = self.gen_func(value)
359-
self.body.writeln(f"auto&& v{vid} = {name};")
360-
else:
361-
raise NotImplementedError(
362-
f"unsupported constant: {constant}")
363384
case hir.AggregateInit():
364385
assert expr.type
365386
ty = self.base.type_cache.gen(expr.type)
366387
self.body.writeln(
367388
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
368-
case hir.TypeValue():
369-
pass
370389
case hir.Intrinsic() as intrin:
371390
def do():
372391
intrin_name = intrin.name
@@ -544,7 +563,7 @@ def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]:
544563
ty = self.base.type_cache.gen(alloca.type)
545564
self.body.writeln(f"{ty} v{vid}{{}};")
546565
self.node_map[alloca] = f"v{vid}"
547-
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member():
566+
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue():
548567
self.gen_expr(node)
549568
case hir.Member() | hir.Index():
550569
pass
@@ -570,8 +589,8 @@ def gen_locals(self):
570589
for local in self.func.locals:
571590
if local.name in self.params:
572591
continue
573-
if isinstance(local.type, (hir.FunctionType, hir.TypeConstructorType)):
574-
continue
592+
# if isinstance(local.type, (hir.FunctionType, hir.TypeConstructorType)):
593+
# continue
575594
assert (
576595
local.type
577596
), f"Local variable `{local.name}` contains unresolved type"

0 commit comments

Comments
 (0)