@@ -240,14 +240,23 @@ def _save(
240240 args : PyTree ,
241241 fn : Callable ,
242242 save_state : SaveState ,
243+ repeat : int ,
243244) -> SaveState :
244245 ts = save_state .ts
245246 ys = save_state .ys
246247 save_index = save_state .save_index
247248
248- ts = ts .at [save_index ].set (t )
249- ys = jtu .tree_map (lambda ys_ , y_ : ys_ .at [save_index ].set (y_ ), ys , fn (t , y , args ))
250- save_index = save_index + 1
249+ ts = lax .dynamic_update_slice_in_dim (
250+ ts , jnp .broadcast_to (t , (repeat ,)), save_index , axis = 0
251+ )
252+ ys = jtu .tree_map (
253+ lambda ys_ , y_ : lax .dynamic_update_slice_in_dim (
254+ ys_ , jnp .broadcast_to (y_ , (repeat , * y_ .shape )), save_index , axis = 0
255+ ),
256+ ys ,
257+ fn (t , y , args ),
258+ )
259+ save_index = save_index + repeat
251260
252261 return eqx .tree_at (
253262 lambda s : [s .ts , s .ys , s .save_index ], save_state , [ts , ys , save_index ]
@@ -306,7 +315,9 @@ def loop(
306315
307316 def save_t0 (subsaveat : SubSaveAt , save_state : SaveState ) -> SaveState :
308317 if subsaveat .t0 :
309- save_state = _save (t0 , init_state .y , args , subsaveat .fn , save_state )
318+ save_state = _save (
319+ t0 , init_state .y , args , subsaveat .fn , save_state , repeat = 1
320+ )
310321 return save_state
311322
312323 save_state = jtu .tree_map (
@@ -638,6 +649,7 @@ def body_fun(state):
638649 final_state = outer_while_loop (
639650 cond_fun , body_fun , init_state , max_steps = max_steps , buffers = _outer_buffers
640651 )
652+ save_state = final_state .save_state
641653 result = final_state .result
642654
643655 if event is None or event .root_finder is None :
@@ -765,66 +777,16 @@ def unsave(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
765777 )
766778
767779 save_state = jtu .tree_map (
768- unsave , saveat .subs , final_state .save_state , is_leaf = _is_subsaveat
769- )
770-
771- final_state = eqx .tree_at (
772- lambda s : s .save_state ,
773- final_state ,
774- save_state ,
775- is_leaf = _is_none ,
780+ unsave , saveat .subs , save_state , is_leaf = _is_subsaveat
776781 )
777782
778- def _save_t1 (subsaveat , save_state ):
779- if event is None or event .root_finder is None :
780- if subsaveat .t1 and not subsaveat .steps :
781- # If subsaveat.steps then the final value is already saved.
782- save_state = _save (tfinal , yfinal , args , subsaveat .fn , save_state )
783- else :
784- if subsaveat .t1 or subsaveat .steps :
785- # In this branch we need to replace the last value with tfinal
786- # and yfinal returned by the root finder also if subsaveat.steps
787- # because we deleted the last value after the event time above.
788- save_state = _save (tfinal , yfinal , args , subsaveat .fn , save_state )
789- return save_state
790-
791783 def _save_if_t0_equals_t1 (subsaveat : SubSaveAt , save_state : SaveState ) -> SaveState :
792784 if subsaveat .ts is not None :
793- out_size = 1 if subsaveat .t0 else 0
794- out_size += 1 if subsaveat .t1 and not subsaveat .steps else 0
795- out_size += len (subsaveat .ts )
796-
797- def _make_ys (out , old_outs ):
798- outs = jnp .stack ([out ] * out_size )
799- if subsaveat .steps :
800- outs = jnp .concatenate (
801- [
802- outs ,
803- jnp .full (
804- (max_steps ,) + out .shape , jnp .inf , dtype = out .dtype
805- ),
806- ]
807- )
808- assert outs .shape == old_outs .shape
809- return outs
810-
811- ts = jnp .full (out_size , t0 )
812- if subsaveat .steps :
813- ts = jnp .concatenate ((ts , jnp .full (max_steps , jnp .inf , dtype = ts .dtype )))
814- assert ts .shape == save_state .ts .shape
815- ys = jtu .tree_map (_make_ys , subsaveat .fn (t0 , yfinal , args ), save_state .ys )
816- save_state = SaveState (
817- saveat_ts_index = out_size ,
818- ts = ts ,
819- ys = ys ,
820- save_index = out_size ,
785+ save_state = _save (
786+ t0 , yfinal , args , subsaveat .fn , save_state , repeat = len (subsaveat .ts )
821787 )
822788 return save_state
823789
824- save_state = jtu .tree_map (
825- _save_t1 , saveat .subs , final_state .save_state , is_leaf = _is_subsaveat
826- )
827-
828790 # if t0 == t1 then we don't enter the integration loop. In this case we have to
829791 # manually update the saved ts and ys if we want to save at "intermediate"
830792 # times specified by saveat.subs.ts
@@ -842,10 +804,28 @@ def _make_ys(out, old_outs):
842804 save_state ,
843805 )
844806
807+ def _save_t1 (subsaveat , save_state ):
808+ if event is None or event .root_finder is None :
809+ if subsaveat .t1 and not subsaveat .steps :
810+ # If subsaveat.steps then the final value is already saved.
811+ save_state = _save (
812+ tfinal , yfinal , args , subsaveat .fn , save_state , repeat = 1
813+ )
814+ else :
815+ if subsaveat .t1 or subsaveat .steps :
816+ # In this branch we need to replace the last value with tfinal
817+ # and yfinal returned by the root finder also if subsaveat.steps
818+ # because we deleted the last value after the event time above.
819+ save_state = _save (
820+ tfinal , yfinal , args , subsaveat .fn , save_state , repeat = 1
821+ )
822+ return save_state
823+
824+ save_state = jtu .tree_map (_save_t1 , saveat .subs , save_state , is_leaf = _is_subsaveat )
825+
845826 final_state = eqx .tree_at (
846827 lambda s : s .save_state , final_state , save_state , is_leaf = _is_none
847828 )
848-
849829 final_state = _handle_static (final_state )
850830 result = RESULTS .where (cond_fun (final_state ), RESULTS .max_steps_reached , result )
851831 aux_stats = dict () # TODO: put something in here?
@@ -1287,9 +1267,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
12871267 "`max_steps=None` is incompatible with `saveat.dense=True`"
12881268 )
12891269 dense_ts = jnp .full (max_steps + 1 , jnp .inf , dtype = time_dtype )
1290- _make_full = lambda x : jnp .full (
1291- (max_steps ,) + jnp .shape (x ), jnp .inf , dtype = x .dtype
1292- )
1270+ _make_full = lambda x : jnp .full ((max_steps ,) + x .shape , jnp .inf , dtype = x .dtype )
12931271 dense_infos = jtu .tree_map (_make_full , dense_info_struct ) # pyright: ignore[reportPossiblyUnboundVariable]
12941272 dense_save_index = 0
12951273 else :
0 commit comments