@@ -794,9 +794,10 @@ def __init__(self, blocks=()):
794
794
# A dictionary mapping the block variables in the checkpointable state
795
795
# to their checkpoint values.
796
796
self ._checkpoint = {}
797
- # A flag to indicate whether the adjoint dependencies have been cleaned
798
- # from the outputs not marked in the path.
799
- self ._adj_deps_cleaned = False
797
+ # Flag indicating whether the adjoint dependencies have been revised
798
+ # by removing outputs not marked in the path and adding checkpointable
799
+ # states that are marked in the path.
800
+ self ._revised_adj_deps = False
800
801
801
802
def copy (self , blocks = None ):
802
803
out = TimeStep (blocks or self )
@@ -826,8 +827,15 @@ def checkpoint(self, checkpointable_state, adj_dependencies, global_deps):
826
827
self ._checkpoint [var ] = var .saved_output ._ad_create_checkpoint ()
827
828
828
829
if adj_dependencies :
829
- for var in self .adjoint_dependencies :
830
- self ._checkpoint [var ] = var .saved_output ._ad_create_checkpoint ()
830
+ if self ._revised_adj_deps :
831
+ for var in self .adjoint_dependencies :
832
+ self ._checkpoint [var ] = var .saved_output ._ad_create_checkpoint ()
833
+ else :
834
+ # The adjoint dependencies have not been revised yet. At this stage,
835
+ # the block nodes are not marked in the path because the control variable(s)
836
+ # are not yet determined.
837
+ for var in self .adjoint_dependencies .union (self .checkpointable_state ):
838
+ self ._checkpoint [var ] = var .saved_output ._ad_create_checkpoint ()
831
839
832
840
def restore_from_checkpoint (self , from_storage ):
833
841
"""Restore the block var checkpoints from the timestep checkpoint."""
0 commit comments