
Commit 3c2e2e6

ops: sync the offload stream with the consumption of w&b

This sync is necessary because PyTorch queues asynchronous CUDA frees on the same stream that created the tensor; in the case of async offload, that is the offload stream. Weights and biases can go out of scope in Python, which triggers the PyTorch garbage collector to queue the free operation on the offload stream, possibly before the compute stream has used the weight. This causes a use-after-free on the weight data, leading to total corruption of some workflows. So sync the offload stream with the compute stream after the weight has been used, forcing the free to wait until the weight has been consumed.

cast_bias_weight is extended in a backwards-compatible way, with the new behaviour opt-in via a defaulted parameter. This keeps custom node packs that call cast_bias_weight working, but disables async offload for them (as they do not handle the race). The pattern is now:

    cast_bias_weight(..., offloadable=True)  # this might be offloaded
    thing(weight, bias, ...)
    uncast_bias_weight(...)
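For custom node authors, a minimal sketch of what the opt-in pattern looks like in a concrete op (MyLinear is hypothetical; cast_bias_weight, uncast_bias_weight, and CastWeightBiasOp are the names used in comfy/ops.py):

    import torch
    import comfy.ops

    class MyLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
        def forward_comfy_cast_weights(self, input):
            # With offloadable=True a third value is returned: the offload
            # stream the cast was queued on (None if offload is inactive).
            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
            out = torch.nn.functional.linear(input, weight, bias)
            # Queue this after the last use of weight/bias so the deferred
            # CUDA free cannot run before the compute stream consumes them.
            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
            return out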
1 parent 4005cc7 commit 3c2e2e6

File tree

1 file changed: +85 −30 lines

comfy/ops.py

Lines changed: 85 additions & 30 deletions
@@ -70,8 +70,12 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
+
 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is legacy; if you are a custom node author reading this, please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
                 weight = f(weight)
 
     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        # Legacy function signature
+        return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
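Note the direction of the sync in uncast_bias_weight above: the offload stream waits on the compute stream, so any free that PyTorch later queues on the offload stream is ordered after the compute work that consumes the weight. A standalone sketch of that ordering in plain torch.cuda (illustrative only, not ComfyUI code):

    import torch

    device = torch.device("cuda")
    offload_stream = torch.cuda.Stream(device=device)
    compute_stream = torch.cuda.current_stream(device)

    with torch.cuda.stream(offload_stream):
        # Allocated on the offload stream, so the caching allocator will
        # also order this tensor's eventual free against that stream.
        weight = torch.randn(1024, 1024, device=device)

    # Compute must wait for the transfer/cast on the offload stream...
    compute_stream.wait_stream(offload_stream)
    out = weight @ weight  # last use of the weight, on the compute stream

    # ...and the offload stream must wait for compute before the deferred
    # free may execute; without this, the free can race the matmul.
    offload_stream.wait_stream(compute_stream)
    del weight  # safe: any reuse of the memory is ordered after the matmul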
@@ -118,8 +143,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
             return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ def reset_parameters(self):
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ def reset_parameters(self):
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
             output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
+
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -356,7 +405,7 @@ def fp8_linear(self, input):
     input_shape = input.shape
     input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
         w = w.t()
 
         scale_weight = self.scale_weight
@@ -379,6 +428,8 @@ def fp8_linear(self, input):
         else:
             o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
 
+        uncast_bias_weight(self, w, bias, offload_stream)
+
         if isinstance(o, tuple):
             o = o[0]
 
@@ -405,8 +456,10 @@ def forward_comfy_cast_weights(self, input):
         except Exception as e:
             logging.info("Exception during fp8 op: {}".format(e))
 
-        weight, bias = cast_bias_weight(self, input)
-        return torch.nn.functional.linear(input, weight, bias)
+        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+        x = torch.nn.functional.linear(input, weight, bias)
+        uncast_bias_weight(self, weight, bias, offload_stream)
+        return x
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -434,12 +487,14 @@ def forward_comfy_cast_weights(self, input):
             if out is not None:
                 return out
 
-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
 
             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace: