1
+ #pragma once
2
+
3
+ #include " recorders.h"
4
+ #include < mutex>
5
+ #include < thread>
6
+ #include < vector>
7
+
8
+ namespace celerity ::detail {
9
// Owning two-dimensional buffer stored row-major in a single contiguous std::vector<T>.
// In C++23 this can be replaced with std::mdspan over a plain vector.
// NOTE(review): judging by the name, the flat layout exists so that data() can be handed
// directly to MPI calls - confirm against the callers.
template <typename T>
struct mpi_multidim_send_wrapper {
  public:
    // Allocates width * height value-initialized elements. `width` is the row length
    // that operator[] uses to flatten a (row, column) pair.
    mpi_multidim_send_wrapper(size_t width, size_t height) : m_data(width * height), m_width(width) {}

    // Read-only access to the element at row `ij.first`, column `ij.second`.
    // Both indices must be non-negative and inside the buffer (checked by assert only).
    const T& operator[](std::pair<int, int> ij) const {
        // Check each dimension on its own: a negative or too-large column would otherwise wrap
        // around / spill into a neighbouring row and still pass a check on the flattened index.
        assert(ij.first >= 0 && ij.second >= 0);
        const auto row = static_cast<size_t>(ij.first);
        const auto col = static_cast<size_t>(ij.second);
        assert(col < m_width);
        const size_t flat_index = row * m_width + col;
        assert(flat_index < m_data.size());
        return m_data[flat_index];
    }

    // Pointer to the first element of the contiguous storage (width * height elements).
    T* data() { return m_data.data(); }
    const T* data() const { return m_data.data(); }

  private:
    std::vector<T> m_data;
    const size_t m_width; // number of elements per row
};
26
+
27
// Sliding, read-only view onto the tail of a std::vector<T> that is owned elsewhere.
// Only a reference to the vector is stored, so the vector must outlive the window.
// In C++20 this can probably be replaced with std::span.
template <typename T>
struct window {
  public:
    // The view starts at offset 0 with a cached width of 0; call size() to pick up
    // the vector's current length before indexing.
    window(const std::vector<T>& value) : m_value(value) {}

    // Element `i` counted from the current window offset.
    // `i` must be below the width cached by the last size() / slide() call (checked by assert only).
    const T& operator[](size_t i) const {
        // `i` is unsigned, so only the upper bound needs checking.
        assert(i < m_width);
        return m_value[m_offset + i];
    }

    // Re-reads the underlying vector's length, caches the resulting width and returns it.
    // Not const on purpose: this is the call that makes newly appended elements visible.
    size_t size() {
        m_width = m_value.size() - m_offset;
        return m_width;
    }

    // Advances the start of the window by `i` elements, shrinking the cached width accordingly.
    // `i` must not exceed the cached width (checked by assert only).
    void slide(size_t i) {
        // Covers i == 0 as well, since 0 <= m_width always holds for unsigned values.
        assert(i <= m_width);
        m_offset += i;
        m_width -= i;
    }

  private:
    const std::vector<T>& m_value;
    size_t m_offset = 0; // index in m_value of the first element inside the window
    size_t m_width = 0;  // number of elements visible as of the last size() / slide()
};
54
+
55
+ using task_hash = size_t ;
56
+ using task_hash_data = mpi_multidim_send_wrapper<task_hash>;
57
+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
58
+
59
+ class abstract_block_chain {
60
+ friend struct abstract_block_chain_testspy ;
61
+
62
+ public:
63
+ virtual void start () { m_is_running = true ; };
64
+ virtual void stop () { m_is_running = false ; };
65
+
66
+ abstract_block_chain (const abstract_block_chain&) = delete ;
67
+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
68
+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
69
+
70
+ abstract_block_chain (abstract_block_chain&&) = default ;
71
+
72
+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
73
+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
74
+
75
+ virtual ~abstract_block_chain () = default ;
76
+
77
+ protected:
78
+ virtual void run () = 0;
79
+
80
+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
81
+
82
+ void add_new_hashes ();
83
+ void clear (const int min_progress);
84
+ virtual void allgather_sizes ();
85
+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
86
+ std::pair<int , int > collect_sizes ();
87
+ task_hash_data collect_hashes (const int max_size);
88
+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
89
+
90
+ void check_for_deadlock () const ;
91
+
92
+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
93
+
94
+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
95
+
96
+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
97
+
98
+ bool check_for_divergence ();
99
+
100
+ protected:
101
+ node_id m_local_nid;
102
+ size_t m_num_nodes;
103
+
104
+ std::vector<task_hash> m_hashes;
105
+ std::vector<int32_t > m_sizes;
106
+
107
+ bool m_is_running = true ;
108
+
109
+ window<task_record> m_task_recorder_window;
110
+
111
+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
112
+
113
+ MPI_Comm m_comm;
114
+ };
115
+
116
+ class single_node_test_divergence_block_chain : public abstract_block_chain {
117
+ public:
118
+ single_node_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm,
119
+ const std::vector<std::reference_wrapper<const std::vector<task_record>>>& other_task_records)
120
+ : abstract_block_chain(num_nodes, local_nid, task_recorder, comm), m_other_hashes(other_task_records.size()) {
121
+ for (auto & tsk_rcd : other_task_records) {
122
+ m_other_task_records.push_back (window<task_record>(tsk_rcd));
123
+ }
124
+ }
125
+
126
+ private:
127
+ void run () override {}
128
+
129
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
130
+ void allgather_sizes () override ;
131
+ void allgather_hashes (const int max_size, task_hash_data& data) override ;
132
+
133
+ void dedub_print_task_record (const divergence_map& check_map, const int task_num) const override ;
134
+
135
+ std::vector<std::vector<task_hash>> m_other_hashes;
136
+ std::vector<window<task_record>> m_other_task_records;
137
+
138
+ int m_injected_delete_size = 0 ;
139
+ };
140
+
141
+ class distributed_test_divergence_block_chain : public abstract_block_chain {
142
+ public:
143
+ distributed_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
144
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {}
145
+
146
+ private:
147
+ void run () override {}
148
+
149
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
150
+ };
151
+
152
+ class divergence_block_chain : public abstract_block_chain {
153
+ public:
154
+ void start () override ;
155
+ void stop () override ;
156
+
157
+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
158
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {
159
+ start ();
160
+ }
161
+
162
+ divergence_block_chain (const divergence_block_chain&) = delete ;
163
+ divergence_block_chain& operator =(const divergence_block_chain&) = delete ;
164
+ divergence_block_chain& operator =(divergence_block_chain&&) = delete ;
165
+
166
+ divergence_block_chain (divergence_block_chain&&) = default ;
167
+
168
+ ~divergence_block_chain () override { stop (); }
169
+
170
+ private:
171
+ void run () override ;
172
+
173
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
174
+
175
+ private:
176
+ std::thread m_thread;
177
+ };
178
+ } // namespace celerity::detail
0 commit comments