Skip to content

Commit 9876a9e

Browse files
authored
Update Particle Container to Pure SoA (#348)
Transition particle containers to pure SoA layouts.
1 parent 7259a22 commit 9876a9e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1015
-795
lines changed

cmake/dependencies/ABLASTR.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ set(ImpactX_openpmd_src ""
178178
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
179179
CACHE STRING
180180
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
181-
set(ImpactX_ablastr_branch "24.02"
181+
set(ImpactX_ablastr_branch "11aabdca56335c5ae1cbb2257b8abd6c8f04a67c"
182182
CACHE STRING
183183
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")
184184

cmake/dependencies/pyAMReX.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ option(ImpactX_pyamrex_internal "Download & build pyAMReX" ON)
7979
set(ImpactX_pyamrex_repo "https://github.com/AMReX-Codes/pyamrex.git"
8080
CACHE STRING
8181
"Repository URI to pull and build pyamrex from if(ImpactX_pyamrex_internal)")
82-
set(ImpactX_pyamrex_branch "24.02"
82+
set(ImpactX_pyamrex_branch "5aa700de18a61f933cb435adbe2299d74d794d6b"
8383
CACHE STRING
8484
"Repository branch for ImpactX_pyamrex_repo if(ImpactX_pyamrex_internal)")
8585

examples/epac2004_benchmarks/input_fodo_rf_SC.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,4 @@ geometry.prob_relative = 4.0
125125
###############################################################################
126126
# Diagnostics
127127
###############################################################################
128-
diag.slice_step_diagnostics = true
128+
diag.slice_step_diagnostics = false

examples/fodo/run_fodo_programmable.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ def my_drift(pge, pti, refpart):
7777

7878
else:
7979
array = np.array
80-
# access AoS data such as positions and cpu/id
81-
aos = pti.aos()
82-
aos_arr = array(aos, copy=False)
8380

84-
# access SoA data such as momentum
81+
# access particle attributes
8582
soa = pti.soa()
86-
real_arrays = soa.GetRealData()
87-
px = array(real_arrays[0], copy=False)
88-
py = array(real_arrays[1], copy=False)
89-
pt = array(real_arrays[2], copy=False)
83+
real_arrays = soa.get_real_data()
84+
x = array(real_arrays[0], copy=False)
85+
y = array(real_arrays[1], copy=False)
86+
t = array(real_arrays[2], copy=False)
87+
px = array(real_arrays[3], copy=False)
88+
py = array(real_arrays[4], copy=False)
89+
pt = array(real_arrays[5], copy=False)
9090

9191
# length of the current slice
9292
slice_ds = pge.ds / pge.nslice
@@ -96,9 +96,9 @@ def my_drift(pge, pti, refpart):
9696
betgam2 = pt_ref**2 - 1.0
9797

9898
# advance position and momentum (drift)
99-
aos_arr[:]["x"] += slice_ds * px[:]
100-
aos_arr[:]["y"] += slice_ds * py[:]
101-
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
99+
x[:] += slice_ds * px[:]
100+
y[:] += slice_ds * py[:]
101+
t[:] += (slice_ds / betgam2) * pt[:]
102102

103103

104104
def my_ref_drift(pge, refpart):

examples/pytorch_surrogate_model/run_ml_surrogate.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
from urllib import request
1212

1313
import numpy as np
14+
15+
try:
16+
import cupy as cp
17+
18+
cupy_available = True
19+
except ImportError:
20+
cupy_available = False
21+
1422
from surrogate_model_definitions import surrogate_model
1523

1624
try:
@@ -20,14 +28,34 @@
2028
sys.exit(0)
2129

2230
from impactx import (
31+
Config,
32+
CoordSystem,
2333
ImpactX,
2434
ImpactXParIter,
25-
TransformationDirection,
2635
coordinate_transformation,
2736
distribution,
2837
elements,
2938
)
3039

40+
# CPU/GPU logic
41+
if Config.have_gpu:
42+
if cupy_available:
43+
array = cp.array
44+
stack = cp.stack
45+
device = torch.device("cuda")
46+
else:
47+
print("Warning: GPU found but cupy not available! Try managed...")
48+
array = np.array
49+
stack = np.stack
50+
device = torch.device("cpu")
51+
if Config.gpu_backend == "SYCL":
52+
print("Warning: SYCL GPU backend not yet implemented for Python")
53+
54+
else:
55+
array = np.array
56+
stack = np.stack
57+
device = torch.device("cpu")
58+
3159

3260
def download_and_unzip(url, data_dir):
3361
request.urlretrieve(url, data_dir)
@@ -50,6 +78,7 @@ def download_and_unzip(url, data_dir):
5078
surrogate_model(
5179
dataset_dir + f"dataset_beam_stage_{i}.pt",
5280
model_dir + f"beam_stage_{i}_model.pt",
81+
device=device,
5382
)
5483
for i in range(N_stage)
5584
]
@@ -78,47 +107,62 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
78107
self.ds = surrogate_length
79108

80109
def surrogate_push(self, pc, step):
81-
array = np.array
82-
83110
ref_part = pc.ref_particle()
84111
ref_z_i = ref_part.z
85112
ref_z_i_LPA = ref_z_i - self.stage_start
86113
ref_z_f = ref_z_i + self.surrogate_length
87114

88115
ref_part_tensor = torch.tensor(
89-
[ref_part.x, ref_part.y, ref_z_i_LPA, ref_part.px, ref_part.py, ref_part.pz]
116+
[
117+
ref_part.x,
118+
ref_part.y,
119+
ref_z_i_LPA,
120+
ref_part.px,
121+
ref_part.py,
122+
ref_part.pz,
123+
],
124+
dtype=torch.float64,
125+
device=device,
90126
)
91-
ref_beta_gamma = np.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
127+
ref_beta_gamma = torch.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
92128

93129
with torch.no_grad():
94-
ref_part_model_final = self.surrogate_model(ref_part_tensor.float())
130+
ref_part_model_final = self.surrogate_model(ref_part_tensor)
95131
ref_uz_f = ref_part_model_final[5]
96132
ref_beta_gamma_final = (
97133
ref_uz_f # NOT np.sqrt(torch.sum(ref_part_model_final[3:]**2))
98134
)
99-
ref_part_final = torch.tensor([0, 0, ref_z_f, 0, 0, ref_uz_f])
135+
ref_part_final = torch.tensor(
136+
[0, 0, ref_z_f, 0, 0, ref_uz_f], dtype=torch.float64, device=device
137+
)
100138

101139
# transform
102-
coordinate_transformation(pc, TransformationDirection.to_fixed_t)
140+
coordinate_transformation(pc, direction=CoordSystem.t)
103141

104142
for lvl in range(pc.finest_level + 1):
105143
for pti in ImpactXParIter(pc, level=lvl):
106-
aos = pti.aos()
107-
aos_arr = array(aos, copy=False)
108-
109144
soa = pti.soa()
110-
real_arrays = soa.GetRealData()
111-
px = array(real_arrays[0], copy=False)
112-
py = array(real_arrays[1], copy=False)
113-
pt = array(real_arrays[2], copy=False)
114-
data_arr = (
115-
torch.tensor(
116-
np.vstack(
117-
[aos_arr["x"], aos_arr["y"], aos_arr["z"], real_arrays[:3]]
118-
)
119-
)
120-
.float()
121-
.T
145+
real_arrays = soa.get_real_data()
146+
x = array(real_arrays[0], copy=False)
147+
y = array(real_arrays[1], copy=False)
148+
t = array(real_arrays[2], copy=False)
149+
px = array(real_arrays[3], copy=False)
150+
py = array(real_arrays[4], copy=False)
151+
pt = array(real_arrays[5], copy=False)
152+
data_arr = torch.tensor(
153+
stack(
154+
[
155+
x,
156+
y,
157+
t,
158+
px,
159+
py,
160+
py,
161+
],
162+
axis=1,
163+
),
164+
dtype=torch.float64,
165+
device=device,
122166
)
123167

124168
data_arr[:, 0] += ref_part.x
@@ -135,7 +179,7 @@ def surrogate_push(self, pc, step):
135179
# # assume for now it is
136180

137181
with torch.no_grad():
138-
data_arr_post_model = self.surrogate_model(data_arr.float())
182+
data_arr_post_model = self.surrogate_model(data_arr)
139183

140184
# need to add stage start to z
141185
data_arr_post_model[:, 2] += self.stage_start
@@ -146,9 +190,9 @@ def surrogate_push(self, pc, step):
146190
data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
147191
data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final
148192

149-
aos_arr["x"] = data_arr_post_model[:, 0]
150-
aos_arr["y"] = data_arr_post_model[:, 1]
151-
aos_arr["z"] = data_arr_post_model[:, 2]
193+
x[:] = data_arr_post_model[:, 0]
194+
y[:] = data_arr_post_model[:, 1]
195+
t[:] = data_arr_post_model[:, 2]
152196
px[:] = data_arr_post_model[:, 3]
153197
py[:] = data_arr_post_model[:, 4]
154198
pt[:] = data_arr_post_model[:, 5]
@@ -160,7 +204,7 @@ def surrogate_push(self, pc, step):
160204
ref_part.x = ref_part_final[0]
161205
ref_part.y = ref_part_final[1]
162206
ref_part.z = ref_part_final[2]
163-
ref_gamma = np.sqrt(1 + ref_beta_gamma_final**2)
207+
ref_gamma = torch.sqrt(1 + ref_beta_gamma_final**2)
164208
ref_part.px = ref_part_final[3]
165209
ref_part.py = ref_part_final[4]
166210
ref_part.pz = ref_part_final[5]
@@ -173,7 +217,7 @@ def surrogate_push(self, pc, step):
173217
# ref_part.s += pge1.ds
174218
# ref_part.t += pge1.ds / ref_beta
175219

176-
coordinate_transformation(pc, TransformationDirection.to_fixed_s)
220+
coordinate_transformation(pc, direction=CoordSystem.s)
177221
## Done!
178222

179223

examples/pytorch_surrogate_model/surrogate_model_definitions.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act):
9090
class surrogate_model:
9191
""" """
9292

93-
def __init__(self, dataset_file, model_file):
93+
def __init__(self, dataset_file, model_file, device):
9494
self.dataset = torch.load(dataset_file)
95-
model_dict = torch.load(model_file, map_location=torch.device("cpu"))
95+
self.device = device
96+
model_dict = torch.load(model_file)
9697
n_in = model_dict["model_state_dict"]["stack.0.weight"].shape[1]
9798
final_layer_key = list(model_dict["model_state_dict"].keys())[-1]
9899
n_out = model_dict["model_state_dict"][final_layer_key].shape[0]
@@ -112,13 +113,20 @@ def __init__(self, dataset_file, model_file):
112113
self.neural_network.load_state_dict(model_dict["model_state_dict"])
113114
self.neural_network.eval()
114115

115-
def __call__(self, data_arr):
116-
data_arr -= self.dataset["source_means"]
117-
data_arr /= self.dataset["source_stds"]
118-
data_arr = data_arr.float()
116+
def __call__(self, data_arr, device=None):
117+
data_arr -= torch.tensor(
118+
self.dataset["source_means"], dtype=torch.float64, device=device
119+
)
120+
data_arr /= torch.tensor(
121+
self.dataset["source_stds"], dtype=torch.float64, device=device
122+
)
119123
with torch.no_grad():
120-
data_arr_post_model = self.neural_network(data_arr)
124+
data_arr_post_model = self.neural_network(data_arr.float()).double()
121125

122-
data_arr_post_model *= self.dataset["target_stds"]
123-
data_arr_post_model += self.dataset["target_means"]
126+
data_arr_post_model *= torch.tensor(
127+
self.dataset["target_stds"], dtype=torch.float64, device=device
128+
)
129+
data_arr_post_model += torch.tensor(
130+
self.dataset["target_means"], dtype=torch.float64, device=device
131+
)
124132
return data_arr_post_model

src/particles/CollectLost.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <AMReX_GpuLaunch.H>
1313
#include <AMReX_GpuQualifiers.H>
1414
#include <AMReX_Math.H>
15+
#include <AMReX_Particle.H>
1516
#include <AMReX_ParticleTransformation.H>
1617
#include <AMReX_RandomEngine.H>
1718

@@ -27,9 +28,9 @@ namespace impactx
2728
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;
2829

2930
AMREX_GPU_HOST_DEVICE
30-
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
31-
dst.m_aos[dst_ip] = src.m_aos[src_ip];
32-
31+
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
32+
{
33+
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
3334
for (int j = 0; j < SrcData::NAR; ++j)
3435
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
3536
for (int j = 0; j < src.m_num_runtime_real; ++j)
@@ -42,7 +43,7 @@ namespace impactx
4243
// dst.m_runtime_idata[j][dst_ip] = src.m_runtime_idata[j][src_ip];
4344

4445
// flip id to positive in destination
45-
dst.id(dst_ip) = amrex::Math::abs(dst.id(dst_ip));
46+
amrex::ParticleIDWrapper{dst.m_idcpu[dst_ip]}.make_valid();
4647

4748
// remember the current s of the ref particle when lost
4849
dst.m_runtime_rdata[s_index][dst_ip] = s_lost;
@@ -85,7 +86,7 @@ namespace impactx
8586
auto const predicate = [] AMREX_GPU_HOST_DEVICE (const SrcData& src, int ip)
8687
/* NVCC 11.3.109 chokes in C++17 on this: noexcept */
8788
{
88-
return src.id(ip) < 0;
89+
return !amrex::ConstParticleIDWrapper{src.m_idcpu[ip]}.is_valid();
8990
};
9091

9192
auto& ptile_dest = dest.DefineAndReturnParticleTile(
@@ -130,9 +131,11 @@ namespace impactx
130131
{
131132
int n_removed = 0;
132133
auto ptile_src_data = ptile_source.getParticleTileData();
134+
auto const ptile_soa = ptile_source.GetStructOfArrays();
135+
auto const ptile_idcpu = ptile_soa.GetIdCPUData().dataPtr();
133136
for (int ip = 0; ip < np; ++ip)
134137
{
135-
if (ptile_source.id(ip) < 0)
138+
if (!amrex::ConstParticleIDWrapper{ptile_idcpu[ip]}.is_valid())
136139
n_removed++;
137140
else
138141
{
@@ -141,8 +144,7 @@ namespace impactx
141144
// move down
142145
int const new_index = ip - n_removed;
143146

144-
ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];
145-
147+
ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
146148
for (int j = 0; j < SrcData::NAR; ++j)
147149
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
148150
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)

0 commit comments

Comments
 (0)