Skip to content

Commit 3cf4c29

Browse files
authored
Make VS relative thread-safe (#5212)
1 parent d7811f0 commit 3cf4c29

File tree

3 files changed

+50
-41
lines changed

3 files changed

+50
-41
lines changed

src/core/cell_system/CellStructure.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,10 @@ class CellStructure : public System::Leaf<CellStructure> {
364364
* @brief Run a kernel on all local particles.
365365
* The kernel is assumed to be thread-safe.
366366
*/
367-
void for_each_local_particle(ParticleUnaryOp &&f) const {
367+
void for_each_local_particle(ParticleUnaryOp &&f,
368+
bool parallel = true) const {
368369
#ifdef ESPRESSO_SHARED_MEMORY_PARALLELISM
369-
if (use_parallel_for_each_local_particle()) {
370+
if (parallel and use_parallel_for_each_local_particle()) {
370371
parallel_for_each_particle_impl(decomposition().local_cells(), f);
371372
return;
372373
}

src/core/virtual_sites/relative.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -168,23 +168,26 @@ void vs_relative_back_transfer_forces_and_torques(
168168
init_forces_ghosts(cell_structure);
169169

170170
// Iterate over all the particles in the local cells
171-
cell_structure.for_each_local_particle([&](Particle &p) {
172-
if (!is_vs(p))
173-
return;
174-
175-
auto *p_ref_ptr = get_reference_particle(cell_structure, p);
176-
assert(p_ref_ptr != nullptr);
177-
178-
auto &p_ref = *p_ref_ptr;
179-
if (is_vs_relative_trans(p)) {
180-
p_ref.force() += p.force();
181-
p_ref.torque() += vector_product(connection_vector(p_ref, p), p.force());
182-
}
183-
184-
if (is_vs_rot(p)) {
185-
p_ref.torque() += p.torque();
186-
}
187-
});
171+
cell_structure.for_each_local_particle(
172+
[&](Particle &p) {
173+
if (!is_vs(p))
174+
return;
175+
176+
auto *p_ref_ptr = get_reference_particle(cell_structure, p);
177+
assert(p_ref_ptr != nullptr);
178+
179+
auto &p_ref = *p_ref_ptr;
180+
if (is_vs_relative_trans(p)) {
181+
p_ref.force() += p.force();
182+
p_ref.torque() +=
183+
vector_product(connection_vector(p_ref, p), p.force());
184+
}
185+
186+
if (is_vs_rot(p)) {
187+
p_ref.torque() += p.torque();
188+
}
189+
},
190+
/* parallel */ false);
188191
}
189192

190193
// Rigid body contribution to scalar pressure and pressure tensor

testsuite/python/virtual_sites_relative.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,19 @@ def test_pos_vel_forces(self):
202202
system.min_global_cut = 0.23
203203
self.assertEqual(system.min_global_cut, 0.23)
204204

205-
# Place central particle + 3 vs
205+
# Place central particle + N virtual sites
206206
p1 = system.part.add(rotation=3 * [True], pos=(0.5, 0.5, 0.5), id=1,
207207
quat=(1, 0, 0, 0), omega_lab=(1, 2, 3))
208-
pos2 = (0.5, 0.4, 0.5)
209-
pos3 = (0.3, 0.5, 0.4)
210-
pos4 = (0.5, 0.5, 0.5)
211-
for pos in (pos2, pos3, pos4):
212-
p = system.part.add(rotation=3 * [True], pos=pos)
208+
209+
# Number of virtual sites to create
210+
N = 100
211+
# Generate N random positions within 1.2 of central particle in each coordinate
212+
np.random.seed(42)
213+
vs_positions = p1.pos + np.random.uniform(-0.15, 0.15, (N, 3))
214+
215+
# Create virtual sites at random positions
216+
sites = system.part.add(rotation=[3 * [True]] * N, pos=vs_positions)
217+
for p in sites:
213218
p.vs_auto_relate_to(p1)
214219
# Was the particle made virtual
215220
self.assertTrue(p.is_virtual())
@@ -226,39 +231,39 @@ def test_pos_vel_forces(self):
226231
p1.v = (0.45, 0.14, 0.447)
227232
p1.omega_lab = (0.45, 0.14, 0.447)
228233
system.integrator.run(0, recalc_forces=True)
229-
for p in system.part:
230-
if p.id != p1.id:
231-
self.verify_vs(p)
234+
for p in sites:
235+
self.verify_vs(p)
232236

233237
# Check if still true, when non-virtual particle has rotated and a
234238
# linear motion
235239
p1.omega_lab = [-5., 3., 8.4]
236240
system.integrator.run(10)
237-
for p in system.part:
238-
if p.id != p1.id:
239-
self.verify_vs(p)
241+
for p in sites:
242+
self.verify_vs(p)
240243

241244
if espressomd.has_features("EXTERNAL_FORCES"):
242245
# Test transfer of forces accumulating on virtual sites
243246
# to central particle
244-
f2 = np.array((3, 4, 5))
245-
f3 = np.array((-4, 5, 6))
246-
# Add forces to vs
247-
p2, p3 = system.part.by_ids([2, 3])
248-
p2.ext_force = f2
249-
p3.ext_force = f3
247+
# Generate random forces for all N virtual sites
248+
sites.ext_force = np.random.uniform(-5, 5, (N, 3))
249+
250250
system.integrator.run(0)
251251
# get force/torques on non-vs
252252
f = p1.f
253253
t = p1.torque_lab
254254

255-
# Expected force = sum of the forces on the vs
256-
self.assertAlmostEqual(np.linalg.norm(f - f2 - f3), 0., delta=1E-6)
255+
# Expected force = sum of all forces on the vs
256+
f_exp = np.sum(sites.ext_force, axis=0)
257+
print()
258+
print(f"{f=} {p1.f=}, {f_exp=}")
259+
self.assertAlmostEqual(np.linalg.norm(f - f_exp), 0., delta=1E-6)
257260

258261
# Expected torque
259262
# Radial components of forces on a rigid body add to the torque
260-
t_exp = np.cross(system.distance_vec(p1, p2), f2)
261-
t_exp += np.cross(system.distance_vec(p1, p3), f3)
263+
t_exp = np.zeros(3)
264+
for vs_p in sites:
265+
t_exp += np.cross(system.distance_vec(p1,
266+
vs_p), vs_p.ext_force)
262267
# Check
263268
self.assertAlmostEqual(np.linalg.norm(t_exp - t), 0., delta=1E-6)
264269

0 commit comments

Comments
 (0)