11import inspect
2- from functools import wraps
2+ from functools import wraps , partial
33
44from mlir .dialects .func import FuncOp , ReturnOp , CallOp
55from mlir .ir import (
88 StringAttr ,
99 TypeAttr ,
1010 FlatSymbolRefAttr ,
11+ Type ,
1112)
1213
1314from mlir_utils .dialects .util import (
1415 get_result_or_results ,
1516 make_maybe_no_args_decorator ,
17+ maybe_cast ,
1618)
1719
1820
19- @make_maybe_no_args_decorator
20- def func (sym_visibility = None , arg_attrs = None , res_attrs = None , loc = None , ip = None ):
21+ def func_base (
22+ FuncOp ,
23+ ReturnOp ,
24+ CallOp ,
25+ sym_visibility = None ,
26+ arg_attrs = None ,
27+ res_attrs = None ,
28+ loc = None ,
29+ ip = None ,
30+ ):
2131 ip = ip or InsertionPoint .current
2232
33+ # if this is set to true then wrapper below won't emit a call op
34+ # it is set below by a def emit fn that is attached to the body_builder
35+ # wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands)
36+ # and the func will be emitted.
37+ _emit = False
38+
2339 def builder_wrapper (body_builder ):
2440 @wraps (body_builder )
2541 def wrapper (* call_args ):
2642 sig = inspect .signature (body_builder )
2743 implicit_return = sig .return_annotation is inspect ._empty
28- input_types = [a .type for a in call_args ]
44+ input_types = [p .annotation for p in sig .parameters .values ()]
45+ if not (
46+ len (input_types ) == len (sig .parameters )
47+ and all (isinstance (t , Type ) for t in input_types )
48+ ):
49+ input_types = [a .type for a in call_args ]
2950 function_type = TypeAttr .get (
3051 FunctionType .get (
3152 inputs = input_types ,
@@ -34,7 +55,7 @@ def wrapper(*call_args):
3455 )
3556 # FuncOp is extended but we do really want the base
3657 func_name = body_builder .__name__
37- func_op = FuncOp . __base__ (
58+ func_op = FuncOp (
3859 func_name ,
3960 function_type ,
4061 sym_visibility = StringAttr .get (str (sym_visibility ))
@@ -45,7 +66,7 @@ def wrapper(*call_args):
4566 loc = loc ,
4667 ip = ip ,
4768 )
48- func_op .regions [0 ].blocks .append (* [ a . type for a in call_args ] )
69+ func_op .regions [0 ].blocks .append (* input_types )
4970 with InsertionPoint (func_op .regions [0 ].blocks [0 ]):
5071 results = get_result_or_results (
5172 body_builder (* func_op .regions [0 ].blocks [0 ].arguments )
@@ -63,14 +84,27 @@ def wrapper(*call_args):
6384 function_type = FunctionType .get (inputs = input_types , results = return_types )
6485 func_op .attributes ["function_type" ] = TypeAttr .get (function_type )
6586
66- call_op = CallOp (
67- [r .type for r in results ], FlatSymbolRefAttr .get (func_name ), call_args
68- )
69- if results is None :
70- return func_op
71- return get_result_or_results (call_op )
87+ if _emit :
88+ return maybe_cast (get_result_or_results (func_op ))
89+ else :
90+ call_op = CallOp (
91+ [r .type for r in results ],
92+ FlatSymbolRefAttr .get (func_name ),
93+ call_args ,
94+ )
95+ return maybe_cast (get_result_or_results (call_op ))
96+
97+ def emit ():
98+ nonlocal _emit
99+ _emit = True
100+ wrapper ()
72101
73- # wrapper.op = op
102+ wrapper .emit = emit
74103 return wrapper
75104
76105 return builder_wrapper
106+
107+
108+ func = make_maybe_no_args_decorator (
109+ partial (func_base , FuncOp = FuncOp .__base__ , ReturnOp = ReturnOp , CallOp = CallOp )
110+ )
0 commit comments