Dump all_reduce duration to detect slow node. #939

Status: Open, proposing to merge 5 commits into base: main.

27 changes: 24 additions & 3 deletions dipu/tests/python/unittests/test_generator.py
@@ -2,7 +2,11 @@
import torch
import torch_dipu
from torch_dipu import diputype
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+from torch_dipu.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    onlyOn,
+)


class TestGenerator(TestCase):
@@ -20,13 +24,13 @@ def test_python_api(self):
torch.cuda.manual_seed(i)

state = torch.cuda.get_rng_state(0)
-new_state = torch.ones_like(state)
+new_state = torch.ones_like(state) * 4
torch.cuda.set_rng_state(new_state, 0)
current_state = torch.cuda.get_rng_state(0)
self.assertTrue(
torch.allclose(
current_state,
-torch.tensor(1, device=current_state.device, dtype=current_state.dtype),
+torch.tensor(4, device=current_state.device, dtype=current_state.dtype),
)
)

@@ -194,6 +198,23 @@ def test_default_generators(self):
torch.cuda.default_generators[0].manual_seed(1)
self.assertEqual(torch.cuda.default_generators[0].initial_seed(), 1)

@onlyOn("CUDA")
def test_cuda_generator(self):
state = torch.cuda.get_rng_state(0)
state[-16] = 4
state[-15:-8] = 0
state[-8:] = 0
torch.cuda.set_rng_state(state)
self.assertEqual(torch.cuda.initial_seed(), 4)

# invalid offset, offset must be a multiple of 4
state[-8:] = 1
try:
torch.cuda.set_rng_state(state)
self.assertTrue(False, "should not go here")
except Exception as ex:
self.assertIn("offset must be a multiple of 4", ex.args[0])


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion dipu/third_party/DIOPI
Submodule DIOPI updated 36 files
+1 −1 .github/workflows/main.yml
+3 −3 diopi_test/diopi_stub/csrc/litert.cpp
+94 −10 diopi_test/python/configs/diopi_configs.py
+33 −0 diopi_test/python/conformance/customized_test.py
+69 −0 diopi_test/python/conformance/diopi_functions.py
+6 −4 diopi_test/python/conformance/gen_input.py
+19 −5 diopi_test/python/conformance/gen_output.py
+1 −0 diopi_test/python/conformance/global_op_list.py
+17 −2 impl/ascend/aclnn/adaptor.hpp
+116 −0 impl/ascend/ascend_tensor.cpp
+5 −0 impl/ascend/ascend_tensor.hpp
+4 −0 impl/ascend/common/acloprunner.hpp
+43 −0 impl/ascend/common/utils.cpp
+9 −0 impl/ascend/convert_config.yaml
+7 −15 impl/ascend/device_configs.py
+335 −0 impl/ascend/functions/index.cpp
+34 −0 impl/ascend/functions/syn_batch_norm.cpp
+107 −0 impl/ascend/functions_ext/token_attention_inference.cpp
+103 −0 impl/ascend/functions_ext/token_softmax_reducev_inference.cpp
+4 −0 impl/ascend_npu/CMakeLists.txt
+7 −4 impl/ascend_npu/ascend_config.yaml
+5 −0 impl/camb/device_configs.py
+78 −0 impl/cuda/device_configs.py
+1 −0 impl/cuda/error.cpp
+49 −1 impl/cuda/functions.cu
+13 −17 impl/cuda/test/CMakeLists.txt
+12 −0 impl/cuda/test/conform_test.cpp
+4 −1 impl/droplet/CMakeLists.txt
+1 −1 impl/droplet/test/CMakeLists.txt
+49 −0 impl/torch/functions/functions.cpp
+62 −12 impl/torch/functions/functions_ext.cpp
+14 −1 impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
+43 −38 impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
+4 −3 impl/torch/functions/functions_sparse.cpp
+4 −3 proto/include/diopi/diopirt.h
+42 −0 proto/include/diopi/functions.h
5 changes: 4 additions & 1 deletion dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp
@@ -1,4 +1,4 @@
// Copyright (c) 2023, DeepLink.

[GitHub Actions / clang-format: dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp does not conform to Custom style guidelines (line 208).]
#include <chrono>
#include <cstddef>
#include <cstdint>
@@ -205,7 +205,8 @@
py::arg("enable_timing") = false, py::arg("blocking") = false,
py::arg("interprocess") = false)
.def("record", py::overload_cast<>(&DIPUEvent::record), "record event")
.def("record", py::overload_cast<const DIPUStream&>(&DIPUEvent::record),
.def("record", py::overload_cast<const DIPUStream&, bool>(&DIPUEvent::record),
py::arg("stream"), py::arg("use_pool") = true,
"record event on stream")
.def("elapsed_time", &dipu::DIPUEvent::elapsed_time)
.def("synchronize",
@@ -249,6 +250,8 @@
return kBackendDefaultTimeout;
});

m.def("dump_info", dumpInfo);

// py::object mdist = py::module::import("torch.distributed");
// py::object register_backend =
// mdist.attr("Backend").attr("register_backend"); The first parameter is the
21 changes: 17 additions & 4 deletions dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h
@@ -1,4 +1,4 @@
// Copyright (c) 2023, DeepLink.

[GitHub Actions / clang-format: dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h does not conform to Custom style guidelines (lines 131, 141).]
#pragma once

#include <cstdint>
@@ -20,6 +20,7 @@
deviceEvent_t event_{nullptr};
c10::DeviceIndex device_index_{-1};
c10::StreamId last_recorded_stream_id_{-1};
bool use_pool_{true};

public:
DIPUEvent(const DIPUEvent&) = delete;
@@ -29,7 +30,8 @@
constexpr DIPUEvent(DIPUEvent&& other) noexcept
: event_(other.event_),
device_index_(other.device_index_),
-last_recorded_stream_id_(other.last_recorded_stream_id_) {
+last_recorded_stream_id_(other.last_recorded_stream_id_),
+use_pool_(other.use_pool_) {
other.unsafe_reset();
}

@@ -39,6 +41,7 @@
event_ = other.event_;
device_index_ = other.device_index_;
last_recorded_stream_id_ = other.last_recorded_stream_id_;
use_pool_ = other.use_pool_;
other.unsafe_reset();
}
return *this;
@@ -76,8 +79,9 @@

void record() { record(getCurrentDIPUStream()); }

-void record(const DIPUStream& stream) {
+void record(const DIPUStream& stream, bool use_pool = true) {
if (!initialized()) {
use_pool_ = use_pool;
create_event(stream.device_index());
}

@@ -124,14 +128,23 @@
void create_event(c10::DeviceIndex device_index) {
device_index_ = device_index;
DIPUGuard guard(device_index_);
-devproxy::createEvent(&event_);
+if(use_pool_) {
+  devproxy::createEvent(&event_);
+} else {
+  devapis::createEvent(&event_);
+}
}

void release_event() {
if (initialized()) {
DIPUGuard guard(device_index_);
-devproxy::destroyEvent(event_);
+if(use_pool_) {
+  devproxy::destroyEvent(event_);
+} else {
+  devapis::destroyEvent(event_);
+}
event_ = nullptr;
use_pool_ = true;
}
}
};
@@ -89,7 +89,7 @@ DIPUGeneratorImpl::DIPUGeneratorImpl(at::DeviceIndex device_index)
*/
void DIPUGeneratorImpl::set_current_seed(uint64_t seed) {
seed_ = seed;
-offset_ = 0;
+set_offset(0);
state_need_reset_ = true;
}

@@ -1,6 +1,8 @@
// Copyright (c) 2023, DeepLink.
#include "ProcessGroupDICL.h"

#include <fstream>
#include <mutex>
#include <utility>
#include <vector>

@@ -120,6 +122,70 @@ void checkGatherScatterRootRank(

} // anonymous namespace

// start WorkStore

class WorkStore {
struct WorkInfo {
DIPUEvent startEvent_;
DIPUEvent endEvent_;
int rank_;
int comm_size_;
};

public:
void setUid(const std::vector<uint8_t>& uidVec) { uniqueidVec_ = uidVec; }

size_t recordStart(const DIPUStream& stream, int rank, int comm_size) {
std::lock_guard<std::mutex> lock(mtx_);
info_vec_.push_back(WorkInfo());
size_t index = info_vec_.size() - 1;
info_vec_[index].startEvent_.record(stream, false);
info_vec_[index].rank_ = rank;
info_vec_[index].comm_size_ = comm_size;

return index;
}

void recordEnd(const DIPUStream& stream, size_t index) {
std::lock_guard<std::mutex> lock(mtx_);
info_vec_[index].endEvent_.record(stream, false);
}

void dump(std::string& path) {
for (auto& wi : info_vec_) {
wi.endEvent_.synchronize();
float duration = wi.startEvent_.elapsed_time(wi.endEvent_);
std::ostringstream oss;
oss << "PG uniqueId = ";
for (int i = 0; i < 32; ++i) {
oss << static_cast<int>(uniqueidVec_[i]);
}
oss << ", comm_size = " << wi.comm_size_ << ", duration = " << duration
<< std::endl;
std::string filePath = path + "/rank_" + std::to_string(wi.rank_);
std::ofstream outFile(filePath, std::ios::app);
outFile << oss.str();
}

info_vec_.clear();
}

private:
std::vector<WorkInfo> info_vec_;
std::mutex mtx_;
std::vector<uint8_t> uniqueidVec_;
};

// end WorkStore

std::vector<std::shared_ptr<WorkStore>> global_stores;

void dumpInfo(std::string& path) {
for (auto p : global_stores) {
p->dump(path);
}
}

// start WorkDICL

// currently DICL does not support error check
@@ -196,7 +262,10 @@ ProcessGroupDICL::WorkDICL::getFuture()

ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr<Store>& store,
int rank, int size)
-: c10d::Backend(rank, size), store_(store) {
+: c10d::Backend(rank, size),
+  store_(store),
+  pWstore_(std::make_shared<WorkStore>()) {
+global_stores.push_back(pWstore_);
char* blockingWait = getenv(DICL_BLOCKING_WAIT);
try {
if (blockingWait != nullptr) {
@@ -238,6 +307,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId,
auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(uniqueId),
reinterpret_cast<uint8_t*>(uniqueId) +
devapis::DICL_UNIQUE_ID_BYTES_SIZE);
pWstore_->setUid(vec);
store_->set(storeKey, vec);
} else {
auto vec = store_->get(storeKey);
@@ -246,6 +316,7 @@
"Unexpected DICL unique ID length received "
"from the store");
}
pWstore_->setUid(vec);
std::memcpy(uniqueId, vec.data(), vec.size());
}
}
@@ -442,6 +513,13 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
auto work = c10::make_intrusive<ProcessGroupDICL::WorkDICL>(
diclComms, blockingWait_, opTimeout_);

size_t eventIndex;
if (opType == OpType::ALLREDUCE) {
eventIndex =
pWstore_->recordStart(diclComms[0]->diclStream_, this->rank_,
inputs[0].element_size() * inputs[0].numel());
}

OptionalDIPUGuard dipuGuard;
pre(diclComms);

@@ -466,6 +544,11 @@
}

post(diclComms);

if (opType == OpType::ALLREDUCE) {
pWstore_->recordEnd(diclComms[0]->diclStream_, eventIndex);
}

work->record();

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
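Taken together, the changes above time every ALLREDUCE: `doComm` records a start event before and an end event after the collective on `diclComms[0]->diclStream_`, storing the message size in bytes (`inputs[0].element_size() * inputs[0].numel()`) in `comm_size_`, and `WorkStore::dump` appends one line per recorded op to `<path>/rank_<rank>` in the form `PG uniqueId = <32 byte values>, comm_size = <bytes>, duration = <elapsed_time result>`. A small, illustrative helper for reading such a file back (not part of the PR; the path in the usage line is a placeholder):

```python
# Illustrative only: parse a rank_<n> file written by WorkStore::dump above.
def read_allreduce_records(path):
    records = []
    with open(path) as f:
        for line in f:
            # Each line: "PG uniqueId = <...>, comm_size = <bytes>, duration = <float>"
            fields = dict(
                part.split(" = ", 1)
                for part in line.strip().split(", ")
                if " = " in part
            )
            records.append((int(fields["comm_size"]), float(fields["duration"])))
    return records

# e.g. sizes_and_durations = read_allreduce_records("/tmp/dicl_dump/rank_0")
```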
@@ -2,6 +2,7 @@
#pragma once

#include <chrono>
#include <queue>
#include <string_view>
#include <unordered_map>
#include <vector>
@@ -40,6 +41,10 @@ using c10d::Work;
constexpr const char* DICL_BLOCKING_WAIT = "DICL_BLOCKING_WAIT";
constexpr int64_t diclSyncBusyWaitMillis = 30;

void dumpInfo(std::string& path);

class WorkStore;

/**
* ProcessGroupDICL implements DICL bindings for c10d.
*
@@ -310,6 +315,8 @@ class DIPU_API ProcessGroupDICL : public Backend {
bool blockingWait_ = false;

std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout;

std::shared_ptr<WorkStore> pWstore_;
};

namespace dicl_hook {
Expand Down
10 changes: 8 additions & 2 deletions dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
@@ -39,11 +39,12 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
#else
auto new_rng_state = state.data_dtype_initialized<uint8_t>();
#endif
-memcpy(&input_seed, new_rng_state, seed_size);
+memcpy(&input_seed, new_rng_state + states_size, seed_size);
this->set_current_seed(input_seed);
int64_t philox_offset = 0;
if (!no_philox_seed) {
-memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
+memcpy(&philox_offset, new_rng_state + states_size + seed_size,
+       offset_size);
}
this->set_offset(static_cast<uint64_t>(philox_offset));

@@ -71,6 +72,11 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
state_need_reset_ = false;
}
}

void set_offset(uint64_t offset) override {
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
DIPUGeneratorImpl::set_offset(offset);
}
};

// NOLINTNEXTLINE(readability-const-return-type)
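The memcpy changes above matter because the generator state blob places the per-thread Philox states first and the seed and offset in the last 16 bytes; the old code read the seed from offset 0 and so picked up Philox state bytes instead. A rough sketch of decoding that tail on the Python side, mirroring the new `test_cuda_generator` test (it assumes the same layout and native little-endian byte order, and is not part of the PR):

```python
import torch
import torch_dipu  # maps the torch.cuda API onto DIPU devices

# get_rng_state returns a CPU uint8 tensor laid out as
#   [ per-thread philox states | seed (8 bytes) | offset (8 bytes) ]
state = torch.cuda.get_rng_state(0)
seed = int.from_bytes(bytes(state[-16:-8].tolist()), "little")
offset = int.from_bytes(bytes(state[-8:].tolist()), "little")
print(seed, offset)  # offset is expected to be a multiple of 4
```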
7 changes: 7 additions & 0 deletions dipu/torch_dipu/dipu/distributed.py
@@ -1,4 +1,5 @@
from datetime import timedelta
import os

import torch
from torch import distributed as dist
@@ -113,6 +114,10 @@ def _wrap_new_group(
return _raw_new_group(ranks, timeout, backend, pg_options)


def _wrap_dump_info(path):
_C.dump_info(path)


def apply_dist_patch():
dist.get_backend = _wrap_get_backend
dist.init_process_group = _wrap_init_process_groups
@@ -123,3 +128,5 @@ def apply_dist_patch():

if dipu.get_dipu_torch_version() == dipu.torch_ver_200:
dist.new_group = _wrap_new_group

dist.dump_info = _wrap_dump_info
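With these patches applied, `torch.distributed` gains a `dump_info` entry point that forwards to the C++ `dumpInfo` shown earlier and writes one `rank_<n>` file per process. A hedged end-to-end sketch follows; the backend name, tensor size, and dump directory are illustrative assumptions, not taken from this PR:

```python
import os
import torch
import torch.distributed as dist
import torch_dipu  # importing torch_dipu is assumed to apply apply_dist_patch()

dist.init_process_group(backend="dicl")  # assumed DIPU backend name
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

x = torch.ones(1 << 20, device="cuda")
for _ in range(10):
    dist.all_reduce(x)  # each ALLREDUCE records a start/end event pair

dump_dir = "/tmp/dicl_dump"
os.makedirs(dump_dir, exist_ok=True)
dist.dump_info(dump_dir)  # appends per-op durations to <dump_dir>/rank_<n>
```

Comparing the duration columns across the per-rank files should make a consistently slow node stand out, which is the stated goal of this PR.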