@@ -68,7 +68,9 @@ class _ObjKind(Enum):
68
68
KERNEL = auto ()
69
69
70
70
71
- def _make_func_template (f : Callable [..., Any ], func_name : str , func_sig : Optional [MethodType ], func_globals : Dict [str , Any ], foreign_type_var_ns : Dict [TypeVar , hir .Type | hir .ComptimeValue ], props : hir .FuncProperties , self_type : Optional [hir .Type ] = None ):
71
+ def _make_func_template (f : Callable [..., Any ], func_name : str , func_sig : Optional [MethodType ],
72
+ func_globals : Dict [str , Any ], foreign_type_var_ns : Dict [TypeVar , hir .Type | hir .ComptimeValue ],
73
+ props : hir .FuncProperties , self_type : Optional [hir .Type ] = None ):
72
74
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
73
75
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
74
76
# func_sig = func_sig_parser.parsed_func
@@ -91,7 +93,8 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
91
93
mapped_implicit_type_params : Dict [str ,
92
94
hir .Type ] = dict ()
93
95
assert func_sig is not None
94
- type_parser = parse .TypeParser (func_name , func_globals , type_var_ns , self_type , 'instantiate' )
96
+ type_parser = parse .TypeParser (
97
+ func_name , func_globals , type_var_ns , self_type , 'instantiate' )
95
98
for (tv , t ) in func_sig .env .items ():
96
99
type_var_ns [tv ] = unwrap (type_parser .parse_type (t ))
97
100
if is_generic :
@@ -115,7 +118,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
115
118
mapped_type = mapping [gp ]
116
119
assert isinstance (mapped_type , hir .Type )
117
120
mapped_implicit_type_params [name ] = mapped_type
118
-
121
+
119
122
func_sig_instantiated , _p = parse .convert_func_signature (
120
123
func_sig , func_name , func_globals , type_var_ns , mapped_implicit_type_params , self_type , mode = 'instantiate' )
121
124
# print(func_name, func_sig)
@@ -124,10 +127,10 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
124
127
assert not isinstance (
125
128
func_sig_instantiated .return_type , hir .SymbolicType )
126
129
func_parser = parse .FuncParser (
127
- func_name , f , func_sig_instantiated , func_globals , type_var_ns , self_type )
130
+ func_name , f , func_sig_instantiated , func_globals , type_var_ns , self_type , props . returning_ref )
128
131
ret = func_parser .parse_body ()
129
132
ret .inline_hint = props .inline
130
- ret .export = props .export
133
+ ret .export = props .export
131
134
return ret
132
135
params = [v [0 ] for v in func_sig .args ]
133
136
is_generic = len (func_sig_converted .generic_params ) > 0
@@ -162,17 +165,23 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
162
165
# return cast(_T, f)
163
166
164
167
165
- def _dsl_struct_impl (cls : type [_TT ], attrs : Dict [str , Any ], ir_ty_override : hir .Type | None = None ) -> type [_TT ]:
166
- ctx = hir .GlobalContext .get ()
168
+ _MakeTemplateFn = Callable [[List [hir .GenericParameter ]], hir .Type ]
169
+ _InstantiateFn = Callable [[List [Any ]], hir .Type ]
170
+
167
171
172
+ def _dsl_struct_impl (cls : type [_TT ], attrs : Dict [str , Any ], ir_ty_override : hir .Type | Tuple [_MakeTemplateFn , _InstantiateFn ] | None = None , opqaue_override : str | None = None ) -> type [_TT ]:
173
+ ctx = hir .GlobalContext .get ()
168
174
register_class (cls )
175
+ assert not (ir_ty_override is not None and opqaue_override is not None )
169
176
cls_info = class_typeinfo (cls )
170
177
globalns = _get_cls_globalns (cls )
171
178
globalns [cls .__name__ ] = cls
172
179
type_var_to_generic_param : Dict [TypeVar , hir .GenericParameter ] = {}
173
180
for type_var in cls_info .type_vars :
174
181
type_var_to_generic_param [type_var ] = hir .GenericParameter (
175
182
type_var .__name__ , cls .__qualname__ )
183
+ generic_params = [type_var_to_generic_param [tv ]
184
+ for tv in cls_info .type_vars ]
176
185
177
186
def parse_fields (tp : parse .TypeParser , self_ty : hir .Type ):
178
187
fields : List [Tuple [str , hir .Type ]] = []
@@ -182,13 +191,14 @@ def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
182
191
raise hir .TypeInferenceError (
183
192
None , f"Cannot infer type for field { name } of { cls .__name__ } " )
184
193
fields .append ((name , field_ty ))
185
- if isinstance (self_ty , hir .StructType ):
186
- self_ty .fields = fields
187
- elif isinstance (self_ty , hir .BoundType ):
188
- assert isinstance (self_ty .instantiated , hir .StructType )
189
- self_ty .instantiated .fields = fields
190
- else :
191
- raise NotImplementedError ()
194
+ if len (fields ) > 0 :
195
+ if isinstance (self_ty , hir .StructType ):
196
+ self_ty .fields = fields
197
+ elif isinstance (self_ty , hir .BoundType ):
198
+ assert isinstance (self_ty .instantiated , hir .StructType )
199
+ self_ty .instantiated .fields = fields
200
+ else :
201
+ raise NotImplementedError ()
192
202
193
203
def parse_methods (type_var_ns : Dict [TypeVar , hir .Type | Any ], self_ty : hir .Type ,):
194
204
for name in cls_info .methods :
@@ -198,16 +208,24 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
198
208
props = getattr (method_object , '__luisa_func_props__' )
199
209
else :
200
210
props = hir .FuncProperties ()
211
+ if name == '__getitem__' :
212
+ props .returning_ref = True
201
213
template = _make_func_template (
202
214
method_object , get_full_name (method_object ), cls_info .methods [name ], globalns , type_var_ns , props , self_type = self_ty )
203
215
if isinstance (self_ty , hir .BoundType ):
204
- assert isinstance (self_ty .instantiated , hir .StructType )
216
+ assert isinstance (self_ty .instantiated ,
217
+ (hir .StructType , hir .OpaqueType ))
205
218
self_ty .instantiated .methods [name ] = template
206
219
else :
207
220
self_ty .methods [name ] = template
208
221
ir_ty : hir .Type
209
222
if ir_ty_override is not None :
210
- ir_ty = ir_ty_override
223
+ if isinstance (ir_ty_override , hir .Type ):
224
+ ir_ty = ir_ty_override
225
+ else :
226
+ ir_ty = ir_ty_override [0 ](generic_params )
227
+ elif opqaue_override is not None :
228
+ ir_ty = hir .OpaqueType (opqaue_override )
211
229
else :
212
230
ir_ty = hir .StructType (
213
231
f'{ cls .__name__ } _{ unique_hash (cls .__qualname__ )} ' , cls .__qualname__ , [])
@@ -226,8 +244,15 @@ def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
226
244
for i , arg in enumerate (args ):
227
245
type_var_ns [cls_info .type_vars [i ]] = arg
228
246
hash_s = unique_hash (f'{ cls .__qualname__ } _{ args } ' )
229
- inner_ty = hir .StructType (
230
- f'{ cls .__name__ } _{ hash_s } M' , f'{ cls .__qualname__ } [{ "," .join ([str (a ) for a in args ])} ]' , [])
247
+ inner_ty : hir .Type
248
+ if ir_ty_override is not None :
249
+ assert isinstance (ir_ty_override , tuple )
250
+ inner_ty = ir_ty_override [1 ](args )
251
+ elif opqaue_override :
252
+ inner_ty = hir .OpaqueType (opqaue_override , args [:])
253
+ else :
254
+ inner_ty = hir .StructType (
255
+ f'{ cls .__name__ } _{ hash_s } M' , f'{ cls .__qualname__ } [{ "," .join ([str (a ) for a in args ])} ]' , [])
231
256
mono_self_ty = hir .BoundType (ir_ty , args , inner_ty )
232
257
mono_type_parser = parse .TypeParser (
233
258
cls .__qualname__ , globalns , type_var_ns , mono_self_ty , 'instantiate' )
@@ -253,6 +278,22 @@ def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
253
278
raise NotImplementedError ()
254
279
255
280
281
+ def opaque (name : str ) -> Callable [[type [_TT ]], type [_TT ]]:
282
+ """
283
+ Mark a class as a DSL opaque type.
284
+
285
+ Example:
286
+ ```python
287
+ @luisa.opaque("Buffer")
288
+ class Buffer(Generic[T]):
289
+ pass
290
+ ```
291
+ """
292
+ def wrapper (cls : type [_TT ]) -> type [_TT ]:
293
+ return _dsl_struct_impl (cls , {}, opqaue_override = name )
294
+ return wrapper
295
+
296
+
256
297
def struct (cls : type [_TT ]) -> type [_TT ]:
257
298
"""
258
299
Mark a class as a DSL struct.
@@ -277,6 +318,12 @@ def decorator(cls: type[_TT]) -> type[_TT]:
277
318
return decorator
278
319
279
320
321
+ def builtin_generic_type (make_template : _MakeTemplateFn , instantiate : _InstantiateFn ) -> Callable [[type [_TT ]], type [_TT ]]:
322
+ def decorator (cls : type [_TT ]) -> type [_TT ]:
323
+ return typing .cast (type [_TT ], _dsl_struct_impl (cls , {}, ir_ty_override = (make_template , instantiate )))
324
+ return decorator
325
+
326
+
280
327
_KernelType = TypeVar ("_KernelType" , bound = Callable [..., None ])
281
328
282
329
@@ -310,6 +357,13 @@ def __init__(self, value: str):
310
357
def _parse_func_kwargs (kwargs : Dict [str , Any ]) -> hir .FuncProperties :
311
358
props = hir .FuncProperties ()
312
359
props .byref = set ()
360
+ return_ = kwargs .get ("return" , None )
361
+ if return_ is not None :
362
+ if return_ == 'ref' :
363
+ props .returning_ref = True
364
+ else :
365
+ raise ValueError (
366
+ f"invalid value for return: { return_ } , expected 'ref'" )
313
367
inline = kwargs .get ("inline" , False )
314
368
if isinstance (inline , bool ):
315
369
props .inline = inline
0 commit comments