Skip to content

Commit 3df311c

Browse files
Removed old pyright-ignores.
1 parent f02d8eb commit 3df311c

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

diffrax/_solver/runge_kutta.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def rk_stage(val):
924924
)
925925

926926
def eval_f_jac():
927-
return self.root_finder.init( # pyright: ignore
927+
return self.root_finder.init(
928928
lambda y, a: (_implicit_relation_f(y, a), None),
929929
lax.stop_gradient(f_pred),
930930
_filter_stop_gradient(f_implicit_args),
@@ -935,7 +935,7 @@ def eval_f_jac():
935935
)
936936

937937
def eval_k_jac():
938-
return self.root_finder.init( # pyright: ignore
938+
return self.root_finder.init(
939939
lambda y, a: (_implicit_relation_k(y, a), None),
940940
lax.stop_gradient(k_pred),
941941
_filter_stop_gradient(k_implicit_args),
@@ -980,12 +980,12 @@ def eval_k_jac():
980980
jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
981981
nonlinear_sol = optx.root_find(
982982
_implicit_relation_f,
983-
self.root_finder, # pyright: ignore
983+
self.root_finder,
984984
f_pred,
985985
f_implicit_args,
986986
options=dict(init_state=jac_f),
987987
throw=False,
988-
max_steps=self.root_find_max_steps, # pyright: ignore
988+
max_steps=self.root_find_max_steps,
989989
)
990990
implicit_fi = nonlinear_sol.value
991991
implicit_ki = _unused
@@ -995,12 +995,12 @@ def eval_k_jac():
995995
jac_k = eqxi.nondifferentiable(jac_k, name="jac_k")
996996
nonlinear_sol = optx.root_find(
997997
_implicit_relation_k,
998-
self.root_finder, # pyright: ignore
998+
self.root_finder,
999999
k_pred,
10001000
k_implicit_args,
10011001
options=dict(init_state=jac_k),
10021002
throw=False,
1003-
max_steps=self.root_find_max_steps, # pyright: ignore
1003+
max_steps=self.root_find_max_steps,
10041004
)
10051005
implicit_fi = _unused
10061006
implicit_ki = implicit_inc = nonlinear_sol.value
@@ -1093,7 +1093,7 @@ def buffers(val):
10931093
args,
10941094
implicit_control,
10951095
)
1096-
jac_f = self.root_finder.init( # pyright: ignore
1096+
jac_f = self.root_finder.init(
10971097
lambda y, a: (_implicit_relation_f(y, a), None),
10981098
jtu.tree_map(jnp.zeros_like, get_implicit(f0)),
10991099
_filter_stop_gradient(f_implicit_args),
@@ -1115,7 +1115,7 @@ def buffers(val):
11151115
implicit_control,
11161116
)
11171117
jac_f = _unused
1118-
jac_k = self.root_finder.init( # pyright: ignore
1118+
jac_k = self.root_finder.init(
11191119
lambda y, a: (_implicit_relation_k(y, a), None),
11201120
jtu.tree_map(jnp.zeros_like, y0),
11211121
_filter_stop_gradient(k_implicit_args),

diffrax/_step_size_controller/pid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def intermediate(carry):
8989
# PIDController(... step_ts=s, jump_ts=j) this should return a
9090
# ClipStepSizeController(PIDController(...), s, j).
9191
class _MetaPID(type(eqx.Module)):
92-
def __call__(cls, *args, **kwargs):
92+
def __call__(cls, *args, **kwargs): # pyright: ignore[reportSelfClsParameterName]
9393
step_ts = kwargs.pop("step_ts", None)
9494
jump_ts = kwargs.pop("jump_ts", None)
9595
if step_ts is not None or jump_ts is not None:

0 commit comments

Comments
 (0)