@@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
310310 self .bias = None
311311
312312 def forward (self , input ):
313- weight , bias = comfy .ops .cast_bias_weight (self , input )
313+ weight , bias , offload_stream = comfy .ops .cast_bias_weight (self , input , offloadable = True )
314314 if self .up is not None :
315- return torch .nn .functional .linear (input , weight + (torch .mm (self .up .flatten (start_dim = 1 ), self .down .flatten (start_dim = 1 ))).reshape (self .weight .shape ).type (input .dtype ), bias )
315+ x = torch .nn .functional .linear (input , weight + (torch .mm (self .up .flatten (start_dim = 1 ), self .down .flatten (start_dim = 1 ))).reshape (self .weight .shape ).type (input .dtype ), bias )
316316 else :
317- return torch .nn .functional .linear (input , weight , bias )
317+ x = torch .nn .functional .linear (input , weight , bias )
318+ comfy .ops .uncast_bias_weight (self , weight , bias , offload_stream )
319+ return x
318320
319321 class Conv2d (torch .nn .Module , comfy .ops .CastWeightBiasOp ):
320322 def __init__ (
@@ -350,12 +352,13 @@ def __init__(
350352
351353
352354 def forward (self , input ):
353- weight , bias = comfy .ops .cast_bias_weight (self , input )
355+ weight , bias , offload_stream = comfy .ops .cast_bias_weight (self , input , offloadable = True )
354356 if self .up is not None :
355- return torch .nn .functional .conv2d (input , weight + (torch .mm (self .up .flatten (start_dim = 1 ), self .down .flatten (start_dim = 1 ))).reshape (self .weight .shape ).type (input .dtype ), bias , self .stride , self .padding , self .dilation , self .groups )
357+ x = torch .nn .functional .conv2d (input , weight + (torch .mm (self .up .flatten (start_dim = 1 ), self .down .flatten (start_dim = 1 ))).reshape (self .weight .shape ).type (input .dtype ), bias , self .stride , self .padding , self .dilation , self .groups )
356358 else :
357- return torch .nn .functional .conv2d (input , weight , bias , self .stride , self .padding , self .dilation , self .groups )
358-
359+ x = torch .nn .functional .conv2d (input , weight , bias , self .stride , self .padding , self .dilation , self .groups )
360+ comfy .ops .uncast_bias_weight (self , weight , bias , offload_stream )
361+ return x
359362
360363class ControlLora (ControlNet ):
361364 def __init__ (self , control_weights , global_average_pooling = False , model_options = {}): #TODO? model_options
0 commit comments