Skip to content

Commit 3a9b4d8

Browse files
committed
Communicate FieldPerp consistently with other Fields
- Store a `variant` in `FieldGroup` that can handle `FieldPerp` - Removes `Mesh::communicate(FieldPerp&)` overload
1 parent 82fe1de commit 3a9b4d8

File tree

7 files changed

+198
-162
lines changed

7 files changed

+198
-162
lines changed

include/bout/fieldgroup.hxx

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#ifndef BOUT_FIELDGROUP_H
22
#define BOUT_FIELDGROUP_H
33

4-
#include "bout/field_data.hxx"
5-
#include <bout/field3d.hxx>
6-
4+
#include <bout/sys/variant.hxx>
5+
#include <bout/traits.hxx>
76
#include <bout/vector2d.hxx>
87
#include <bout/vector3d.hxx>
98

109
#include <vector>
1110

12-
#include <algorithm>
11+
class Field2D;
12+
class Field3D;
13+
class FieldPerp;
14+
class FieldData;
1315

1416
/// Group together fields for easier communication
1517
///
@@ -19,14 +21,18 @@
1921
/// components (x,y,z) as Field2D or Field3D objects.
2022
class FieldGroup {
2123
public:
24+
using Item = bout::utils::variant<Field3D*, Field2D*, FieldPerp*>;
25+
2226
FieldGroup() = default;
2327
FieldGroup(const FieldGroup& other) = default;
2428
FieldGroup(FieldGroup&& other) = default;
2529
FieldGroup& operator=(const FieldGroup& other) = default;
2630
FieldGroup& operator=(FieldGroup&& other) = default;
31+
~FieldGroup() = default;
2732

2833
/// Constructor with a single FieldData \p f
29-
FieldGroup(FieldData& f) { fvec.push_back(&f); }
34+
FieldGroup(Field2D& f) { fvec.push_back(&f); }
35+
FieldGroup(FieldPerp& f) { fvec.push_back(&f); }
3036

3137
/// Constructor with a single Field3D \p f
3238
FieldGroup(Field3D& f) {
@@ -83,7 +89,8 @@ public:
8389
/// A pointer to this field will be stored internally,
8490
/// so the lifetime of this variable should be longer
8591
/// than the lifetime of this group.
86-
void add(FieldData& f) { fvec.push_back(&f); }
92+
void add(Field2D& f) { fvec.push_back(&f); }
93+
void add(FieldPerp& f) { fvec.push_back(&f); }
8794

8895
// Add a 3D field \p f, which goes into both vectors.
8996
//
@@ -121,18 +128,8 @@ public:
121128
}
122129

123130
/// Add multiple fields to this group
124-
///
125-
/// This is a variadic template which allows Field3D objects to be
126-
/// treated as a special case. An arbitrary number of fields can be
127-
/// added.
128-
template <typename... Ts>
129-
void add(FieldData& t, Ts&... ts) {
130-
add(t); // Add the first using functions above
131-
add(ts...); // Add the rest
132-
}
133-
134-
template <typename... Ts>
135-
void add(Field3D& t, Ts&... ts) {
131+
template <typename T, typename... Ts, typename = bout::utils::EnableIfField<T>>
132+
void add(T& t, Ts&... ts) {
136133
add(t); // Add the first using functions above
137134
add(ts...); // Add the rest
138135
}
@@ -165,16 +162,14 @@ public:
165162
}
166163

167164
/// Iteration over all fields
168-
using iterator = std::vector<FieldData*>::iterator;
169-
iterator begin() { return fvec.begin(); }
170-
iterator end() { return fvec.end(); }
165+
auto begin() { return fvec.begin(); }
166+
auto end() { return fvec.end(); }
171167

172168
/// Const iteration over all fields
173-
using const_iterator = std::vector<FieldData*>::const_iterator;
174-
const_iterator begin() const { return fvec.begin(); }
175-
const_iterator end() const { return fvec.end(); }
169+
auto begin() const { return fvec.cbegin(); }
170+
auto end() const { return fvec.cend(); }
176171

177-
const std::vector<FieldData*>& get() const { return fvec; }
172+
const std::vector<Item>& get() const { return fvec; }
178173

179174
/// Iteration over 3D fields
180175
const std::vector<Field3D*>& field3d() const { return f3vec; }
@@ -183,8 +178,8 @@ public:
183178
void makeUnique();
184179

185180
private:
186-
std::vector<FieldData*> fvec; // Vector of fields
187-
std::vector<Field3D*> f3vec; // Vector of 3D fields
181+
std::vector<Item> fvec; // Vector of fields
182+
std::vector<Field3D*> f3vec; // Vector of 3D fields
188183
};
189184

190185
/// Combine two FieldGroups

include/bout/mesh.hxx

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
*
44
* Interface for mesh classes. Contains standard variables and useful
55
* routines.
6-
*
6+
*
77
* Changelog
88
* =========
99
*
1010
* 2014-12 Ben Dudson <bd512@york.ac.uk>
1111
* * Removing coordinate system into separate
1212
* Coordinates class
1313
* * Adding index derivative functions from derivs.cxx
14-
*
14+
*
1515
* 2010-06 Ben Dudson, Sean Farley
1616
* * Initial version, adapted from GridData class
1717
* * Incorporates code from topology.cpp and Communicator
@@ -20,7 +20,7 @@
2020
* Copyright 2010-2025 BOUT++ contributors
2121
*
2222
* Contact: Ben Dudson, dudson2@llnl.gov
23-
*
23+
*
2424
* This file is part of BOUT++.
2525
*
2626
* BOUT++ is free software: you can redistribute it and/or modify
@@ -301,11 +301,6 @@ public:
301301
/// @param g The group of fields to communicate. Guard cells will be modified
302302
void communicateYZ(FieldGroup& g);
303303

304-
/*!
305-
* Communicate an X-Z field
306-
*/
307-
virtual void communicate(FieldPerp& f);
308-
309304
/*!
310305
* Send a list of FieldData objects
311306
* Packs arguments into a FieldGroup and passes
@@ -815,7 +810,7 @@ protected:
815810
const std::vector<int> readInts(const std::string& name, int n);
816811

817812
/// Calculates the size of a message for a given x and y range
818-
int msg_len(const std::vector<FieldData*>& var_list, int xge, int xlt, int yge,
813+
int msg_len(const std::vector<FieldGroup::Item>& var_list, int xge, int xlt, int yge,
819814
int ylt);
820815

821816
/// Initialise derivatives

src/mesh/impls/bout/boutmesh.cxx

Lines changed: 114 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include <bout/sys/gettext.hxx>
5050
#include <bout/sys/range.hxx>
5151
#include <bout/sys/timer.hxx>
52+
#include <bout/sys/variant.hxx>
5253
#include <bout/utils.hxx>
5354

5455
#include <fmt/format.h>
@@ -1400,6 +1401,16 @@ comm_handle BoutMesh::sendY(FieldGroup& g, comm_handle handle) {
14001401
return static_cast<void*>(ch);
14011402
}
14021403

1404+
namespace {
1405+
// FieldGroup stores a vector of variants now, rather than a pointer to the
1406+
// FieldData base class, so we need this visitor as a shim
1407+
struct DoneCommsVisitor {
1408+
void operator()(Field3D* var) const { var->doneComms(); }
1409+
void operator()(Field2D* var) const { var->doneComms(); }
1410+
void operator()(FieldPerp* var) const { var->doneComms(); }
1411+
};
1412+
} // namespace
1413+
14031414
int BoutMesh::wait(comm_handle handle) {
14041415

14051416
if (handle == nullptr) {
@@ -1545,7 +1556,7 @@ int BoutMesh::wait(comm_handle handle) {
15451556
#if CHECK > 0
15461557
// Keeping track of whether communications have been done
15471558
for (const auto& var : ch->var_list) {
1548-
var->doneComms();
1559+
bout::utils::visit(DoneCommsVisitor{}, var);
15491560
}
15501561
#endif
15511562

@@ -2218,9 +2229,9 @@ void BoutMesh::topology() {
22182229
}
22192230

22202231
for (int i = 0; i < limiter_count; ++i) {
2221-
int const yind = limiter_yinds[i];
2222-
int const xstart = limiter_xstarts[i];
2223-
int const xend = limiter_xends[i];
2232+
const int yind = limiter_yinds[i];
2233+
const int xstart = limiter_xstarts[i];
2234+
const int xend = limiter_xends[i];
22242235
output_info.write("Adding a limiter between y={} and {}. X indices {} to {}\n",
22252236
yind, yind + 1, xstart, xend);
22262237
add_target(yind, xstart, xend);
@@ -2397,72 +2408,121 @@ void BoutMesh::overlapHandleMemory(BoutMesh* yup, BoutMesh* ydown, BoutMesh* xin
23972408
* Communication utilities
23982409
****************************************************************/
23992410

2400-
int BoutMesh::pack_data(const std::vector<FieldData*>& var_list, int xge, int xlt,
2401-
int yge, int ylt, BoutReal* buffer) {
2402-
2403-
int len = 0;
2411+
namespace {
2412+
// Visitor for packing data from a `FieldGroup::Item` into an existing buffer
2413+
struct PackDataVisitor {
2414+
int xge;
2415+
int xlt;
2416+
int yge;
2417+
int ylt;
2418+
int zge;
2419+
int zlt;
2420+
BoutReal* buffer;
2421+
int len;
24042422

2405-
/// Loop over variables
2406-
for (const auto& var : var_list) {
2407-
if (var->is3D()) {
2408-
// 3D variable
2409-
auto* var3d_ref_ptr = dynamic_cast<Field3D*>(var);
2410-
ASSERT0(var3d_ref_ptr != nullptr);
2411-
auto& var3d_ref = *var3d_ref_ptr;
2412-
ASSERT2(var3d_ref.isAllocated());
2413-
for (int jx = xge; jx != xlt; jx++) {
2414-
for (int jy = yge; jy < ylt; jy++) {
2415-
for (int jz = 0; jz < LocalNz; jz++, len++) {
2416-
buffer[len] = var3d_ref(jx, jy, jz);
2417-
}
2423+
int operator()(const Field3D* var) {
2424+
const auto& var3d_ref = *var;
2425+
ASSERT2(var3d_ref.isAllocated());
2426+
for (int jx = xge; jx < xlt; jx++) {
2427+
for (int jy = yge; jy < ylt; jy++) {
2428+
for (int jz = zge; jz < zlt; jz++, len++) {
2429+
buffer[len] = var3d_ref(jx, jy, jz);
24182430
}
24192431
}
2420-
} else {
2421-
// 2D variable
2422-
auto* var2d_ref_ptr = dynamic_cast<Field2D*>(var);
2423-
ASSERT0(var2d_ref_ptr != nullptr);
2424-
auto& var2d_ref = *var2d_ref_ptr;
2425-
ASSERT2(var2d_ref.isAllocated());
2426-
for (int jx = xge; jx != xlt; jx++) {
2427-
for (int jy = yge; jy < ylt; jy++, len++) {
2428-
buffer[len] = var2d_ref(jx, jy);
2432+
}
2433+
return len;
2434+
}
2435+
2436+
int operator()(const Field2D* var) {
2437+
const auto& var2d_ref = *var;
2438+
ASSERT2(var2d_ref.isAllocated());
2439+
for (int jx = xge; jx < xlt; jx++) {
2440+
for (int jy = yge; jy < ylt; jy++, len++) {
2441+
buffer[len] = var2d_ref(jx, jy);
2442+
}
2443+
}
2444+
return len;
2445+
}
2446+
2447+
int operator()(const FieldPerp* var) {
2448+
const auto& varperp_ref = *var;
2449+
ASSERT2(varperp_ref.isAllocated());
2450+
for (int jx = xge; jx < xlt; jx++) {
2451+
for (int jz = zge; jz < zlt; jz++, len++) {
2452+
buffer[len] = varperp_ref(jx, jz);
2453+
}
2454+
}
2455+
return len;
2456+
}
2457+
};
2458+
2459+
// Visitor for unpacking a buffer into a `FieldGroup::Item`
2460+
struct UnpackDataVisitor {
2461+
int xge;
2462+
int xlt;
2463+
int yge;
2464+
int ylt;
2465+
int zge;
2466+
int zlt;
2467+
BoutReal* buffer;
2468+
int len;
2469+
2470+
void operator()(Field3D* var) {
2471+
auto& var3d_ref = *var;
2472+
ASSERT2(var3d_ref.isAllocated());
2473+
for (int jx = xge; jx < xlt; jx++) {
2474+
for (int jy = yge; jy < ylt; jy++) {
2475+
for (int jz = zge; jz < zlt; jz++, len++) {
2476+
var3d_ref(jx, jy, jz) = buffer[len];
24292477
}
24302478
}
24312479
}
24322480
}
24332481

2434-
return (len);
2482+
void operator()(Field2D* var) {
2483+
auto& var2d_ref = *var;
2484+
ASSERT2(var2d_ref.isAllocated());
2485+
for (int jx = xge; jx < xlt; jx++) {
2486+
for (int jy = yge; jy < ylt; jy++, len++) {
2487+
var2d_ref(jx, jy) = buffer[len];
2488+
}
2489+
}
2490+
}
2491+
2492+
void operator()(FieldPerp* var) {
2493+
auto& varperp_ref = *var;
2494+
ASSERT2(varperp_ref.isAllocated());
2495+
for (int jx = xge; jx < xlt; jx++) {
2496+
for (int jz = zge; jz < zlt; jz++, len++) {
2497+
varperp_ref(jx, jz) = buffer[len];
2498+
}
2499+
}
2500+
}
2501+
};
2502+
} // namespace
2503+
2504+
int BoutMesh::pack_data(const std::vector<FieldGroup::Item>& var_list, int xge, int xlt,
2505+
int yge, int ylt, BoutReal* buffer) {
2506+
2507+
auto visitor = PackDataVisitor{xge, xlt, yge, ylt, 0, LocalNz, buffer, 0};
2508+
2509+
for (const auto& var : var_list) {
2510+
bout::utils::visit(visitor, var);
2511+
}
2512+
2513+
return visitor.len;
24352514
}
24362515

2437-
int BoutMesh::unpack_data(const std::vector<FieldData*>& var_list, int xge, int xlt,
2516+
int BoutMesh::unpack_data(const std::vector<FieldGroup::Item>& var_list, int xge, int xlt,
24382517
int yge, int ylt, BoutReal* buffer) {
24392518

2440-
int len = 0;
2519+
auto visitor = UnpackDataVisitor{xge, xlt, yge, ylt, 0, LocalNz, buffer, 0};
24412520

2442-
/// Loop over variables
24432521
for (const auto& var : var_list) {
2444-
if (var->is3D()) {
2445-
// 3D variable
2446-
auto& var3d_ref = *dynamic_cast<Field3D*>(var);
2447-
for (int jx = xge; jx != xlt; jx++) {
2448-
for (int jy = yge; jy < ylt; jy++) {
2449-
for (int jz = 0; jz < LocalNz; jz++, len++) {
2450-
var3d_ref(jx, jy, jz) = buffer[len];
2451-
}
2452-
}
2453-
}
2454-
} else {
2455-
// 2D variable
2456-
auto& var2d_ref = *dynamic_cast<Field2D*>(var);
2457-
for (int jx = xge; jx != xlt; jx++) {
2458-
for (int jy = yge; jy < ylt; jy++, len++) {
2459-
var2d_ref(jx, jy) = buffer[len];
2460-
}
2461-
}
2462-
}
2522+
bout::utils::visit(visitor, var);
24632523
}
24642524

2465-
return (len);
2525+
return visitor.len;
24662526
}
24672527

24682528
/****************************************************************

src/mesh/impls/bout/boutmesh.hxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,12 @@ private:
476476
void post_receiveY(CommHandle& ch);
477477

478478
/// Take data from objects and put into a buffer
479-
int pack_data(const std::vector<FieldData*>& var_list, int xge, int xlt, int yge,
479+
int pack_data(const std::vector<FieldGroup::Item>& var_list, int xge, int xlt, int yge,
480480
int ylt, BoutReal* buffer);
481481
/// Copy data from a buffer back into the fields
482482

483-
int unpack_data(const std::vector<FieldData*>& var_list, int xge, int xlt, int yge,
484-
int ylt, BoutReal* buffer);
483+
int unpack_data(const std::vector<FieldGroup::Item>& var_list, int xge, int xlt,
484+
int yge, int ylt, BoutReal* buffer);
485485
};
486486

487487
namespace {

0 commit comments

Comments
 (0)