@@ -20,18 +20,25 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const {
20
20
(collective::IsDistributed ());
21
21
if (not_use_default_selector) {
22
22
DeviceRegister& device_register = GetDevicesRegister ();
23
- const int device_idx =
24
- collective::IsDistributed () ? collective::GetRank () : device_spec.ordinal ;
25
23
if (device_spec.IsSyclDefault ()) {
26
24
auto & devices = device_register.devices ;
25
+ const int device_idx = collective::IsDistributed ()
26
+ ? collective::GetRank () % devices.size ()
27
+ : device_spec.ordinal ;
27
28
CHECK_LT (device_idx, devices.size ());
28
29
return devices[device_idx];
29
30
} else if (device_spec.IsSyclCPU ()) {
30
31
auto & cpu_devices = device_register.cpu_devices ;
32
+ const int device_idx = collective::IsDistributed ()
33
+ ? collective::GetRank () % cpu_devices.size ()
34
+ : device_spec.ordinal ;
31
35
CHECK_LT (device_idx, cpu_devices.size ());
32
36
return cpu_devices[device_idx];
33
37
} else {
34
38
auto & gpu_devices = device_register.gpu_devices ;
39
+ const int device_idx = collective::IsDistributed ()
40
+ ? collective::GetRank () % gpu_devices.size ()
41
+ : device_spec.ordinal ;
35
42
CHECK_LT (device_idx, gpu_devices.size ());
36
43
return gpu_devices[device_idx];
37
44
}
@@ -63,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
63
70
std::lock_guard<std::mutex> guard (queue_registering_mutex);
64
71
if (not_use_default_selector) {
65
72
DeviceRegister& device_register = GetDevicesRegister ();
66
- const int device_idx =
67
- collective::IsDistributed () ? collective::GetRank () : device_spec.ordinal ;
68
73
if (device_spec.IsSyclDefault ()) {
69
74
auto & devices = device_register.devices ;
75
+ const int device_idx = collective::IsDistributed ()
76
+ ? collective::GetRank () % devices.size ()
77
+ : device_spec.ordinal ;
70
78
CHECK_LT (device_idx, devices.size ());
71
79
queue_register[device_spec.Name ()] = ::sycl::queue (devices[device_idx]);
72
80
} else if (device_spec.IsSyclCPU ()) {
73
81
auto & cpu_devices = device_register.cpu_devices ;
82
+ const int device_idx = collective::IsDistributed ()
83
+ ? collective::GetRank () % cpu_devices.size ()
84
+ : device_spec.ordinal ;
74
85
CHECK_LT (device_idx, cpu_devices.size ());
75
86
queue_register[device_spec.Name ()] = ::sycl::queue (cpu_devices[device_idx]);
76
87
} else if (device_spec.IsSyclGPU ()) {
77
88
auto & gpu_devices = device_register.gpu_devices ;
89
+ const int device_idx = collective::IsDistributed ()
90
+ ? collective::GetRank () % gpu_devices.size ()
91
+ : device_spec.ordinal ;
78
92
CHECK_LT (device_idx, gpu_devices.size ());
79
93
queue_register[device_spec.Name ()] = ::sycl::queue (gpu_devices[device_idx]);
80
94
}
0 commit comments