@@ -61,6 +61,8 @@ def __init__(
6161 operator_blocklist : Optional [Set [OpKey ]] = None ,
6262 operator_allowlist : Optional [Set [OpKey ]] = None ,
6363 fusable_subgraphs : Optional [List [InternalMatch ]] = None ,
64+ nn_module_blocklist : Optional [Set [str ]] = None ,
65+ nn_module_allowlist : Optional [Set [str ]] = None ,
6466 ) -> None :
6567 super ().__init__ ()
6668 self .texture_limits : utils .ImageExtents = texture_limits
@@ -78,6 +80,9 @@ def __init__(
7880 for match in self .fusable_subgraphs :
7981 self .fusable_nodes .update (match .nodes_map .values ())
8082
83+ self .nn_module_blocklist = nn_module_blocklist
84+ self .nn_module_allowlist = nn_module_allowlist
85+
8186 def op_node_is_compatible ( # noqa: C901: Function is too complex
8287 self , node : torch .fx .Node , features : Optional [OpFeatures ] = None
8388 ) -> Tuple [bool , str ]:
@@ -213,10 +218,26 @@ def is_node_supported(
213218 r = self ._is_node_supported (node )
214219 return r
215220
216- def _is_node_supported (self , node : torch .fx .Node ) -> bool :
217- # Check if this node is part of a fusable subgraph
218- if node .op == "call_function" and node in self .fusable_nodes :
219- return True
221+ def _is_node_supported (self , node : torch .fx .Node ) -> bool : # noqa: C901
222+ if node .op == "call_function" :
223+ # Apply nn module allowlist and blocklist
224+ if self .nn_module_allowlist is not None :
225+ if not utils .node_comes_from_any_nn_module_in_set (
226+ node , self .nn_module_allowlist
227+ ):
228+ self .log_skip (node , "source nn.Module is not in allowlist" )
229+ return False
230+
231+ if self .nn_module_blocklist is not None :
232+ if utils .node_comes_from_any_nn_module_in_set (
233+ node , self .nn_module_blocklist
234+ ):
235+ self .log_skip (node , "source nn.Module is in blocklist" )
236+ return False
237+
238+ # Check if this node is part of a fusable subgraph
239+ if node in self .fusable_nodes :
240+ return True
220241
221242 target = node .target
222243 if node .target == torch .ops .higher_order .auto_functionalized :
@@ -311,6 +332,8 @@ def __init__(
311332 compile_options : Optional [Dict [str , Any ]] = None ,
312333 operator_blocklist : Optional [List [OpKey ]] = None ,
313334 operator_allowlist : Optional [List [OpKey ]] = None ,
335+ nn_module_blocklist : Optional [List [str ]] = None ,
336+ nn_module_allowlist : Optional [List [str ]] = None ,
314337 ) -> None :
315338 self .options : Dict [str , Any ] = {}
316339 if compile_options is not None :
@@ -331,6 +354,20 @@ def __init__(
331354 assert self .operator_allowlist is not None
332355 self .operator_allowlist .add (entry )
333356
357+ self .nn_module_blocklist : Optional [Set [str ]] = None
358+ if nn_module_blocklist is not None :
359+ self .nn_module_blocklist = set ()
360+ for entry in nn_module_blocklist or []:
361+ assert self .nn_module_blocklist is not None
362+ self .nn_module_blocklist .add (entry )
363+
364+ self .nn_module_allowlist : Optional [Set [str ]] = None
365+ if nn_module_allowlist is not None :
366+ self .nn_module_allowlist = set ()
367+ for entry in nn_module_allowlist :
368+ assert self .nn_module_allowlist is not None
369+ self .nn_module_allowlist .add (entry )
370+
334371 def ops_to_not_decompose (
335372 self , ep : ExportedProgram
336373 ) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
@@ -362,6 +399,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
362399 operator_blocklist = self .operator_blocklist ,
363400 operator_allowlist = self .operator_allowlist ,
364401 fusable_subgraphs = fusable_subgraphs ,
402+ nn_module_blocklist = self .nn_module_blocklist ,
403+ nn_module_allowlist = self .nn_module_allowlist ,
365404 ),
366405 allows_single_node_partition = True ,
367406 )
0 commit comments