22
22
23
23
PATH_PREFIX = "luisa_lang"
24
24
25
- FunctionLike = Union ["Function" ]
26
-
27
25
28
26
# @dataclass
29
27
# class FunctionTemplateResolveResult:
30
- # func: Optional[FunctionLike ]
28
+ # func: Optional[Function ]
31
29
# matched: bool
32
30
33
31
40
38
41
39
42
40
FunctionTemplateResolvingFunc = Callable [[
43
- FunctionTemplateResolvingArgs ], Union [FunctionLike , 'TemplateMatchingError' ]]
41
+ FunctionTemplateResolvingArgs ], Union ['Function' , 'TemplateMatchingError' ]]
42
+
44
43
45
44
class FuncProperties :
46
- inline : bool | Literal ["always" ]
45
+ inline : bool | Literal ["never" , " always" ]
47
46
export : bool
48
47
byref : Set [str ]
49
48
50
49
def __init__ (self ):
51
50
self .inline = False
52
51
self .export = False
53
- self .byref = set ()
52
+ self .byref = set ()
54
53
55
54
56
55
class FunctionTemplate :
@@ -63,7 +62,7 @@ class FunctionTemplate:
63
62
"""
64
63
parsing_func : FunctionTemplateResolvingFunc
65
64
__resolved : Dict [Tuple [Tuple [str ,
66
- Union ['Type' , Any ]], ...], FunctionLike ]
65
+ Union ['Type' , Any ]], ...], "Function" ]
67
66
is_generic : bool
68
67
name : str
69
68
params : List [str ]
@@ -78,7 +77,7 @@ def __init__(self, name: str, params: List[str], parsing_func: FunctionTemplateR
78
77
self .name = name
79
78
self .props = None
80
79
81
- def resolve (self , args : FunctionTemplateResolvingArgs | None ) -> Union [FunctionLike , 'TemplateMatchingError' ]:
80
+ def resolve (self , args : FunctionTemplateResolvingArgs | None ) -> Union ["Function" , 'TemplateMatchingError' ]:
82
81
args = args or []
83
82
if not self .is_generic :
84
83
key = tuple (args )
@@ -101,7 +100,7 @@ class DynamicIndex:
101
100
102
101
103
102
class Type (ABC ):
104
- methods : Dict [str , Union [FunctionLike ]]
103
+ methods : Dict [str , Union ["Function" , FunctionTemplate ]]
105
104
is_builtin : bool
106
105
107
106
def __init__ (self ):
@@ -132,7 +131,7 @@ def member(self, field: Any) -> Optional['Type']:
132
131
return FunctionType (m , None )
133
132
return None
134
133
135
- def method (self , name : str ) -> Optional [FunctionLike | FunctionTemplate ]:
134
+ def method (self , name : str ) -> Optional [Union [ "Function" , FunctionTemplate ] ]:
136
135
m = self .methods .get (name )
137
136
if m :
138
137
return m
@@ -738,7 +737,7 @@ def member(self, field) -> Optional['Type']:
738
737
raise RuntimeError ("member access on uninstantiated BoundType" )
739
738
740
739
@override
741
- def method (self , name ) -> Optional [FunctionLike | FunctionTemplate ]:
740
+ def method (self , name ) -> Optional [Union [ "Function" , FunctionTemplate ] ]:
742
741
if self .instantiated is not None :
743
742
return self .instantiated .method (name )
744
743
else :
@@ -766,10 +765,10 @@ def __hash__(self) -> int:
766
765
767
766
768
767
class FunctionType (Type ):
769
- func_like : FunctionLike | FunctionTemplate
768
+ func_like : Union [ "Function" , FunctionTemplate ]
770
769
bound_object : Optional ['Ref' ]
771
770
772
- def __init__ (self , func_like : FunctionLike | FunctionTemplate , bound_object : Optional ['Ref' ]) -> None :
771
+ def __init__ (self , func_like : Union [ "Function" , FunctionTemplate ] , bound_object : Optional ['Ref' ]) -> None :
773
772
super ().__init__ ()
774
773
self .func_like = func_like
775
774
self .bound_object = bound_object
@@ -950,6 +949,8 @@ def __eq__(self, value: object) -> bool:
950
949
951
950
def __hash__ (self ) -> int :
952
951
return hash (self .value )
952
+
953
+
953
954
class TypeValue (Value ):
954
955
def __init__ (self , ty : Type , span : Optional [Span ] = None ) -> None :
955
956
super ().__init__ (TypeConstructorType (ty ), span )
@@ -958,10 +959,12 @@ def inner_type(self) -> Type:
958
959
assert isinstance (self .type , TypeConstructorType )
959
960
return self .type .inner
960
961
962
+
961
963
class FunctionValue (Value ):
962
- def __init__ (self , ty :FunctionType , span : Optional [Span ] = None ) -> None :
964
+ def __init__ (self , ty : FunctionType , span : Optional [Span ] = None ) -> None :
963
965
super ().__init__ (ty , span )
964
966
967
+
965
968
class Alloca (Ref ):
966
969
"""
967
970
A temporary variable
@@ -1003,14 +1006,14 @@ def __repr__(self) -> str:
1003
1006
1004
1007
1005
1008
class Call (Value ):
1006
- op : FunctionLike
1009
+ op : "Function"
1007
1010
"""After type inference, op should be a Value."""
1008
1011
1009
1012
args : List [Value | Ref ]
1010
1013
1011
1014
def __init__ (
1012
1015
self ,
1013
- op : FunctionLike ,
1016
+ op : "Function" ,
1014
1017
args : List [Value | Ref ],
1015
1018
type : Type ,
1016
1019
span : Optional [Span ] = None ,
@@ -1077,7 +1080,7 @@ class Assign(Node):
1077
1080
value : Value
1078
1081
1079
1082
def __init__ (self , ref : Ref , value : Value , span : Optional [Span ] = None ) -> None :
1080
- assert not isinstance (value .type , (FunctionType , TypeConstructorType ))
1083
+ assert not isinstance (value .type , (FunctionType , TypeConstructorType ))
1081
1084
super ().__init__ (span )
1082
1085
self .ref = ref
1083
1086
self .value = value
@@ -1206,7 +1209,7 @@ class Function:
1206
1209
locals : List [Var ]
1207
1210
complete : bool
1208
1211
is_method : bool
1209
- inline_hint : Literal [True , 'always' , 'never' ] | None
1212
+ inline_hint : bool | Literal ['always' , 'never' ]
1210
1213
1211
1214
def __init__ (
1212
1215
self ,
@@ -1223,7 +1226,7 @@ def __init__(
1223
1226
self .locals = []
1224
1227
self .complete = False
1225
1228
self .is_method = is_method
1226
- self .inline_hint = None
1229
+ self .inline_hint = False
1227
1230
1228
1231
1229
1232
def match_template_args (
@@ -1408,3 +1411,98 @@ def is_type_compatible_to(ty: Type, target: Type) -> bool:
1408
1411
if isinstance (target , IntType ):
1409
1412
return isinstance (ty , GenericIntType )
1410
1413
return False
1414
+
1415
+
1416
+ class FunctionInliner :
1417
+ mapping : Dict [Ref | Value , Ref | Value ]
1418
+ ret : Value | None
1419
+
1420
+ def __init__ (self , func : Function , args : List [Value | Ref ], body : BasicBlock , span : Optional [Span ] = None ) -> None :
1421
+ self .mapping = {}
1422
+ for param , arg in zip (func .params , args ):
1423
+ self .mapping [param ] = arg
1424
+ assert func .body
1425
+ self .do_inline (func .body , body )
1426
+
1427
+ def do_inline (self , func_body : BasicBlock , body : BasicBlock ) -> None :
1428
+ for node in func_body .nodes :
1429
+ assert node not in self .mapping
1430
+
1431
+ match node :
1432
+ case Var ():
1433
+ assert node .type
1434
+ assert node .semantic == ParameterSemantic .BYVAL
1435
+ self .mapping [node ] = Alloca (node .type , node .span )
1436
+ case Load ():
1437
+ mapped_var = self .mapping [node .ref ]
1438
+ assert isinstance (mapped_var , Ref )
1439
+ body .append (Load (mapped_var ))
1440
+ case Index ():
1441
+ base = self .mapping .get (node .base )
1442
+ assert isinstance (base , Value )
1443
+ index = self .mapping .get (node .index )
1444
+ assert isinstance (index , Value )
1445
+ assert node .type
1446
+ self .mapping [node ] = body .append (
1447
+ Index (base , index , node .type , node .span ))
1448
+ case IndexRef ():
1449
+ base = self .mapping .get (node .base )
1450
+ index = self .mapping .get (node .index )
1451
+ assert isinstance (base , Ref )
1452
+ assert isinstance (index , Value )
1453
+ assert node .type
1454
+ self .mapping [node ] = body .append (IndexRef (
1455
+ base , index , node .type , node .span ))
1456
+ case Member ():
1457
+ base = self .mapping .get (node .base )
1458
+ assert isinstance (base , Value )
1459
+ assert node .type
1460
+ self .mapping [node ] = body .append (Member (
1461
+ base , node .field , node .type , node .span ))
1462
+ case MemberRef ():
1463
+ base = self .mapping .get (node .base )
1464
+ assert isinstance (base , Ref )
1465
+ assert node .type
1466
+ self .mapping [node ] = body .append (MemberRef (
1467
+ base , node .field , node .type , node .span ))
1468
+ case Call () as call :
1469
+ def do ():
1470
+ args : List [Ref | Value ] = []
1471
+ for arg in call .args :
1472
+ mapped_arg = self .mapping .get (arg )
1473
+ if mapped_arg is None :
1474
+ raise ParsingError (node , "unable to inline call" )
1475
+ args .append (mapped_arg )
1476
+ assert call .type
1477
+ self .mapping [call ] = body .append (
1478
+ Call (call .op , args , call .type , node .span ))
1479
+ do ()
1480
+ case Intrinsic () as intrin :
1481
+ def do ():
1482
+ args : List [Ref | Value ] = []
1483
+ for arg in intrin .args :
1484
+ mapped_arg = self .mapping .get (arg )
1485
+ if mapped_arg is None :
1486
+ raise ParsingError (
1487
+ node , "unable to inline intrinsic" )
1488
+ args .append (mapped_arg )
1489
+ assert intrin .type
1490
+ self .mapping [intrin ] = body .append (
1491
+ Intrinsic (intrin .name , args , intrin .type , node .span ))
1492
+ do ()
1493
+ case Return ():
1494
+ if self .ret is not None :
1495
+ raise ParsingError (node , "multiple return statement" )
1496
+ assert node .value is not None
1497
+ mapped_value = self .mapping .get (node .value )
1498
+ if mapped_value is None or isinstance (mapped_value , Ref ):
1499
+ raise ParsingError (node , "unable to inline return" )
1500
+ self .ret = mapped_value
1501
+ case _:
1502
+ raise ParsingError (node , "invalid node for inlining" )
1503
+
1504
+ @staticmethod
1505
+ def inline (func : Function , args : List [Value | Ref ], body : BasicBlock , span : Optional [Span ] = None ) -> Value :
1506
+ inliner = FunctionInliner (func , args , body , span )
1507
+ assert inliner .ret
1508
+ return inliner .ret
0 commit comments