Skip to content

Commit 51a7f03

Browse files
committed
Update Particle Container to Pure SoA
Transition particle containers to pure SoA layouts.
1 parent ce33709 commit 51a7f03

39 files changed

+693
-645
lines changed

cmake/dependencies/ABLASTR.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ set(ImpactX_openpmd_src ""
178178
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
179179
CACHE STRING
180180
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
181-
set(ImpactX_ablastr_branch "24.01"
181+
set(ImpactX_ablastr_branch "94ae11900131846e9f8b3704194673f7f02d8959"
182182
CACHE STRING
183183
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")
184184

examples/fodo/run_fodo_programmable.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,16 @@ def my_drift(pge, pti, refpart):
7878

7979
else:
8080
array = np.array
81-
# access AoS data such as positions and cpu/id
82-
aos = pti.aos()
83-
aos_arr = array(aos, copy=False)
8481

85-
# access SoA data such as momentum
82+
# access particle attributes
8683
soa = pti.soa()
8784
real_arrays = soa.GetRealData()
88-
px = array(real_arrays[0], copy=False)
89-
py = array(real_arrays[1], copy=False)
90-
pt = array(real_arrays[2], copy=False)
85+
x = array(real_arrays[0], copy=False)
86+
y = array(real_arrays[1], copy=False)
87+
t = array(real_arrays[2], copy=False)
88+
px = array(real_arrays[3], copy=False)
89+
py = array(real_arrays[4], copy=False)
90+
pt = array(real_arrays[5], copy=False)
9191

9292
# length of the current slice
9393
slice_ds = pge.ds / pge.nslice
@@ -97,9 +97,9 @@ def my_drift(pge, pti, refpart):
9797
betgam2 = pt_ref**2 - 1.0
9898

9999
# advance position and momentum (drift)
100-
aos_arr[:]["x"] += slice_ds * px[:]
101-
aos_arr[:]["y"] += slice_ds * py[:]
102-
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
100+
x[:] += slice_ds * px[:]
101+
y[:] += slice_ds * py[:]
102+
t[:] += (slice_ds / betgam2) * pt[:]
103103

104104

105105
def my_ref_drift(pge, refpart):

src/particles/CollectLost.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ namespace impactx
2727
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;
2828

2929
AMREX_GPU_HOST_DEVICE
30-
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
31-
dst.m_aos[dst_ip] = src.m_aos[src_ip];
32-
30+
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
31+
{
32+
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
3333
for (int j = 0; j < SrcData::NAR; ++j)
3434
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
3535
for (int j = 0; j < src.m_num_runtime_real; ++j)
@@ -141,8 +141,7 @@ namespace impactx
141141
// move down
142142
int const new_index = ip - n_removed;
143143

144-
ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];
145-
144+
ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
146145
for (int j = 0; j < SrcData::NAR; ++j)
147146
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
148147
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)

src/particles/ImpactXParticleContainer.H

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <AMReX_MultiFab.H>
1919
#include <AMReX_ParIter.H>
2020
#include <AMReX_Particles.H>
21+
#include <AMReX_ParticleTile.H>
2122

2223
#include <AMReX_IntVect.H>
2324
#include <AMReX_Vector.H>
@@ -35,43 +36,16 @@ namespace impactx
3536
t ///< fixed t as the independent variable
3637
};
3738

38-
/** AMReX pre-defined Real attributes
39-
*
40-
* These are the AMReX pre-defined struct indexes for the Real attributes
41-
* stored in an AoS in ImpactXParticleContainer. We document this here,
42-
* because we change the meaning of these "positions" depending on the
43-
* coordinate system we are currently in.
44-
*/
45-
struct RealAoS
46-
{
47-
enum
48-
{
49-
x, ///< position in x [m] (at fixed s OR fixed t)
50-
y, ///< position in y [m] (at fixed s OR fixed t)
51-
t, ///< c * time-of-flight [m] (at fixed s)
52-
nattribs ///< the number of attributes above (always last)
53-
};
54-
55-
// at fixed t, the third component represents the position z
56-
enum {
57-
z = t ///< position in z [m] (at fixed t)
58-
};
59-
60-
//! named labels for fixed s
61-
static constexpr auto names_s = { "position_x", "position_y", "position_t" };
62-
//! named labels for fixed t
63-
static constexpr auto names_t = { "position_x", "position_y", "position_z" };
64-
static_assert(names_s.size() == nattribs);
65-
static_assert(names_t.size() == nattribs);
66-
};
67-
68-
/** This struct indexes the additional Real attributes
39+
/** This struct indexes the Real attributes
6940
* stored in an SoA in ImpactXParticleContainer
7041
*/
7142
struct RealSoA
7243
{
7344
enum
7445
{
46+
x, ///< position in x [m] (at fixed s or t)
47+
y, ///< position in y [m] (at fixed s or t)
48+
t, ///< time-of-flight ct [m] (at fixed s)
7549
px, ///< momentum in x, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
7650
py, ///< momentum in y, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
7751
pt, ///< energy deviation, scaled by speed of light * the magnitude of the reference momentum [unitless] (at fixed s)
@@ -80,27 +54,28 @@ namespace impactx
8054
nattribs ///< the number of attributes above (always last)
8155
};
8256

83-
// at fixed t, the third component represents the momentum in z
57+
// at fixed t, the third component represents the position z, the 6th component represents the momentum in z
8458
enum {
59+
z = t, ///< position in z [m] (at fixed t)
8560
pz = pt ///< momentum in z, scaled by the magnitude of the reference momentum [unitless] (at fixed t)
8661
};
8762

8863
//! named labels for fixed s
89-
static constexpr auto names_s = { "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
64+
static constexpr auto names_s = { "position_x", "position_y", "position_t", "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
9065
//! named labels for fixed t
91-
static constexpr auto names_t = { "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
66+
static constexpr auto names_t = { "position_x", "position_y", "position_z", "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
9267
static_assert(names_s.size() == nattribs);
9368
static_assert(names_t.size() == nattribs);
9469
};
9570

96-
/** This struct indexes the additional Integer attributes
71+
/** This struct indexes the Integer attributes
9772
* stored in an SoA in ImpactXParticleContainer
9873
*/
9974
struct IntSoA
10075
{
10176
enum
10277
{
103-
nattribs ///< the number of particles above (always last)
78+
nattribs ///< the number of attributes above (always last)
10479
};
10580
};
10681

@@ -109,46 +84,46 @@ namespace impactx
10984
* We subclass here to change the default threading strategy, which is
11085
* `static` in AMReX, to `dynamic` in ImpactX.
11186
*/
112-
class ParIter
113-
: public amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
87+
class ParIterSoA
88+
: public amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>
11489
{
11590
public:
116-
using amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParIter;
91+
using amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParIterSoA;
11792

118-
ParIter (ContainerType& pc, int level);
93+
ParIterSoA (ContainerType& pc, int level);
11994

120-
ParIter (ContainerType& pc, int level, amrex::MFItInfo& info);
95+
ParIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
12196
};
12297

12398
/** Const AMReX iterator for particle boxes - data is read only.
12499
*
125100
* We subclass here to change the default threading strategy, which is
126101
* `static` in AMReX, to `dynamic` in ImpactX.
127102
*/
128-
class ParConstIter
129-
: public amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
103+
class ParConstIterSoA
104+
: public amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>
130105
{
131106
public:
132-
using amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParConstIter;
107+
using amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParConstIterSoA;
133108

134-
ParConstIter (ContainerType& pc, int level);
109+
ParConstIterSoA (ContainerType& pc, int level);
135110

136-
ParConstIter (ContainerType& pc, int level, amrex::MFItInfo& info);
111+
ParConstIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
137112
};
138113

139114
/** Beam Particles in ImpactX
140115
*
141116
* This class stores particles, distributed over MPI ranks.
142117
*/
143118
class ImpactXParticleContainer
144-
: public amrex::ParticleContainer<0, 0, RealSoA::nattribs, IntSoA::nattribs>
119+
: public amrex::ParticleContainerPureSoA<RealSoA::nattribs, IntSoA::nattribs>
145120
{
146121
public:
147122
//! amrex iterator for particle boxes
148-
using iterator = impactx::ParIter;
123+
using iterator = impactx::ParIterSoA;
149124

150125
//! amrex constant iterator for particle boxes (read-only)
151-
using const_iterator = impactx::ParConstIter;
126+
using const_iterator = impactx::ParConstIterSoA;
152127

153128
//! Construct a new particle container
154129
ImpactXParticleContainer (initialization::AmrCoreData* amr_core);
@@ -276,10 +251,6 @@ namespace impactx
276251
DepositCharge (std::unordered_map<int, amrex::MultiFab> & rho,
277252
amrex::Vector<amrex::IntVect> const & ref_ratio);
278253

279-
/** Get the name of each Real AoS component */
280-
std::vector<std::string>
281-
RealAoS_names () const;
282-
283254
/** Get the name of each Real SoA component */
284255
std::vector<std::string>
285256
RealSoA_names () const;
@@ -311,10 +282,6 @@ namespace impactx
311282

312283
}; // ImpactXParticleContainer
313284

314-
/** Get the name of each Real AoS component */
315-
std::vector<std::string>
316-
get_RealAoS_names ();
317-
318285
/** Get the name of each Real SoA component
319286
*
320287
* @param num_real_comps number of compile-time + runtime arrays

src/particles/ImpactXParticleContainer.cpp

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <AMReX_AmrParGDB.H>
2020
#include <AMReX_ParallelDescriptor.H>
2121
#include <AMReX_ParmParse.H>
22-
#include <AMReX_ParticleTile.H>
22+
#include <AMReX_Particle.H>
2323

2424
#include <algorithm>
2525
#include <stdexcept>
@@ -38,24 +38,24 @@ namespace
3838

3939
namespace impactx
4040
{
41-
ParIter::ParIter (ContainerType& pc, int level)
42-
: amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>(pc, level,
41+
ParIterSoA::ParIterSoA (ContainerType& pc, int level)
42+
: amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>(pc, level,
4343
amrex::MFItInfo().SetDynamic(do_omp_dynamic())) {}
4444

45-
ParIter::ParIter (ContainerType& pc, int level, amrex::MFItInfo& info)
46-
: amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>(pc, level,
45+
ParIterSoA::ParIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info)
46+
: amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>(pc, level,
4747
info.SetDynamic(do_omp_dynamic())) {}
4848

49-
ParConstIter::ParConstIter (ContainerType& pc, int level)
50-
: amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>(pc, level,
49+
ParConstIterSoA::ParConstIterSoA (ContainerType& pc, int level)
50+
: amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>(pc, level,
5151
amrex::MFItInfo().SetDynamic(do_omp_dynamic())) {}
5252

53-
ParConstIter::ParConstIter (ContainerType& pc, int level, amrex::MFItInfo& info)
54-
: amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>(pc, level,
53+
ParConstIterSoA::ParConstIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info)
54+
: amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>(pc, level,
5555
info.SetDynamic(do_omp_dynamic())) {}
5656

5757
ImpactXParticleContainer::ImpactXParticleContainer (initialization::AmrCoreData* amr_core)
58-
: amrex::ParticleContainer<0, 0, RealSoA::nattribs, IntSoA::nattribs>(amr_core->GetParGDB())
58+
: amrex::ParticleContainerPureSoA<RealSoA::nattribs, IntSoA::nattribs>(amr_core->GetParGDB())
5959
{
6060
SetParticleSize();
6161
}
@@ -157,14 +157,18 @@ namespace impactx
157157

158158
const int cpuid = amrex::ParallelDescriptor::MyProc();
159159

160-
auto * AMREX_RESTRICT pstructs = particle_tile.GetArrayOfStructs()().dataPtr();
161160
auto & soa = particle_tile.GetStructOfArrays().GetRealData();
161+
amrex::ParticleReal * const AMREX_RESTRICT x_arr = soa[RealSoA::x].dataPtr();
162+
amrex::ParticleReal * const AMREX_RESTRICT y_arr = soa[RealSoA::y].dataPtr();
163+
amrex::ParticleReal * const AMREX_RESTRICT t_arr = soa[RealSoA::t].dataPtr();
162164
amrex::ParticleReal * const AMREX_RESTRICT px_arr = soa[RealSoA::px].dataPtr();
163165
amrex::ParticleReal * const AMREX_RESTRICT py_arr = soa[RealSoA::py].dataPtr();
164166
amrex::ParticleReal * const AMREX_RESTRICT pt_arr = soa[RealSoA::pt].dataPtr();
165167
amrex::ParticleReal * const AMREX_RESTRICT qm_arr = soa[RealSoA::qm].dataPtr();
166168
amrex::ParticleReal * const AMREX_RESTRICT w_arr = soa[RealSoA::w ].dataPtr();
167169

170+
uint64_t * const AMREX_RESTRICT idcpu_arr = particle_tile.GetStructOfArrays().GetIdCPUData().dataPtr();
171+
168172
amrex::ParticleReal const * const AMREX_RESTRICT x_ptr = x.data();
169173
amrex::ParticleReal const * const AMREX_RESTRICT y_ptr = y.data();
170174
amrex::ParticleReal const * const AMREX_RESTRICT t_ptr = t.data();
@@ -175,12 +179,12 @@ namespace impactx
175179
amrex::ParallelFor(np,
176180
[=] AMREX_GPU_DEVICE (int i) noexcept
177181
{
178-
ParticleType& p = pstructs[old_np + i];
179-
p.id() = pid + i;
180-
p.cpu() = cpuid;
181-
p.pos(RealAoS::x) = x_ptr[i];
182-
p.pos(RealAoS::y) = y_ptr[i];
183-
p.pos(RealAoS::t) = t_ptr[i];
182+
amrex::ParticleIDWrapper{idcpu_arr[old_np+i]} = pid + i;
183+
amrex::ParticleCPUWrapper{idcpu_arr[old_np+i]} = cpuid;
184+
185+
x_arr[old_np+i] = x_ptr[i];
186+
y_arr[old_np+i] = y_ptr[i];
187+
t_arr[old_np+i] = t_ptr[i];
184188

185189
px_arr[old_np+i] = px_ptr[i];
186190
py_arr[old_np+i] = py_ptr[i];
@@ -240,12 +244,6 @@ namespace impactx
240244
>(*this);
241245
}
242246

243-
std::vector<std::string>
244-
ImpactXParticleContainer::RealAoS_names () const
245-
{
246-
return get_RealAoS_names();
247-
}
248-
249247
std::vector<std::string>
250248
ImpactXParticleContainer::RealSoA_names () const
251249
{
@@ -264,17 +262,6 @@ namespace impactx
264262
m_coordsystem = coord_system;
265263
}
266264

267-
std::vector<std::string>
268-
get_RealAoS_names ()
269-
{
270-
std::vector<std::string> real_aos_names(RealAoS::names_s.size());
271-
272-
// compile-time attributes
273-
std::copy(RealAoS::names_s.begin(), RealAoS::names_s.end(), real_aos_names.begin());
274-
275-
return real_aos_names;
276-
}
277-
278265
std::vector<std::string>
279266
get_RealSoA_names (int num_real_comps)
280267
{

0 commit comments

Comments
 (0)