File tree 1 file changed +27
-0
lines changed
fbgemm_gpu/experimental/gen_ai/src/comm
1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change @@ -224,6 +224,32 @@ void nccl_allreduce(
224
224
default :
225
225
TORCH_CHECK (false , " unsupported type: " , src.scalar_type ());
226
226
}
227
+ #if defined(USE_ROCM)
228
+ if (bias) {
229
+ C10D_NCCL_CHECK (
230
+ ncclAllReduceWithBias (
231
+ src.data_ptr (),
232
+ dst.data_ptr (),
233
+ src.numel (),
234
+ type,
235
+ ncclSum,
236
+ *get_nccl_comm (comm_idx),
237
+ at::cuda::getCurrentCUDAStream (),
238
+ (*bias).data_ptr ()),
239
+ " ncclAllReduceWithBias" );
240
+ } else {
241
+ C10D_NCCL_CHECK (
242
+ ncclAllReduce (
243
+ src.data_ptr (),
244
+ dst.data_ptr (),
245
+ src.numel (),
246
+ type,
247
+ ncclSum,
248
+ *get_nccl_comm (comm_idx),
249
+ at::cuda::getCurrentCUDAStream ()),
250
+ " ncclAllReduce" );
251
+ }
252
+ #else
227
253
C10D_NCCL_CHECK (
228
254
ncclAllReduce (
229
255
src.data_ptr (),
@@ -237,6 +263,7 @@ void nccl_allreduce(
237
263
if (bias) {
238
264
dst.add_ (*bias);
239
265
}
266
+ #endif
240
267
}
241
268
242
269
} // namespace
You can’t perform that action at this time.
0 commit comments