Skip to content

Commit 6a6bec1

Browse files
committed
controlnet: adopt new cast_bias_weight synchronization scheme
This is nessacary for safe async weight offloading.
1 parent 3c2e2e6 commit 6a6bec1

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

comfy/controlnet.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
310310
self.bias = None
311311

312312
def forward(self, input):
313-
weight, bias = comfy.ops.cast_bias_weight(self, input)
313+
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
314314
if self.up is not None:
315-
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
315+
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
316316
else:
317-
return torch.nn.functional.linear(input, weight, bias)
317+
x = torch.nn.functional.linear(input, weight, bias)
318+
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
319+
return x
318320

319321
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
320322
def __init__(
@@ -350,12 +352,13 @@ def __init__(
350352

351353

352354
def forward(self, input):
353-
weight, bias = comfy.ops.cast_bias_weight(self, input)
355+
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
354356
if self.up is not None:
355-
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
357+
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
356358
else:
357-
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
358-
359+
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
360+
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
361+
return x
359362

360363
class ControlLora(ControlNet):
361364
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

0 commit comments

Comments
 (0)