Skip to content

Commit c9a214d

Browse files
SRKs now support forward-mode autodiff.
1 parent 9eb1bff commit c9a214d

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

diffrax/_adjoint.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
from ._heuristics import is_sde, is_unsafe_sde
1818
from ._saveat import save_y, SaveAt, SubSaveAt
19-
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
19+
from ._solver import (
20+
AbstractItoSolver,
21+
AbstractRungeKutta,
22+
AbstractSRK,
23+
AbstractStratonovichSolver,
24+
)
2025
from ._term import AbstractTerm, AdjointTerm
2126

2227

@@ -272,7 +277,7 @@ def loop(
272277
if is_unsafe_sde(terms):
273278
raise ValueError(
274279
"`adjoint=RecursiveCheckpointAdjoint()` does not support "
275-
"`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` "
280+
"`UnsafeBrownianPath`. Consider using `adjoint=ForwardMode()` "
276281
"instead."
277282
)
278283
if self.checkpoints is None and max_steps is None:
@@ -376,7 +381,10 @@ def loop(
376381
msg = None
377382
# Support forward-mode autodiff.
378383
# TODO: remove this hack once we can JVP through custom_vjps.
379-
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
384+
if (
385+
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
386+
and solver.scan_kind is None
387+
):
380388
solver = eqx.tree_at(
381389
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
382390
)
@@ -888,7 +896,10 @@ def loop(
888896
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
889897
# Support forward-mode autodiff.
890898
# TODO: remove this hack once we can JVP through custom_vjps.
891-
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
899+
if (
900+
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
901+
and solver.scan_kind is None
902+
):
892903
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
893904
final_state = self._loop(
894905
solver=solver,

diffrax/_solver/srk.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
from dataclasses import dataclass
3-
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
3+
from typing import Any, Generic, Literal, Optional, TYPE_CHECKING, TypeVar, Union
44
from typing_extensions import TypeAlias
55

66
import equinox as eqx
@@ -255,6 +255,8 @@ class AbstractSRK(AbstractSolver[_SolverState]):
255255
as well as $b^H$, $a^H$, $b^K$, and $a^K$ if needed.
256256
"""
257257

258+
# Kind of scan used for the stage loop. `None` falls back to "checkpointed" at the
# scan call site unless an adjoint overrides it first; the adjoint code assigns
# "lax" or "bounded" via `eqx.tree_at`, so "bounded" must be an admissible value.
scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None
259+
258260
interpolation_cls = LocalLinearInterpolation
259261
term_compatible_contr_kwargs = (dict(), dict(use_levy=True))
260262
tableau: AbstractClassVar[StochasticButcherTableau]
@@ -583,7 +585,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
583585
scan_inputs,
584586
len(b_sol),
585587
buffers=lambda x: x,
586-
kind="checkpointed",
588+
kind="checkpointed" if self.scan_kind is None else self.scan_kind,
587589
checkpoints="all",
588590
)
589591

test/test_adjoint.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,7 @@ def run(model):
366366
run(mlp)
367367

368368

369-
@pytest.mark.parametrize(
370-
"diffusion_fn",
371-
["weak", "lineax"],
372-
)
369+
@pytest.mark.parametrize("diffusion_fn", ["weak", "lineax"])
373370
def test_sde_against(diffusion_fn, getkey):
374371
def f(t, y, args):
375372
del t
@@ -427,3 +424,31 @@ def test_implicit_runge_kutta_direct_adjoint():
427424
adjoint=diffrax.DirectAdjoint(),
428425
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
429426
)
427+
428+
429+
@pytest.mark.parametrize("solver", (diffrax.Tsit5(), diffrax.GeneralShARK()))
def test_forward_mode_runge_kutta(solver, getkey):
    """Smoke test: JVP through `diffeqsolve` with `adjoint=ForwardMode()`.

    Builds a scalar SDE with drift `-y` and diffusion `0.1 * y` driven by an
    `UnsafeBrownianPath`, then evaluates a jitted `jax.jvp` of the solve with
    respect to the initial condition. Nothing is asserted about the values;
    the test passes if the computation runs without raising.
    """

    def _drift_fn(t, y, args):
        return -y

    def _diffusion_fn(t, y, args):
        return 0.1 * y

    # Using Tsit5 (an ODE solver) on an SDE is fine here: it should converge to
    # the Stratonovich solution.
    brownian = diffrax.UnsafeBrownianPath(
        (), getkey(), levy_area=diffrax.SpaceTimeLevyArea
    )
    sde_terms = diffrax.MultiTerm(
        diffrax.ODETerm(_drift_fn),
        diffrax.ControlTerm(_diffusion_fn, brownian),
    )

    def _solve(y0):
        # Positional arguments after the solver: t0=0, t1=1, dt0=0.01, then y0.
        solution = diffrax.diffeqsolve(
            sde_terms, solver, 0, 1, 0.01, y0, adjoint=diffrax.ForwardMode()
        )
        return solution.ys

    def _solve_jvp(y0):
        primals = (y0,)
        tangents = (jnp.ones_like(y0),)
        return jax.jvp(_solve, primals, tangents)

    jitted_jvp = jax.jit(_solve_jvp)
    jitted_jvp(jnp.array(1.0))

0 commit comments

Comments
 (0)