@@ -924,7 +924,7 @@ def rk_stage(val):
924924 )
925925
926926 def eval_f_jac ():
927- return self .root_finder .init ( # pyright: ignore
927+ return self .root_finder .init (
928928 lambda y , a : (_implicit_relation_f (y , a ), None ),
929929 lax .stop_gradient (f_pred ),
930930 _filter_stop_gradient (f_implicit_args ),
@@ -935,7 +935,7 @@ def eval_f_jac():
935935 )
936936
937937 def eval_k_jac ():
938- return self .root_finder .init ( # pyright: ignore
938+ return self .root_finder .init (
939939 lambda y , a : (_implicit_relation_k (y , a ), None ),
940940 lax .stop_gradient (k_pred ),
941941 _filter_stop_gradient (k_implicit_args ),
@@ -980,12 +980,12 @@ def eval_k_jac():
980980 jac_f = eqxi .nondifferentiable (jac_f , name = "jac_f" )
981981 nonlinear_sol = optx .root_find (
982982 _implicit_relation_f ,
983- self .root_finder , # pyright: ignore
983+ self .root_finder ,
984984 f_pred ,
985985 f_implicit_args ,
986986 options = dict (init_state = jac_f ),
987987 throw = False ,
988- max_steps = self .root_find_max_steps , # pyright: ignore
988+ max_steps = self .root_find_max_steps ,
989989 )
990990 implicit_fi = nonlinear_sol .value
991991 implicit_ki = _unused
@@ -995,12 +995,12 @@ def eval_k_jac():
995995 jac_k = eqxi .nondifferentiable (jac_k , name = "jac_k" )
996996 nonlinear_sol = optx .root_find (
997997 _implicit_relation_k ,
998- self .root_finder , # pyright: ignore
998+ self .root_finder ,
999999 k_pred ,
10001000 k_implicit_args ,
10011001 options = dict (init_state = jac_k ),
10021002 throw = False ,
1003- max_steps = self .root_find_max_steps , # pyright: ignore
1003+ max_steps = self .root_find_max_steps ,
10041004 )
10051005 implicit_fi = _unused
10061006 implicit_ki = implicit_inc = nonlinear_sol .value
@@ -1093,7 +1093,7 @@ def buffers(val):
10931093 args ,
10941094 implicit_control ,
10951095 )
1096- jac_f = self .root_finder .init ( # pyright: ignore
1096+ jac_f = self .root_finder .init (
10971097 lambda y , a : (_implicit_relation_f (y , a ), None ),
10981098 jtu .tree_map (jnp .zeros_like , get_implicit (f0 )),
10991099 _filter_stop_gradient (f_implicit_args ),
@@ -1115,7 +1115,7 @@ def buffers(val):
11151115 implicit_control ,
11161116 )
11171117 jac_f = _unused
1118- jac_k = self .root_finder .init ( # pyright: ignore
1118+ jac_k = self .root_finder .init (
11191119 lambda y , a : (_implicit_relation_k (y , a ), None ),
11201120 jtu .tree_map (jnp .zeros_like , y0 ),
11211121 _filter_stop_gradient (k_implicit_args ),
0 commit comments