Skip to content

Commit 8d29a32

Browse files
committed
fix(nyz): fix multi-machine gpu id bug (#875)
1 parent f78aed1 commit 8d29a32

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

ding/worker/learner/learner_hook.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
6 6
from easydict import EasyDict
7 7

8 8
import ding
9-
from ding.utils import allreduce, read_file, save_file, get_rank
9+
from ding.utils import allreduce, read_file, save_file
10 10

11 11

12 12
class Hook(ABC):
@@ -287,6 +287,7 @@ def should_reduce(key):
287 287
# The "noreduce_" prefix is used in the unizero_multitask ddp pipeline
288 288
# to indicate data that should not be reduced.
289 289
return not key.startswith("noreduce_")
290+
cuda_device = torch.cuda.current_device()
290 291

291 292
if isinstance(data, dict):
292 293
new_data = {}
@@ -303,15 +304,15 @@ def should_reduce(key):
303 304
if ding.enable_linklink:
304 305
allreduce(new_data)
305 306
else:
306-
new_data = new_data.to(get_rank())
307+
new_data = new_data.to(cuda_device)
307 308
allreduce(new_data)
308 309
new_data = new_data.cpu()
309 310
elif isinstance(data, numbers.Integral) or isinstance(data, numbers.Real):
310 311
new_data = torch.scalar_tensor(data).reshape([1])
311 312
if ding.enable_linklink:
312 313
allreduce(new_data)
313 314
else:
314-
new_data = new_data.to(get_rank())
315+
new_data = new_data.to(cuda_device)
315 316
allreduce(new_data)
316 317
new_data = new_data.cpu()
317 318
new_data = new_data.item()

0 commit comments

Comments
 (0)