Skip to content

Commit ab7ab5b

Browse files
authored
Fix Race condition in --async-offload that can cause corruption (#10501)
* mm: factor out the current stream getter Make this a reusable function. * ops: sync the offload stream with the consumption of w&b This sync is necessary as pytorch will queue cuda async frees on the same stream that created the tensor. In the case of async offload, this will be on the offload stream. Weights and biases can go out of scope in python, which then triggers the pytorch garbage collector to queue the free operation on the offload stream, possibly before the compute stream has used the weight. This causes a use after free on weight data, leading to total corruption of some workflows. So sync the offload stream with the compute stream after the weight has been used so the free has to wait for the weight to be used. The cast_bias_weight is extended in a backwards compatible way with the new behaviour opt-in on a defaulted parameter. This handles custom node packs calling cast_bias_weight and defeatures async-offload for them (as they do not handle the race). The pattern is now: cast_bias_weight(... , offloadable=True) #This might be offloaded thing(weight, bias, ...) uncast_bias_weight(...) * controlnet: adopt new cast_bias_weight synchronization scheme This is necessary for safe async weight offloading. * mm: sync the last stream in the queue, not the next Currently this peeks ahead to sync the next stream in the queue of streams with the compute stream. This doesn't allow a lot of parallelization, as the end result is you can only get one weight load ahead regardless of how many streams you have. Rotate the loop logic here to synchronize the end of the queue before returning the next stream. This allows weights to be loaded ahead of the compute stream's position.
1 parent ec4fc2a commit ab7ab5b

File tree

3 files changed

+114
-52
lines changed

3 files changed

+114
-52
lines changed

comfy/controlnet.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
310310
self.bias = None
311311

312312
def forward(self, input):
313-
weight, bias = comfy.ops.cast_bias_weight(self, input)
313+
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
314314
if self.up is not None:
315-
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
315+
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
316316
else:
317-
return torch.nn.functional.linear(input, weight, bias)
317+
x = torch.nn.functional.linear(input, weight, bias)
318+
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
319+
return x
318320

319321
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
320322
def __init__(
@@ -350,12 +352,13 @@ def __init__(
350352

351353

352354
def forward(self, input):
353-
weight, bias = comfy.ops.cast_bias_weight(self, input)
355+
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
354356
if self.up is not None:
355-
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
357+
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
356358
else:
357-
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
358-
359+
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
360+
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
361+
return x
359362

360363
class ControlLora(ControlNet):
361364
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

comfy/model_management.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,16 @@ def force_channels_last():
10131013
NUM_STREAMS = 2
10141014
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
10151015

1016+
def current_stream(device):
1017+
if device is None:
1018+
return None
1019+
if is_device_cuda(device):
1020+
return torch.cuda.current_stream()
1021+
elif is_device_xpu(device):
1022+
return torch.xpu.current_stream()
1023+
else:
1024+
return None
1025+
10161026
stream_counters = {}
10171027
def get_offload_stream(device):
10181028
stream_counter = stream_counters.get(device, 0)
@@ -1021,21 +1031,17 @@ def get_offload_stream(device):
10211031

10221032
if device in STREAMS:
10231033
ss = STREAMS[device]
1024-
s = ss[stream_counter]
1034+
#Sync the oldest stream in the queue with the current
1035+
ss[stream_counter].wait_stream(current_stream(device))
10251036
stream_counter = (stream_counter + 1) % len(ss)
1026-
if is_device_cuda(device):
1027-
ss[stream_counter].wait_stream(torch.cuda.current_stream())
1028-
elif is_device_xpu(device):
1029-
ss[stream_counter].wait_stream(torch.xpu.current_stream())
10301037
stream_counters[device] = stream_counter
1031-
return s
1038+
return ss[stream_counter]
10321039
elif is_device_cuda(device):
10331040
ss = []
10341041
for k in range(NUM_STREAMS):
10351042
ss.append(torch.cuda.Stream(device=device, priority=0))
10361043
STREAMS[device] = ss
10371044
s = ss[stream_counter]
1038-
stream_counter = (stream_counter + 1) % len(ss)
10391045
stream_counters[device] = stream_counter
10401046
return s
10411047
elif is_device_xpu(device):
@@ -1044,18 +1050,14 @@ def get_offload_stream(device):
10441050
ss.append(torch.xpu.Stream(device=device, priority=0))
10451051
STREAMS[device] = ss
10461052
s = ss[stream_counter]
1047-
stream_counter = (stream_counter + 1) % len(ss)
10481053
stream_counters[device] = stream_counter
10491054
return s
10501055
return None
10511056

10521057
def sync_stream(device, stream):
1053-
if stream is None:
1058+
if stream is None or current_stream(device) is None:
10541059
return
1055-
if is_device_cuda(device):
1056-
torch.cuda.current_stream().wait_stream(stream)
1057-
elif is_device_xpu(device):
1058-
torch.xpu.current_stream().wait_stream(stream)
1060+
current_stream(device).wait_stream(stream)
10591061

10601062
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
10611063
if device is None or weight.device == device:

comfy/ops.py

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
7070
def cast_to_input(weight, input, non_blocking=False, copy=True):
7171
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
7272

73+
7374
@torch.compiler.disable()
74-
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
75+
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
76+
# NOTE: offloadable=False is a legacy default; if you are a custom node author reading this, please pass
77+
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
78+
# will add async-offload support to your cast and improve performance.
7579
if input is not None:
7680
if dtype is None:
7781
dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
8084
if device is None:
8185
device = input.device
8286

83-
offload_stream = comfy.model_management.get_offload_stream(device)
87+
if offloadable:
88+
offload_stream = comfy.model_management.get_offload_stream(device)
89+
else:
90+
offload_stream = None
91+
8492
if offload_stream is not None:
8593
wf_context = offload_stream
8694
else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
105113
weight = f(weight)
106114

107115
comfy.model_management.sync_stream(device, offload_stream)
108-
return weight, bias
116+
if offloadable:
117+
return weight, bias, offload_stream
118+
else:
119+
#Legacy function signature
120+
return weight, bias
121+
122+
123+
def uncast_bias_weight(s, weight, bias, offload_stream):
124+
if offload_stream is None:
125+
return
126+
if weight is not None:
127+
device = weight.device
128+
else:
129+
if bias is None:
130+
return
131+
device = bias.device
132+
offload_stream.wait_stream(comfy.model_management.current_stream(device))
133+
109134

110135
class CastWeightBiasOp:
111136
comfy_cast_weights = False
@@ -118,8 +143,10 @@ def reset_parameters(self):
118143
return None
119144

120145
def forward_comfy_cast_weights(self, input):
121-
weight, bias = cast_bias_weight(self, input)
122-
return torch.nn.functional.linear(input, weight, bias)
146+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
147+
x = torch.nn.functional.linear(input, weight, bias)
148+
uncast_bias_weight(self, weight, bias, offload_stream)
149+
return x
123150

124151
def forward(self, *args, **kwargs):
125152
run_every_op()
@@ -133,8 +160,10 @@ def reset_parameters(self):
133160
return None
134161

135162
def forward_comfy_cast_weights(self, input):
136-
weight, bias = cast_bias_weight(self, input)
137-
return self._conv_forward(input, weight, bias)
163+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
164+
x = self._conv_forward(input, weight, bias)
165+
uncast_bias_weight(self, weight, bias, offload_stream)
166+
return x
138167

139168
def forward(self, *args, **kwargs):
140169
run_every_op()
@@ -148,8 +177,10 @@ def reset_parameters(self):
148177
return None
149178

150179
def forward_comfy_cast_weights(self, input):
151-
weight, bias = cast_bias_weight(self, input)
152-
return self._conv_forward(input, weight, bias)
180+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
181+
x = self._conv_forward(input, weight, bias)
182+
uncast_bias_weight(self, weight, bias, offload_stream)
183+
return x
153184

154185
def forward(self, *args, **kwargs):
155186
run_every_op()
@@ -172,8 +203,10 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
172203
return super()._conv_forward(input, weight, bias, *args, **kwargs)
173204

174205
def forward_comfy_cast_weights(self, input):
175-
weight, bias = cast_bias_weight(self, input)
176-
return self._conv_forward(input, weight, bias)
206+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
207+
x = self._conv_forward(input, weight, bias)
208+
uncast_bias_weight(self, weight, bias, offload_stream)
209+
return x
177210

178211
def forward(self, *args, **kwargs):
179212
run_every_op()
@@ -187,8 +220,10 @@ def reset_parameters(self):
187220
return None
188221

189222
def forward_comfy_cast_weights(self, input):
190-
weight, bias = cast_bias_weight(self, input)
191-
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
223+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
224+
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
225+
uncast_bias_weight(self, weight, bias, offload_stream)
226+
return x
192227

193228
def forward(self, *args, **kwargs):
194229
run_every_op()
@@ -203,11 +238,14 @@ def reset_parameters(self):
203238

204239
def forward_comfy_cast_weights(self, input):
205240
if self.weight is not None:
206-
weight, bias = cast_bias_weight(self, input)
241+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
207242
else:
208243
weight = None
209244
bias = None
210-
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
245+
offload_stream = None
246+
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
247+
uncast_bias_weight(self, weight, bias, offload_stream)
248+
return x
211249

212250
def forward(self, *args, **kwargs):
213251
run_every_op()
@@ -223,11 +261,15 @@ def reset_parameters(self):
223261

224262
def forward_comfy_cast_weights(self, input):
225263
if self.weight is not None:
226-
weight, bias = cast_bias_weight(self, input)
264+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
227265
else:
228266
weight = None
229-
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
230-
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
267+
bias = None
268+
offload_stream = None
269+
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
270+
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
271+
uncast_bias_weight(self, weight, bias, offload_stream)
272+
return x
231273

232274
def forward(self, *args, **kwargs):
233275
run_every_op()
@@ -246,10 +288,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
246288
input, output_size, self.stride, self.padding, self.kernel_size,
247289
num_spatial_dims, self.dilation)
248290

249-
weight, bias = cast_bias_weight(self, input)
250-
return torch.nn.functional.conv_transpose2d(
291+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
292+
x = torch.nn.functional.conv_transpose2d(
251293
input, weight, bias, self.stride, self.padding,
252294
output_padding, self.groups, self.dilation)
295+
uncast_bias_weight(self, weight, bias, offload_stream)
296+
return x
253297

254298
def forward(self, *args, **kwargs):
255299
run_every_op()
@@ -268,10 +312,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
268312
input, output_size, self.stride, self.padding, self.kernel_size,
269313
num_spatial_dims, self.dilation)
270314

271-
weight, bias = cast_bias_weight(self, input)
272-
return torch.nn.functional.conv_transpose1d(
315+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
316+
x = torch.nn.functional.conv_transpose1d(
273317
input, weight, bias, self.stride, self.padding,
274318
output_padding, self.groups, self.dilation)
319+
uncast_bias_weight(self, weight, bias, offload_stream)
320+
return x
275321

276322
def forward(self, *args, **kwargs):
277323
run_every_op()
@@ -289,8 +335,11 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
289335
output_dtype = out_dtype
290336
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
291337
out_dtype = None
292-
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
293-
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
338+
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
339+
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
340+
uncast_bias_weight(self, weight, bias, offload_stream)
341+
return x
342+
294343

295344
def forward(self, *args, **kwargs):
296345
run_every_op()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
361410
input_dtype = input.dtype
362411

363412
if len(input.shape) == 3:
364-
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
413+
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
365414

366415
scale_weight = self.scale_weight
367416
scale_input = self.scale_input
@@ -382,6 +431,8 @@ def fp8_linear(self, input):
382431
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
383432
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
384433

434+
uncast_bias_weight(self, w, bias, offload_stream)
435+
385436
if tensor_2d:
386437
return o.reshape(input_shape[0], -1)
387438
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
@@ -404,8 +455,10 @@ def forward_comfy_cast_weights(self, input):
404455
except Exception as e:
405456
logging.info("Exception during fp8 op: {}".format(e))
406457

407-
weight, bias = cast_bias_weight(self, input)
408-
return torch.nn.functional.linear(input, weight, bias)
458+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
459+
x = torch.nn.functional.linear(input, weight, bias)
460+
uncast_bias_weight(self, weight, bias, offload_stream)
461+
return x
409462

410463
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
411464
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -433,12 +486,14 @@ def forward_comfy_cast_weights(self, input):
433486
if out is not None:
434487
return out
435488

436-
weight, bias = cast_bias_weight(self, input)
489+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
437490

438491
if weight.numel() < input.numel(): #TODO: optimize
439-
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
492+
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
440493
else:
441-
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
494+
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
495+
uncast_bias_weight(self, weight, bias, offload_stream)
496+
return x
442497

443498
def convert_weight(self, weight, inplace=False, **kwargs):
444499
if inplace:
@@ -577,8 +632,10 @@ def _forward(self, input, weight, bias):
577632
return torch.nn.functional.linear(input, weight, bias)
578633

579634
def forward_comfy_cast_weights(self, input):
580-
weight, bias = cast_bias_weight(self, input)
581-
return self._forward(input, weight, bias)
635+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
636+
x = self._forward(input, weight, bias)
637+
uncast_bias_weight(self, weight, bias, offload_stream)
638+
return x
582639

583640
def forward(self, input, *args, **kwargs):
584641
run_every_op()

0 commit comments

Comments
 (0)