Skip to content

Commit c1adef1

Browse files
committed
feat(script): add switch node elimination
1 parent 6a473e0 commit c1adef1

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

script/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ def _node_to_assign_st(self, node):
125125
# `value is str` doesn't work
126126
# TODO: BOOLEAN, not used in any node?
127127
if type(value) is str:
128-
args[name] = {'exp': astutil.to_str(value)}
128+
args[name] = {'exp': astutil.to_str(value), 'value': value}
129129
else:
130130
# int, float
131-
args[name] = {'exp': str(value)}
131+
args[name] = {'exp': str(value), 'value': value}
132132
if hasattr(v, 'inputs'):
133133
# If a node's output is not used, it is allowed to have dangling inputs, in which case the link is None.
134134
# TODO: This breaks the order and arg positions.
@@ -149,7 +149,8 @@ def _node_to_assign_st(self, node):
149149
'type': input.type,
150150
'move': output_links is None or len(output_links) == 1
151151
}
152-
args = self._keyword_args_to_positional(v.type, args)
152+
args_dict = args
153+
args = self._keyword_args_to_positional(v.type, args_dict)
153154

154155
args_of_any_type = [arg for arg in args if arg.get('type', None) == '*']
155156

@@ -209,6 +210,7 @@ def _node_to_assign_st(self, node):
209210

210211
c = passes.reroute_elimination(v, args, vars, c)
211212
c = passes.primitive_node_elimination(v, args, vars, c)
213+
c = passes.switch_node_elimination(v, args_dict, args, vars, c)
212214
return c
213215

214216
def _topological_generations_ordered_dfs(self, end_nodes: Union[list[int], None] = None):

script/passes/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import re
22

33
# TODO: Switch nodes
4-
# e.g. TomePatchModel, CRLoadLoRA, CLIPSetLastLayer
54
# How to prevent var rename?
65
# e.g. ModelMergeSimple, CRModelInputSwitch
76

@@ -19,7 +18,26 @@ def primitive_node_elimination(v, args, vars, c):
1918
assert new_c != c, c
2019
return new_c
2120

21+
SWITCH_NODES = {
22+
'CLIPSetLastLayer': [{'stop_at_clip_layer': -1}],
23+
'CR Apply ControlNet': [{'switch': 'Off'}, {'strength': 0}],
24+
'CR Load LoRA': [{'switch': 'Off'}, {'strength_model': 0, 'strength_clip': 0}],
25+
'TomePatchModel': [{'ratio': 0}],
26+
}
27+
def switch_node_elimination(v, args_dict: dict, args, vars, c):
28+
switch_inputs = SWITCH_NODES.get(v.type)
29+
if switch_inputs is None:
30+
return c
31+
# print('switch_node_elimination:', v.type, args_dict)
32+
for switch_input in switch_inputs:
33+
for k, v in switch_input.items():
34+
# Only widget values are considered
35+
if args_dict[k].get('value') == v:
36+
return ''
37+
return c
38+
2239
__all__ = [
2340
'reroute_elimination',
2441
'primitive_node_elimination',
42+
'switch_node_elimination',
2543
]

0 commit comments

Comments
 (0)