Skip to content

Commit 86e0b54

Browse files
Adding both a datstruct and a dict for the particledata
Following @VeckoTheGecko's suggestion at Parcels-code/parcels-benchmarks#1 (comment)
1 parent f681162 commit 86e0b54

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

parcels/particleset.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,35 +135,49 @@ def __init__(
135135
lon.size == kwargs[kwvar].size
136136
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
137137

138-
self._data = {
139-
"lon": lon.astype(lonlatdepth_dtype),
140-
"lat": lat.astype(lonlatdepth_dtype),
141-
"depth": depth.astype(lonlatdepth_dtype),
142-
"time": time,
143-
"dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
144-
# "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
145-
"state": np.zeros((len(trajectory_ids)), dtype=np.int32),
146-
"lon_nextloop": lon.astype(lonlatdepth_dtype),
147-
"lat_nextloop": lat.astype(lonlatdepth_dtype),
148-
"depth_nextloop": depth.astype(lonlatdepth_dtype),
149-
"time_nextloop": time,
150-
"trajectory": trajectory_ids,
151-
}
152-
self._ptype = pclass.getPType()
138+
self._ds = xr.Dataset(
139+
{
140+
"lon": (["trajectory"], lon.astype(lonlatdepth_dtype)),
141+
"lat": (["trajectory"], lat.astype(lonlatdepth_dtype)),
142+
"depth": (["trajectory"], depth.astype(lonlatdepth_dtype)),
143+
"time": (["trajectory"], time),
144+
"dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))),
145+
"ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
146+
"state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)),
147+
"lon_nextloop": (["trajectory"], lon.astype(lonlatdepth_dtype)),
148+
"lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)),
149+
"depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)),
150+
"time_nextloop": (["trajectory"], time),
151+
},
152+
coords={
153+
"trajectory": ("trajectory", trajectory_ids),
154+
},
155+
attrs={
156+
"ngrid": len(fieldset.gridset),
157+
"ptype": pclass.getPType(),
158+
},
159+
)
153160
# add extra fields from the custom Particle class
154161
for v in pclass.__dict__.values():
155162
if isinstance(v, Variable):
156163
if isinstance(v.initial, attrgetter):
157-
initial = v.initial(self)
164+
initial = v.initial(self).values
158165
else:
159166
initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
160-
self._data[v.name] = initial
167+
self._ds[v.name] = (["trajectory"], initial)
161168

162169
# update initial values provided on ParticleSet creation
163170
for kwvar, kwval in kwargs.items():
164171
if not hasattr(pclass, kwvar):
165172
raise RuntimeError(f"Particle class does not have Variable {kwvar}")
166-
self._data[kwvar][:] = kwval
173+
self._ds[kwvar][:] = kwval
174+
175+
# also keep a struct of numpy arrays for faster access (see parcels-benchmarks/pull/1)
176+
self._data = {}
177+
for v in self._ds.keys():
178+
self._data[v] = self._ds[v].data
179+
self._data["trajectory"] = self._ds["trajectory"].data
180+
self._ptype = self._ds.attrs["ptype"]
167181

168182
self._kernel = None
169183

tests/v4/test_particleset_execute.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,20 @@ def PythonFail(particle, fieldset, time): # pragma: no cover
116116
assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]])
117117

118118

119+
def test_pset_update_particle(fieldset, npart=10):
120+
lon_start = np.linspace(0, 1, npart)
121+
lat_start = np.linspace(1, 0, npart)
122+
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
123+
124+
def UpdateParticle(particle, fieldset, time): # pragma: no cover
125+
particle.lon += 0.1
126+
particle.lat -= 0.1
127+
128+
pset.execute(pset.Kernel(UpdateParticle), runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"))
129+
assert np.allclose(pset.lon, lon_start + 1, atol=1e-5)
130+
assert np.allclose(pset.lat, lat_start - 1, atol=1e-5)
131+
132+
119133
@pytest.mark.parametrize("verbose_progress", [True, False])
120134
def test_uxstommelgyre_pset_execute(verbose_progress):
121135
ds = datasets_unstructured["stommel_gyre_delaunay"]

0 commit comments

Comments
 (0)