Skip to content

Commit 217688a

Browse files
authored
Fix for multinode (#65)
* fix for multinode * one more fix --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent c25ed19 commit 217688a

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

plugin/sycl/device_manager.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,25 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const {
2020
(collective::IsDistributed());
2121
if (not_use_default_selector) {
2222
DeviceRegister& device_register = GetDevicesRegister();
23-
const int device_idx =
24-
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
2523
if (device_spec.IsSyclDefault()) {
2624
auto& devices = device_register.devices;
25+
const int device_idx = collective::IsDistributed()
26+
? collective::GetRank() % devices.size()
27+
: device_spec.ordinal;
2728
CHECK_LT(device_idx, devices.size());
2829
return devices[device_idx];
2930
} else if (device_spec.IsSyclCPU()) {
3031
auto& cpu_devices = device_register.cpu_devices;
32+
const int device_idx = collective::IsDistributed()
33+
? collective::GetRank() % cpu_devices.size()
34+
: device_spec.ordinal;
3135
CHECK_LT(device_idx, cpu_devices.size());
3236
return cpu_devices[device_idx];
3337
} else {
3438
auto& gpu_devices = device_register.gpu_devices;
39+
const int device_idx = collective::IsDistributed()
40+
? collective::GetRank() % gpu_devices.size()
41+
: device_spec.ordinal;
3542
CHECK_LT(device_idx, gpu_devices.size());
3643
return gpu_devices[device_idx];
3744
}
@@ -63,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
6370
std::lock_guard<std::mutex> guard(queue_registering_mutex);
6471
if (not_use_default_selector) {
6572
DeviceRegister& device_register = GetDevicesRegister();
66-
const int device_idx =
67-
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
6873
if (device_spec.IsSyclDefault()) {
6974
auto& devices = device_register.devices;
75+
const int device_idx = collective::IsDistributed()
76+
? collective::GetRank() % devices.size()
77+
: device_spec.ordinal;
7078
CHECK_LT(device_idx, devices.size());
7179
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
7280
} else if (device_spec.IsSyclCPU()) {
7381
auto& cpu_devices = device_register.cpu_devices;
82+
const int device_idx = collective::IsDistributed()
83+
? collective::GetRank() % cpu_devices.size()
84+
: device_spec.ordinal;
7485
CHECK_LT(device_idx, cpu_devices.size());
7586
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
7687
} else if (device_spec.IsSyclGPU()) {
7788
auto& gpu_devices = device_register.gpu_devices;
89+
const int device_idx = collective::IsDistributed()
90+
? collective::GetRank() % gpu_devices.size()
91+
: device_spec.ordinal;
7892
CHECK_LT(device_idx, gpu_devices.size());
7993
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
8094
}

0 commit comments

Comments
 (0)