Skip to content

Commit 976fa6e

Browse files
committed
Add communicator for XZ planes
Required for properly communicating `FieldPerp` with multiple Z processors when using PETSc
1 parent 8397daa commit 976fa6e

File tree

10 files changed

+60
-18
lines changed

10 files changed

+60
-18
lines changed

include/bout/globalindexer.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public:
8686

8787
int localSize = size();
8888
MPI_Comm comm =
89-
std::is_same_v<T, FieldPerp> ? fieldmesh->getXcomm() : BoutComm::get();
89+
std::is_same_v<T, FieldPerp> ? fieldmesh->getXZcomm() : BoutComm::get();
9090
fieldmesh->getMpi().MPI_Scan(&localSize, &globalEnd, 1, MPI_INT, MPI_SUM, comm);
9191
globalEnd--;
9292
int counter = globalStart = globalEnd - size() + 1;

include/bout/hypre_interface.hxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public:
159159
explicit HypreVector(IndexerPtr<T> indConverter) : indexConverter(indConverter) {
160160
Mesh& mesh = *indConverter->getMesh();
161161
const MPI_Comm comm =
162-
std::is_same_v<T, FieldPerp> ? mesh.getXcomm() : BoutComm::get();
162+
std::is_same_v<T, FieldPerp> ? mesh.getXZcomm() : BoutComm::get();
163163

164164
HYPRE_BigInt jlower = indConverter->getGlobalStart();
165165
HYPRE_BigInt jupper = jlower + indConverter->size() - 1; // inclusive end
@@ -380,7 +380,7 @@ public:
380380
: hypre_matrix(new HYPRE_IJMatrix, MatrixDeleter{}), index_converter(indConverter) {
381381
Mesh* mesh = indConverter->getMesh();
382382
const MPI_Comm comm =
383-
std::is_same_v<T, FieldPerp> ? mesh->getXcomm() : BoutComm::get();
383+
std::is_same_v<T, FieldPerp> ? mesh->getXZcomm() : BoutComm::get();
384384
parallel_transform = &mesh->getCoordinates()->getParallelTransform();
385385

386386
ilower = indConverter->getGlobalStart();
@@ -812,7 +812,7 @@ public:
812812
"values are: gmres, bicgstab, pcg")
813813
.withDefault(HYPRE_SOLVER_TYPE::bicgstab);
814814

815-
comm = std::is_same_v<T, FieldPerp> ? mesh.getXcomm() : BoutComm::get();
815+
comm = std::is_same_v<T, FieldPerp> ? mesh.getXZcomm() : BoutComm::get();
816816

817817
auto print_level =
818818
options["hypre_print_level"]

include/bout/mesh.hxx

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
*
44
* Interface for mesh classes. Contains standard variables and useful
55
* routines.
6-
*
6+
*
77
* Changelog
88
* =========
99
*
1010
* 2014-12 Ben Dudson <bd512@york.ac.uk>
1111
* * Removing coordinate system into separate
1212
* Coordinates class
1313
* * Adding index derivative functions from derivs.cxx
14-
*
14+
*
1515
* 2010-06 Ben Dudson, Sean Farley
1616
* * Initial version, adapted from GridData class
1717
* * Incorporates code from topology.cpp and Communicator
@@ -20,7 +20,7 @@
2020
* Copyright 2010-2025 BOUT++ contributors
2121
*
2222
* Contact: Ben Dudson, dudson2@llnl.gov
23-
*
23+
*
2424
* This file is part of BOUT++.
2525
*
2626
* BOUT++ is free software: you can redistribute it and/or modify
@@ -58,8 +58,6 @@ class Mesh;
5858
#include "bout/sys/range.hxx" // RangeIterator
5959
#include "bout/unused.hxx"
6060

61-
#include "mpi.h"
62-
6361
#include <map>
6462
#include <memory>
6563
#include <optional>
@@ -405,6 +403,7 @@ public:
405403
} ///< Return communicator containing all processors in X
406404
virtual MPI_Comm getXcomm(int jy) const = 0; ///< Return X communicator
407405
virtual MPI_Comm getYcomm(int jx) const = 0; ///< Return Y communicator
406+
virtual MPI_Comm getXZcomm() const = 0; ///< Communicator in X-Z
408407

409408
/// Return pointer to the mesh's MPI Wrapper object
410409
MpiWrapper& getMpi() { return *mpi; }

include/bout/petsc_interface.hxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ inline MPI_Comm getComm([[maybe_unused]] const T& field) {
7979

8080
template <>
8181
inline MPI_Comm getComm([[maybe_unused]] const FieldPerp& field) {
82-
return field.getMesh()->getXcomm();
82+
return field.getMesh()->getXZcomm();
8383
}
8484

8585
template <class T>
@@ -293,7 +293,7 @@ public:
293293
PetscMatrix(IndexerPtr<T> indConverter, bool preallocate = true)
294294
: matrix(new Mat()), indexConverter(indConverter),
295295
pt(&indConverter->getMesh()->getCoordinates()->getParallelTransform()) {
296-
MPI_Comm comm = std::is_same_v<T, FieldPerp> ? indConverter->getMesh()->getXcomm()
296+
MPI_Comm comm = std::is_same_v<T, FieldPerp> ? indConverter->getMesh()->getXZcomm()
297297
: BoutComm::get();
298298
const int size = indexConverter->size();
299299

src/invert/laplace/impls/petsc/petsc_laplace.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ LaplacePetsc::LaplacePetsc(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
140140
[[maybe_unused]] Solver* solver)
141141
: Laplacian(opt, loc, mesh_in), A(0.0, mesh_in), C1(1.0, mesh_in), C2(1.0, mesh_in),
142142
D(1.0, mesh_in), Ex(0.0, mesh_in), Ez(0.0, mesh_in), issetD(false), issetC(false),
143-
issetE(false), comm(localmesh->getXcomm()),
143+
issetE(false), comm(localmesh->getXZcomm()),
144144
opts(opt == nullptr ? &(Options::root()["laplace"]) : opt),
145145
// WARNING: only a few of these options actually make sense: see the
146146
// PETSc documentation to work out which they are (possibly

src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ LaplaceXZpetsc::LaplaceXZpetsc(Mesh* m, Options* opt, const CELL_LOC loc)
155155
.withDefault("petsc");
156156

157157
// Get MPI communicator
158-
MPI_Comm comm = localmesh->getXcomm();
158+
MPI_Comm comm = localmesh->getXZcomm();
159159

160160
// Local size
161161
int localN = (localmesh->xend - localmesh->xstart + 1) * (localmesh->LocalNz);

src/mesh/impls/bout/boutmesh.cxx

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include <bout/utils.hxx>
5353

5454
#include <fmt/format.h>
55+
#include <fmt/ranges.h>
5556

5657
#include <algorithm>
5758
#include <cmath>
@@ -106,6 +107,9 @@ BoutMesh::~BoutMesh() {
106107
if (comm_outer != MPI_COMM_NULL) {
107108
MPI_Comm_free(&comm_outer);
108109
}
110+
if (comm_xz != MPI_COMM_NULL) {
111+
MPI_Comm_free(&comm_xz);
112+
}
109113
}
110114

111115
BoutMesh::YDecompositionIndices
@@ -665,10 +669,43 @@ int BoutMesh::load() {
665669
return 0;
666670
}
667671

672+
namespace {
673+
auto make_XZ_communicator(const BoutMesh& mesh, MPI_Group group_world) -> MPI_Comm {
674+
std::vector<int> ranks;
675+
676+
const int yp = mesh.getYProcIndex();
677+
678+
// All processors with the same Y index
679+
for (int xp = 0; xp < mesh.getNXPE(); ++xp) {
680+
for (int zp = 0; zp < mesh.getNZPE(); ++zp) {
681+
ranks.push_back(mesh.getProcIndex(xp, yp, zp));
682+
}
683+
}
684+
MPI_Group group{};
685+
if (MPI_Group_incl(group_world, static_cast<int>(ranks.size()), ranks.data(), &group)
686+
!= MPI_SUCCESS) {
687+
throw BoutException("Could not create X-Z communication group for ranks {}",
688+
fmt::join(ranks, ", "));
689+
}
690+
691+
MPI_Comm comm_xz{};
692+
if (MPI_Comm_create(BoutComm::get(), group, &comm_xz) != MPI_SUCCESS) {
693+
throw BoutException("Could not create X-Z communicator for yp={} (xind={}, yind={}, "
694+
"zind={}) ranks={}",
695+
yp, mesh.getXProcIndex(), mesh.getYProcIndex(),
696+
mesh.getZProcIndex(), fmt::join(ranks, ", "));
697+
}
698+
699+
return comm_xz;
700+
}
701+
} // namespace
702+
668703
void BoutMesh::createCommunicators() {
669704
MPI_Group group_world{};
670705
MPI_Comm_group(BoutComm::get(), &group_world); // Get the entire group
671706

707+
comm_xz = make_XZ_communicator(*this, group_world);
708+
672709
//////////////////////////////////////////////////////
673710
/// Communicator in X
674711

@@ -1038,7 +1075,9 @@ void BoutMesh::createXBoundaries() {
10381075
}
10391076
}
10401077

1041-
int BoutMesh::getProcIndex(int X, int Y, int Z) const { return Y * NXPE + X; }
1078+
int BoutMesh::getProcIndex(int X, int Y, int Z) const {
1079+
return (((Z * NYPE) + Y) * NXPE) + X;
1080+
}
10421081

10431082
void BoutMesh::createYBoundaries() {
10441083
if (MYG <= 0) {
@@ -2218,9 +2257,9 @@ void BoutMesh::topology() {
22182257
}
22192258

22202259
for (int i = 0; i < limiter_count; ++i) {
2221-
int const yind = limiter_yinds[i];
2222-
int const xstart = limiter_xstarts[i];
2223-
int const xend = limiter_xends[i];
2260+
const int yind = limiter_yinds[i];
2261+
const int xstart = limiter_xstarts[i];
2262+
const int xend = limiter_xends[i];
22242263
output_info.write("Adding a limiter between y={} and {}. X indices {} to {}\n",
22252264
yind, yind + 1, xstart, xend);
22262265
add_target(yind, xstart, xend);

src/mesh/impls/bout/boutmesh.hxx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ public:
107107
MPI_Comm getXcomm(int UNUSED(jy)) const override { return comm_x; }
108108
/// Return communicator containing all processors in Y
109109
MPI_Comm getYcomm(int xpos) const override;
110+
MPI_Comm getXZcomm() const override { return comm_xz; }
110111

111112
/// Is local X index \p jx periodic in Y?
112113
///
@@ -455,6 +456,8 @@ private:
455456

456457
/// Communicator containing all processors in X
457458
MPI_Comm comm_x{MPI_COMM_NULL};
459+
/// Communicator for all processors in an XZ plane
460+
MPI_Comm comm_xz{MPI_COMM_NULL};
458461

459462
//////////////////////////////////////////////////
460463
// Surface communications

src/mesh/mesh.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ int Mesh::globalStartIndex2D() {
516516
int Mesh::globalStartIndexPerp() {
517517
int localSize = localSizePerp();
518518
int cumulativeSize = 0;
519-
mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXcomm());
519+
mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXZcomm());
520520
return cumulativeSize - localSize;
521521
}
522522

tests/unit/fake_mesh.hxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ public:
139139
}
140140
MPI_Comm getXcomm(int UNUSED(jy)) const override { return BoutComm::get(); }
141141
MPI_Comm getYcomm(int UNUSED(jx)) const override { return BoutComm::get(); }
142+
MPI_Comm getXZcomm() const override { return BoutComm::get(); }
142143

143144
// Periodic Y
144145
int ix_separatrix{1000000}; // separatrix index

0 commit comments

Comments
 (0)