Skip to content

Commit 78fac1d

Browse files
Resolves speed issue of #606.
1 parent 2fafbc7 commit 78fac1d

File tree

1 file changed

+39
-61
lines changed

1 file changed

+39
-61
lines changed

diffrax/_integrate.py

Lines changed: 39 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,23 @@ def _save(
240240
args: PyTree,
241241
fn: Callable,
242242
save_state: SaveState,
243+
repeat: int,
243244
) -> SaveState:
244245
ts = save_state.ts
245246
ys = save_state.ys
246247
save_index = save_state.save_index
247248

248-
ts = ts.at[save_index].set(t)
249-
ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, fn(t, y, args))
250-
save_index = save_index + 1
249+
ts = lax.dynamic_update_slice_in_dim(
250+
ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0
251+
)
252+
ys = jtu.tree_map(
253+
lambda ys_, y_: lax.dynamic_update_slice_in_dim(
254+
ys_, jnp.broadcast_to(y_, (repeat, *y_.shape)), save_index, axis=0
255+
),
256+
ys,
257+
fn(t, y, args),
258+
)
259+
save_index = save_index + repeat
251260

252261
return eqx.tree_at(
253262
lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index]
@@ -306,7 +315,9 @@ def loop(
306315

307316
def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
308317
if subsaveat.t0:
309-
save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state)
318+
save_state = _save(
319+
t0, init_state.y, args, subsaveat.fn, save_state, repeat=1
320+
)
310321
return save_state
311322

312323
save_state = jtu.tree_map(
@@ -638,6 +649,7 @@ def body_fun(state):
638649
final_state = outer_while_loop(
639650
cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
640651
)
652+
save_state = final_state.save_state
641653
result = final_state.result
642654

643655
if event is None or event.root_finder is None:
@@ -765,66 +777,16 @@ def unsave(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
765777
)
766778

767779
save_state = jtu.tree_map(
768-
unsave, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
769-
)
770-
771-
final_state = eqx.tree_at(
772-
lambda s: s.save_state,
773-
final_state,
774-
save_state,
775-
is_leaf=_is_none,
780+
unsave, saveat.subs, save_state, is_leaf=_is_subsaveat
776781
)
777782

778-
def _save_t1(subsaveat, save_state):
779-
if event is None or event.root_finder is None:
780-
if subsaveat.t1 and not subsaveat.steps:
781-
# If subsaveat.steps then the final value is already saved.
782-
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
783-
else:
784-
if subsaveat.t1 or subsaveat.steps:
785-
# In this branch we need to replace the last value with tfinal
786-
# and yfinal returned by the root finder also if subsaveat.steps
787-
# because we deleted the last value after the event time above.
788-
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
789-
return save_state
790-
791783
def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
792784
if subsaveat.ts is not None:
793-
out_size = 1 if subsaveat.t0 else 0
794-
out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0
795-
out_size += len(subsaveat.ts)
796-
797-
def _make_ys(out, old_outs):
798-
outs = jnp.stack([out] * out_size)
799-
if subsaveat.steps:
800-
outs = jnp.concatenate(
801-
[
802-
outs,
803-
jnp.full(
804-
(max_steps,) + out.shape, jnp.inf, dtype=out.dtype
805-
),
806-
]
807-
)
808-
assert outs.shape == old_outs.shape
809-
return outs
810-
811-
ts = jnp.full(out_size, t0)
812-
if subsaveat.steps:
813-
ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf, dtype=ts.dtype)))
814-
assert ts.shape == save_state.ts.shape
815-
ys = jtu.tree_map(_make_ys, subsaveat.fn(t0, yfinal, args), save_state.ys)
816-
save_state = SaveState(
817-
saveat_ts_index=out_size,
818-
ts=ts,
819-
ys=ys,
820-
save_index=out_size,
785+
save_state = _save(
786+
t0, yfinal, args, subsaveat.fn, save_state, repeat=len(subsaveat.ts)
821787
)
822788
return save_state
823789

824-
save_state = jtu.tree_map(
825-
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
826-
)
827-
828790
# if t0 == t1 then we don't enter the integration loop. In this case we have to
829791
# manually update the saved ts and ys if we want to save at "intermediate"
830792
# times specified by saveat.subs.ts
@@ -842,10 +804,28 @@ def _make_ys(out, old_outs):
842804
save_state,
843805
)
844806

807+
def _save_t1(subsaveat, save_state):
808+
if event is None or event.root_finder is None:
809+
if subsaveat.t1 and not subsaveat.steps:
810+
# If subsaveat.steps then the final value is already saved.
811+
save_state = _save(
812+
tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
813+
)
814+
else:
815+
if subsaveat.t1 or subsaveat.steps:
816+
# In this branch we need to replace the last value with tfinal
817+
# and yfinal returned by the root finder also if subsaveat.steps
818+
# because we deleted the last value after the event time above.
819+
save_state = _save(
820+
tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
821+
)
822+
return save_state
823+
824+
save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
825+
845826
final_state = eqx.tree_at(
846827
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
847828
)
848-
849829
final_state = _handle_static(final_state)
850830
result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result)
851831
aux_stats = dict() # TODO: put something in here?
@@ -1287,9 +1267,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
12871267
"`max_steps=None` is incompatible with `saveat.dense=True`"
12881268
)
12891269
dense_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
1290-
_make_full = lambda x: jnp.full(
1291-
(max_steps,) + jnp.shape(x), jnp.inf, dtype=x.dtype
1292-
)
1270+
_make_full = lambda x: jnp.full((max_steps,) + x.shape, jnp.inf, dtype=x.dtype)
12931271
dense_infos = jtu.tree_map(_make_full, dense_info_struct) # pyright: ignore[reportPossiblyUnboundVariable]
12941272
dense_save_index = 0
12951273
else:

0 commit comments

Comments
 (0)