Skip to content

Commit 167b3e9

Browse files
Fixed incompatibility between step_ts/jump_ts and implicit solvers.
1 parent 034ead8 commit 167b3e9

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

diffrax/_step_size_controller/clip.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Sequence
22
from typing import cast, Generic, Optional, TypeVar
33

44
import equinox as eqx
@@ -19,7 +19,7 @@
1919
from .._misc import upcast_or_raise
2020
from .._solution import is_okay, RESULTS
2121
from .._term import AbstractTerm
22-
from .base import AbstractStepSizeController
22+
from .base import AbstractAdaptiveStepSizeController
2323

2424

2525
_ControllerState = TypeVar("_ControllerState")
@@ -118,7 +118,7 @@ def cond_down(_i):
118118

119119

120120
class ClipStepSizeController(
121-
AbstractStepSizeController[_ClipState[_ControllerState], _Dt0]
121+
AbstractAdaptiveStepSizeController[_ClipState[_ControllerState], _Dt0]
122122
):
123123
"""Wraps an existing step controller with three pieces of functionality:
124124
@@ -166,20 +166,32 @@ class ClipStepSizeController(
166166
```
167167
"""
168168

169-
controller: AbstractStepSizeController[_ControllerState, _Dt0]
169+
controller: AbstractAdaptiveStepSizeController[_ControllerState, _Dt0]
170170
step_ts: Optional[Real[Array, " steps"]]
171171
jump_ts: Optional[Real[Array, " jumps"]]
172172
store_rejected_steps: Optional[int] = eqx.field(static=True)
173173
callback_on_reject: Optional[Callable] = eqx.field(static=True)
174174

175+
@property
176+
def atol(self):
177+
return self.controller.atol
178+
179+
@property
180+
def rtol(self):
181+
return self.controller.rtol
182+
183+
@property
184+
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
185+
return self.controller.norm
186+
175187
@eqxi.doc_remove_args("_callback_on_reject")
176188
def __init__(
177189
self,
178-
controller,
179-
step_ts=None,
180-
jump_ts=None,
181-
store_rejected_steps=None,
182-
_callback_on_reject=None,
190+
controller: AbstractAdaptiveStepSizeController[_ControllerState, _Dt0],
191+
step_ts: None | Sequence[RealScalarLike] | Real[Array, " steps"] = None,
192+
jump_ts: None | Sequence[RealScalarLike] | Real[Array, " jumps"] = None,
193+
store_rejected_steps: Optional[int] = None,
194+
_callback_on_reject: Optional[Callable] = None,
183195
):
184196
"""**Arguments**:
185197
@@ -198,6 +210,11 @@ def __init__(
198210
that this is not the total number of rejected steps in a solve, but just the
199211
maximum number of *consecutive* rejected steps.)
200212
"""
213+
if not isinstance(controller, AbstractAdaptiveStepSizeController):
214+
raise ValueError(
215+
"Can only apply `ClipStepSizeController` to adaptive step size "
216+
f"controllers, but got {controller}."
217+
)
201218
self.controller = controller
202219
self.step_ts = _none_or_sorted_array(step_ts)
203220
self.jump_ts = _none_or_sorted_array(jump_ts)

test/test_adaptive_stepsize_controller.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,27 @@ def test_find_idx_with_hint():
312312
assert idx == 3 # not 2; we want the first value *strictly* greater.
313313
idx = _find_idx_with_hint(1.9, ts, hint)
314314
assert idx == 2
315+
316+
317+
# https://github.com/patrick-kidger/diffrax/issues/607
318+
@pytest.mark.parametrize("new", (False, True))
319+
def test_implicit_solver_with_clip_controller(new: bool):
320+
term = diffrax.ODETerm(lambda t, y, args: -y)
321+
solver = diffrax.Kvaerno3()
322+
if new:
323+
ssc = diffrax.PIDController(rtol=1e-3, atol=1e-3)
324+
ssc = diffrax.ClipStepSizeController(ssc, jump_ts=[0.5])
325+
else:
326+
ssc = diffrax.PIDController(jump_ts=[0.5], rtol=1e-3, atol=1e-3) # pyright: ignore[reportCallIssue]
327+
diffrax.diffeqsolve(
328+
term,
329+
solver,
330+
t0=0,
331+
t1=1,
332+
dt0=0.01,
333+
args=None,
334+
y0=1.0,
335+
stepsize_controller=ssc,
336+
max_steps=16384,
337+
saveat=diffrax.SaveAt(t1=True),
338+
)

0 commit comments

Comments
 (0)