Skip to content

Commit 0d624e5

Browse files
committed
added comptime and reveal_type
1 parent d4c1044 commit 0d624e5

File tree

4 files changed

+177
-99
lines changed

4 files changed

+177
-99
lines changed

luisa_lang/hir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,9 @@ class Ref(TypedNode):
730730
class Value(TypedNode):
731731
pass
732732

733+
class Unit(Value):
734+
def __init__(self) -> None:
735+
super().__init__(UnitType())
733736

734737
class SymbolicConstant(Value):
735738
generic: GenericParameter

luisa_lang/lang.py

Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from luisa_lang.utils import get_full_name, unique_hash
2121
from luisa_lang.math_types import *
2222
from luisa_lang._builtin_decor import _builtin_type, _builtin, _intrinsic_impl
23+
from luisa_lang.lang_builtins import *
2324
import luisa_lang.hir as hir
2425
import luisa_lang.classinfo as classinfo
2526
import luisa_lang.parse as parse
@@ -224,80 +225,3 @@ def decorator(f):
224225
return decorator
225226

226227

227-
def type_of_opt(value: Any) -> Optional[hir.Type]:
228-
if isinstance(value, hir.Type):
229-
return value
230-
if isinstance(value, type):
231-
return hir.GlobalContext.get().types[value]
232-
return hir.GlobalContext.get().types.get(type(value))
233-
234-
235-
def typeof(value: Any) -> hir.Type:
236-
ty = type_of_opt(value)
237-
if ty is None:
238-
raise TypeError(f"Cannot determine type of {value}")
239-
return ty
240-
241-
242-
_t = hir.SymbolicType(hir.GenericParameter("_T", "luisa_lang.lang"))
243-
_n = hir.SymbolicConstant(hir.GenericParameter(
244-
"_N", "luisa_lang.lang")), typeof(u32)
245-
246-
247-
# @_builtin_type(
248-
# hir.ParametricType(
249-
# "Array", [hir.TypeParameter(_t, bound=[])], hir.ArrayType(_t, _n)
250-
# )
251-
# )
252-
class Array(Generic[_T, _N]):
253-
def __init__(self) -> None:
254-
return _intrinsic_impl()
255-
256-
def __getitem__(self, index: int | u32 | u64) -> _T:
257-
return _intrinsic_impl()
258-
259-
def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
260-
return _intrinsic_impl()
261-
262-
def __len__(self) -> u32 | u64:
263-
return _intrinsic_impl()
264-
265-
266-
# @_builtin_type(
267-
# hir.ParametricType(
268-
# "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
269-
# )
270-
# )
271-
class Buffer(Generic[_T]):
272-
def __getitem__(self, index: int | u32 | u64) -> _T:
273-
return _intrinsic_impl()
274-
275-
def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
276-
return _intrinsic_impl()
277-
278-
def __len__(self) -> u32 | u64:
279-
return _intrinsic_impl()
280-
281-
282-
# @_builtin_type(
283-
# hir.ParametricType(
284-
# "Pointer", [hir.TypeParameter(_t, bound=[])], hir.PointerType(_t)
285-
# )
286-
# )
287-
class Pointer(Generic[_T]):
288-
def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> _T:
289-
return _intrinsic_impl()
290-
291-
def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: _T) -> None:
292-
return _intrinsic_impl()
293-
294-
@property
295-
def value(self) -> _T:
296-
return _intrinsic_impl()
297-
298-
@value.setter
299-
def value(self, value: _T) -> None:
300-
return _intrinsic_impl()
301-
302-
303-
# hir.GlobalContext.get().flush()

luisa_lang/lang_builtins.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from typing import Any, Generic, List, Optional, Sequence, TypeVar, overload
12
from typing_extensions import TypeAliasType
23
from luisa_lang.lang import _builtin, _intrinsic_impl
3-
from luisa_lang.lang import *
4+
from luisa_lang.math_types import *
5+
import luisa_lang.hir as hir
46

57
_T = TypeVar("_T")
6-
8+
_N = TypeVar("_N")
79

810
@_builtin
911
def dispatch_id() -> uint3:
@@ -89,8 +91,98 @@ def unroll(range_: Sequence[int]) -> Sequence[int]:
8991

9092

9193
@_builtin
92-
def address_of(a: _T) -> Pointer[_T]:
94+
def address_of(a: _T) -> 'Pointer[_T]':
9395
return _intrinsic_impl()
9496

9597
# class StaticEval:
9698
#
99+
100+
101+
def type_of_opt(value: Any) -> Optional[hir.Type]:
102+
if isinstance(value, hir.Type):
103+
return value
104+
if isinstance(value, type):
105+
return hir.GlobalContext.get().types[value]
106+
return hir.GlobalContext.get().types.get(type(value))
107+
108+
109+
def typeof(value: Any) -> hir.Type:
110+
ty = type_of_opt(value)
111+
if ty is None:
112+
raise TypeError(f"Cannot determine type of {value}")
113+
return ty
114+
115+
116+
_t = hir.SymbolicType(hir.GenericParameter("_T", "luisa_lang.lang"))
117+
_n = hir.SymbolicConstant(hir.GenericParameter(
118+
"_N", "luisa_lang.lang")), typeof(u32)
119+
120+
121+
# @_builtin_type(
122+
# hir.ParametricType(
123+
# "Array", [hir.TypeParameter(_t, bound=[])], hir.ArrayType(_t, _n)
124+
# )
125+
# )
126+
class Array(Generic[_T, _N]):
127+
def __init__(self) -> None:
128+
return _intrinsic_impl()
129+
130+
def __getitem__(self, index: int | u32 | u64) -> _T:
131+
return _intrinsic_impl()
132+
133+
def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
134+
return _intrinsic_impl()
135+
136+
def __len__(self) -> u32 | u64:
137+
return _intrinsic_impl()
138+
139+
140+
# @_builtin_type(
141+
# hir.ParametricType(
142+
# "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
143+
# )
144+
# )
145+
class Buffer(Generic[_T]):
146+
def __getitem__(self, index: int | u32 | u64) -> _T:
147+
return _intrinsic_impl()
148+
149+
def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
150+
return _intrinsic_impl()
151+
152+
def __len__(self) -> u32 | u64:
153+
return _intrinsic_impl()
154+
155+
156+
# @_builtin_type(
157+
# hir.ParametricType(
158+
# "Pointer", [hir.TypeParameter(_t, bound=[])], hir.PointerType(_t)
159+
# )
160+
# )
161+
class Pointer(Generic[_T]):
162+
def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> _T:
163+
return _intrinsic_impl()
164+
165+
def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: _T) -> None:
166+
return _intrinsic_impl()
167+
168+
@property
169+
def value(self) -> _T:
170+
return _intrinsic_impl()
171+
172+
@value.setter
173+
def value(self, value: _T) -> None:
174+
return _intrinsic_impl()
175+
176+
177+
__all__: List[str] = [
178+
'Pointer',
179+
'Buffer',
180+
'Array',
181+
'comptime',
182+
'device_log',
183+
'address_of',
184+
'unroll',
185+
'static_assert',
186+
'type_of_opt',
187+
'typeof',
188+
]

luisa_lang/parse.py

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
55
import typing
66
import luisa_lang
7+
from luisa_lang.lang_builtins import comptime
78
from luisa_lang.utils import get_typevar_constrains_and_bounds, report_error
89
import luisa_lang.hir as hir
910
import sys
@@ -114,6 +115,12 @@ def convert_func_signature(signature: classinfo.MethodType,
114115
return hir.FunctionSignature(type_parser.generic_params, params, return_type), type_parser
115116

116117

118+
SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = {
119+
comptime,
120+
reveal_type,
121+
}
122+
123+
117124
class FuncParser:
118125
name: str
119126
func: object
@@ -138,7 +145,7 @@ def __init__(self, name: str,
138145
self.signature = signature
139146
self.globalns = globalns
140147
obj_ast, _obj_file = retrieve_ast_and_filename(func)
141-
print(ast.dump(obj_ast))
148+
# print(ast.dump(obj_ast))
142149
assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module"
143150
if not isinstance(obj_ast.body[0], ast.FunctionDef):
144151
raise RuntimeError("Function definition expected.")
@@ -205,6 +212,18 @@ def parse_const(self, const: ast.Constant) -> hir.Value:
205212
report_error(
206213
const, f"unsupported constant type {type(value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")
207214

215+
def convert_any_to_value(self, a: Any, span: hir.Span | None) -> hir.Value | ComptimeValue:
216+
if not isinstance(a, ComptimeValue):
217+
a = ComptimeValue(a, None)
218+
if a.value in SPECIAL_FUNCTIONS:
219+
return a
220+
if (converted := self.convert_constexpr(a, span)) is not None:
221+
return converted
222+
if is_valid_comptime_value_in_dsl_code(a.value):
223+
return a
224+
report_error(
225+
span, f"unsupported constant type {type(a.value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")
226+
208227
def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value | ComptimeValue:
209228
span = hir.Span.from_ast(name)
210229
var = self.vars.get(name.id)
@@ -218,13 +237,9 @@ def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value
218237
# look up in global namespace
219238
if name.id in self.globalns:
220239
resolved = self.globalns[name.id]
240+
return self.convert_any_to_value(resolved, span)
221241
# assert isinstance(resolved, ComptimeValue), type(resolved)
222-
if not isinstance(resolved, ComptimeValue):
223-
resolved = ComptimeValue(resolved, None)
224-
if (converted := self.convert_constexpr(resolved, span)) is not None:
225-
return converted
226-
if is_valid_comptime_value_in_dsl_code(resolved.value):
227-
return resolved
242+
228243
report_error(name, f"unknown variable {name.id}")
229244

230245
def try_convert_comptime_value(self, value: ComptimeValue, span: hir.Span | None = None) -> hir.Value:
@@ -346,12 +361,49 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
346361
return self.cur_bb().append(hir.Call(resolved_f, args, type=ty, span=span))
347362
raise NotImplementedError() # unreachable
348363

349-
def parse_call(self, expr: ast.Call) -> hir.Value:
364+
def handle_special_functions(self, f: Callable[..., Any], expr: ast.Call) -> hir.Value | ComptimeValue:
365+
match f:
366+
case _ if f == comptime:
367+
if len(expr.args) != 1:
368+
report_error(
369+
expr, f"when used in expressions, lc.comptime function expects exactly one argument")
370+
arg = expr.args[0]
371+
# print(ast.dump(arg))
372+
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
373+
evaled = self.eval_expr(arg.value)
374+
else:
375+
evaled = self.eval_expr(arg)
376+
# print(evaled)
377+
v = self.convert_any_to_value(evaled, hir.Span.from_ast(expr))
378+
return v
379+
case _ if f == reveal_type:
380+
if len(expr.args) != 1:
381+
report_error(
382+
expr, f"lc.reveal_type expects exactly one argument")
383+
arg = expr.args[0]
384+
cur_bb = self.cur_bb()
385+
cur_bb_len = len(cur_bb.nodes)
386+
value = self.parse_expr(arg)
387+
assert cur_bb is self.cur_bb()
388+
del self.cur_bb().nodes[cur_bb_len:]
389+
unparsed_arg = ast.unparse(arg)
390+
if isinstance(value, ComptimeValue):
391+
print(
392+
f"Type of {unparsed_arg} is ComptimeValue({type(value.value)})")
393+
else:
394+
print(f"Type of {unparsed_arg} is {value.type}")
395+
return hir.Unit()
396+
case _:
397+
raise RuntimeError(f"Unsupported special function {f}")
398+
399+
def parse_call(self, expr: ast.Call) -> hir.Value | ComptimeValue:
350400
func = self.parse_expr(expr.func)
351401

352402
if isinstance(func, hir.Ref):
353403
report_error(expr, f"function expected")
354404
elif isinstance(func, ComptimeValue):
405+
if func.value in SPECIAL_FUNCTIONS:
406+
return self.handle_special_functions(func.value, expr)
355407
func = self.try_convert_comptime_value(
356408
func, hir.Span.from_ast(expr))
357409

@@ -471,9 +523,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
471523
case _:
472524
raise RuntimeError(f"Unsupported expression: {ast.dump(expr)}")
473525

474-
def eval_expr(self, tree: ast.Expression | ast.expr):
526+
def eval_expr(self, tree: str | ast.Expression | ast.expr):
475527
if isinstance(tree, ast.expr):
476528
tree = ast.Expression(tree)
529+
# print(tree)
477530
code_object = compile(tree, "<string>", "eval")
478531
localns = {}
479532
for name, v in self.vars.items():
@@ -531,18 +584,19 @@ def check_return_type(ty: hir.Type):
531584
report_error(
532585
stmt, f"expected {var.type}, got {value.type}")
533586
else:
587+
if not value.type.is_concrete():
588+
report_error(
589+
stmt, "only concrete type can be assigned, please annotate the variable with type hint")
534590
var.type = value.type
535591
self.cur_bb().append(hir.Assign(var, value, span))
536592
case ast.AnnAssign():
537593
var = self.parse_ref(stmt.target, maybe_new_var=True)
538-
if isinstance(var, hir.Value):
539-
report_error(stmt, f"value cannot be assigned")
540-
elif isinstance(var, hir.Ref):
541-
type_annotation = self.eval_expr(stmt.annotation)
542-
type_hint = classinfo.parse_type_hint(type_annotation)
543-
ty = self.parse_type(type_hint)
544-
assert ty
545-
var.type = ty
594+
595+
type_annotation = self.eval_expr(stmt.annotation)
596+
type_hint = classinfo.parse_type_hint(type_annotation)
597+
ty = self.parse_type(type_hint)
598+
assert ty
599+
var.type = ty
546600

547601
if stmt.value:
548602
value = self.parse_expr(stmt.value)
@@ -560,14 +614,19 @@ def check_return_type(ty: hir.Type):
560614
value = hir.Load(value)
561615
assert value.type
562616
assert ty
617+
if not var.type.is_concrete():
618+
report_error(
619+
stmt, "only concrete type can be assigned, please annotate the variable with concrete types")
563620
if not hir.is_type_compatible_to(value.type, ty):
564621
report_error(
565622
stmt, f"expected {ty}, got {value.type}")
623+
if not value.type.is_concrete():
624+
value.type = var.type
566625
self.cur_bb().append(hir.Assign(var, value, span))
567626
else:
568627
assert isinstance(var, hir.Var)
569-
case ast.Expression():
570-
self.parse_expr(stmt.body)
628+
case ast.Expr():
629+
self.parse_expr(stmt.value)
571630
case ast.Pass():
572631
return
573632
case _:

0 commit comments

Comments
 (0)