@@ -70,8 +70,12 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
7070def cast_to_input (weight , input , non_blocking = False , copy = True ):
7171 return comfy .model_management .cast_to (weight , input .dtype , input .device , non_blocking = non_blocking , copy = copy )
7272
73+
7374@torch .compiler .disable ()
74- def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None ):
75+ def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None , offloadable = False ):
76+ # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
77+ # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
78+ # will add async-offload support to your cast and improve performance.
7579 if input is not None :
7680 if dtype is None :
7781 dtype = input .dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
8084 if device is None :
8185 device = input .device
8286
83- offload_stream = comfy .model_management .get_offload_stream (device )
87+ if offloadable :
88+ offload_stream = comfy .model_management .get_offload_stream (device )
89+ else :
90+ offload_stream = None
91+
8492 if offload_stream is not None :
8593 wf_context = offload_stream
8694 else :
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
105113 weight = f (weight )
106114
107115 comfy .model_management .sync_stream (device , offload_stream )
108- return weight , bias
116+ if offloadable :
117+ return weight , bias , offload_stream
118+ else :
119+ #Legacy function signature
120+ return weight , bias
121+
122+
123+ def uncast_bias_weight (s , weight , bias , offload_stream ):
124+ if offload_stream is None :
125+ return
126+ if weight is not None :
127+ device = weight .device
128+ else :
129+ if bias is None :
130+ return
131+ device = bias .device
132+ offload_stream .wait_stream (comfy .model_management .current_stream (device ))
133+
109134
110135class CastWeightBiasOp :
111136 comfy_cast_weights = False
@@ -118,8 +143,10 @@ def reset_parameters(self):
118143 return None
119144
120145 def forward_comfy_cast_weights (self , input ):
121- weight , bias = cast_bias_weight (self , input )
122- return torch .nn .functional .linear (input , weight , bias )
146+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
147+ x = torch .nn .functional .linear (input , weight , bias )
148+ uncast_bias_weight (self , weight , bias , offload_stream )
149+ return x
123150
124151 def forward (self , * args , ** kwargs ):
125152 run_every_op ()
@@ -133,8 +160,10 @@ def reset_parameters(self):
133160 return None
134161
135162 def forward_comfy_cast_weights (self , input ):
136- weight , bias = cast_bias_weight (self , input )
137- return self ._conv_forward (input , weight , bias )
163+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
164+ x = self ._conv_forward (input , weight , bias )
165+ uncast_bias_weight (self , weight , bias , offload_stream )
166+ return x
138167
139168 def forward (self , * args , ** kwargs ):
140169 run_every_op ()
@@ -148,8 +177,10 @@ def reset_parameters(self):
148177 return None
149178
150179 def forward_comfy_cast_weights (self , input ):
151- weight , bias = cast_bias_weight (self , input )
152- return self ._conv_forward (input , weight , bias )
180+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
181+ x = self ._conv_forward (input , weight , bias )
182+ uncast_bias_weight (self , weight , bias , offload_stream )
183+ return x
153184
154185 def forward (self , * args , ** kwargs ):
155186 run_every_op ()
@@ -172,8 +203,10 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
172203 return super ()._conv_forward (input , weight , bias , * args , ** kwargs )
173204
174205 def forward_comfy_cast_weights (self , input ):
175- weight , bias = cast_bias_weight (self , input )
176- return self ._conv_forward (input , weight , bias )
206+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
207+ x = self ._conv_forward (input , weight , bias )
208+ uncast_bias_weight (self , weight , bias , offload_stream )
209+ return x
177210
178211 def forward (self , * args , ** kwargs ):
179212 run_every_op ()
@@ -187,8 +220,10 @@ def reset_parameters(self):
187220 return None
188221
189222 def forward_comfy_cast_weights (self , input ):
190- weight , bias = cast_bias_weight (self , input )
191- return torch .nn .functional .group_norm (input , self .num_groups , weight , bias , self .eps )
223+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
224+ x = torch .nn .functional .group_norm (input , self .num_groups , weight , bias , self .eps )
225+ uncast_bias_weight (self , weight , bias , offload_stream )
226+ return x
192227
193228 def forward (self , * args , ** kwargs ):
194229 run_every_op ()
@@ -203,11 +238,14 @@ def reset_parameters(self):
203238
204239 def forward_comfy_cast_weights (self , input ):
205240 if self .weight is not None :
206- weight , bias = cast_bias_weight (self , input )
241+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
207242 else :
208243 weight = None
209244 bias = None
210- return torch .nn .functional .layer_norm (input , self .normalized_shape , weight , bias , self .eps )
245+ offload_stream = None
246+ x = torch .nn .functional .layer_norm (input , self .normalized_shape , weight , bias , self .eps )
247+ uncast_bias_weight (self , weight , bias , offload_stream )
248+ return x
211249
212250 def forward (self , * args , ** kwargs ):
213251 run_every_op ()
@@ -223,11 +261,15 @@ def reset_parameters(self):
223261
224262 def forward_comfy_cast_weights (self , input ):
225263 if self .weight is not None :
226- weight , bias = cast_bias_weight (self , input )
264+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
227265 else :
228266 weight = None
229- return comfy .rmsnorm .rms_norm (input , weight , self .eps ) # TODO: switch to commented out line when old torch is deprecated
230- # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
267+ bias = None
268+ offload_stream = None
269+ x = comfy .rmsnorm .rms_norm (input , weight , self .eps ) # TODO: switch to commented out line when old torch is deprecated
270+ # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
271+ uncast_bias_weight (self , weight , bias , offload_stream )
272+ return x
231273
232274 def forward (self , * args , ** kwargs ):
233275 run_every_op ()
@@ -246,10 +288,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
246288 input , output_size , self .stride , self .padding , self .kernel_size ,
247289 num_spatial_dims , self .dilation )
248290
249- weight , bias = cast_bias_weight (self , input )
250- return torch .nn .functional .conv_transpose2d (
291+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
292+ x = torch .nn .functional .conv_transpose2d (
251293 input , weight , bias , self .stride , self .padding ,
252294 output_padding , self .groups , self .dilation )
295+ uncast_bias_weight (self , weight , bias , offload_stream )
296+ return x
253297
254298 def forward (self , * args , ** kwargs ):
255299 run_every_op ()
@@ -268,10 +312,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
268312 input , output_size , self .stride , self .padding , self .kernel_size ,
269313 num_spatial_dims , self .dilation )
270314
271- weight , bias = cast_bias_weight (self , input )
272- return torch .nn .functional .conv_transpose1d (
315+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
316+ x = torch .nn .functional .conv_transpose1d (
273317 input , weight , bias , self .stride , self .padding ,
274318 output_padding , self .groups , self .dilation )
319+ uncast_bias_weight (self , weight , bias , offload_stream )
320+ return x
275321
276322 def forward (self , * args , ** kwargs ):
277323 run_every_op ()
@@ -289,8 +335,11 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
289335 output_dtype = out_dtype
290336 if self .weight .dtype == torch .float16 or self .weight .dtype == torch .bfloat16 :
291337 out_dtype = None
292- weight , bias = cast_bias_weight (self , device = input .device , dtype = out_dtype )
293- return torch .nn .functional .embedding (input , weight , self .padding_idx , self .max_norm , self .norm_type , self .scale_grad_by_freq , self .sparse ).to (dtype = output_dtype )
338+ weight , bias , offload_stream = cast_bias_weight (self , device = input .device , dtype = out_dtype , offloadable = True )
339+ x = torch .nn .functional .embedding (input , weight , self .padding_idx , self .max_norm , self .norm_type , self .scale_grad_by_freq , self .sparse ).to (dtype = output_dtype )
340+ uncast_bias_weight (self , weight , bias , offload_stream )
341+ return x
342+
294343
295344 def forward (self , * args , ** kwargs ):
296345 run_every_op ()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
361410 input_dtype = input .dtype
362411
363412 if len (input .shape ) == 3 :
364- w , bias = cast_bias_weight (self , input , dtype = dtype , bias_dtype = input_dtype )
413+ w , bias , offload_stream = cast_bias_weight (self , input , dtype = dtype , bias_dtype = input_dtype , offloadable = True )
365414
366415 scale_weight = self .scale_weight
367416 scale_input = self .scale_input
@@ -382,6 +431,8 @@ def fp8_linear(self, input):
382431 quantized_input = QuantizedTensor .from_float (input .reshape (- 1 , input_shape [2 ]), TensorCoreFP8Layout , scale = scale_input , dtype = dtype )
383432 o = torch .nn .functional .linear (quantized_input , quantized_weight , bias )
384433
434+ uncast_bias_weight (self , w , bias , offload_stream )
435+
385436 if tensor_2d :
386437 return o .reshape (input_shape [0 ], - 1 )
387438 return o .reshape ((- 1 , input_shape [1 ], self .weight .shape [0 ]))
@@ -404,8 +455,10 @@ def forward_comfy_cast_weights(self, input):
404455 except Exception as e :
405456 logging .info ("Exception during fp8 op: {}" .format (e ))
406457
407- weight , bias = cast_bias_weight (self , input )
408- return torch .nn .functional .linear (input , weight , bias )
458+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
459+ x = torch .nn .functional .linear (input , weight , bias )
460+ uncast_bias_weight (self , weight , bias , offload_stream )
461+ return x
409462
410463def scaled_fp8_ops (fp8_matrix_mult = False , scale_input = False , override_dtype = None ):
411464 logging .info ("Using scaled fp8: fp8 matrix mult: {}, scale input: {}" .format (fp8_matrix_mult , scale_input ))
@@ -433,12 +486,14 @@ def forward_comfy_cast_weights(self, input):
433486 if out is not None :
434487 return out
435488
436- weight , bias = cast_bias_weight (self , input )
489+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
437490
438491 if weight .numel () < input .numel (): #TODO: optimize
439- return torch .nn .functional .linear (input , weight * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), bias )
492+ x = torch .nn .functional .linear (input , weight * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), bias )
440493 else :
441- return torch .nn .functional .linear (input * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), weight , bias )
494+ x = torch .nn .functional .linear (input * self .scale_weight .to (device = weight .device , dtype = weight .dtype ), weight , bias )
495+ uncast_bias_weight (self , weight , bias , offload_stream )
496+ return x
442497
443498 def convert_weight (self , weight , inplace = False , ** kwargs ):
444499 if inplace :
@@ -577,8 +632,10 @@ def _forward(self, input, weight, bias):
577632 return torch .nn .functional .linear (input , weight , bias )
578633
579634 def forward_comfy_cast_weights (self , input ):
580- weight , bias = cast_bias_weight (self , input )
581- return self ._forward (input , weight , bias )
635+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
636+ x = self ._forward (input , weight , bias )
637+ uncast_bias_weight (self , weight , bias , offload_stream )
638+ return x
582639
583640 def forward (self , input , * args , ** kwargs ):
584641 run_every_op ()
0 commit comments