Skip to content

Commit fe812bd

Browse files
committed
Update Particle Container to Pure SoA
Transition particle containers to pure SoA layouts.
1 parent ce33709 commit fe812bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+760
-693
lines changed

cmake/dependencies/ABLASTR.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ set(ImpactX_openpmd_src ""
178178
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
179179
CACHE STRING
180180
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
181-
set(ImpactX_ablastr_branch "24.01"
181+
set(ImpactX_ablastr_branch "94ae11900131846e9f8b3704194673f7f02d8959"
182182
CACHE STRING
183183
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")
184184

examples/fodo/run_fodo_programmable.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,16 @@ def my_drift(pge, pti, refpart):
7878

7979
else:
8080
array = np.array
81-
# access AoS data such as positions and cpu/id
82-
aos = pti.aos()
83-
aos_arr = array(aos, copy=False)
8481

85-
# access SoA data such as momentum
82+
# access particle attributes
8683
soa = pti.soa()
8784
real_arrays = soa.GetRealData()
88-
px = array(real_arrays[0], copy=False)
89-
py = array(real_arrays[1], copy=False)
90-
pt = array(real_arrays[2], copy=False)
85+
x = array(real_arrays[0], copy=False)
86+
y = array(real_arrays[1], copy=False)
87+
t = array(real_arrays[2], copy=False)
88+
px = array(real_arrays[3], copy=False)
89+
py = array(real_arrays[4], copy=False)
90+
pt = array(real_arrays[5], copy=False)
9191

9292
# length of the current slice
9393
slice_ds = pge.ds / pge.nslice
@@ -97,9 +97,9 @@ def my_drift(pge, pti, refpart):
9797
betgam2 = pt_ref**2 - 1.0
9898

9999
# advance position and momentum (drift)
100-
aos_arr[:]["x"] += slice_ds * px[:]
101-
aos_arr[:]["y"] += slice_ds * py[:]
102-
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
100+
x[:] += slice_ds * px[:]
101+
y[:] += slice_ds * py[:]
102+
t[:] += (slice_ds / betgam2) * pt[:]
103103

104104

105105
def my_ref_drift(pge, refpart):

examples/pytorch_surrogate_model/run_ml_surrogate.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
from urllib import request
1212

1313
import numpy as np
14+
15+
try:
16+
import cupy as cp
17+
18+
cupy_available = True
19+
except ImportError:
20+
cupy_available = False
21+
1422
from surrogate_model_definitions import surrogate_model
1523

1624
try:
@@ -20,10 +28,11 @@
2028
sys.exit(0)
2129

2230
from impactx import (
31+
Config,
32+
CoordSystem,
2333
ImpactX,
2434
ImpactXParIter,
2535
RefPart,
26-
TransformationDirection,
2736
coordinate_transformation,
2837
distribution,
2938
elements,
@@ -79,7 +88,18 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
7988
self.ds = surrogate_length
8089

8190
def surrogate_push(self, pc, step):
82-
array = np.array
91+
# CPU/GPU logic
92+
if Config.have_gpu:
93+
if cupy_available:
94+
array = cp.array
95+
else:
96+
print("Warning: GPU found but cupy not available! Try managed...")
97+
array = np.array
98+
if Config.gpu_backend == "SYCL":
99+
print("Warning: SYCL GPU backend not yet implemented for Python")
100+
101+
else:
102+
array = np.array
83103

84104
ref_part = pc.ref_particle()
85105
ref_z_i = ref_part.z
@@ -100,26 +120,28 @@ def surrogate_push(self, pc, step):
100120
ref_part_final = torch.tensor([0, 0, ref_z_f, 0, 0, ref_uz_f])
101121

102122
# transform
103-
coordinate_transformation(pc, TransformationDirection.to_fixed_t)
123+
coordinate_transformation(pc, direction=CoordSystem.t)
104124

105125
for lvl in range(pc.finest_level + 1):
106126
for pti in ImpactXParIter(pc, level=lvl):
107-
aos = pti.aos()
108-
aos_arr = array(aos, copy=False)
109-
110127
soa = pti.soa()
111128
real_arrays = soa.GetRealData()
112-
px = array(real_arrays[0], copy=False)
113-
py = array(real_arrays[1], copy=False)
114-
pt = array(real_arrays[2], copy=False)
115-
data_arr = (
116-
torch.tensor(
117-
np.vstack(
118-
[aos_arr["x"], aos_arr["y"], aos_arr["z"], real_arrays[:3]]
119-
)
120-
)
121-
.float()
122-
.T
129+
x = array(real_arrays[0], copy=False)
130+
y = array(real_arrays[1], copy=False)
131+
t = array(real_arrays[2], copy=False)
132+
px = array(real_arrays[3], copy=False)
133+
py = array(real_arrays[4], copy=False)
134+
pt = array(real_arrays[5], copy=False)
135+
data_arr = torch.stack(
136+
(
137+
x,
138+
y,
139+
t,
140+
px,
141+
py,
142+
py,
143+
),
144+
dim=0,
123145
)
124146

125147
data_arr[:, 0] += ref_part.x
@@ -147,9 +169,9 @@ def surrogate_push(self, pc, step):
147169
data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
148170
data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final
149171

150-
aos_arr["x"] = data_arr_post_model[:, 0]
151-
aos_arr["y"] = data_arr_post_model[:, 1]
152-
aos_arr["z"] = data_arr_post_model[:, 2]
172+
x[:] = data_arr_post_model[:, 0]
173+
y[:] = data_arr_post_model[:, 1]
174+
t[:] = data_arr_post_model[:, 2]
153175
px[:] = data_arr_post_model[:, 3]
154176
py[:] = data_arr_post_model[:, 4]
155177
pt[:] = data_arr_post_model[:, 5]
@@ -174,7 +196,7 @@ def surrogate_push(self, pc, step):
174196
# ref_part.s += pge1.ds
175197
# ref_part.t += pge1.ds / ref_beta
176198

177-
coordinate_transformation(pc, TransformationDirection.to_fixed_s)
199+
coordinate_transformation(pc, direction=CoordSystem.s)
178200
## Done!
179201

180202

src/particles/CollectLost.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ namespace impactx
2727
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;
2828

2929
AMREX_GPU_HOST_DEVICE
30-
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
31-
dst.m_aos[dst_ip] = src.m_aos[src_ip];
32-
30+
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
31+
{
32+
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
3333
for (int j = 0; j < SrcData::NAR; ++j)
3434
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
3535
for (int j = 0; j < src.m_num_runtime_real; ++j)
@@ -141,8 +141,7 @@ namespace impactx
141141
// move down
142142
int const new_index = ip - n_removed;
143143

144-
ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];
145-
144+
ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
146145
for (int j = 0; j < SrcData::NAR; ++j)
147146
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
148147
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)

src/particles/ImpactXParticleContainer.H

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <AMReX_MultiFab.H>
1919
#include <AMReX_ParIter.H>
2020
#include <AMReX_Particles.H>
21+
#include <AMReX_ParticleTile.H>
2122

2223
#include <AMReX_IntVect.H>
2324
#include <AMReX_Vector.H>
@@ -35,43 +36,16 @@ namespace impactx
3536
t ///< fixed t as the independent variable
3637
};
3738

38-
/** AMReX pre-defined Real attributes
39-
*
40-
* These are the AMReX pre-defined struct indexes for the Real attributes
41-
* stored in an AoS in ImpactXParticleContainer. We document this here,
42-
* because we change the meaning of these "positions" depending on the
43-
* coordinate system we are currently in.
44-
*/
45-
struct RealAoS
46-
{
47-
enum
48-
{
49-
x, ///< position in x [m] (at fixed s OR fixed t)
50-
y, ///< position in y [m] (at fixed s OR fixed t)
51-
t, ///< c * time-of-flight [m] (at fixed s)
52-
nattribs ///< the number of attributes above (always last)
53-
};
54-
55-
// at fixed t, the third component represents the position z
56-
enum {
57-
z = t ///< position in z [m] (at fixed t)
58-
};
59-
60-
//! named labels for fixed s
61-
static constexpr auto names_s = { "position_x", "position_y", "position_t" };
62-
//! named labels for fixed t
63-
static constexpr auto names_t = { "position_x", "position_y", "position_z" };
64-
static_assert(names_s.size() == nattribs);
65-
static_assert(names_t.size() == nattribs);
66-
};
67-
68-
/** This struct indexes the additional Real attributes
39+
/** This struct indexes the Real attributes
6940
* stored in an SoA in ImpactXParticleContainer
7041
*/
7142
struct RealSoA
7243
{
7344
enum
7445
{
46+
x, ///< position in x [m] (at fixed s or t)
47+
y, ///< position in y [m] (at fixed s or t)
48+
t, ///< time-of-flight ct [m] (at fixed s)
7549
px, ///< momentum in x, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
7650
py, ///< momentum in y, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
7751
pt, ///< energy deviation, scaled by speed of light * the magnitude of the reference momentum [unitless] (at fixed s)
@@ -80,27 +54,28 @@ namespace impactx
8054
nattribs ///< the number of attributes above (always last)
8155
};
8256

83-
// at fixed t, the third component represents the momentum in z
57+
// at fixed t, the third component represents the position z, the 6th component represents the momentum in z
8458
enum {
59+
z = t, ///< position in z [m] (at fixed t)
8560
pz = pt ///< momentum in z, scaled by the magnitude of the reference momentum [unitless] (at fixed t)
8661
};
8762

8863
//! named labels for fixed s
89-
static constexpr auto names_s = { "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
64+
static constexpr auto names_s = { "position_x", "position_y", "position_t", "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
9065
//! named labels for fixed t
91-
static constexpr auto names_t = { "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
66+
static constexpr auto names_t = { "position_x", "position_y", "position_z", "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
9267
static_assert(names_s.size() == nattribs);
9368
static_assert(names_t.size() == nattribs);
9469
};
9570

96-
/** This struct indexes the additional Integer attributes
71+
/** This struct indexes the Integer attributes
9772
* stored in an SoA in ImpactXParticleContainer
9873
*/
9974
struct IntSoA
10075
{
10176
enum
10277
{
103-
nattribs ///< the number of particles above (always last)
78+
nattribs ///< the number of attributes above (always last)
10479
};
10580
};
10681

@@ -109,46 +84,46 @@ namespace impactx
10984
* We subclass here to change the default threading strategy, which is
11085
* `static` in AMReX, to `dynamic` in ImpactX.
11186
*/
112-
class ParIter
113-
: public amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
87+
class ParIterSoA
88+
: public amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>
11489
{
11590
public:
116-
using amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParIter;
91+
using amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParIterSoA;
11792

118-
ParIter (ContainerType& pc, int level);
93+
ParIterSoA (ContainerType& pc, int level);
11994

120-
ParIter (ContainerType& pc, int level, amrex::MFItInfo& info);
95+
ParIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
12196
};
12297

12398
/** Const AMReX iterator for particle boxes - data is read only.
12499
*
125100
* We subclass here to change the default threading strategy, which is
126101
* `static` in AMReX, to `dynamic` in ImpactX.
127102
*/
128-
class ParConstIter
129-
: public amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
103+
class ParConstIterSoA
104+
: public amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>
130105
{
131106
public:
132-
using amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParConstIter;
107+
using amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParConstIterSoA;
133108

134-
ParConstIter (ContainerType& pc, int level);
109+
ParConstIterSoA (ContainerType& pc, int level);
135110

136-
ParConstIter (ContainerType& pc, int level, amrex::MFItInfo& info);
111+
ParConstIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
137112
};
138113

139114
/** Beam Particles in ImpactX
140115
*
141116
* This class stores particles, distributed over MPI ranks.
142117
*/
143118
class ImpactXParticleContainer
144-
: public amrex::ParticleContainer<0, 0, RealSoA::nattribs, IntSoA::nattribs>
119+
: public amrex::ParticleContainerPureSoA<RealSoA::nattribs, IntSoA::nattribs>
145120
{
146121
public:
147122
//! amrex iterator for particle boxes
148-
using iterator = impactx::ParIter;
123+
using iterator = impactx::ParIterSoA;
149124

150125
//! amrex constant iterator for particle boxes (read-only)
151-
using const_iterator = impactx::ParConstIter;
126+
using const_iterator = impactx::ParConstIterSoA;
152127

153128
//! Construct a new particle container
154129
ImpactXParticleContainer (initialization::AmrCoreData* amr_core);
@@ -276,10 +251,6 @@ namespace impactx
276251
DepositCharge (std::unordered_map<int, amrex::MultiFab> & rho,
277252
amrex::Vector<amrex::IntVect> const & ref_ratio);
278253

279-
/** Get the name of each Real AoS component */
280-
std::vector<std::string>
281-
RealAoS_names () const;
282-
283254
/** Get the name of each Real SoA component */
284255
std::vector<std::string>
285256
RealSoA_names () const;
@@ -311,10 +282,6 @@ namespace impactx
311282

312283
}; // ImpactXParticleContainer
313284

314-
/** Get the name of each Real AoS component */
315-
std::vector<std::string>
316-
get_RealAoS_names ();
317-
318285
/** Get the name of each Real SoA component
319286
*
320287
* @param num_real_comps number of compile-time + runtime arrays

0 commit comments

Comments
 (0)