diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e3ae2f8f..2cf8d1310 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Versioning](http://semver.org/spec/v2.0.0.html). - Introduce new experimental `for_each_item` utility to iterate over a celerity range (#199) - Add new environment variables `CELERITY_HORIZON_STEP` and `CELERITY_HORIZON_MAX_PARALLELISM` to control Horizon generation (#199) - Add new `experimental::constrain_split` API to limit how a kernel can be split (#?) +- Add automatic detection of diverging execution in debug mode (#217) - `distr_queue::fence` and `buffer_snapshot` are now stable, subsuming the `experimental::` APIs of the same name (#225) - Celerity now warns at runtime when a task declares reads from uninitialized buffers or writes with overlapping ranges between nodes (#224) - Introduce new `experimental::hint` API for providing the runtime with additional information on how to execute a task (#227) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6de2e3875..d84968ea7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,11 +23,16 @@ endif() option(CELERITY_ACCESS_PATTERN_DIAGNOSTICS "Diagnose uninitialized reads and overlapping writes" ${DEFAULT_ENABLE_DEBUG_CHECKS}) option(CELERITY_ACCESSOR_BOUNDARY_CHECK "Enable accessor boundary check" ${DEFAULT_ENABLE_DEBUG_CHECKS}) +option(CELERITY_DIVERGENCE_CHECK "Enable divergence check" ${DEFAULT_ENABLE_DEBUG_CHECKS}) if(CELERITY_ACCESSOR_BOUNDARY_CHECK AND NOT (CMAKE_BUILD_TYPE STREQUAL "Debug")) message(STATUS "Accessor boundary check enabled - this will impact kernel performance") endif() +if(CELERITY_DIVERGENCE_CHECK AND NOT (CMAKE_BUILD_TYPE STREQUAL "Debug")) + message(STATUS "Divergence checker enabled - this will impact the overall performance") +endif() + set(CELERITY_CMAKE_DIR "${PROJECT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}" "${CELERITY_CMAKE_DIR}") find_package(MPI 2.0 REQUIRED) @@ -186,6 +191,7 @@ set(SOURCES src/command_graph.cc src/config.cc src/device_queue.cc + src/divergence_checker.cc src/executor.cc src/distributed_graph_generator.cc src/graph_serializer.cc @@ -288,6 +294,7 @@ target_compile_definitions(celerity_runtime PUBLIC CELERITY_FEATURE_UNNAMED_KERNELS=$ CELERITY_DETAIL_HAS_NAMED_THREADS=$ CELERITY_ACCESSOR_BOUNDARY_CHECK=$ + CELERITY_DIVERGENCE_CHECK=$ CELERITY_ACCESS_PATTERN_DIAGNOSTICS=$ ) diff --git a/docs/pitfalls.md b/docs/pitfalls.md index 94da5df46..8fed84f8e 100644 --- a/docs/pitfalls.md +++ b/docs/pitfalls.md @@ -132,3 +132,7 @@ if(rand() > 1337) { celerity::buffer my_buffer(...); } ``` + +> Diverging Host-Execution can be detected at runtime by enabling the +> `CELERITY_DIVERGENCE_CHECK` CMake option at the cost of some runtime +> overhead (enabled by default in debug builds). diff --git a/include/communicator.h b/include/communicator.h new file mode 100644 index 000000000..68ffed950 --- /dev/null +++ b/include/communicator.h @@ -0,0 +1,47 @@ +#pragma once + +#include "types.h" + +namespace celerity::detail { + +/* + * @brief Defines an interface for a communicator that can be used to communicate between nodes. + * + * This interface is used to abstract away the communication between nodes. This allows us to use different communication backends during testing and + * runtime. For example, we can use MPI for the runtime and a custom implementation for testing. 
+ */ +class communicator { + public: + communicator() = default; + communicator(const communicator&) = delete; + communicator(communicator&&) noexcept = default; + + communicator& operator=(const communicator&) = delete; + communicator& operator=(communicator&&) noexcept = default; + + virtual ~communicator() = default; + + template + void allgather_inplace(S* sendrecvbuf, const int sendrecvcount) { + allgather_inplace_impl(reinterpret_cast(sendrecvbuf), sendrecvcount * sizeof(S)); + } + + template + void allgather(const S* sendbuf, const int sendcount, R* recvbuf, const int recvcount) { + allgather_impl(reinterpret_cast(sendbuf), sendcount * sizeof(S), reinterpret_cast(recvbuf), recvcount * sizeof(R)); + } + + void barrier() { barrier_impl(); } + + size_t get_num_nodes() { return num_nodes_impl(); } + + node_id get_local_nid() { return local_nid_impl(); } + + protected: + virtual void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) = 0; + virtual void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) = 0; + virtual void barrier_impl() = 0; + virtual size_t num_nodes_impl() = 0; + virtual node_id local_nid_impl() = 0; +}; +} // namespace celerity::detail \ No newline at end of file diff --git a/include/divergence_checker.h b/include/divergence_checker.h new file mode 100644 index 000000000..8d17f2966 --- /dev/null +++ b/include/divergence_checker.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include + +#include "communicator.h" +#include "recorders.h" + +namespace celerity::detail::divergence_checker_detail { +using task_hash = size_t; +using divergence_map = std::unordered_map>; + +/** + * @brief Stores the hashes of tasks for each node. + * + * The data is stored densely so it can easily be exchanged through MPI collective operations. + */ +struct per_node_task_hashes { + public: + per_node_task_hashes(const size_t max_hash_count, const size_t num_nodes) : m_data(max_hash_count * num_nodes), m_max_hash_count(max_hash_count){}; + const task_hash& operator()(const node_id nid, const size_t i) const { return m_data.at(nid * m_max_hash_count + i); } + task_hash* data() { return m_data.data(); } + + private: + std::vector m_data; + size_t m_max_hash_count; +}; + +/** + * @brief This class checks for divergences of tasks between nodes. + * + * It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence. + * When a divergence is found, the task record for the diverging task is printed and the program is terminated. + * Additionally it will also print a warning when a deadlock is suspected. 
+ */ + +class divergence_block_chain { + friend struct divergence_block_chain_testspy; + + public: + divergence_block_chain(task_recorder& task_recorder, std::unique_ptr comm) + : m_local_nid(comm->get_local_nid()), m_num_nodes(comm->get_num_nodes()), m_per_node_hash_counts(comm->get_num_nodes()), + m_communicator(std::move(comm)) { + task_recorder.add_callback([this](const task_record& task) { add_new_task(task); }); + } + + divergence_block_chain(const divergence_block_chain&) = delete; + divergence_block_chain(divergence_block_chain&&) = delete; + + ~divergence_block_chain() = default; + + divergence_block_chain& operator=(const divergence_block_chain&) = delete; + divergence_block_chain& operator=(divergence_block_chain&&) = delete; + + bool check_for_divergence(); + + private: + node_id m_local_nid; + size_t m_num_nodes; + + std::vector m_local_hashes; + std::vector m_task_records; + size_t m_tasks_checked = 0; + size_t m_hashes_added = 0; + task_hash m_last_hash = 0; + + std::vector m_per_node_hash_counts; + std::mutex m_task_records_mutex; + + std::chrono::time_point m_last_cleared = std::chrono::steady_clock::now(); + std::chrono::seconds m_time_of_last_warning = std::chrono::seconds(0); + + std::unique_ptr m_communicator; + + void reprot_divergence(const divergence_map& check_map, const int task_num); + + void add_new_hashes(); + void clear(const int min_progress); + std::pair collect_hash_counts(); + per_node_task_hashes collect_hashes(const int min_hash_count) const; + divergence_map create_divergence_map(const per_node_task_hashes& task_hashes, const int task_num) const; + + void check_for_deadlock(); + + static void log_node_divergences(const divergence_map& check_map, const int task_id); + static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash); + void log_task_record_once(const divergence_map& check_map, const int task_num); + + void add_new_task(const task_record& task); + task_record thread_save_get_task_record(const size_t task_num); +}; +}; // namespace celerity::detail::divergence_checker_detail + +namespace celerity::detail { +class divergence_checker { + friend struct runtime_testspy; + + public: + divergence_checker(task_recorder& task_recorder, std::unique_ptr comm, bool test_mode = false) + : m_block_chain(task_recorder, std::move(comm)) { + if(!test_mode) { start(); } + } + + divergence_checker(const divergence_checker&) = delete; + divergence_checker(const divergence_checker&&) = delete; + + divergence_checker& operator=(const divergence_checker&) = delete; + divergence_checker& operator=(divergence_checker&&) = delete; + + ~divergence_checker() { stop(); } + + private: + std::thread m_thread; + bool m_is_running = false; + divergence_checker_detail::divergence_block_chain m_block_chain; + + void start() { + m_thread = std::thread(&divergence_checker::run, this); + m_is_running = true; + } + + void stop() { + m_is_running = false; + if(m_thread.joinable()) { m_thread.join(); } + } + + void run() { + bool is_finished = false; + while(!is_finished || m_is_running) { + is_finished = m_block_chain.check_for_divergence(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } +}; +}; // namespace celerity::detail diff --git a/include/grid.h b/include/grid.h index df3a3a8f8..6563fa6f3 100644 --- a/include/grid.h +++ b/include/grid.h @@ -8,6 +8,7 @@ #include #include "ranges.h" +#include "utils.h" #include "workaround.h" namespace celerity::detail { @@ -257,6 +258,27 @@ class region { } // namespace 
celerity::detail +template +struct std::hash> { + std::size_t operator()(const celerity::detail::box r) { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash>{}(r.get_min()), std::hash>{}(r.get_max())); + return seed; + }; +}; + +template +struct std::hash> { + std::size_t operator()(const celerity::detail::region r) { + std::size_t seed = 0; + for(auto& box : r.get_boxes()) { + celerity::detail::utils::hash_combine(seed, std::hash>{}(box)); + } + return seed; + }; +}; + + namespace celerity::detail::grid_detail { // forward-declaration for tests (explicitly instantiated) diff --git a/include/mpi_communicator.h b/include/mpi_communicator.h new file mode 100644 index 000000000..912e73fa3 --- /dev/null +++ b/include/mpi_communicator.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include + +#include "communicator.h" + +namespace celerity::detail { +class mpi_communicator : public communicator { + public: + mpi_communicator(MPI_Comm comm) : m_comm(comm) {} + + private: + MPI_Comm m_comm; + + void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) override { + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sendrecvbuf, sendrecvcount, MPI_BYTE, m_comm); + }; + + void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) override { + MPI_Allgather(sendbuf, sendcount, MPI_BYTE, recvbuf, recvcount, MPI_BYTE, m_comm); + }; + + void barrier_impl() override { MPI_Barrier(m_comm); } + + size_t num_nodes_impl() override { + int size = -1; + MPI_Comm_size(m_comm, &size); + return static_cast(size); + } + + node_id local_nid_impl() override { + int rank = -1; + MPI_Comm_rank(m_comm, &rank); + return static_cast(rank); + } +}; +} // namespace celerity::detail diff --git a/include/ranges.h b/include/ranges.h index 110676933..6f24fb3fb 100644 --- a/include/ranges.h +++ b/include/ranges.h @@ -1,6 +1,7 @@ #pragma once #include "sycl_wrappers.h" +#include "utils.h" #include "workaround.h" namespace celerity { @@ -229,6 +230,17 @@ struct ones_t { }; // namespace celerity::detail +template +struct std::hash> { + std::size_t operator()(const celerity::detail::coordinate& r) const noexcept { + std::size_t seed = 0; + for(int i = 0; i < Dims; ++i) { + celerity::detail::utils::hash_combine(seed, std::hash{}(r[i])); + } + return seed; + }; +}; + namespace celerity { template @@ -401,6 +413,17 @@ nd_range(range<3> global_range, range<3> local_range)->nd_range<3>; } // namespace celerity + +template +struct std::hash> { + std::size_t operator()(const celerity::range& r) const noexcept { return std::hash, Dims>>{}(r); }; +}; + +template +struct std::hash> { + std::size_t operator()(const celerity::id& r) const noexcept { return std::hash, Dims>>{}(r); }; +}; + namespace celerity { namespace detail { diff --git a/include/recorders.h b/include/recorders.h index caf45b8c6..8eb457987 100644 --- a/include/recorders.h +++ b/include/recorders.h @@ -53,17 +53,23 @@ struct task_record { class task_recorder { public: - using task_record = std::vector; + using task_records = std::vector; + using task_callback = std::function; task_recorder(const buffer_manager* buff_mngr = nullptr) : m_buff_mngr(buff_mngr) {} void record_task(const task& tsk); - const task_record& get_tasks() const { return m_recorded_tasks; } + void add_callback(task_callback callback); + + const task_records& get_tasks() const { return m_recorded_tasks; } private: - task_record m_recorded_tasks; + task_records m_recorded_tasks; + std::vector m_callbacks{}; const 
buffer_manager* m_buff_mngr; + + void invoke_callbacks(const task_record& tsk) const; }; // Command recording @@ -99,18 +105,100 @@ struct command_record { class command_recorder { public: - using command_record = std::vector; + using command_records = std::vector; command_recorder(const task_manager* task_mngr, const buffer_manager* buff_mngr = nullptr) : m_task_mngr(task_mngr), m_buff_mngr(buff_mngr) {} void record_command(const abstract_command& com); - const command_record& get_commands() const { return m_recorded_commands; } + const command_records& get_commands() const { return m_recorded_commands; } private: - command_record m_recorded_commands; + command_records m_recorded_commands; const task_manager* m_task_mngr; const buffer_manager* m_buff_mngr; }; } // namespace celerity::detail + +template <> +struct std::hash { + std::size_t operator()(const celerity::detail::reduction_record& r) const noexcept { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash{}(r.rid), std::hash{}(r.bid), + std::hash{}(r.buffer_name), std::hash{}(r.init_from_buffer)); + return seed; + }; +}; + +template <> +struct std::hash { + std::size_t operator()(const celerity::detail::access_record& r) { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash{}(r.bid), std::hash{}(r.buffer_name), + std::hash{}(r.mode), std::hash>{}(r.req)); + return seed; + }; +}; + +template +struct std::hash> { + std::size_t operator()(const celerity::detail::dependency_record& r) const noexcept { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash{}(r.node), std::hash{}(r.kind), + std::hash{}(r.origin)); + return seed; + }; +}; + +template <> +struct std::hash { + std::size_t operator()(const celerity::detail::side_effect_map& m) const noexcept { + std::size_t seed = 0; + for(auto& [hoid, order] : m) { + celerity::detail::utils::hash_combine( + seed, std::hash{}(hoid), std::hash{}(order)); + } + return seed; + }; +}; + +template <> +struct std::hash { + std::size_t operator()(const celerity::detail::task_record& t) const noexcept { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash{}(t.tid), std::hash{}(t.debug_name), + std::hash{}(t.cgid), std::hash{}(t.type), + std::hash{}(t.geometry), celerity::detail::utils::vector_hash{}(t.reductions), + celerity::detail::utils::vector_hash{}(t.accesses), std::hash{}(t.side_effect_map), + celerity::detail::utils::vector_hash{}(t.dependencies)); + + return seed; + }; +}; + +template <> +struct fmt::formatter : fmt::formatter { + static format_context::iterator format(const celerity::detail::dependency_kind& dk, format_context& ctx) { + auto out = ctx.out(); + switch(dk) { + case celerity::detail::dependency_kind::anti_dep: out = std::copy_n("anti-dep", 8, out); break; + case celerity::detail::dependency_kind::true_dep: out = std::copy_n("true-dep", 8, out); break; + } + return out; + } +}; + +template <> +struct fmt::formatter : fmt::formatter { + static format_context::iterator format(const celerity::detail::dependency_origin& dk, format_context& ctx) { + auto out = ctx.out(); + switch(dk) { + case celerity::detail::dependency_origin::dataflow: out = std::copy_n("dataflow", 8, out); break; + case celerity::detail::dependency_origin::collective_group_serialization: out = std::copy_n("collective-group-serialization", 31, out); break; + case celerity::detail::dependency_origin::execution_front: out = std::copy_n("execution-front", 15, out); break; + case 
celerity::detail::dependency_origin::last_epoch: out = std::copy_n("last-epoch", 10, out); break; + } + return out; + } +}; diff --git a/include/runtime.h b/include/runtime.h index fb2672619..481315578 100644 --- a/include/runtime.h +++ b/include/runtime.h @@ -7,6 +7,7 @@ #include "command.h" #include "config.h" #include "device_queue.h" +#include "divergence_checker.h" #include "frame.h" #include "host_queue.h" #include "recorders.h" @@ -101,6 +102,8 @@ namespace detail { size_t m_num_nodes; node_id m_local_nid; + std::unique_ptr m_divergence_check; + // These management classes are only constructed on the master node. std::unique_ptr m_cdag; std::unique_ptr m_schdlr; diff --git a/include/task.h b/include/task.h index d41f617dc..3716b6bfd 100644 --- a/include/task.h +++ b/include/task.h @@ -14,6 +14,7 @@ #include "lifetime_extending_state.h" #include "range_mapper.h" #include "types.h" +#include "utils.h" namespace celerity { @@ -291,3 +292,31 @@ namespace detail { } // namespace detail } // namespace celerity + +template <> +struct std::hash { + std::size_t operator()(const celerity::detail::task_geometry& g) const noexcept { + std::size_t seed = 0; + celerity::detail::utils::hash_combine(seed, std::hash{}(g.dimensions), std::hash>{}(g.global_size), + std::hash>{}(g.global_offset), std::hash>{}(g.granularity)); + return seed; + }; +}; + +template <> +struct fmt::formatter : fmt::formatter { + static format_context::iterator format(const celerity::detail::task_type& tt, format_context& ctx) { + auto out = ctx.out(); + switch(tt) { + case celerity::detail::task_type::epoch: out = std::copy_n("epoch", 5, out); break; + case celerity::detail::task_type::host_compute: out = std::copy_n("host-compute", 12, out); break; + case celerity::detail::task_type::device_compute: out = std::copy_n("device-compute", 14, out); break; + case celerity::detail::task_type::collective: out = std::copy_n("collective", 10, out); break; + case celerity::detail::task_type::master_node: out = std::copy_n("master-node", 11, out); break; + case celerity::detail::task_type::horizon: out = std::copy_n("horizon", 7, out); break; + case celerity::detail::task_type::fence: out = std::copy_n("fence", 5, out); break; + default: out = std::copy_n("unknown", 7, out); break; + } + return out; + } +}; diff --git a/include/utils.h b/include/utils.h index b296c5ffb..aefac6931 100644 --- a/include/utils.h +++ b/include/utils.h @@ -47,15 +47,30 @@ decltype(auto) match(Variant&& v, Arms&&... arms) { return std::visit(overload{std::forward(arms)...}, std::forward(v)); } -// Implementation from Boost.ContainerHash, licensed under the Boost Software License, Version 1.0. -inline void hash_combine(std::size_t& seed, std::size_t value) { seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +// A parameter pack extension to the implementation from Boost.ContainerHash, licensed under the Boost Software License, Version 1.0. +template +inline void hash_combine(std::size_t& seed, const T& v, const Rest&... 
rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} struct pair_hash { template std::size_t operator()(const std::pair& p) const { std::size_t seed = 0; - hash_combine(seed, std::hash{}(p.first)); - hash_combine(seed, std::hash{}(p.second)); + hash_combine(seed, p.first, p.second); + return seed; + } +}; + + +struct vector_hash { + template + std::size_t operator()(const std::vector& v) const { + std::size_t seed = 0; + for(auto& e : v) { + hash_combine(seed, e); + } return seed; } }; diff --git a/src/config.cc b/src/config.cc index 16192909c..98458970f 100644 --- a/src/config.cc +++ b/src/config.cc @@ -201,7 +201,12 @@ namespace detail { const auto has_dry_run_nodes = parsed_and_validated_envs.get(env_dry_run_nodes); if(has_dry_run_nodes) { m_dry_run_nodes = *has_dry_run_nodes; } +#if CELERITY_DIVERGENCE_CHECK + // divergence checker needs recording + m_recording = true; +#else m_recording = parsed_and_validated_envs.get_or(env_recording, false); +#endif m_horizon_step = parsed_and_validated_envs.get(env_horizon_step); m_horizon_max_parallelism = parsed_and_validated_envs.get(env_horizon_max_para); diff --git a/src/divergence_checker.cc b/src/divergence_checker.cc new file mode 100644 index 000000000..f575d4442 --- /dev/null +++ b/src/divergence_checker.cc @@ -0,0 +1,176 @@ +#include "divergence_checker.h" + +namespace celerity::detail::divergence_checker_detail { +bool divergence_block_chain::check_for_divergence() { + add_new_hashes(); + + const auto [min_hash_count, max_hash_count] = collect_hash_counts(); + + if(min_hash_count == 0) { + if(max_hash_count != 0 && m_local_nid == 0) { + check_for_deadlock(); + } else if(max_hash_count == 0) { + return true; + } + return false; + } + + const per_node_task_hashes task_hashes = collect_hashes(min_hash_count); + + for(int j = 0; j < min_hash_count; ++j) { + const divergence_map check_map = create_divergence_map(task_hashes, j); + + // If there is more than one hash for this task, we have a divergence! + if(check_map.size() > 1) { reprot_divergence(check_map, j); } + } + + clear(min_hash_count); + + return false; +} + +void divergence_block_chain::reprot_divergence(const divergence_map& check_map, const int task_num) { + if(m_local_nid == 0) { log_node_divergences(check_map, task_num + static_cast(m_tasks_checked) + 1); } + + // sleep for local_nid * 100 ms such that we have a no lock synchronized output + std::this_thread::sleep_for(std::chrono::milliseconds(m_local_nid * 100)); + + log_task_record_once(check_map, task_num); + + m_communicator->barrier(); + + throw std::runtime_error("Divergence in task graph detected"); +} + +void divergence_block_chain::add_new_hashes() { + std::lock_guard lock(m_task_records_mutex); + for(size_t i = m_hashes_added; i < m_task_records.size(); ++i) { + std::size_t seed = m_local_hashes.empty() ? m_last_hash : m_local_hashes.back(); + celerity::detail::utils::hash_combine(seed, std::hash{}(m_task_records[i])); + m_local_hashes.push_back(seed); + } + m_last_hash = m_local_hashes.empty() ? 
m_last_hash : m_local_hashes.back(); + m_hashes_added = m_task_records.size(); +} + +void divergence_block_chain::clear(const int min_progress) { + m_local_hashes.erase(m_local_hashes.begin(), m_local_hashes.begin() + min_progress); + m_tasks_checked += min_progress; + + m_last_cleared = std::chrono::steady_clock::now(); +} + +std::pair divergence_block_chain::collect_hash_counts() { + m_per_node_hash_counts[m_local_nid] = static_cast(m_local_hashes.size()); + + m_communicator->allgather_inplace(m_per_node_hash_counts.data(), 1); + + const auto [min, max] = std::minmax_element(m_per_node_hash_counts.cbegin(), m_per_node_hash_counts.cend()); + + return {*min, *max}; +} + +per_node_task_hashes divergence_block_chain::collect_hashes(const int min_hash_count) const { + per_node_task_hashes data(min_hash_count, m_num_nodes); + + m_communicator->allgather(m_local_hashes.data(), min_hash_count, data.data(), min_hash_count); + + return data; +} + + +divergence_map divergence_block_chain::create_divergence_map(const per_node_task_hashes& task_hashes, const int task_num) const { + divergence_map check_map; + for(node_id nid = 0; nid < m_num_nodes; ++nid) { + check_map[task_hashes(nid, task_num)].push_back(nid); + } + return check_map; +} + +void divergence_block_chain::check_for_deadlock() { + auto diff = std::chrono::duration_cast(std::chrono::steady_clock::now() - m_last_cleared); + + if(diff >= std::chrono::seconds(10) && diff - m_time_of_last_warning >= std::chrono::seconds(5)) { + std::string warning = fmt::format("After {} seconds of waiting, node(s)", diff.count()); + + std::vector stuck_nodes; + for(node_id nid = 0; nid < m_num_nodes; ++nid) { + if(m_per_node_hash_counts[nid] == 0) stuck_nodes.push_back(nid); + } + warning += fmt::format(" {} ", fmt::join(stuck_nodes, ",")); + warning += "did not move to the next task. The runtime might be stuck."; + + CELERITY_WARN("{}", warning); + m_time_of_last_warning = diff; + } +} + +void divergence_block_chain::log_node_divergences(const divergence_map& check_map, const int task_id) { + // TODO: Can we print task debug label here? + std::string error = fmt::format( + "Detected divergence in execution between worker nodes. This is a bug in your program! Task {} has different hashes on these nodes:\n\n", task_id); + for(auto& [hash, nodes] : check_map) { + error += fmt::format("{:#x} on {} {}\n", hash, nodes.size() > 1 ? "nodes" : "node", fmt::join(nodes, ",")); + } + CELERITY_CRITICAL("{}", error); +} + +void divergence_block_chain::log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash) { + std::string task_record_output = fmt::format("Task record for hash {:#x}:\n\n", hash); + task_record_output += fmt::format("id: {}, debug_name: {}, type: {}, cgid: {}\n", task.tid, task.debug_name, task.type, task.cgid); + const auto& geometry = task.geometry; + task_record_output += fmt::format("geometry:\n"); + task_record_output += fmt::format("\t dimensions: {}, global_size: {}, global_offset: {}, granularity: {}\n", geometry.dimensions, geometry.global_size, + geometry.global_offset, geometry.granularity); + + if(!task.reductions.empty()) { + task_record_output += fmt::format("reductions: \n"); + for(const auto& red : task.reductions) { + task_record_output += fmt::format( + "\t id: {}, bid: {}, buffer_name: {}, init_from_buffer: {}\n", red.rid, red.bid, red.buffer_name, red.init_from_buffer ? 
"true" : "false"); + } + } + + if(!task.accesses.empty()) { + task_record_output += fmt::format("accesses: \n"); + for(const auto& acc : task.accesses) { + task_record_output += fmt::format("\t bid: {}, buffer_name: {}, mode: {}, req: {}\n", acc.bid, acc.buffer_name, acc.mode, acc.req); + } + } + + if(!task.side_effect_map.empty()) { + task_record_output += fmt::format("side_effect_map: \n"); + for(const auto& [hoid, order] : task.side_effect_map) { + task_record_output += fmt::format("\t hoid: {}, order: {}\n", hoid, order); + } + } + + if(!task.dependencies.empty()) { + task_record_output += fmt::format("dependencies: \n"); + for(const auto& dep : task.dependencies) { + task_record_output += fmt::format("\t node: {}, kind: {}, origin: {}\n", dep.node, dep.kind, dep.origin); + } + } + CELERITY_ERROR("{}", task_record_output); +} + +void divergence_block_chain::log_task_record_once(const divergence_map& check_map, const int task_num) { + for(auto& [hash, nodes] : check_map) { + if(nodes[0] == m_local_nid) { + const auto task = thread_save_get_task_record(task_num + m_tasks_checked); + log_task_record(check_map, task, hash); + } + } +} + +void divergence_block_chain::add_new_task(const task_record& task) { // + std::lock_guard lock(m_task_records_mutex); + // make copy of task record so that we can access it later + m_task_records.emplace_back(task); +} + +task_record divergence_block_chain::thread_save_get_task_record(const size_t task_num) { + std::lock_guard lock(m_task_records_mutex); + return m_task_records[task_num]; +} +} // namespace celerity::detail::divergence_checker_detail diff --git a/src/recorders.cc b/src/recorders.cc index 187201cb0..15fc2d361 100644 --- a/src/recorders.cc +++ b/src/recorders.cc @@ -48,8 +48,19 @@ task_record::task_record(const task& from, const buffer_manager* buff_mngr) reductions(build_reduction_list(from, buff_mngr)), accesses(build_access_list(from, buff_mngr)), side_effect_map(from.get_side_effect_map()), dependencies(build_task_dependency_list(from)) {} -void task_recorder::record_task(const task& tsk) { // +void task_recorder::record_task(const task& tsk) { m_recorded_tasks.emplace_back(tsk, m_buff_mngr); + invoke_callbacks(m_recorded_tasks.back()); +} + +void task_recorder::add_callback(task_callback callback) { // + m_callbacks.push_back(std::move(callback)); +} + +void task_recorder::invoke_callbacks(const task_record& tsk) const { + for(const auto& cb : m_callbacks) { + cb(tsk); + } } // Commands diff --git a/src/runtime.cc b/src/runtime.cc index 8eb65aba0..b956c5c96 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -26,6 +26,7 @@ #include "executor.h" #include "host_object.h" #include "log.h" +#include "mpi_communicator.h" #include "mpi_support.h" #include "named_threads.h" #include "print_graph.h" @@ -176,6 +177,19 @@ namespace detail { m_schdlr = std::make_unique(is_dry_run(), std::move(dggen), *m_exec); m_task_mngr->register_task_callback([this](const task* tsk) { m_schdlr->notify_task_created(tsk); }); + +#if CELERITY_DIVERGENCE_CHECK + MPI_Comm comm = nullptr; + MPI_Comm_dup(MPI_COMM_WORLD, &comm); + m_divergence_check = std::make_unique(*m_task_recorder, std::make_unique(comm), m_test_mode); +#endif + // if (m_cfg->is_recording()) { + // MPI_Comm comm = nullptr; + // MPI_Comm_dup(MPI_COMM_WORLD, &comm); + // m_divergence_check = + // std::make_unique(*m_task_recorder, std::make_unique(comm), m_test_mode); + // } + CELERITY_INFO("Celerity runtime version {} running on {}. 
PID = {}, build type = {}, {}", get_version_string(), get_sycl_version(), get_pid(), get_build_type(), get_mimalloc_string()); m_d_queue->init(*m_cfg, user_device_or_selector); @@ -224,20 +238,34 @@ namespace detail { m_d_queue->wait(); m_h_queue->wait(); - if(spdlog::should_log(log_level::trace) && m_cfg->is_recording()) { - if(m_local_nid == 0) { // It's the same across all nodes - assert(m_task_recorder.get() != nullptr); - const auto graph_str = detail::print_task_graph(*m_task_recorder); - CELERITY_TRACE("Task graph:\n\n{}\n", graph_str); - } - // must be called on all nodes - auto cmd_graph = gather_command_graph(); - if(m_local_nid == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Avoid racing on stdout with other nodes (funneled through mpirun) - CELERITY_TRACE("Command graph:\n\n{}\n", cmd_graph); + if(m_cfg->is_recording()) { + if(spdlog::should_log(log_level::trace)) { + if(m_local_nid == 0) { // It's the same across all nodes + assert(m_task_recorder.get() != nullptr); + const auto graph_str = detail::print_task_graph(*m_task_recorder); + CELERITY_TRACE("Task graph:\n\n{}\n", graph_str); + } + // must be called on all nodes + auto cmd_graph = gather_command_graph(); + if(m_local_nid == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Avoid racing on stdout with other nodes (funneled through mpirun) + CELERITY_TRACE("Command graph:\n\n{}\n", cmd_graph); + } } + + // // Synchronize all nodes before resetting such that we don't get into a deadlock + // // With this barrier we can be sure that every node has fully finished and all MPI operations are done (sends, etc.) + // MPI_Barrier(MPI_COMM_WORLD); + // m_divergence_check.reset(); } +#if CELERITY_DIVERGENCE_CHECK + // Synchronize all nodes before resetting such that we don't get into a deadlock + // With this barrier we can be sure that every node has fully finished and all MPI operations are done (sends, etc.) + MPI_Barrier(MPI_COMM_WORLD); + m_divergence_check.reset(); +#endif + // Shutting down the task_manager will cause all buffers captured inside command group functions to unregister. // Since we check whether the runtime is still active upon unregistering, we have to set this to false first. 
m_is_active = false; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 35dce7596..6e1bba87a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -48,6 +48,7 @@ set(TEST_TARGETS test_utils_tests utils_tests device_selection_tests + divergence_checker_tests ) add_library(test_main test_main.cc grid_test_utils.cc) diff --git a/test/divergence_checker_test_utils.h b/test/divergence_checker_test_utils.h new file mode 100644 index 000000000..12aa2f504 --- /dev/null +++ b/test/divergence_checker_test_utils.h @@ -0,0 +1,137 @@ +#pragma once + +#include "divergence_checker.h" + +using namespace celerity; +using namespace celerity::detail; +using namespace celerity::detail::divergence_checker_detail; + +struct divergence_checker_detail::divergence_block_chain_testspy { + static per_node_task_hashes pre_check(divergence_block_chain& div_test, const int max_size) { + div_test.add_new_hashes(); + div_test.collect_hash_counts(); + return div_test.collect_hashes(max_size); + } + + static void post_check(divergence_block_chain& div_test, const int min_size) { div_test.clear(min_size); } + + static void call_check_for_divergence_with_pre_post(std::vector>& div_test) { + std::vector sizes; + std::transform(div_test.begin(), div_test.end(), std::back_inserter(sizes), [](auto& div) { return div->m_task_records.size(); }); + auto [min, max] = std::minmax_element(sizes.begin(), sizes.end()); + + std::vector extended_lifetime_hashes; + for(size_t i = 1; i < div_test.size(); i++) { + extended_lifetime_hashes.push_back(divergence_block_chain_testspy::pre_check(*div_test[i], static_cast(*max))); + } + + call_check_for_divergence(*div_test[0]); + + for(size_t i = 1; i < div_test.size(); i++) { + divergence_block_chain_testspy::post_check(*div_test[i], static_cast(*min)); + } + } + + static bool call_check_for_divergence(divergence_block_chain& div_test) { return div_test.check_for_divergence(); } + + static void set_last_cleared(divergence_block_chain& div_test, std::chrono::time_point time) { div_test.m_last_cleared = time; } +}; + +namespace celerity::test_utils { +// Note: this is only a simulator for this specific case. In the general case, we should use something more sophisticated for tracking the allgather +// communication. 
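To illustrate how the byte-based virtual `_impl` functions back the typed `allgather` front-ends of `communicator`, here is a minimal sketch of an alternative backend: a hypothetical single-process communicator in which every collective degenerates to a local copy. The class name and its behaviour are illustrative assumptions, not part of this patch or of the test simulator below.

```cpp
#include <cstring>

#include "communicator.h"

namespace celerity::detail {
// Hypothetical single-process backend: with one node, an allgather is just a copy of the
// local contribution into the receive buffer, and a barrier is a no-op.
class single_node_communicator final : public communicator {
  private:
	void allgather_inplace_impl(std::byte* /* sendrecvbuf */, const int /* sendrecvcount */) override {
		// Nothing to exchange: the local data already occupies the single node's slot.
	}

	void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int /* recvcount */) override {
		std::memcpy(recvbuf, sendbuf, static_cast<size_t>(sendcount));
	}

	void barrier_impl() override {} // no other node to wait for

	size_t num_nodes_impl() override { return 1; }
	node_id local_nid_impl() override { return 0; }
};
} // namespace celerity::detail
```

Any such backend only has to provide the byte-level primitives; the templated `allgather`/`allgather_inplace` wrappers take care of converting typed buffers into byte counts.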
+class divergence_test_communicator_provider { + public: + divergence_test_communicator_provider(size_t num_nodes) : m_num_nodes(num_nodes), m_inplace_data(num_nodes), m_gather_data(num_nodes) {} + + std::unique_ptr create(node_id local_nid) { + return std::make_unique(local_nid, m_num_nodes, m_inplace_data, m_gather_data); + } + + private: + struct inplace_data { + std::byte* data; + int count; + }; + + struct gather_data { + const std::byte* sendbuf; + int sendcount; + std::byte* recvbuf; + int recvcount; + }; + + template + struct tracker { + tracker(size_t num_nodes) : m_was_called(num_nodes), m_data(num_nodes) {} + + void operator()(T data, const node_id nid) { + m_was_called[nid] = true; + m_data[nid] = data; + } + + bool all() const { + return std::all_of(m_was_called.cbegin(), m_was_called.cend(), [](bool b) { return b; }); + } + + void reset() { std::fill(m_was_called.begin(), m_was_called.end(), false); } + + std::vector m_was_called; + std::vector m_data; + }; + + class divergence_test_communicator : public communicator { + public: + divergence_test_communicator(node_id local_nid, size_t num_nodes, tracker& inplace_data, tracker& gather_data) + : m_local_nid(local_nid), m_num_nodes(num_nodes), m_inplace_data(inplace_data), m_gather_data(gather_data) {} + + private: + node_id local_nid_impl() override { return m_local_nid; } + size_t num_nodes_impl() override { return m_num_nodes; } + + void allgather_inplace_impl(std::byte* data, const int count) override { + m_inplace_data({data, count}, m_local_nid); + if(m_inplace_data.all()) { + for(size_t i = 0; i < m_num_nodes; ++i) { + for(size_t j = 0; j < m_num_nodes; ++j) { + for(int k = 0; k < count; ++k) { + if(j != i) { m_inplace_data.m_data[i].data[j * count + k] = m_inplace_data.m_data[j].data[j * count + k]; } + } + } + } + + m_inplace_data.reset(); + } + } + + void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) override { + m_gather_data({sendbuf, sendcount, recvbuf, recvcount}, m_local_nid); + if(m_gather_data.all()) { + for(size_t i = 0; i < m_num_nodes; ++i) { + for(size_t j = 0; j < m_num_nodes; ++j) { + for(int k = 0; k < m_gather_data.m_data[i].sendcount; ++k) { + m_gather_data.m_data[i].recvbuf[j * (m_gather_data.m_data[i].sendcount) + k] = m_gather_data.m_data[j].sendbuf[k]; + } + } + } + + m_gather_data.reset(); + } + } + + void barrier_impl() override {} + + node_id m_local_nid; + size_t m_num_nodes; + + tracker& m_inplace_data; + tracker& m_gather_data; + }; + + size_t m_num_nodes; + + tracker m_inplace_data{m_num_nodes}; + tracker m_gather_data{m_num_nodes}; +}; + +} // namespace celerity::test_utils diff --git a/test/divergence_checker_tests.cc b/test/divergence_checker_tests.cc new file mode 100644 index 000000000..abb7ac941 --- /dev/null +++ b/test/divergence_checker_tests.cc @@ -0,0 +1,171 @@ +#include +#include +#include + +#include + +#include "divergence_checker_test_utils.h" +#include "log_test_utils.h" +#include "test_utils.h" + +using namespace celerity; +using namespace celerity::detail; +using celerity::access_mode; +using celerity::access::fixed; + +TEST_CASE("test diverged task execution on device tasks", "[divergence]") { + test_utils::task_test_context tt = test_utils::task_test_context{}; + test_utils::task_test_context tt_two = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{2}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + 
div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + auto buf_two = tt_two.mbf.create_buffer(range<1>(128)); + + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 128}}); }); + test_utils::add_compute_task(tt_two.tm, [&](handler& cgh) { buf_two.get_access(cgh, fixed<1>{{64, 128}}); }); + + test_utils::log_capture log_capture; + + CHECK_THROWS(divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests)); + + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("Detected divergence in execution between worker nodes")); +} + +TEST_CASE("test divergence free task execution on device", "[divergence]") { + auto tt = test_utils::task_test_context{}; + auto tt_two = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{2}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + auto buf_two = tt_two.mbf.create_buffer(range<1>(128)); + + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); + test_utils::add_compute_task(tt_two.tm, [&](handler& cgh) { buf_two.get_access(cgh, fixed<1>{{0, 64}}); }); + + test_utils::log_capture log_capture; + + divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests); + + CHECK_THAT(log_capture.get_log(), !Catch::Matchers::ContainsSubstring("Detected divergence in execution between worker nodes")); +} + +TEST_CASE("test diverged task execution on host task", "[divergence]") { + auto tt = test_utils::task_test_context{}; + auto tt_two = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{2}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + auto buf_two = tt_two.mbf.create_buffer(range<1>(128)); + + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, fixed<1>({0, 128})); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, fixed<1>({64, 128})); }); + test_utils::add_host_task(tt_two.tm, on_master_node, [&](handler& cgh) { buf_two.get_access(cgh, fixed<1>({64, 128})); }); + + test_utils::log_capture log_capture; + + CHECK_THROWS(divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests)); + + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("Detected divergence in execution between worker nodes")); +} + +TEST_CASE("test divergence free task execution on host task", "[divergence]") { + auto tt = test_utils::task_test_context{}; + auto tt_two = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{2}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + auto buf_two = tt_two.mbf.create_buffer(range<1>(128)); + + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { 
buf.get_access(cgh, fixed<1>({0, 128})); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, fixed<1>({64, 128})); }); + + test_utils::add_host_task(tt_two.tm, on_master_node, [&](handler& cgh) { buf_two.get_access(cgh, fixed<1>({0, 128})); }); + test_utils::add_host_task(tt_two.tm, on_master_node, [&](handler& cgh) { buf_two.get_access(cgh, fixed<1>({64, 128})); }); + + test_utils::log_capture log_capture; + + divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests); + + CHECK_THAT(log_capture.get_log(), !Catch::Matchers::ContainsSubstring("Detected divergence in execution between worker nodes")); +} + +TEST_CASE("test deadlock warning for tasks that are stale longer than 10 seconds", "[divergence]") { + auto tt = test_utils::task_test_context{}; + auto tt_two = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{2}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, fixed<1>({0, 128})); }); + + test_utils::log_capture log_capture; + + // call two times because first time the start task has to be cleared + divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests); + divergence_block_chain_testspy::set_last_cleared(*div_tests[0], (std::chrono::steady_clock::now() - std::chrono::seconds(10))); + divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests); + + CHECK_THAT(log_capture.get_log(), + Catch::Matchers::ContainsSubstring("After 10 seconds of waiting, node(s) 1 did not move to the next task. 
The runtime might be stuck.")); +} + +TEST_CASE("test correct output of 3 different divergent tasks", "[divergence]") { + auto tt = test_utils::task_test_context{}; + auto tt_two = test_utils::task_test_context{}; + auto tt_three = test_utils::task_test_context{}; + + test_utils::divergence_test_communicator_provider provider{3}; + std::vector> div_tests; + div_tests.emplace_back(std::make_unique(tt.trec, provider.create(0))); + div_tests.emplace_back(std::make_unique(tt_two.trec, provider.create(1))); + div_tests.emplace_back(std::make_unique(tt_three.trec, provider.create(2))); + + auto buf = tt.mbf.create_buffer(range<1>(128)); + auto buf_two = tt_two.mbf.create_buffer(range<1>(128)); + auto buf_three = tt_three.mbf.create_buffer(range<1>(128)); + + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { + celerity::debug::set_task_name(cgh, "task_a"); + buf.get_access(cgh, fixed<1>{{0, 64}}); + }); + + test_utils::add_compute_task(tt_two.tm, [&](handler& cgh) { + celerity::debug::set_task_name(cgh, "task_a"); + buf_two.get_access(cgh, fixed<1>{{64, 128}}); + }); + + test_utils::add_compute_task(tt_three.tm, [&](handler& cgh) { + celerity::debug::set_task_name(cgh, "task_a"); + buf_three.get_access(cgh, fixed<1>{{0, 128}}); + }); + + test_utils::log_capture log_capture; + + CHECK_THROWS(divergence_block_chain_testspy::call_check_for_divergence_with_pre_post(div_tests)); + + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("Task 1 has different hashes on these nodes:")); + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("on node 2")); + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("on node 1")); + CHECK_THAT(log_capture.get_log(), Catch::Matchers::ContainsSubstring("on node 0")); +} diff --git a/test/system/distr_tests.cc b/test/system/distr_tests.cc index bed2ba300..f03f187f0 100644 --- a/test/system/distr_tests.cc +++ b/test/system/distr_tests.cc @@ -11,6 +11,7 @@ #include +#include "../divergence_checker_test_utils.h" #include "../log_test_utils.h" namespace celerity { @@ -476,5 +477,89 @@ namespace detail { #endif } + TEST_CASE_METHOD(test_utils::runtime_fixture, "Check divergence of different nodes", "[divergence]") { +#if !CELERITY_DIVERGENCE_CHECK + SKIP("Distributed divergence boundary check only enabled when CELERITY_DIVERGENCE_CHECK=ON"); +#endif + + env::scoped_test_environment tenv(recording_enabled_env_setting); + + runtime::init(nullptr, nullptr); + + test_utils::log_capture log_capture; + + size_t n = 0; + size_t rank = 0; + + { + distr_queue queue; + + n = runtime::get_instance().get_num_nodes(); + REQUIRE(n > 1); + + auto& div_check = runtime_testspy::get_divergence_block_chain(runtime::get_instance()); + + const auto range = celerity::range<1>(10000); + celerity::buffer buff(range); + + celerity::debug::set_buffer_name(buff, "mat_a"); + + rank = celerity::detail::runtime::get_instance().get_local_nid(); + + divergence_block_chain_testspy::call_check_for_divergence(div_check); + + // here we need a divergence which doesn't result in a deadlock, because else we would run into ether a failed test or a incompletable test... 
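For context, the divergence this system test provokes corresponds to the pitfall documented in `docs/pitfalls.md`: host code whose submitted command groups depend on node-local state. A hedged user-level sketch of what that looks like, assuming the usual `<celerity.h>` entry header; the kernel and variable names are made up and are not part of the test:

```cpp
#include <cstdlib>

#include <celerity.h>

int main() {
	celerity::distr_queue q;
	celerity::buffer<int, 1> buf{celerity::range<1>{1024}};

	// Every node runs this same program, but the submitted task graph depends on
	// node-local state (rand()), so the recorded task hashes diverge between nodes.
	if(std::rand() % 2 == 0) {
		q.submit([&](celerity::handler& cgh) {
			celerity::accessor acc{buf, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init};
			cgh.parallel_for<class divergent_fill>(buf.get_range(), [=](celerity::item<1> it) { acc[it] = 1; });
		});
	}
	// With CELERITY_DIVERGENCE_CHECK enabled, the checker reports the first task whose
	// hash differs across nodes instead of letting the run silently derail or hang.
	return 0;
}
```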
+ if(rank % 2 == 0) { + queue.submit([&](celerity::handler& cgh) { + celerity::accessor dw{buff, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init}; + const auto range = buff.get_range(); + cgh.parallel_for(range, [=](celerity::item<1> item) { + if(item[0] % 2 == 0) { dw[item] = 2.5; } + }); + }); + } + + divergence_block_chain_testspy::set_last_cleared(div_check, std::chrono::steady_clock::now() - std::chrono::seconds(10)); + divergence_block_chain_testspy::call_check_for_divergence(div_check); + + if(rank % 2 == 1) { + queue.submit([&](celerity::handler& cgh) { + celerity::accessor dw{buff, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init}; + const auto range = buff.get_range(); + cgh.parallel_for(range, [=](celerity::item<1> item) { + if(item[0] % 2 == 0) { dw[item] = 0.5; } + }); + }); + } + + queue.submit([&](celerity::handler& cgh) { + celerity::accessor acc{buff, cgh, celerity::access::all{}, celerity::read_only_host_task}; + const auto range = buff.get_range(); + cgh.host_task(celerity::on_master_node, [=] { + for(size_t i = 0; i < range.get(0); ++i) { + if(acc[i] == 3) { break; } + } + }); + }); + + CHECK_THROWS(divergence_block_chain_testspy::call_check_for_divergence(div_check)); + } + + // create the check text + std::string check_text = fmt::format("After 10 seconds of waiting, node(s)"); + std::vector stuck_nodes; + for(node_id nid = 0; nid < n; ++nid) { + // every second node in this test is stuck + if(nid % 2 == 1) { stuck_nodes.push_back(nid); } + } + check_text += fmt::format(" {} ", fmt::join(stuck_nodes, ",")); + check_text += "did not move to the next task. The runtime might be stuck."; + + if(rank == 0) { + const auto log = log_capture.get_log(); + CHECK_THAT(log, Catch::Matchers::ContainsSubstring(check_text)); + CHECK_THAT(log, Catch::Matchers::ContainsSubstring("Task record for ")); + } + } } // namespace detail } // namespace celerity diff --git a/test/test_utils.h b/test/test_utils.h index f966abf02..b8f0601aa 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -21,6 +21,7 @@ #include "command_graph.h" #include "device_queue.h" #include "distributed_graph_generator.h" +#include "divergence_checker.h" #include "graph_serializer.h" #include "print_graph.h" #include "range_mapper.h" @@ -60,6 +61,7 @@ namespace detail { static command_graph& get_cdag(runtime& rt) { return *rt.m_cdag; } static std::string print_task_graph(runtime& rt) { return detail::print_task_graph(*rt.m_task_recorder); } static std::string print_command_graph(const node_id local_nid, runtime& rt) { return detail::print_command_graph(local_nid, *rt.m_command_recorder); } + static divergence_checker_detail::divergence_block_chain& get_divergence_block_chain(runtime& rt) { return rt.m_divergence_check->m_block_chain; } }; struct task_ring_buffer_testspy {
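To summarize the detection scheme implemented by `divergence_block_chain` in one place, here is a condensed, self-contained model; it is not the real class, only the idea. Each node chains the hash of every new task record onto the previous chain value (using the Boost-derived `hash_combine` step from `utils.h`), the per-task chain values are exchanged, and a task is divergent as soon as the nodes disagree on its chained hash. Because the hashes are chained, the first mismatching index pinpoints the first diverging task, and all later hashes differ as well.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using task_hash = std::size_t;

// Build the local hash chain: each entry combines the previous chain value with the
// hash of the new task record (here modelled as a plain string for simplicity).
std::vector<task_hash> build_hash_chain(const std::vector<std::string>& task_descriptions) {
	std::vector<task_hash> chain;
	task_hash seed = 0;
	for(const auto& desc : task_descriptions) {
		seed ^= std::hash<std::string>{}(desc) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
		chain.push_back(seed);
	}
	return chain;
}

// Group node ids by the chain value they reported for task `i`;
// more than one group means the execution has diverged at (or before) task `i`.
std::unordered_map<task_hash, std::vector<std::size_t>> divergence_map_for(
    const std::vector<std::vector<task_hash>>& per_node_chains, const std::size_t i) {
	std::unordered_map<task_hash, std::vector<std::size_t>> groups;
	for(std::size_t nid = 0; nid < per_node_chains.size(); ++nid) {
		groups[per_node_chains[nid][i]].push_back(nid);
	}
	return groups;
}
```

In the actual implementation the chain values are exchanged via the communicator's allgather, the grouping above corresponds to `create_divergence_map`, and on a mismatch the full task record is logged once per distinct hash before the run is aborted; nodes that stop contributing hashes for too long additionally trigger the deadlock warning.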