4
4
from typing import Any , Callable , Dict , List , Optional , Tuple , Union , overload
5
5
import typing
6
6
import luisa_lang
7
+ from luisa_lang .lang_builtins import comptime
7
8
from luisa_lang .utils import get_typevar_constrains_and_bounds , report_error
8
9
import luisa_lang .hir as hir
9
10
import sys
@@ -114,6 +115,12 @@ def convert_func_signature(signature: classinfo.MethodType,
114
115
return hir .FunctionSignature (type_parser .generic_params , params , return_type ), type_parser
115
116
116
117
118
+ SPECIAL_FUNCTIONS : Set [Callable [..., Any ]] = {
119
+ comptime ,
120
+ reveal_type ,
121
+ }
122
+
123
+
117
124
class FuncParser :
118
125
name : str
119
126
func : object
@@ -138,7 +145,7 @@ def __init__(self, name: str,
138
145
self .signature = signature
139
146
self .globalns = globalns
140
147
obj_ast , _obj_file = retrieve_ast_and_filename (func )
141
- print (ast .dump (obj_ast ))
148
+ # print(ast.dump(obj_ast))
142
149
assert isinstance (obj_ast , ast .Module ), f"{ obj_ast } is not a module"
143
150
if not isinstance (obj_ast .body [0 ], ast .FunctionDef ):
144
151
raise RuntimeError ("Function definition expected." )
@@ -205,6 +212,18 @@ def parse_const(self, const: ast.Constant) -> hir.Value:
205
212
report_error (
206
213
const , f"unsupported constant type { type (value )} , wrap it in lc.comptime(...) if you intead to use it as a compile-time expression" )
207
214
215
+ def convert_any_to_value (self , a : Any , span : hir .Span | None ) -> hir .Value | ComptimeValue :
216
+ if not isinstance (a , ComptimeValue ):
217
+ a = ComptimeValue (a , None )
218
+ if a .value in SPECIAL_FUNCTIONS :
219
+ return a
220
+ if (converted := self .convert_constexpr (a , span )) is not None :
221
+ return converted
222
+ if is_valid_comptime_value_in_dsl_code (a .value ):
223
+ return a
224
+ report_error (
225
+ span , f"unsupported constant type { type (a .value )} , wrap it in lc.comptime(...) if you intead to use it as a compile-time expression" )
226
+
208
227
def parse_name (self , name : ast .Name , maybe_new_var : bool ) -> hir .Ref | hir .Value | ComptimeValue :
209
228
span = hir .Span .from_ast (name )
210
229
var = self .vars .get (name .id )
@@ -218,13 +237,9 @@ def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value
218
237
# look up in global namespace
219
238
if name .id in self .globalns :
220
239
resolved = self .globalns [name .id ]
240
+ return self .convert_any_to_value (resolved , span )
221
241
# assert isinstance(resolved, ComptimeValue), type(resolved)
222
- if not isinstance (resolved , ComptimeValue ):
223
- resolved = ComptimeValue (resolved , None )
224
- if (converted := self .convert_constexpr (resolved , span )) is not None :
225
- return converted
226
- if is_valid_comptime_value_in_dsl_code (resolved .value ):
227
- return resolved
242
+
228
243
report_error (name , f"unknown variable { name .id } " )
229
244
230
245
def try_convert_comptime_value (self , value : ComptimeValue , span : hir .Span | None = None ) -> hir .Value :
@@ -346,12 +361,49 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
346
361
return self .cur_bb ().append (hir .Call (resolved_f , args , type = ty , span = span ))
347
362
raise NotImplementedError () # unreachable
348
363
349
- def parse_call (self , expr : ast .Call ) -> hir .Value :
364
+ def handle_special_functions (self , f : Callable [..., Any ], expr : ast .Call ) -> hir .Value | ComptimeValue :
365
+ match f :
366
+ case _ if f == comptime :
367
+ if len (expr .args ) != 1 :
368
+ report_error (
369
+ expr , f"when used in expressions, lc.comptime function expects exactly one argument" )
370
+ arg = expr .args [0 ]
371
+ # print(ast.dump(arg))
372
+ if isinstance (arg , ast .Constant ) and isinstance (arg .value , str ):
373
+ evaled = self .eval_expr (arg .value )
374
+ else :
375
+ evaled = self .eval_expr (arg )
376
+ # print(evaled)
377
+ v = self .convert_any_to_value (evaled , hir .Span .from_ast (expr ))
378
+ return v
379
+ case _ if f == reveal_type :
380
+ if len (expr .args ) != 1 :
381
+ report_error (
382
+ expr , f"lc.reveal_type expects exactly one argument" )
383
+ arg = expr .args [0 ]
384
+ cur_bb = self .cur_bb ()
385
+ cur_bb_len = len (cur_bb .nodes )
386
+ value = self .parse_expr (arg )
387
+ assert cur_bb is self .cur_bb ()
388
+ del self .cur_bb ().nodes [cur_bb_len :]
389
+ unparsed_arg = ast .unparse (arg )
390
+ if isinstance (value , ComptimeValue ):
391
+ print (
392
+ f"Type of { unparsed_arg } is ComptimeValue({ type (value .value )} )" )
393
+ else :
394
+ print (f"Type of { unparsed_arg } is { value .type } " )
395
+ return hir .Unit ()
396
+ case _:
397
+ raise RuntimeError (f"Unsupported special function { f } " )
398
+
399
+ def parse_call (self , expr : ast .Call ) -> hir .Value | ComptimeValue :
350
400
func = self .parse_expr (expr .func )
351
401
352
402
if isinstance (func , hir .Ref ):
353
403
report_error (expr , f"function expected" )
354
404
elif isinstance (func , ComptimeValue ):
405
+ if func .value in SPECIAL_FUNCTIONS :
406
+ return self .handle_special_functions (func .value , expr )
355
407
func = self .try_convert_comptime_value (
356
408
func , hir .Span .from_ast (expr ))
357
409
@@ -471,9 +523,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
471
523
case _:
472
524
raise RuntimeError (f"Unsupported expression: { ast .dump (expr )} " )
473
525
474
- def eval_expr (self , tree : ast .Expression | ast .expr ):
526
+ def eval_expr (self , tree : str | ast .Expression | ast .expr ):
475
527
if isinstance (tree , ast .expr ):
476
528
tree = ast .Expression (tree )
529
+ # print(tree)
477
530
code_object = compile (tree , "<string>" , "eval" )
478
531
localns = {}
479
532
for name , v in self .vars .items ():
@@ -531,18 +584,19 @@ def check_return_type(ty: hir.Type):
531
584
report_error (
532
585
stmt , f"expected { var .type } , got { value .type } " )
533
586
else :
587
+ if not value .type .is_concrete ():
588
+ report_error (
589
+ stmt , "only concrete type can be assigned, please annotate the variable with type hint" )
534
590
var .type = value .type
535
591
self .cur_bb ().append (hir .Assign (var , value , span ))
536
592
case ast .AnnAssign ():
537
593
var = self .parse_ref (stmt .target , maybe_new_var = True )
538
- if isinstance (var , hir .Value ):
539
- report_error (stmt , f"value cannot be assigned" )
540
- elif isinstance (var , hir .Ref ):
541
- type_annotation = self .eval_expr (stmt .annotation )
542
- type_hint = classinfo .parse_type_hint (type_annotation )
543
- ty = self .parse_type (type_hint )
544
- assert ty
545
- var .type = ty
594
+
595
+ type_annotation = self .eval_expr (stmt .annotation )
596
+ type_hint = classinfo .parse_type_hint (type_annotation )
597
+ ty = self .parse_type (type_hint )
598
+ assert ty
599
+ var .type = ty
546
600
547
601
if stmt .value :
548
602
value = self .parse_expr (stmt .value )
@@ -560,14 +614,19 @@ def check_return_type(ty: hir.Type):
560
614
value = hir .Load (value )
561
615
assert value .type
562
616
assert ty
617
+ if not var .type .is_concrete ():
618
+ report_error (
619
+ stmt , "only concrete type can be assigned, please annotate the variable with concrete types" )
563
620
if not hir .is_type_compatible_to (value .type , ty ):
564
621
report_error (
565
622
stmt , f"expected { ty } , got { value .type } " )
623
+ if not value .type .is_concrete ():
624
+ value .type = var .type
566
625
self .cur_bb ().append (hir .Assign (var , value , span ))
567
626
else :
568
627
assert isinstance (var , hir .Var )
569
- case ast .Expression ():
570
- self .parse_expr (stmt .body )
628
+ case ast .Expr ():
629
+ self .parse_expr (stmt .value )
571
630
case ast .Pass ():
572
631
return
573
632
case _:
0 commit comments