1- from collections .abc import Callable
1+ from collections .abc import Callable , Sequence
22from typing import cast , Generic , Optional , TypeVar
33
44import equinox as eqx
1919from .._misc import upcast_or_raise
2020from .._solution import is_okay , RESULTS
2121from .._term import AbstractTerm
22- from .base import AbstractStepSizeController
22+ from .base import AbstractAdaptiveStepSizeController
2323
2424
2525_ControllerState = TypeVar ("_ControllerState" )
@@ -118,7 +118,7 @@ def cond_down(_i):
118118
119119
120120class ClipStepSizeController (
121- AbstractStepSizeController [_ClipState [_ControllerState ], _Dt0 ]
121+ AbstractAdaptiveStepSizeController [_ClipState [_ControllerState ], _Dt0 ]
122122):
123123 """Wraps an existing step controller with three pieces of functionality:
124124
@@ -166,20 +166,32 @@ class ClipStepSizeController(
166166 ```
167167 """
168168
169- controller : AbstractStepSizeController [_ControllerState , _Dt0 ]
169+ controller : AbstractAdaptiveStepSizeController [_ControllerState , _Dt0 ]
170170 step_ts : Optional [Real [Array , " steps" ]]
171171 jump_ts : Optional [Real [Array , " jumps" ]]
172172 store_rejected_steps : Optional [int ] = eqx .field (static = True )
173173 callback_on_reject : Optional [Callable ] = eqx .field (static = True )
174174
175+ @property
176+ def atol (self ):
177+ return self .controller .atol
178+
179+ @property
180+ def rtol (self ):
181+ return self .controller .rtol
182+
183+ @property
184+ def norm (self ): # pyright: ignore[reportIncompatibleMethodOverride]
185+ return self .controller .norm
186+
175187 @eqxi .doc_remove_args ("_callback_on_reject" )
176188 def __init__ (
177189 self ,
178- controller ,
179- step_ts = None ,
180- jump_ts = None ,
181- store_rejected_steps = None ,
182- _callback_on_reject = None ,
190+ controller : AbstractAdaptiveStepSizeController [ _ControllerState , _Dt0 ] ,
191+ step_ts : None | Sequence [ RealScalarLike ] | Real [ Array , " steps" ] = None ,
192+ jump_ts : None | Sequence [ RealScalarLike ] | Real [ Array , " jumps" ] = None ,
193+ store_rejected_steps : Optional [ int ] = None ,
194+ _callback_on_reject : Optional [ Callable ] = None ,
183195 ):
184196 """**Arguments**:
185197
@@ -198,6 +210,11 @@ def __init__(
198210 that this is not the total number of rejected steps in a solve, but just the
199211 maximum number of *consecutive* rejected steps.)
200212 """
213+ if not isinstance (controller , AbstractAdaptiveStepSizeController ):
214+ raise ValueError (
215+ "Can only apply `ClipStepSizeController` to adaptive step size "
216+ f"controllers, but got { controller } ."
217+ )
201218 self .controller = controller
202219 self .step_ts = _none_or_sorted_array (step_ts )
203220 self .jump_ts = _none_or_sorted_array (jump_ts )
0 commit comments