Skip to content

Commit 4e956b3

Browse files
committed
__getitem__ works
1 parent 78733cb commit 4e956b3

File tree

8 files changed

+385
-204
lines changed

8 files changed

+385
-204
lines changed

luisa_lang/_builtin_decor.py

+72-18
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ class _ObjKind(Enum):
6868
KERNEL = auto()
6969

7070

71-
def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType], func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue], props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
71+
def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType],
72+
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue],
73+
props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
7274
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
7375
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
7476
# func_sig = func_sig_parser.parsed_func
@@ -91,7 +93,8 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
9193
mapped_implicit_type_params: Dict[str,
9294
hir.Type] = dict()
9395
assert func_sig is not None
94-
type_parser = parse.TypeParser(func_name, func_globals, type_var_ns, self_type, 'instantiate')
96+
type_parser = parse.TypeParser(
97+
func_name, func_globals, type_var_ns, self_type, 'instantiate')
9598
for (tv, t) in func_sig.env.items():
9699
type_var_ns[tv] = unwrap(type_parser.parse_type(t))
97100
if is_generic:
@@ -115,7 +118,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
115118
mapped_type = mapping[gp]
116119
assert isinstance(mapped_type, hir.Type)
117120
mapped_implicit_type_params[name] = mapped_type
118-
121+
119122
func_sig_instantiated, _p = parse.convert_func_signature(
120123
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
121124
# print(func_name, func_sig)
@@ -124,10 +127,10 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
124127
assert not isinstance(
125128
func_sig_instantiated.return_type, hir.SymbolicType)
126129
func_parser = parse.FuncParser(
127-
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
130+
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type, props.returning_ref)
128131
ret = func_parser.parse_body()
129132
ret.inline_hint = props.inline
130-
ret.export = props.export
133+
ret.export = props.export
131134
return ret
132135
params = [v[0] for v in func_sig.args]
133136
is_generic = len(func_sig_converted.generic_params) > 0
@@ -162,17 +165,23 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
162165
# return cast(_T, f)
163166

164167

165-
def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | None = None) -> type[_TT]:
166-
ctx = hir.GlobalContext.get()
168+
_MakeTemplateFn = Callable[[List[hir.GenericParameter]], hir.Type]
169+
_InstantiateFn = Callable[[List[Any]], hir.Type]
170+
167171

172+
def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | Tuple[_MakeTemplateFn, _InstantiateFn] | None = None, opqaue_override: str | None = None) -> type[_TT]:
173+
ctx = hir.GlobalContext.get()
168174
register_class(cls)
175+
assert not (ir_ty_override is not None and opqaue_override is not None)
169176
cls_info = class_typeinfo(cls)
170177
globalns = _get_cls_globalns(cls)
171178
globalns[cls.__name__] = cls
172179
type_var_to_generic_param: Dict[TypeVar, hir.GenericParameter] = {}
173180
for type_var in cls_info.type_vars:
174181
type_var_to_generic_param[type_var] = hir.GenericParameter(
175182
type_var.__name__, cls.__qualname__)
183+
generic_params = [type_var_to_generic_param[tv]
184+
for tv in cls_info.type_vars]
176185

177186
def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
178187
fields: List[Tuple[str, hir.Type]] = []
@@ -182,13 +191,14 @@ def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
182191
raise hir.TypeInferenceError(
183192
None, f"Cannot infer type for field {name} of {cls.__name__}")
184193
fields.append((name, field_ty))
185-
if isinstance(self_ty, hir.StructType):
186-
self_ty.fields = fields
187-
elif isinstance(self_ty, hir.BoundType):
188-
assert isinstance(self_ty.instantiated, hir.StructType)
189-
self_ty.instantiated.fields = fields
190-
else:
191-
raise NotImplementedError()
194+
if len(fields) > 0:
195+
if isinstance(self_ty, hir.StructType):
196+
self_ty.fields = fields
197+
elif isinstance(self_ty, hir.BoundType):
198+
assert isinstance(self_ty.instantiated, hir.StructType)
199+
self_ty.instantiated.fields = fields
200+
else:
201+
raise NotImplementedError()
192202

193203
def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,):
194204
for name in cls_info.methods:
@@ -198,16 +208,24 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
198208
props = getattr(method_object, '__luisa_func_props__')
199209
else:
200210
props = hir.FuncProperties()
211+
if name == '__getitem__':
212+
props.returning_ref = True
201213
template = _make_func_template(
202214
method_object, get_full_name(method_object), cls_info.methods[name], globalns, type_var_ns, props, self_type=self_ty)
203215
if isinstance(self_ty, hir.BoundType):
204-
assert isinstance(self_ty.instantiated, hir.StructType)
216+
assert isinstance(self_ty.instantiated,
217+
(hir.StructType, hir.OpaqueType))
205218
self_ty.instantiated.methods[name] = template
206219
else:
207220
self_ty.methods[name] = template
208221
ir_ty: hir.Type
209222
if ir_ty_override is not None:
210-
ir_ty = ir_ty_override
223+
if isinstance(ir_ty_override, hir.Type):
224+
ir_ty = ir_ty_override
225+
else:
226+
ir_ty = ir_ty_override[0](generic_params)
227+
elif opqaue_override is not None:
228+
ir_ty = hir.OpaqueType(opqaue_override)
211229
else:
212230
ir_ty = hir.StructType(
213231
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, [])
@@ -226,8 +244,15 @@ def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
226244
for i, arg in enumerate(args):
227245
type_var_ns[cls_info.type_vars[i]] = arg
228246
hash_s = unique_hash(f'{cls.__qualname__}_{args}')
229-
inner_ty = hir.StructType(
230-
f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', [])
247+
inner_ty: hir.Type
248+
if ir_ty_override is not None:
249+
assert isinstance(ir_ty_override, tuple)
250+
inner_ty = ir_ty_override[1](args)
251+
elif opqaue_override:
252+
inner_ty = hir.OpaqueType(opqaue_override, args[:])
253+
else:
254+
inner_ty = hir.StructType(
255+
f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', [])
231256
mono_self_ty = hir.BoundType(ir_ty, args, inner_ty)
232257
mono_type_parser = parse.TypeParser(
233258
cls.__qualname__, globalns, type_var_ns, mono_self_ty, 'instantiate')
@@ -253,6 +278,22 @@ def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
253278
raise NotImplementedError()
254279

255280

281+
def opaque(name: str) -> Callable[[type[_TT]], type[_TT]]:
282+
"""
283+
Mark a class as a DSL opaque type.
284+
285+
Example:
286+
```python
287+
@luisa.opaque("Buffer")
288+
class Buffer(Generic[T]):
289+
pass
290+
```
291+
"""
292+
def wrapper(cls: type[_TT]) -> type[_TT]:
293+
return _dsl_struct_impl(cls, {}, opqaue_override=name)
294+
return wrapper
295+
296+
256297
def struct(cls: type[_TT]) -> type[_TT]:
257298
"""
258299
Mark a class as a DSL struct.
@@ -277,6 +318,12 @@ def decorator(cls: type[_TT]) -> type[_TT]:
277318
return decorator
278319

279320

321+
def builtin_generic_type(make_template: _MakeTemplateFn, instantiate: _InstantiateFn) -> Callable[[type[_TT]], type[_TT]]:
322+
def decorator(cls: type[_TT]) -> type[_TT]:
323+
return typing.cast(type[_TT], _dsl_struct_impl(cls, {}, ir_ty_override=(make_template, instantiate)))
324+
return decorator
325+
326+
280327
_KernelType = TypeVar("_KernelType", bound=Callable[..., None])
281328

282329

@@ -310,6 +357,13 @@ def __init__(self, value: str):
310357
def _parse_func_kwargs(kwargs: Dict[str, Any]) -> hir.FuncProperties:
311358
props = hir.FuncProperties()
312359
props.byref = set()
360+
return_ = kwargs.get("return", None)
361+
if return_ is not None:
362+
if return_ == 'ref':
363+
props.returning_ref = True
364+
else:
365+
raise ValueError(
366+
f"invalid value for return: {return_}, expected 'ref'")
313367
inline = kwargs.get("inline", False)
314368
if isinstance(inline, bool):
315369
props.inline = inline

luisa_lang/codegen/cpp.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,16 @@ def gen(self, ty: hir.Type) -> str:
3434
def gen_impl(self, ty: hir.Type) -> str:
3535
match ty:
3636
case hir.IntType(bits=bits, signed=signed):
37+
int_names = {
38+
'8':'byte',
39+
'16':'short',
40+
'32':'int',
41+
'64':'long',
42+
}
3743
if signed:
38-
return f"i{bits}"
44+
return f"lc_{int_names[str(bits)]}"
3945
else:
40-
return f"u{bits}"
46+
return f"lc_u{int_names[str(bits)]}"
4147
case hir.FloatType(bits=bits):
4248
match bits:
4349
case 16:
@@ -77,6 +83,15 @@ def do():
7783
return ''
7884
case hir.TypeConstructorType():
7985
return ''
86+
case hir.OpaqueType():
87+
def do():
88+
match ty.name:
89+
case 'Buffer':
90+
elem_ty = self.gen(ty.extra_args[0])
91+
return f'__builtin__Buffer<{elem_ty}>'
92+
case _:
93+
raise NotImplementedError(f"unsupported opaque type: {ty.name}")
94+
return do()
8095
case _:
8196
raise NotImplementedError(f"unsupported type: {ty}")
8297

@@ -167,6 +182,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
167182
case hir.BoundType():
168183
assert obj.instantiated
169184
return self.mangle(obj.instantiated)
185+
case hir.OpaqueType():
186+
return obj.name
170187
case _:
171188
raise NotImplementedError(f"unsupported object: {obj}")
172189

@@ -263,6 +280,16 @@ def gen_ref(self, ref: hir.Ref) -> str:
263280
base = self.gen_ref(index.base)
264281
idx = self.gen_expr(index.index)
265282
return f"{base}[{idx}]"
283+
case hir.IntrinsicRef() as intrin:
284+
def do():
285+
intrin_name = intrin.name
286+
gened_args = [self.gen_value_or_ref(
287+
arg) for arg in intrin.args]
288+
if intrin_name == 'buffer_ref':
289+
return f"{gened_args[0]}[{gened_args[1]}]"
290+
else:
291+
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
292+
return do()
266293
case _:
267294
raise NotImplementedError(f"unsupported reference: {ref}")
268295

luisa_lang/codegen/cpp_lib.py

+1-1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)