diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index f18309217c3..ac839c21f18 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, next_pipeline_rank, + torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ) ops.append(recv_next_op) if len(ops) > 0: