Skip to content

Commit 4a308b8

Browse files
Improves error messages for mismatched terms.
Serendipitously, this can also use the new `wadler_lindig` library for pretty-printing complicated types.
1 parent c9a214d commit 4a308b8

File tree

5 files changed

+32
-4
lines changed

5 files changed

+32
-4
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ repos:
1111
rev: v1.1.350
1212
hooks:
1313
- id: pyright
14-
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions]
14+
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]

diffrax/_custom_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea):
112112
K: BM
113113

114114

115+
AbstractBrownianIncrement.__module__ = "diffrax"
116+
AbstractSpaceTimeLevyArea.__module__ = "diffrax"
117+
AbstractSpaceTimeTimeLevyArea.__module__ = "diffrax"
118+
BrownianIncrement.__module__ = "diffrax"
119+
SpaceTimeLevyArea.__module__ = "diffrax"
120+
SpaceTimeTimeLevyArea.__module__ = "diffrax"
121+
122+
115123
def levy_tree_transpose(
116124
tree_shape, tree: PyTree[AbstractBrownianIncrement]
117125
) -> AbstractBrownianIncrement:

diffrax/_integrate.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import lineax.internal as lxi
2323
import numpy as np
2424
import optimistix as optx
25+
import wadler_lindig as wl
2526
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real
2627

2728
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
@@ -184,7 +185,11 @@ def _check(term_cls, term, term_contr_kwargs, yi):
184185
better_isinstance, control_type, control_type_expected
185186
)
186187
if not control_type_compatible:
187-
raise ValueError(f"Control term {term} is incompatible.")
188+
raise ValueError(
189+
"Control term is incompatible: the returned control (e.g. "
190+
f"Brownian motion for an SDE) was {control_type}, but this "
191+
f"solver expected {control_type_expected}."
192+
)
188193
else:
189194
assert False, "Malformed term structure"
190195
# If we've got to this point then the term is compatible
@@ -194,7 +199,13 @@ def _check(term_cls, term, term_contr_kwargs, yi):
194199
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
195200
except Exception as e:
196201
# ValueError may also arise from mismatched tree structures
197-
raise ValueError("Terms are not compatible with solver!") from e
202+
pretty_term = wl.pformat(terms)
203+
pretty_expected = wl.pformat(term_structure)
204+
raise ValueError(
205+
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
206+
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
207+
"scroll up you may find a root-cause error that is more specific."
208+
) from e
198209

199210

200211
def _is_subsaveat(x: Any) -> bool:

diffrax/_term.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,3 +981,12 @@ def prod(
981981
self, vf: UnderdampedLangevinTuple, control: RealScalarLike
982982
) -> UnderdampedLangevinTuple:
983983
return jtu.tree_map(lambda _vf: control * _vf, vf)
984+
985+
986+
AbstractTerm.__module__ = "diffrax"
987+
ODETerm.__module__ = "diffrax"
988+
ControlTerm.__module__ = "diffrax"
989+
WeaklyDiagonalControlTerm.__module__ = "diffrax"
990+
MultiTerm.__module__ = "diffrax"
991+
UnderdampedLangevinDriftTerm.__module__ = "diffrax"
992+
UnderdampedLangevinDiffusionTerm.__module__ = "diffrax"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Topic :: Scientific/Engineering :: Mathematics",
2424
]
2525
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
26-
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10"]
26+
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"]
2727

2828
[build-system]
2929
requires = ["hatchling"]

0 commit comments

Comments
 (0)