2424import comfy .rmsnorm
2525import contextlib
2626
def run_every_op():
    """Hook executed at the start of every wrapped op's forward pass.

    Currently it only polls ComfyUI's interrupt flag, raising if the user
    has requested that processing be cancelled, so long-running model
    executions can be aborted between individual layer calls.
    """
    comfy.model_management.throw_exception_if_processing_interrupted()
2729
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    """Thin pass-through wrapper around ``torch.nn.functional.scaled_dot_product_attention``.

    Exists as a module-level indirection point so attention can be patched or
    swapped in one place; all positional and keyword arguments are forwarded
    unchanged to the torch implementation.
    """
    sdpa = torch.nn.functional.scaled_dot_product_attention
    return sdpa(q, k, v, *args, **kwargs)
@@ -109,6 +111,7 @@ def forward_comfy_cast_weights(self, input):
109111 return torch .nn .functional .linear (input , weight , bias )
110112
111113 def forward (self , * args , ** kwargs ):
114+ run_every_op ()
112115 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
113116 return self .forward_comfy_cast_weights (* args , ** kwargs )
114117 else :
@@ -123,6 +126,7 @@ def forward_comfy_cast_weights(self, input):
123126 return self ._conv_forward (input , weight , bias )
124127
125128 def forward (self , * args , ** kwargs ):
129+ run_every_op ()
126130 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
127131 return self .forward_comfy_cast_weights (* args , ** kwargs )
128132 else :
@@ -137,6 +141,7 @@ def forward_comfy_cast_weights(self, input):
137141 return self ._conv_forward (input , weight , bias )
138142
139143 def forward (self , * args , ** kwargs ):
144+ run_every_op ()
140145 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
141146 return self .forward_comfy_cast_weights (* args , ** kwargs )
142147 else :
@@ -151,6 +156,7 @@ def forward_comfy_cast_weights(self, input):
151156 return self ._conv_forward (input , weight , bias )
152157
153158 def forward (self , * args , ** kwargs ):
159+ run_every_op ()
154160 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
155161 return self .forward_comfy_cast_weights (* args , ** kwargs )
156162 else :
@@ -165,6 +171,7 @@ def forward_comfy_cast_weights(self, input):
165171 return torch .nn .functional .group_norm (input , self .num_groups , weight , bias , self .eps )
166172
167173 def forward (self , * args , ** kwargs ):
174+ run_every_op ()
168175 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
169176 return self .forward_comfy_cast_weights (* args , ** kwargs )
170177 else :
@@ -183,6 +190,7 @@ def forward_comfy_cast_weights(self, input):
183190 return torch .nn .functional .layer_norm (input , self .normalized_shape , weight , bias , self .eps )
184191
185192 def forward (self , * args , ** kwargs ):
193+ run_every_op ()
186194 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
187195 return self .forward_comfy_cast_weights (* args , ** kwargs )
188196 else :
@@ -202,6 +210,7 @@ def forward_comfy_cast_weights(self, input):
202210 # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
203211
204212 def forward (self , * args , ** kwargs ):
213+ run_every_op ()
205214 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
206215 return self .forward_comfy_cast_weights (* args , ** kwargs )
207216 else :
@@ -223,6 +232,7 @@ def forward_comfy_cast_weights(self, input, output_size=None):
223232 output_padding , self .groups , self .dilation )
224233
225234 def forward (self , * args , ** kwargs ):
235+ run_every_op ()
226236 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
227237 return self .forward_comfy_cast_weights (* args , ** kwargs )
228238 else :
@@ -244,6 +254,7 @@ def forward_comfy_cast_weights(self, input, output_size=None):
244254 output_padding , self .groups , self .dilation )
245255
246256 def forward (self , * args , ** kwargs ):
257+ run_every_op ()
247258 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
248259 return self .forward_comfy_cast_weights (* args , ** kwargs )
249260 else :
@@ -262,6 +273,7 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
262273 return torch .nn .functional .embedding (input , weight , self .padding_idx , self .max_norm , self .norm_type , self .scale_grad_by_freq , self .sparse ).to (dtype = output_dtype )
263274
264275 def forward (self , * args , ** kwargs ):
276+ run_every_op ()
265277 if self .comfy_cast_weights or len (self .weight_function ) > 0 or len (self .bias_function ) > 0 :
266278 return self .forward_comfy_cast_weights (* args , ** kwargs )
267279 else :
0 commit comments