Skip to content

Commit a2b0618

Browse files
Merge pull request #2096 from OceanParcels/kernelloop-speedups
Kernelloop speedups
2 parents 488e3fb + de6c374 commit a2b0618

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

parcels/kernel.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def remove_deleted(self, pset):
7575
# TODO v4: need to implement ParticleFile writing of deleted particles
7676
# if len(indices) > 0 and self.fieldset.particlefile is not None:
7777
# self.fieldset.particlefile.write(pset, None, indices=indices)
78-
pset.remove_indices(indices)
78+
if len(indices) > 0:
79+
pset.remove_indices(indices)
7980

8081

8182
class Kernel(BaseKernel):
@@ -378,25 +379,24 @@ def evaluate_particle(self, p, endtime):
378379
dt :
379380
computational integration timestep
380381
"""
382+
sign_dt = 1 if p.dt >= 0 else -1
381383
while p.state in [StatusCode.Evaluate, StatusCode.Repeat]:
382-
pre_dt = p.dt
383-
384-
sign_dt = np.sign(p.dt).astype(int)
385-
if sign_dt * (endtime - p.time_nextloop) <= np.timedelta64(0, "ns"):
384+
if sign_dt * (endtime - p.time_nextloop) <= 0:
386385
return p
387386

387+
pre_dt = p.dt
388388
# TODO implement below later again
389389
# try: # Use next_dt from AdvectionRK45 if it is set
390390
# if abs(endtime - p.time_nextloop) < abs(p.next_dt) - 1e-6:
391391
# p.next_dt = abs(endtime - p.time_nextloop) * sign_dt
392392
# except AttributeError:
393-
if abs(endtime - p.time_nextloop) <= abs(p.dt):
394-
p.dt = abs(endtime - p.time_nextloop) * sign_dt
393+
if sign_dt * (endtime - p.time_nextloop) <= p.dt:
394+
p.dt = sign_dt * (endtime - p.time_nextloop)
395395
res = self._pyfunc(p, self._fieldset, p.time_nextloop)
396396

397397
if res is None:
398398
if p.state == StatusCode.Success:
399-
if sign_dt * (p.time - endtime) > np.timedelta64(0, "ns"):
399+
if sign_dt * (p.time - endtime) > 0:
400400
p.state = StatusCode.Evaluate
401401
else:
402402
p.state = res

parcels/particleset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -797,9 +797,9 @@ def execute(
797797
time = start_time
798798
while sign_dt * (time - end_time) < 0:
799799
if sign_dt > 0:
800-
next_time = min(time + dt, end_time)
800+
next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works
801801
else:
802-
next_time = max(time + dt, end_time)
802+
next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works
803803
res = self._kernel.execute(self, endtime=next_time, dt=dt)
804804
if res == StatusCode.StopAllExecution:
805805
return StatusCode.StopAllExecution
@@ -813,7 +813,7 @@ def execute(
813813
next_output += outputdt
814814

815815
if verbose_progress:
816-
pbar.update(dt / np.timedelta64(1, "s"))
816+
pbar.update((next_time - time) / np.timedelta64(1, "s"))
817817

818818
time = next_time
819819

tests/v4/test_particleset_execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def PythonFail(particle, fieldset, time): # pragma: no cover
113113
pset.execute(PythonFail, runtime=np.timedelta64(20, "s"), dt=np.timedelta64(2, "s"))
114114
assert len(pset) == npart
115115
assert pset.time[0] == fieldset.time_interval.left + np.timedelta64(10, "s")
116-
assert all([time == fieldset.time_interval.left + np.timedelta64(8, "s") for time in pset.time[1:]])
116+
assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]])
117117

118118

119119
@pytest.mark.parametrize("verbose_progress", [True, False])

0 commit comments

Comments
 (0)