Skip to content

Commit 85d37c0

Browse files
ycui1984 authored and facebook-github-bot committed
Leverage fuse kernel in inference workload (#4157)
Summary:
X-link: facebookresearch/FBGEMM#1237
Differential Revision: D75011610
1 parent 52f07e7 commit 85d37c0

File tree

1 file changed

+27
-0
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/comm

1 file changed

+27
-0
lines changed

fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,32 @@ void nccl_allreduce(
224224
default:
225225
TORCH_CHECK(false, "unsupported type: ", src.scalar_type());
226226
}
227+
#if defined(USE_ROCM)
228+
if (bias) {
229+
C10D_NCCL_CHECK(
230+
ncclAllReduceWithBias(
231+
src.data_ptr(),
232+
dst.data_ptr(),
233+
src.numel(),
234+
type,
235+
ncclSum,
236+
*get_nccl_comm(comm_idx),
237+
at::cuda::getCurrentCUDAStream(),
238+
(*bias).data_ptr()),
239+
"ncclAllReduceWithBias");
240+
} else {
241+
C10D_NCCL_CHECK(
242+
ncclAllReduce(
243+
src.data_ptr(),
244+
dst.data_ptr(),
245+
src.numel(),
246+
type,
247+
ncclSum,
248+
*get_nccl_comm(comm_idx),
249+
at::cuda::getCurrentCUDAStream()),
250+
"ncclAllReduce");
251+
}
252+
#else
227253
C10D_NCCL_CHECK(
228254
ncclAllReduce(
229255
src.data_ptr(),
@@ -237,6 +263,7 @@ void nccl_allreduce(
237263
if (bias) {
238264
dst.add_(*bias);
239265
}
266+
#endif
240267
}
241268

242269
} // namespace

0 commit comments

Comments
 (0)