Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def AdvectionRK4(particle, fieldset, time): # pragma: no cover
"""Advection of particles using fourth-order Runge-Kutta integration."""
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
dt = particle.dt
(u1, v1) = fieldset.UV[particle]
lon1, lat1 = (particle.lon + u1 * 0.5 * dt, particle.lat + v1 * 0.5 * dt)
(u2, v2) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
Expand All @@ -30,7 +30,7 @@ def AdvectionRK4(particle, fieldset, time): # pragma: no cover

def AdvectionRK4_3D(particle, fieldset, time): # pragma: no cover
"""Advection of particles using fourth-order Runge-Kutta integration including vertical velocity."""
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
dt = particle.dt
(u1, v1, w1) = fieldset.UVW[particle]
lon1 = particle.lon + u1 * 0.5 * dt
lat1 = particle.lat + v1 * 0.5 * dt
Expand All @@ -53,7 +53,7 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
"""Advection of particles using fourth-order Runge-Kutta integration including vertical velocity.
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
"""
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
dt = particle.dt
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]

(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
Expand Down Expand Up @@ -97,7 +97,7 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover

def AdvectionEE(particle, fieldset, time): # pragma: no cover
"""Advection of particles using Explicit Euler (aka Euler Forward) integration."""
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
dt = particle.dt
(u1, v1) = fieldset.UV[particle]
particle_dlon += u1 * dt # noqa
particle_dlat += v1 * dt # noqa
Expand All @@ -113,7 +113,7 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
Time-step dt is halved if error is larger than fieldset.RK45_tol,
and doubled if error is smaller than 1/10th of tolerance.
"""
dt = min(particle.next_dt, fieldset.RK45_max_dt) / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
dt = min(particle.next_dt, fieldset.RK45_max_dt)
c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
A = [
[1.0 / 4.0, 0.0, 0.0, 0.0, 0.0],
Expand Down Expand Up @@ -178,7 +178,7 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover

tol = 1e-10
I_s = 10 # number of intermediate time steps
dt = particle.dt / np.timedelta64(1, "s") # TODO improve API for converting dt to seconds
dt = particle.dt
direction = 1.0 if dt > 0 else -1.0
withW = True if "W" in [f.name for f in fieldset.fields.values()] else False
withTime = True if len(fieldset.U.grid.time) > 1 else False
Expand Down
2 changes: 1 addition & 1 deletion parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def execute(self, pset, endtime, dt):
"""Execute this Kernel over a ParticleSet for several timesteps."""
pset._data["state"][:] = StatusCode.Evaluate

if abs(dt) < np.timedelta64(1000, "ns"): # TODO still needed?
if abs(dt) < 1e-6:
warnings.warn(
"'dt' is too small, causing numerical accuracy limit problems. Please chose a higher 'dt' and rather scale the 'time' axis of the field accordingly. (related issue #762)",
RuntimeWarning,
Expand Down
61 changes: 41 additions & 20 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm

from parcels._core.utils.time import TimeInterval
from parcels._core.utils.time import is_compatible as time_is_compatible
from parcels._reprs import particleset_repr
from parcels.application_kernels.advection import AdvectionRK4
from parcels.basegrid import GridType
Expand Down Expand Up @@ -108,12 +109,7 @@ def __init__(
depth = convert_to_flat_array(depth)
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"

if time is None or len(time) == 0:
time = np.datetime64("NaT", "ns") # do not set a time yet (because sign_dt not known)
elif type(time[0]) in [np.datetime64, np.timedelta64]:
pass # already in the right format
else:
raise TypeError("particle time must be a datetime, timedelta, or date object")
time = _get_release_times_from_interval_start(time, self.fieldset.time_interval)
time = np.repeat(time, lon.size) if time.size == 1 else time

assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
Expand All @@ -140,7 +136,7 @@ def __init__(
"lat": lat.astype(lonlatdepth_dtype),
"depth": depth.astype(lonlatdepth_dtype),
"time": time,
"dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
"dt": np.ones(len(trajectory_ids), dtype=np.float64),
# "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
"state": np.zeros((len(trajectory_ids)), dtype=np.int32),
"lon_nextloop": lon.astype(lonlatdepth_dtype),
Expand Down Expand Up @@ -733,37 +729,42 @@ def execute(
if output_file:
output_file.metadata["parcels_kernels"] = self._kernel.name

if (dt is not None) and (not isinstance(dt, np.timedelta64)):
raise TypeError("dt must be a np.timedelta64 object")
if dt is None or np.isnat(dt):
if dt is None:
dt = np.timedelta64(1, "s")
if not isinstance(dt, np.timedelta64):
raise TypeError("dt must be a np.timedelta64 object")

dt /= np.timedelta64(1, "s")

self._data["dt"][:] = dt
sign_dt = np.sign(dt).astype(int)
if sign_dt not in [-1, 1]:
raise ValueError("dt must be a positive or negative np.timedelta64 object")

if self.fieldset.time_interval is None:
start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object
start_time = 0 # For the execution loop, we need a start time as a timedelta object
if runtime is None:
raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.")

else:
if isinstance(runtime, np.timedelta64):
end_time = runtime
end_time = runtime / np.timedelta64(1, "s")
else:
raise TypeError("The runtime must be a np.timedelta64 object")

else:
if not np.isnat(self._data["time_nextloop"]).any():
if not np.isnan(self._data["time_nextloop"]).any():
if sign_dt > 0:
start_time = self._data["time_nextloop"].min()
else:
start_time = self._data["time_nextloop"].max()
else:
if sign_dt > 0:
start_time = self.fieldset.time_interval.left
start_time = 0.0
else:
start_time = self.fieldset.time_interval.right
start_time = (
self.fieldset.time_interval.right - self.fieldset.time_interval.left
) / np.timedelta64(1, "s")

if runtime is None:
if endtime is None:
Expand All @@ -784,11 +785,16 @@ def execute(
end_time = max(endtime, self.fieldset.time_interval.left)
else:
raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.")
end_time = (end_time - self.fieldset.time_interval.left) / np.timedelta64(1, "s")
else:
if isinstance(runtime, np.timedelta64):
runtime = runtime / np.timedelta64(1, "s")
elif isinstance(runtime, (int, float)):
raise TypeError("The runtime must be a np.timedelta64 object")
end_time = start_time + runtime * sign_dt

# Set the time of the particles if it hadn't been set on initialisation
if np.isnat(self._data["time"]).any():
if np.isnan(self._data["time"]).any():
self._data["time"][:] = start_time
self._data["time_nextloop"][:] = start_time

Expand All @@ -799,7 +805,7 @@ def execute(
logger.info(f"Output files are stored in {output_file.fname}.")

if verbose_progress:
pbar = tqdm(total=(end_time - start_time) / np.timedelta64(1, "s"), file=sys.stdout)
pbar = tqdm(total=(end_time - start_time), file=sys.stdout)

next_output = outputdt if output_file else None

Expand All @@ -822,14 +828,29 @@ def execute(
next_output += outputdt

if verbose_progress:
pbar.update((next_time - time) / np.timedelta64(1, "s"))
pbar.update(next_time - time)

time = next_time

if verbose_progress:
pbar.close()


def _get_release_times_from_interval_start(time: np.ndarray | None, time_interval: TimeInterval | None) -> np.ndarray:
if time is None or len(time) == 0:
return np.array([np.nan]) # do not set a time yet (because sign_dt not known)

if time_interval is None:
return (time - np.min(time)) / np.timedelta64(1, "s")

if time_is_compatible(time[0], time_interval.left):
raise TypeError(
f"Release times must be of a compatible type with the fieldset's time interval. Got {time[0]=!r}, but time interval is {time_interval=!r}."
)

return (time - time_interval.left) / np.timedelta64(1, "s")


def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_times: Iterable[float]):
"""Gives the user a warning if the release time isn't a multiple of outputdt."""
if any((np.isfinite(t) and (t - starttime) % outputdt != 0) for t in release_times):
Expand All @@ -844,13 +865,13 @@ def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_tim

def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time: TimeInterval):
if np.any(release_times):
if np.any(release_times < time.left):
if np.any(release_times < 0):
warnings.warn(
"Some particles are set to be released outside the FieldSet's executable time domain.",
ParticleSetWarning,
stacklevel=2,
)
if np.any(release_times > time.right):
if np.any(release_times > (time.right - time.left) / np.timedelta64(1, "s")):
warnings.warn(
"Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
ParticleSetWarning,
Expand Down
20 changes: 18 additions & 2 deletions tests/v4/test_particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_pset_custominit_on_pclass(fieldset, pset_override):
@pytest.mark.parametrize(
"time, expectation",
[
(np.timedelta64(0, "s"), does_not_raise()),
(np.timedelta64(0, "s"), pytest.raises(TypeError)),
(np.datetime64("2000-01-02T00:00:00"), does_not_raise()),
(0.0, pytest.raises(TypeError)),
(timedelta(seconds=0), pytest.raises(TypeError)),
Expand All @@ -127,6 +127,22 @@ def test_pset_create_outside_time(fieldset):
ParticleSet(fieldset, pclass=Particle, lon=[0] * len(time), lat=[0] * len(time), time=time)


def test_pset_invalid_release_times(fieldset):
# define release times to be floats, and make sure that execution errors out informatively
...


def test_pset_incompatible_release_times(fieldset):
# define release times to be incompatible with the fieldset time interval, and make sure that execution errors out informatively
...


def test_get_release_times_from_interval_start(): ... # test inputs and returns


def test_get_release_times_from_interval_start_time_interval_none(): ...


@pytest.mark.parametrize(
"dt, expectation",
[
Expand All @@ -148,7 +164,7 @@ def test_pset_starttime_not_multiple_dt(fieldset):
pset = ParticleSet(fieldset, lon=[0] * len(times), lat=[0] * len(times), pclass=Particle, time=datetimes)

def Addlon(particle, fieldset, time): # pragma: no cover
particle_dlon += particle.dt / np.timedelta64(1, "s") # noqa
particle_dlon += particle.dt # noqa

pset.execute(Addlon, dt=np.timedelta64(2, "s"), runtime=np.timedelta64(8, "s"), verbose_progress=False)
assert np.allclose([p.lon_nextloop for p in pset], [8 - t for t in times])
Expand Down
30 changes: 15 additions & 15 deletions tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def test_pset_stop_simulation(fieldset):
pset = ParticleSet(fieldset, lon=0, lat=0, pclass=Particle)

def Delete(particle, fieldset, time): # pragma: no cover
if time >= fieldset.time_interval.left + np.timedelta64(4, "s"):
if time >= 4:
return StatusCode.StopExecution

pset.execute(Delete, dt=np.timedelta64(1, "s"), runtime=np.timedelta64(21, "s"))
assert pset[0].time == fieldset.time_interval.left + np.timedelta64(4, "s")
assert pset[0].time == 4


@pytest.mark.parametrize("with_delete", [True, False])
Expand All @@ -74,45 +74,45 @@ def AddLat(particle, fieldset, time): # pragma: no cover

@pytest.mark.parametrize(
"starttime, endtime, dt",
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
)
def test_execution_endtime(fieldset, starttime, endtime, dt):
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s")
endtime_date = fieldset.time_interval.left + np.timedelta64(endtime, "s")
dt = np.timedelta64(dt, "s")
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
pset.execute(DoNothing, endtime=endtime, dt=dt)
assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms")
pset.execute(DoNothing, endtime=endtime_date, dt=dt)
assert abs(pset.time_nextloop - endtime) < 1e-3


@pytest.mark.parametrize(
"starttime, runtime, dt",
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
)
def test_execution_runtime(fieldset, starttime, runtime, dt):
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
runtime = np.timedelta64(runtime, "s")
starttime_date = fieldset.time_interval.left + np.timedelta64(starttime, "s")
runtime_date = np.timedelta64(runtime, "s")
sign_dt = 1 if dt is None else np.sign(dt)
dt = np.timedelta64(dt, "s")
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
pset.execute(DoNothing, runtime=runtime, dt=dt)
assert abs(pset.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms")
pset = ParticleSet(fieldset, time=starttime_date, lon=0, lat=0)
pset.execute(DoNothing, runtime=runtime_date, dt=dt)
assert abs(pset.time_nextloop - starttime - runtime * sign_dt) < 1e-3


def test_execution_fail_python_exception(fieldset, npart=10):
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))

def PythonFail(particle, fieldset, time): # pragma: no cover
if particle.time >= fieldset.time_interval.left + np.timedelta64(10, "s"):
if particle.time >= 10:
raise RuntimeError("Enough is enough!")
else:
pass

with pytest.raises(RuntimeError):
pset.execute(PythonFail, runtime=np.timedelta64(20, "s"), dt=np.timedelta64(2, "s"))
assert len(pset) == npart
assert pset.time[0] == fieldset.time_interval.left + np.timedelta64(10, "s")
assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]])
assert pset.time[0] == 10
assert all([time == 0 for time in pset.time[1:]])


@pytest.mark.parametrize("verbose_progress", [True, False])
Expand Down