Skip to content

Commit 62a596c

Browse files
Merge pull request #2122 from OceanParcels/vectorized-kernel
Implementing vectorized kernels. More testing and development need to be done, but for now let's merge this in so we can continue with v4-dev.
2 parents add8158 + 5d2bf9a commit 62a596c

21 files changed

+692
-688
lines changed

parcels/_core/utils/time.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def __init__(self, left: T, right: T) -> None:
4848
def __contains__(self, item: T) -> bool:
4949
return self.left <= item <= self.right
5050

51+
def is_all_time_in_interval(self, time):
52+
item = np.atleast_1d(time)
53+
return (self.left <= item).all() and (item <= self.right).all()
54+
5155
def __repr__(self) -> str:
5256
return f"TimeInterval(left={self.left!r}, right={self.right!r})"
5357

parcels/_index_search.py

Lines changed: 33 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,14 @@ def _search_time_index(field: Field, time: datetime):
3939
if the sampled value is outside the time value range.
4040
"""
4141
if field.time_interval is None:
42-
return 0, 0
42+
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)
4343

44-
if time not in field.time_interval:
44+
if not field.time_interval.is_all_time_in_interval(time):
4545
_raise_time_extrapolation_error(time, field=None)
4646

47-
time_index = field.data.time <= time
48-
49-
if time_index.all():
50-
# If given time > last known field time, use
51-
# the last field frame without interpolation
52-
ti = len(field.data.time) - 1
53-
54-
elif np.logical_not(time_index).all():
55-
# If given time < any time in the field, use
56-
# the first field frame without interpolation
57-
ti = 0
58-
else:
59-
ti = int(time_index.argmin() - 1) if time_index.any() else 0
60-
if len(field.data.time) == 1:
61-
tau = 0
62-
elif ti == len(field.data.time) - 1:
63-
tau = 1
64-
else:
65-
tau = (
66-
(time - field.data.time[ti]).dt.total_seconds().values
67-
/ (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds().values
68-
if field.data.time[ti] != field.data.time[ti + 1]
69-
else 0
70-
)
71-
return tau, ti
47+
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
48+
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
49+
return np.atleast_1d(tau), np.atleast_1d(ti)
7250

7351

7452
def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float):
@@ -273,13 +251,13 @@ def _search_indices_rectilinear(
273251

274252
def _search_indices_curvilinear_2d(
275253
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
276-
):
254+
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
277255
yi, xi = yi_guess, xi_guess
278256
if yi is None or xi is None:
279257
faces = grid.get_spatial_hash().query(np.column_stack((y, x)))
280258
yi, xi = faces[0]
281259

282-
xsi = eta = -1.0
260+
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
283261
invA = np.array(
284262
[
285263
[1, 0, 0, 0],
@@ -303,7 +281,7 @@ def _search_indices_curvilinear_2d(
303281
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
304282
# _raise_field_out_of_bound_error(z, y, x)
305283

306-
while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
284+
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
307285
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
308286

309287
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
@@ -313,40 +291,29 @@ def _search_indices_curvilinear_2d(
313291
aa = a[3] * b[2] - a[2] * b[3]
314292
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
315293
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
316-
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
317-
eta = -cc / bb
318-
else:
319-
det2 = bb * bb - 4 * aa * cc
320-
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
321-
det = np.sqrt(det2)
322-
eta = (-bb + det) / (2 * aa)
323-
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
324-
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
325-
else:
326-
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
327-
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
328-
_raise_field_out_of_bound_error(0, y, x)
329-
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
330-
_raise_field_out_of_bound_error(0, y, x)
331-
if xsi < -tol:
332-
xi -= 1
333-
elif xsi > 1 + tol:
334-
xi += 1
335-
if eta < -tol:
336-
yi -= 1
337-
elif eta > 1 + tol:
338-
yi += 1
294+
295+
det2 = bb * bb - 4 * aa * cc
296+
det = np.where(det2 > 0, np.sqrt(det2), eta)
297+
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
298+
299+
xsi = np.where(
300+
abs(a[1] + a[3] * eta) < 1e-12,
301+
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
302+
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
303+
)
304+
305+
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
306+
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
307+
339308
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
340309
it += 1
341310
if it > maxIterSearch:
342311
print(f"Correct cell not found after {maxIterSearch} iterations")
343312
_raise_field_out_of_bound_error(0, y, x)
344-
xsi = max(0.0, xsi)
345-
eta = max(0.0, eta)
346-
xsi = min(1.0, xsi)
347-
eta = min(1.0, eta)
313+
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
314+
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
348315

349-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1)):
316+
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
350317
_raise_field_sampling_error(y, x)
351318
return (yi, eta, xi, xsi)
352319

@@ -442,20 +409,12 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
442409

443410

444411
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
445-
if xi < 0:
446-
if sphere_mesh:
447-
xi = xdim - 2
448-
else:
449-
xi = 0
450-
if xi > xdim - 2:
451-
if sphere_mesh:
452-
xi = 0
453-
else:
454-
xi = xdim - 2
455-
if yi < 0:
456-
yi = 0
457-
if yi > ydim - 2:
458-
yi = ydim - 2
459-
if sphere_mesh:
460-
xi = xdim - xi
412+
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
413+
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
414+
415+
xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
416+
417+
yi = np.where(yi < 0, 0, yi)
418+
yi = np.where(yi > ydim - 2, ydim - 2, yi)
419+
461420
return yi, xi

parcels/application_kernels/advection.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ def AdvectionRK4(particle, fieldset, time): # pragma: no cover
2121
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
2222
(u1, v1) = fieldset.UV[particle]
2323
lon1, lat1 = (particle.lon + u1 * 0.5 * dt, particle.lat + v1 * 0.5 * dt)
24-
(u2, v2) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
24+
(u2, v2) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
2525
lon2, lat2 = (particle.lon + u2 * 0.5 * dt, particle.lat + v2 * 0.5 * dt)
26-
(u3, v3) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
26+
(u3, v3) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
2727
lon3, lat3 = (particle.lon + u3 * dt, particle.lat + v3 * dt)
28-
(u4, v4) = fieldset.UV[time + particle.dt, particle.depth, lat3, lon3, particle]
28+
(u4, v4) = fieldset.UV[particle.time + particle.dt, particle.depth, lat3, lon3, particle]
2929
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6.0 * dt
3030
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6.0 * dt
3131

@@ -37,15 +37,15 @@ def AdvectionRK4_3D(particle, fieldset, time): # pragma: no cover
3737
lon1 = particle.lon + u1 * 0.5 * dt
3838
lat1 = particle.lat + v1 * 0.5 * dt
3939
dep1 = particle.depth + w1 * 0.5 * dt
40-
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
40+
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
4141
lon2 = particle.lon + u2 * 0.5 * dt
4242
lat2 = particle.lat + v2 * 0.5 * dt
4343
dep2 = particle.depth + w2 * 0.5 * dt
44-
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
44+
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
4545
lon3 = particle.lon + u3 * dt
4646
lat3 = particle.lat + v3 * dt
4747
dep3 = particle.depth + w3 * dt
48-
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
48+
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
4949
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
5050
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
5151
particle.ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * dt
@@ -56,35 +56,35 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
5656
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
5757
"""
5858
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
59-
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]
59+
sig_dep = particle.depth / fieldset.H[particle.time, 0, particle.lat, particle.lon]
6060

61-
(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
62-
w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
61+
(u1, v1, w1) = fieldset.UVW[particle.time, particle.depth, particle.lat, particle.lon, particle]
62+
w1 *= sig_dep / fieldset.H[particle.time, 0, particle.lat, particle.lon]
6363
lon1 = particle.lon + u1 * 0.5 * dt
6464
lat1 = particle.lat + v1 * 0.5 * dt
6565
sig_dep1 = sig_dep + w1 * 0.5 * dt
66-
dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]
66+
dep1 = sig_dep1 * fieldset.H[particle.time, 0, lat1, lon1]
6767

68-
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
69-
w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
68+
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
69+
w2 *= sig_dep1 / fieldset.H[particle.time, 0, lat1, lon1]
7070
lon2 = particle.lon + u2 * 0.5 * dt
7171
lat2 = particle.lat + v2 * 0.5 * dt
7272
sig_dep2 = sig_dep + w2 * 0.5 * dt
73-
dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]
73+
dep2 = sig_dep2 * fieldset.H[particle.time, 0, lat2, lon2]
7474

75-
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
76-
w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
75+
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
76+
w3 *= sig_dep2 / fieldset.H[particle.time, 0, lat2, lon2]
7777
lon3 = particle.lon + u3 * dt
7878
lat3 = particle.lat + v3 * dt
7979
sig_dep3 = sig_dep + w3 * dt
80-
dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]
80+
dep3 = sig_dep3 * fieldset.H[particle.time, 0, lat3, lon3]
8181

82-
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
83-
w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
82+
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
83+
w4 *= sig_dep3 / fieldset.H[particle.time, 0, lat3, lon3]
8484
lon4 = particle.lon + u4 * dt
8585
lat4 = particle.lat + v4 * dt
8686
sig_dep4 = sig_dep + w4 * dt
87-
dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]
87+
dep4 = sig_dep4 * fieldset.H[particle.time, 0, lat4, lon4]
8888

8989
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
9090
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
@@ -115,14 +115,7 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
115115
Time-step dt is halved if error is larger than fieldset.RK45_tol,
116116
and doubled if error is smaller than 1/10th of tolerance.
117117
"""
118-
dt = particle.next_dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
119-
if dt > fieldset.RK45_max_dt:
120-
dt = fieldset.RK45_max_dt
121-
particle.next_dt = fieldset.RK45_max_dt * np.timedelta64(1, "s")
122-
if dt < fieldset.RK45_min_dt:
123-
particle.next_dt = fieldset.RK45_min_dt * np.timedelta64(1, "s")
124-
return StatusCode.Repeat
125-
particle.dt = particle.next_dt
118+
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
126119

127120
c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
128121
A = [
@@ -137,42 +130,58 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
137130

138131
(u1, v1) = fieldset.UV[particle]
139132
lon1, lat1 = (particle.lon + u1 * A[0][0] * dt, particle.lat + v1 * A[0][0] * dt)
140-
(u2, v2) = fieldset.UV[time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
133+
(u2, v2) = fieldset.UV[particle.time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
141134
lon2, lat2 = (
142135
particle.lon + (u1 * A[1][0] + u2 * A[1][1]) * dt,
143136
particle.lat + (v1 * A[1][0] + v2 * A[1][1]) * dt,
144137
)
145-
(u3, v3) = fieldset.UV[time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
138+
(u3, v3) = fieldset.UV[particle.time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
146139
lon3, lat3 = (
147140
particle.lon + (u1 * A[2][0] + u2 * A[2][1] + u3 * A[2][2]) * dt,
148141
particle.lat + (v1 * A[2][0] + v2 * A[2][1] + v3 * A[2][2]) * dt,
149142
)
150-
(u4, v4) = fieldset.UV[time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
143+
(u4, v4) = fieldset.UV[particle.time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
151144
lon4, lat4 = (
152145
particle.lon + (u1 * A[3][0] + u2 * A[3][1] + u3 * A[3][2] + u4 * A[3][3]) * dt,
153146
particle.lat + (v1 * A[3][0] + v2 * A[3][1] + v3 * A[3][2] + v4 * A[3][3]) * dt,
154147
)
155-
(u5, v5) = fieldset.UV[time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
148+
(u5, v5) = fieldset.UV[particle.time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
156149
lon5, lat5 = (
157150
particle.lon + (u1 * A[4][0] + u2 * A[4][1] + u3 * A[4][2] + u4 * A[4][3] + u5 * A[4][4]) * dt,
158151
particle.lat + (v1 * A[4][0] + v2 * A[4][1] + v3 * A[4][2] + v4 * A[4][3] + v5 * A[4][4]) * dt,
159152
)
160-
(u6, v6) = fieldset.UV[time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
153+
(u6, v6) = fieldset.UV[particle.time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
161154

162155
lon_4th = (u1 * b4[0] + u2 * b4[1] + u3 * b4[2] + u4 * b4[3] + u5 * b4[4]) * dt
163156
lat_4th = (v1 * b4[0] + v2 * b4[1] + v3 * b4[2] + v4 * b4[3] + v5 * b4[4]) * dt
164157
lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * dt
165158
lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * dt
166159

167-
kappa = math.sqrt(math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2))
168-
if (kappa <= fieldset.RK45_tol) or (math.fabs(dt) < math.fabs(fieldset.RK45_min_dt)):
169-
particle.dlon += lon_4th
170-
particle.dlat += lat_4th
171-
if (kappa <= fieldset.RK45_tol / 10) and (math.fabs(dt * 2) <= math.fabs(fieldset.RK45_max_dt)):
172-
particle.next_dt *= 2
173-
else:
174-
particle.next_dt /= 2
175-
return StatusCode.Repeat
160+
kappa = np.sqrt(np.pow(lon_5th - lon_4th, 2) + np.pow(lat_5th - lat_4th, 2))
161+
162+
good_particles = (kappa <= fieldset.RK45_tol) | (np.fabs(dt) <= np.fabs(fieldset.RK45_min_dt))
163+
particle.dlon += np.where(good_particles, lon_5th, 0)
164+
particle.dlat += np.where(good_particles, lat_5th, 0)
165+
166+
increase_dt_particles = (
167+
good_particles & (kappa <= fieldset.RK45_tol / 10) & (np.fabs(dt * 2) <= np.fabs(fieldset.RK45_max_dt))
168+
)
169+
particle.dt = np.where(increase_dt_particles, particle.dt * 2, particle.dt)
170+
particle.dt = np.where(
171+
particle.dt > fieldset.RK45_max_dt * np.timedelta64(1, "s"),
172+
fieldset.RK45_max_dt * np.timedelta64(1, "s"),
173+
particle.dt,
174+
)
175+
particle.state = np.where(good_particles, StatusCode.Success, particle.state)
176+
177+
repeat_particles = np.invert(good_particles)
178+
particle.dt = np.where(repeat_particles, particle.dt / 2, particle.dt)
179+
particle.dt = np.where(
180+
particle.dt < fieldset.RK45_min_dt * np.timedelta64(1, "s"),
181+
fieldset.RK45_min_dt * np.timedelta64(1, "s"),
182+
particle.dt,
183+
)
184+
particle.state = np.where(repeat_particles, StatusCode.Repeat, particle.state)
176185

177186

178187
def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover

0 commit comments

Comments
 (0)