Skip to content

Commit

Permalink
Fix the AllReduce hang issue in torch plugin (#26)
Browse files Browse the repository at this point in the history
* Fix the AllReduce hang issue in torch plugin

* Add the FLAGCX_GLOO_SOCKET_IFNAME environment variable to specify the network interface (tcp attr.iface) used by the gloo adaptor
  • Loading branch information
MC952-arch authored Feb 5, 2025
1 parent d733f33 commit e13b888
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,4 @@ $(OBJDIR)/%.o: %.cc
-include $(LIBOBJ:.o=.d)

clean:
@rm -rf $(LIBDIR)/$(TARGET) $(OBJDIR)
@rm -rf $(LIBDIR)/$(TARGET) $(OBJDIR)
15 changes: 11 additions & 4 deletions flagcx/adaptor/gloo_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@ flagcxResult_t glooAdaptorCommInitRank(flagcxInnerComm_t *comm, int nranks, flag
// std::cout << "Caught an exception during the creation of ibverbs transport device: " << e.what() << ". Try tcp transport device alternatively." << std::endl;
// // Alternatively, try tcp
try {
char line[1024];
FLAGCXCHECK(getHostName(line, 1024, '.'));
std::string hostname(line);
::gloo::transport::tcp::attr attr;
attr.hostname = hostname;
char line[1024];
const char* glooIface = flagcxGetEnv("FLAGCX_GLOO_SOCKET_IFNAME");
if(glooIface == NULL) {
FLAGCXCHECK(getHostName(line, 1024, '.'));
std::string hostname(line);
attr.hostname = hostname;
} else {
strcpy(line, glooIface);
std::string iface(line);
attr.iface = iface;
}
dev = ::gloo::transport::tcp::CreateDevice(attr);
} catch (const std::exception& e) {
std::cout << "Caught an exception during the creation of tcp transport device: " << e.what() << ". Fail to create gloo transport device." << std::endl;
Expand Down
182 changes: 179 additions & 3 deletions flagcx/flagcx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,27 +581,72 @@ flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff, size_t coun
}
if (has_host_comm())
{
uint64_t timers[TIMERS_COLL_COUNT] = {0};
timers[TIMER_COLL_TOTAL] = clockNano();
void *buff_in;
void *buff_out;
size_t size = count * getFlagcxDataTypeSize(datatype);

// step 1: malloc host buffer
timers[TIMER_COLL_ALLOC] = clockNano();
deviceAdaptor->deviceMalloc(&buff_in, size, flagcxMemHost);
deviceAdaptor->deviceMalloc(&buff_out, size, flagcxMemHost);
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];

// step 2: memcpy d2h
timers[TIMER_COLL_MEM_D2H] = clockNano();
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), size, flagcxMemcpyDeviceToHost, NULL, NULL);
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];

// step 3: allreduce
timers[TIMER_COLL_COMM] = clockNano();
cclAdaptors[flagcxCCLAdaptorHost]->allReduce(buff_in, buff_out, count, datatype, op, comm->host_comm, NULL);
timers[TIMER_COLL_COMM] = clockNano() - timers[TIMER_COLL_COMM];

// step 4: memcpy h2d
timers[TIMER_COLL_MEM_H2D] = clockNano();
deviceAdaptor->deviceMemcpy(recvbuff, buff_out, size, flagcxMemcpyHostToDevice, NULL, NULL);
timers[TIMER_COLL_MEM_H2D] = clockNano() - timers[TIMER_COLL_MEM_H2D];

// step 5: free host buffer
timers[TIMER_COLL_FREE] = clockNano();
deviceAdaptor->deviceFree(buff_in, flagcxMemHost);
deviceAdaptor->deviceFree(buff_out, flagcxMemHost);
timers[TIMER_COLL_FREE] = clockNano() - timers[TIMER_COLL_FREE];

timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
INFO(FLAGCX_COLL,
"Flagcx timings - %s: rank %d nranks %d total %.2fms (memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, memory h2d %.2fms, comm %.2fms)",
"GlooAllReduce",
comm->rank,
comm->nranks,
timers[TIMER_COLL_TOTAL] / 1e6,
timers[TIMER_COLL_ALLOC] / 1e6,
timers[TIMER_COLL_FREE] / 1e6,
timers[TIMER_COLL_MEM_D2H] / 1e6,
timers[TIMER_COLL_MEM_H2D] / 1e6,
timers[TIMER_COLL_COMM] / 1e6
);
}
else
{
// op validation
if (op != flagcxSum && op != flagcxMax && op != flagcxMin)
{
WARN("Unsupported reduction operation %d", op);
return flagcxInvalidArgument;
}

// intra-cluster reduce
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->reduce(sendbuff, recvbuff, count, datatype, op, comm->homo_inter_rank, comm->homo_comm, stream));

// inter-cluster sendrecv
deviceAdaptor->streamSynchronize(stream);
if (comm->homo_inter_rank != comm->homo_rank)
{
deviceAdaptor->deviceMemset(recvbuff, 0, count * getFlagcxDataTypeSize(datatype), flagcxMemDevice, stream);
if (op == flagcxSum) {
deviceAdaptor->deviceMemset(recvbuff, 0, count * getFlagcxDataTypeSize(datatype), flagcxMemDevice, stream);
}
}
int cid = 0;
flagcxGroupStart();
Expand Down Expand Up @@ -703,17 +748,53 @@ flagcxResult_t flagcxAllGather(const void *sendbuff, void *recvbuff, size_t send
}
if (has_host_comm())
{
uint64_t timers[TIMERS_COLL_COUNT] = {0};
timers[TIMER_COLL_TOTAL] = clockNano();
void *buff_in;
void *buff_out;
size_t size = sendcount * getFlagcxDataTypeSize(datatype);
size_t totalSize = comm->nranks * size;

// step 1: malloc host buffer
timers[TIMER_COLL_ALLOC] = clockNano();
deviceAdaptor->deviceMalloc(&buff_in, size, flagcxMemHost);
deviceAdaptor->deviceMalloc(&buff_out, totalSize, flagcxMemHost);
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];

// step 2: memcpy d2h
timers[TIMER_COLL_MEM_D2H] = clockNano();
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), size, flagcxMemcpyDeviceToHost, NULL, NULL);
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];

// step 3: allgather
timers[TIMER_COLL_COMM] = clockNano();
cclAdaptors[flagcxCCLAdaptorHost]->allGather(buff_in, buff_out, sendcount, datatype, comm->host_comm, NULL);
timers[TIMER_COLL_COMM] = clockNano() - timers[TIMER_COLL_COMM];

// step 4: memcpy h2d
timers[TIMER_COLL_MEM_H2D] = clockNano();
deviceAdaptor->deviceMemcpy(recvbuff, buff_out, totalSize, flagcxMemcpyHostToDevice, NULL, NULL);
timers[TIMER_COLL_MEM_H2D] = clockNano() - timers[TIMER_COLL_MEM_H2D];

// step 5: free host buffer
timers[TIMER_COLL_FREE] = clockNano();
deviceAdaptor->deviceFree(buff_in, flagcxMemHost);
deviceAdaptor->deviceFree(buff_out, flagcxMemHost);
timers[TIMER_COLL_FREE] = clockNano() - timers[TIMER_COLL_FREE];

timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
INFO(FLAGCX_COLL,
"Flagcx timings - %s: rank %d nranks %d total %.2fms (memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, memory h2d %.2fms, comm %.2fms)",
"GlooAllGather",
comm->rank,
comm->nranks,
timers[TIMER_COLL_TOTAL] / 1e6,
timers[TIMER_COLL_ALLOC] / 1e6,
timers[TIMER_COLL_FREE] / 1e6,
timers[TIMER_COLL_MEM_D2H] / 1e6,
timers[TIMER_COLL_MEM_H2D] / 1e6,
timers[TIMER_COLL_COMM] / 1e6
);
}
else
{
Expand Down Expand Up @@ -808,16 +889,52 @@ flagcxResult_t flagcxAlltoAll(const void *sendbuff, void *recvbuff, size_t count
}
if (has_host_comm())
{
uint64_t timers[TIMERS_COLL_COUNT] = {0};
timers[TIMER_COLL_TOTAL] = clockNano();
void *buff_in;
void *buff_out;
size_t size = comm->nranks * count * getFlagcxDataTypeSize(datatype);

// step 1: malloc host buffer
timers[TIMER_COLL_ALLOC] = clockNano();
deviceAdaptor->deviceMalloc(&buff_in, size, flagcxMemHost);
deviceAdaptor->deviceMalloc(&buff_out, size, flagcxMemHost);
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];

// step 2: memcpy d2h
timers[TIMER_COLL_MEM_D2H] = clockNano();
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), size, flagcxMemcpyDeviceToHost, NULL, NULL);
cclAdaptors[flagcxCCLAdaptorHost]->allGather(buff_in, buff_out, count, datatype, comm->host_comm, NULL);
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];

// step 3: alltoall
timers[TIMER_COLL_COMM] = clockNano();
cclAdaptors[flagcxCCLAdaptorHost]->alltoAll(buff_in, buff_out, count, datatype, comm->host_comm, NULL);
timers[TIMER_COLL_COMM] = clockNano() - timers[TIMER_COLL_COMM];

// step 4: memcpy h2d
timers[TIMER_COLL_MEM_H2D] = clockNano();
deviceAdaptor->deviceMemcpy(recvbuff, buff_out, size, flagcxMemcpyHostToDevice, NULL, NULL);
timers[TIMER_COLL_MEM_H2D] = clockNano() - timers[TIMER_COLL_MEM_H2D];

// step 5: free host buffer
timers[TIMER_COLL_FREE] = clockNano();
deviceAdaptor->deviceFree(buff_in, flagcxMemHost);
deviceAdaptor->deviceFree(buff_out, flagcxMemHost);
timers[TIMER_COLL_FREE] = clockNano() - timers[TIMER_COLL_FREE];

timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
INFO(FLAGCX_COLL,
"Flagcx timings - %s: rank %d nranks %d total %.2fms (memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, memory h2d %.2fms, comm %.2fms)",
"GlooAlltoAll",
comm->rank,
comm->nranks,
timers[TIMER_COLL_TOTAL] / 1e6,
timers[TIMER_COLL_ALLOC] / 1e6,
timers[TIMER_COLL_FREE] / 1e6,
timers[TIMER_COLL_MEM_D2H] / 1e6,
timers[TIMER_COLL_MEM_H2D] / 1e6,
timers[TIMER_COLL_COMM] / 1e6
);
}
else
{
Expand Down Expand Up @@ -861,13 +978,41 @@ flagcxResult_t flagcxSend(const void *sendbuff, size_t count, flagcxDataType_t d
{
if (has_host_comm())
{
uint64_t timers[TIMERS_COLL_COUNT] = {0};
timers[TIMER_COLL_TOTAL] = clockNano();
void *buff_in;
size_t size = count * getFlagcxDataTypeSize(datatype);

// step 1: malloc host buffer
timers[TIMER_COLL_ALLOC] = clockNano();
deviceAdaptor->deviceMalloc(&buff_in, size, flagcxMemHost);
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];

// step 2: memcpy d2h
timers[TIMER_COLL_MEM_D2H] = clockNano();
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), size, flagcxMemcpyDeviceToHost, NULL, NULL);
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];

// step 3: send
timers[TIMER_COLL_COMM] = clockNano();
cclAdaptors[flagcxCCLAdaptorHost]->send(buff_in, count, datatype, peer, comm->host_comm, NULL);
// buff_in will be freed in gloo adaptor send function
timers[TIMER_COLL_COMM] = clockNano() - timers[TIMER_COLL_COMM];

// buff_in will be freed in gloo adaptor send function?
// TODO: check if buff_in should be freed here
// deviceAdaptor->deviceFree(buff_in, flagcxMemHost);

timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
INFO(FLAGCX_COLL,
"Flagcx timings - %s: rank %d nranks %d total %.2fms (memory alloc %.2fms, memory d2h %.2fms, comm %.2fms)",
"GlooSend",
comm->rank,
comm->nranks,
timers[TIMER_COLL_TOTAL] / 1e6,
timers[TIMER_COLL_ALLOC] / 1e6,
timers[TIMER_COLL_MEM_D2H] / 1e6,
timers[TIMER_COLL_COMM] / 1e6
);
}
else
{
Expand All @@ -888,12 +1033,43 @@ flagcxResult_t flagcxRecv(void *recvbuff, size_t count, flagcxDataType_t datatyp
{
if (has_host_comm())
{
uint64_t timers[TIMERS_COLL_COUNT] = {0};
timers[TIMER_COLL_TOTAL] = clockNano();
void *buff_out;
size_t size = count * getFlagcxDataTypeSize(datatype);

// step 1: malloc host buffer
timers[TIMER_COLL_ALLOC] = clockNano();
deviceAdaptor->deviceMalloc(&buff_out, size, flagcxMemHost);
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];

// step 2: recv
timers[TIMER_COLL_COMM] = clockNano();
cclAdaptors[flagcxCCLAdaptorHost]->recv(buff_out, count, datatype, peer, comm->host_comm, NULL);
timers[TIMER_COLL_COMM] = clockNano() - timers[TIMER_COLL_COMM];

// step 3: memcpy h2d
timers[TIMER_COLL_MEM_H2D] = clockNano();
deviceAdaptor->deviceMemcpy(recvbuff, buff_out, size, flagcxMemcpyHostToDevice, NULL, NULL);
timers[TIMER_COLL_MEM_H2D] = clockNano() - timers[TIMER_COLL_MEM_H2D];

// step 4: free host buffer
timers[TIMER_COLL_FREE] = clockNano();
deviceAdaptor->deviceFree(buff_out, flagcxMemHost);
timers[TIMER_COLL_FREE] = clockNano() - timers[TIMER_COLL_FREE];

timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
INFO(FLAGCX_COLL,
"Flagcx timings - %s: rank %d nranks %d total %.2fms (memory alloc %.2fms, memory free %.2fms, memory h2d %.2fms, comm %.2fms)",
"GlooRecv",
comm->rank,
comm->nranks,
timers[TIMER_COLL_TOTAL] / 1e6,
timers[TIMER_COLL_ALLOC] / 1e6,
timers[TIMER_COLL_FREE] / 1e6,
timers[TIMER_COLL_MEM_H2D] / 1e6,
timers[TIMER_COLL_COMM] / 1e6
);
}
else
{
Expand Down
3 changes: 2 additions & 1 deletion flagcx/service/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void flagcxSetThreadName(pthread_t thread, const char *fmt, ...);
#define TIMER_COLL_MEM_D2H 4
#define TIMER_COLL_MEM_H2D 5
#define TIMER_COLL_ALLOC 6
#define TIMERS_COLL_COUNT 7
#define TIMER_COLL_FREE 7
#define TIMERS_COLL_COUNT 8

#endif
9 changes: 4 additions & 5 deletions plugin/torch/include/backend_flagcx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ namespace c10d
public:
WorkFlagcx(
OpType opType,
c10::intrusive_ptr<c10::ivalue::Future> future, // future of the output
flagcxStream_t stream = nullptr,
flagcxDeviceHandle_t handler = nullptr,
c10::intrusive_ptr<c10::ivalue::Future> future = nullptr, // future of the output
int device_id = 0,
bool coalesced = false)
: Work(
-1, // rank, only used by recvAnySource, irrelevant in this implementation
opType),
future_(std::move(future)), stream_(stream), handler_(handler), device_id_(device_id), coalesced_(coalesced), isBarrierOp_(false)
stream_(stream), handler_(handler), future_(std::move(future)), device_id_(device_id), coalesced_(coalesced), isBarrierOp_(false)
{
#ifdef USE_NVIDIA_ADAPTOR
event_ = std::make_unique<CUDAEventFlagcx>();
Expand All @@ -40,18 +40,17 @@ namespace c10d
#elif USE_CAMBRICON_ADAPTOR
event_ = std::make_unique<MLUEventFlagcx>();
#endif
event_->record(stream_, device_id_);
printf("WorkFlagcx created with device_id = %d, coalesced = %d\n", device_id_, coalesced_);
// event_->record(stream_, device_id_);
}
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

private:
c10::intrusive_ptr<c10::ivalue::Future> future_;
flagcxStream_t stream_;
flagcxDeviceHandle_t handler_;
c10::intrusive_ptr<c10::ivalue::Future> future_;
int device_id_;
bool coalesced_; // for group semantics, unused for now
bool isBarrierOp_;
Expand Down
Loading

0 comments on commit e13b888

Please sign in to comment.