Skip to content

Commit a50c8c2

Browse files
committed
mm: sync the last stream in the queue, not the next
Currently this peeks ahead to sync the next stream in the queue of streams with the compute stream. This doesn't allow a lot of parallelization, as the end result is that you can only get one weight load ahead regardless of how many streams you have. Rotate the loop logic here to synchronize the end of the queue before returning the next stream. This allows weights to be loaded ahead of the compute stream's position.
1 parent 6a6bec1 commit a50c8c2

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

comfy/model_management.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,18 +1031,17 @@ def get_offload_stream(device):
10311031

10321032
if device in STREAMS:
10331033
ss = STREAMS[device]
1034-
s = ss[stream_counter]
1035-
stream_counter = (stream_counter + 1) % len(ss)
1034+
#Sync the oldest stream in the queue with the current
10361035
ss[stream_counter].wait_stream(current_stream(device))
1036+
stream_counter = (stream_counter + 1) % len(ss)
10371037
stream_counters[device] = stream_counter
1038-
return s
1038+
return ss[stream_counter]
10391039
elif is_device_cuda(device):
10401040
ss = []
10411041
for k in range(NUM_STREAMS):
10421042
ss.append(torch.cuda.Stream(device=device, priority=0))
10431043
STREAMS[device] = ss
10441044
s = ss[stream_counter]
1045-
stream_counter = (stream_counter + 1) % len(ss)
10461045
stream_counters[device] = stream_counter
10471046
return s
10481047
elif is_device_xpu(device):
@@ -1051,7 +1050,6 @@ def get_offload_stream(device):
10511050
ss.append(torch.xpu.Stream(device=device, priority=0))
10521051
STREAMS[device] = ss
10531052
s = ss[stream_counter]
1054-
stream_counter = (stream_counter + 1) % len(ss)
10551053
stream_counters[device] = stream_counter
10561054
return s
10571055
return None

0 commit comments

Comments
 (0)