Skip to content

Commit a40fde2

Browse files
committed
cleanup
1 parent 7584161 commit a40fde2

File tree

6 files changed

+13
-10
lines changed

6 files changed

+13
-10
lines changed

diffrax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
AbstractLocalInterpolation as AbstractLocalInterpolation,
4646
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
4747
LocalLinearInterpolation as LocalLinearInterpolation,
48-
RodasInterpolation as RodasInterpolation, # noqa: E501
48+
RodasInterpolation as RodasInterpolation,
4949
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
5050
)
5151
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm

diffrax/_local_interpolation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,17 @@ def evaluate(
199199
return self.evaluate(t1) - self.evaluate(t0)
200200

201201
t = linear_rescale(self.t0, t0, self.t1)
202-
weighted_increment = jax.vmap(
203-
lambda coeff, stage_k: (t * jnp.polyval(jnp.flip(coeff), t)) * stage_k
204-
)(self.stage_poly_coeffs, self.k)
202+
203+
def eval_increment():
204+
with jax.numpy_dtype_promotion("standard"):
205+
weighted_increment = jax.vmap(
206+
lambda coeff, stage_k: (t * jnp.polyval(jnp.flip(coeff), t))
207+
* stage_k
208+
)(self.stage_poly_coeffs, self.k)
209+
return jnp.sum(weighted_increment, axis=0).astype(self.k.dtype)
205210

206211
y0, unravel = fu.ravel_pytree(self.y0)
207-
y1 = y0 + jnp.sum(weighted_increment, axis=0)
212+
y1 = y0 + eval_increment()
208213
return unravel(y1)
209214

210215
@classmethod

diffrax/_solver/rodas5p.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Callable
22
from typing import ClassVar
3-
from xmlrpc.client import Boolean
43

54
import numpy as np
65

@@ -229,7 +228,7 @@ class Rodas5p(AbstractRosenbrock):
229228
_Rodas5pInterpolation.from_k
230229
)
231230

232-
rodas: ClassVar[Boolean] = True
231+
rodas: ClassVar[bool] = True
233232

234233
def order(self, terms):
235234
del terms

diffrax/_solver/rosenbrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class AbstractRosenbrock(AbstractAdaptiveSolver):
101101

102102
rodas: ClassVar[bool] = False
103103

104-
linear_solver: lx.AbstractLinearSolver = lx.LU()
104+
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=True)
105105

106106
def init(self, terms, t0, t1, y0, args) -> _SolverState:
107107
del t0, t1

test/test_integrate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_ode_order(solver, dtype):
152152
A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5
153153

154154
if isinstance(solver, AbstractRosenbrock) and dtype == jnp.complex128:
155-
## complex support is not added to rosenbrock.
155+
# complex support is not added to rosenbrock.
156156
return
157157

158158
if (

test/test_solver.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,6 @@ def rober(t, y, args):
465465
[6.1723488239606716e-01, 6.1535912746388841e-06, 3.8275896401264059e-01],
466466
]
467467
)
468-
print(sol.ys)
469468
assert jnp.allclose(sol.ys, true_ys, rtol=1e-3, atol=1e-8) # pyright: ignore
470469

471470

0 commit comments

Comments
 (0)