Skip to content

Commit 49f912a

Browse files
committed
refactoring intrinsics
1 parent e3fec07 commit 49f912a

File tree

5 files changed

+84
-16
lines changed

5 files changed

+84
-16
lines changed

luisa_lang/_builtin_decor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def make_builtin():
134134
return decorator
135135

136136

137+
138+
139+
137140
def builtin(s: str) -> Callable[[_F], _F]:
138141
def wrapper(func: _F) -> _F:
139142
setattr(func, "__luisa_builtin__", s)
@@ -207,7 +210,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
207210
params = [v[0] for v in func_sig.args]
208211
is_generic = len(func_sig_converted.generic_params) > 0
209212
# print(
210-
# f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}")
213+
# f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}")
211214
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)
212215

213216

@@ -303,7 +306,7 @@ def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
303306
pass
304307
ctx.types[cls] = ir_ty
305308
if not is_generic:
306-
parse_methods({},ir_ty)
309+
parse_methods({}, ir_ty)
307310
return cls
308311

309312

luisa_lang/codegen/cpp.py

+6
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ def impl() -> None:
312312
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
313313
case hir.TypeValue():
314314
pass
315+
case hir.Intrinsic():
316+
intrin_name = expr.name.replace('.', '_')
317+
args_s = ','.join(self.gen_value_or_ref(
318+
arg) for arg in expr.args)
319+
self.body.writeln(
320+
f"auto v{vid} = __intrin__{intrin_name}({args_s});")
315321
case _:
316322
raise NotImplementedError(
317323
f"unsupported expression: {expr}")

luisa_lang/hir.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,21 @@ def is_concrete(self) -> bool:
131131
def __len__(self) -> int:
132132
return 1
133133

134+
class AnyType(Type):
135+
def size(self) -> int:
136+
raise RuntimeError("AnyType has no size")
137+
138+
def align(self) -> int:
139+
raise RuntimeError("AnyType has no align")
140+
141+
def __eq__(self, value: object) -> bool:
142+
return isinstance(value, AnyType)
143+
144+
def __hash__(self) -> int:
145+
return hash(AnyType)
146+
147+
def __str__(self) -> str:
148+
return "AnyType"
134149

135150
class UnitType(Type):
136151
def size(self) -> int:
@@ -807,7 +822,6 @@ def __init__(self, value: 'Value') -> None:
807822
class Value(TypedNode):
808823
pass
809824

810-
811825
class Unit(Value):
812826
def __init__(self) -> None:
813827
super().__init__(UnitType())
@@ -907,7 +921,10 @@ class TypeValue(Value):
907921
def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
908922
super().__init__(TypeConstructorType(ty), span)
909923

910-
924+
def inner_type(self) -> Type:
925+
assert isinstance(self.type, TypeConstructorType)
926+
return self.type.inner
927+
911928
class Alloca(Ref):
912929
"""
913930
A temporary variable
@@ -917,6 +934,8 @@ def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
917934
super().__init__(ty, span)
918935

919936

937+
938+
920939
# class Init(Value):
921940
# init_call: 'Call'
922941

@@ -931,6 +950,14 @@ def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -
931950
super().__init__(type, span)
932951
self.args = args
933952

953+
class Intrinsic(Value):
954+
name: str
955+
args: List[Value]
956+
957+
def __init__(self, name: str, args: List[Value], type: Type, span: Optional[Span] = None) -> None:
958+
super().__init__(type, span)
959+
self.name = name
960+
self.args = args
934961

935962
class Call(Value):
936963
op: FunctionLike
@@ -1335,7 +1362,7 @@ def get_dsl_type(cls: type) -> Optional[Type]:
13351362

13361363

13371364
def is_type_compatible_to(ty: Type, target: Type) -> bool:
1338-
if ty == target:
1365+
if ty == target or isinstance(ty, AnyType):
13391366
return True
13401367
if isinstance(target, FloatType):
13411368
return isinstance(ty, GenericFloatType)

luisa_lang/lang_builtins.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
N = TypeVar("N", int, u32, u64)
2727

2828

29+
def intrinsic(name: str, ret_type: type[T], *args, **kwargs) -> T:
30+
raise NotImplementedError(
31+
"intrinsic functions should not be called in host-side Python code. "
32+
"Did you mistakenly called a DSL function?"
33+
)
34+
35+
2936
@builtin("dispatch_id")
3037
def dispatch_id() -> uint3:
3138
return _intrinsic_impl()
@@ -103,7 +110,10 @@ def comptime(a):
103110
return a
104111

105112

106-
parse.comptime = comptime
113+
parse._add_special_function("comptime", comptime)
114+
parse._add_special_function("intrinsic", intrinsic)
115+
parse._add_special_function("range", range)
116+
parse._add_special_function('reveal_type', typing.reveal_type)
107117

108118

109119
def static_assert(cond: Any, msg: str = ""):
@@ -160,6 +170,7 @@ def __setitem__(self, index: int | u32 | u64, value: T) -> None:
160170
def __len__(self) -> u32 | u64:
161171
return _intrinsic_impl()
162172

173+
163174
def __buffer_ty():
164175
t = hir.GenericParameter("T", "luisa_lang.lang")
165176
return hir.ParametricType(
@@ -171,6 +182,8 @@ def __buffer_ty():
171182
# # "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
172183
# # )
173184
# )
185+
186+
174187
class Buffer(Generic[T]):
175188
def __getitem__(self, index: int | u32 | u64) -> T:
176189
return _intrinsic_impl()
@@ -216,4 +229,5 @@ def value(self, value: T) -> None:
216229
"dispatch_id",
217230
"thread_id",
218231
"block_id",
232+
"intrinsic",
219233
]

luisa_lang/parse.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from luisa_lang.hir import get_dsl_type, ComptimeValue
2525
import luisa_lang.classinfo as classinfo
2626

27-
comptime: Any = None
28-
2927

3028
def _implicit_typevar_name(v: str) -> str:
3129
return f"T#{v}"
@@ -156,11 +154,14 @@ def convert_func_signature(signature: classinfo.MethodType,
156154
return hir.FunctionSignature(type_parser.generic_params, params, return_type), type_parser
157155

158156

159-
SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = {
160-
comptime,
161-
reveal_type,
162-
range
163-
}
157+
SPECIAL_FUNCTIONS_DICT: Dict[str, Callable[..., Any]] = {}
158+
SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = set()
159+
160+
161+
def _add_special_function(name: str, f: Callable[..., Any]) -> None:
162+
SPECIAL_FUNCTIONS_DICT[name] = f
163+
SPECIAL_FUNCTIONS.add(f)
164+
164165

165166
NewVarHint = Literal[False, 'dsl', 'comptime']
166167

@@ -390,7 +391,8 @@ def parse_type_arg(expr: ast.expr) -> hir.Type:
390391
case _:
391392
type_args.append(parse_type_arg(expr.slice))
392393
# print(f"Type args: {type_args}")
393-
assert isinstance(value.type, hir.TypeConstructorType) and isinstance(value.type.inner, hir.ParametricType)
394+
assert isinstance(value.type, hir.TypeConstructorType) and isinstance(
395+
value.type.inner, hir.ParametricType)
394396
return hir.TypeValue(
395397
hir.BoundType(value.type.inner, type_args, value.type.inner.instantiate(type_args)))
396398

@@ -476,7 +478,23 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
476478
raise NotImplementedError() # unreachable
477479

478480
def handle_special_functions(self, f: Callable[..., Any], expr: ast.Call) -> hir.Value | ComptimeValue:
479-
if f is comptime:
481+
if f is SPECIAL_FUNCTIONS_DICT['intrinsic']:
482+
def do() -> hir.Intrinsic:
483+
intrinsic_name = expr.args[0]
484+
if not isinstance(intrinsic_name, ast.Constant) or not isinstance(intrinsic_name.value, str):
485+
raise hir.ParsingError(
486+
expr, "intrinsic function expects a string literal as its first argument")
487+
args = [self.parse_expr(arg) for arg in expr.args[1:]]
488+
ret_type = args[0]
489+
if not isinstance(ret_type, hir.TypeValue):
490+
raise hir.ParsingError(
491+
expr, f"intrinsic function expects a type as its second argument but found {ret_type}")
492+
if any([not isinstance(arg, hir.Value) for arg in args[1:]]):
493+
raise hir.ParsingError(
494+
expr, "intrinsic function expects values as its arguments")
495+
return hir.Intrinsic(intrinsic_name.value, cast(List[hir.Value], args[1:]), ret_type.inner_type(), hir.Span.from_ast(expr))
496+
return do()
497+
elif f is SPECIAL_FUNCTIONS_DICT['comptime']:
480498
if len(expr.args) != 1:
481499
raise hir.ParsingError(
482500
expr, f"when used in expressions, lc.comptime function expects exactly one argument")
@@ -563,7 +581,7 @@ def collect_args() -> List[hir.Value | hir.Ref]:
563581
arg, hir.Span.from_ast(expr.args[i]))
564582
return cast(List[hir.Value | hir.Ref], args)
565583

566-
if isinstance(func.type, hir.TypeConstructorType):
584+
if isinstance(func.type, hir.TypeConstructorType):
567585
# TypeConstructorType is unique for each type
568586
# so if any value has this type, it must be referring to the same underlying type
569587
# even if it comes from a very complex expression, it's still fine

0 commit comments

Comments
 (0)