Skip to content

Commit 9743eb8

Browse files
Randlpatrick-kidger
authored andcommitted
Enable implicit solvers for complex inputs (#411)
* Enable implicit solvers for complex inputs * change version * make pyright happy
1 parent 068b4b9 commit 9743eb8

File tree

5 files changed

+6
-19
lines changed

5 files changed

+6
-19
lines changed

diffrax/_integrate.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -709,11 +709,6 @@ def diffeqsolve(
709709
eqx.is_array(xi) and jnp.iscomplexobj(xi)
710710
for xi in jtu.tree_leaves((terms, y0, args))
711711
):
712-
if isinstance(solver, AbstractImplicitSolver):
713-
raise ValueError(
714-
"Implicit solvers in conjunction with complex dtypes is currently not "
715-
"supported."
716-
)
717712
warnings.warn(
718713
"Complex dtype support is work in progress, please read "
719714
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.",

diffrax/_progress_meter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _step(_progress, _idx):
294294
# Return the idx to thread the callbacks in the correct order.
295295
return _idx
296296

297-
return jax.pure_callback(_step, idx, progress, idx, vectorized=True) # pyright: ignore
297+
return jax.pure_callback(_step, idx, progress, idx, vectorized=True)
298298

299299
def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike):
300300
def _close(_idx):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Topic :: Scientific/Engineering :: Mathematics",
2424
]
2525
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
26-
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.6"]
26+
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.7"]
2727

2828
[build-system]
2929
requires = ["hatchling"]

test/test_integrate.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,10 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
7777
return
7878

7979
if jnp.iscomplexobj(y_dtype) and treedef != jtu.tree_structure(None):
80-
if isinstance(solver, diffrax.AbstractImplicitSolver):
81-
return
82-
else:
83-
complex_warn = pytest.warns(match="Complex dtype")
80+
complex_warn = pytest.warns(match="Complex dtype")
8481

85-
def f(t, y, args):
86-
return jtu.tree_map(lambda yi: -1j * yi, y)
82+
def f(t, y, args):
83+
return jtu.tree_map(lambda yi: -1j * yi, y)
8784
else:
8885
complex_warn = contextlib.nullcontext()
8986

@@ -152,8 +149,6 @@ def test_ode_order(solver, dtype):
152149

153150
A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5
154151

155-
if jnp.iscomplexobj(A) and isinstance(solver, diffrax.AbstractImplicitSolver):
156-
return
157152
if (
158153
solver.term_structure
159154
== diffrax.MultiTerm[tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]

test/test_interpolation.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def test_derivative(dtype, getkey):
6060
solver = implicit_tol(solver)
6161
y0 = jr.normal(getkey(), (3,), dtype=dtype)
6262

63-
if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver):
64-
continue
6563
solution = diffrax.diffeqsolve(
6664
diffrax.ODETerm(lambda t, y, p: -y),
6765
solver,
@@ -77,8 +75,7 @@ def test_derivative(dtype, getkey):
7775
for solver in all_split_solvers:
7876
solver = implicit_tol(solver)
7977
y0 = jr.normal(getkey(), (3,), dtype=dtype)
80-
if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver):
81-
continue
78+
8279
solution = diffrax.diffeqsolve(
8380
diffrax.MultiTerm(
8481
diffrax.ODETerm(lambda t, y, p: -0.7 * y),

0 commit comments

Comments
 (0)