Skip to content

Commit 9d9d1dd

Browse files
authored
Revise the adjoint dependencies (#201)
* wip * Revise the adjoint dependences * falke8
1 parent 8eaa3ed commit 9d9d1dd

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

pyadjoint/checkpointing.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,15 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
406406
current_step = self.tape.timesteps[step]
407407
for block in reversed(current_step):
408408
block.evaluate_adj(markings=markings)
409-
if not current_step._adj_deps_cleaned:
409+
if not current_step._revised_adj_deps:
410+
# Update the adjoint dependency set.
411+
for deps in block.get_dependencies():
412+
if deps.marked_in_path:
413+
current_step.adjoint_dependencies.add(deps)
410414
for out in block._outputs:
411415
if not out.marked_in_path:
412416
current_step.adjoint_dependencies.discard(out)
413-
current_step._adj_deps_cleaned = True
417+
current_step._revised_adj_deps = True
414418
# Output variables are used for the last time when running
415419
# backwards.
416420
to_keep = current_step.checkpointable_state

pyadjoint/tape.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -794,9 +794,10 @@ def __init__(self, blocks=()):
794794
# A dictionary mapping the block variables in the checkpointable state
795795
# to their checkpoint values.
796796
self._checkpoint = {}
797-
# A flag to indicate whether the adjoint dependencies have been cleaned
798-
# from the outputs not marked in the path.
799-
self._adj_deps_cleaned = False
797+
# Flag indicating whether the adjoint dependencies have been revised
798+
# by removing outputs not marked in the path and adding checkpointable
799+
# states that are marked in the path.
800+
self._revised_adj_deps = False
800801

801802
def copy(self, blocks=None):
802803
out = TimeStep(blocks or self)
@@ -826,8 +827,15 @@ def checkpoint(self, checkpointable_state, adj_dependencies, global_deps):
826827
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
827828

828829
if adj_dependencies:
829-
for var in self.adjoint_dependencies:
830-
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
830+
if self._revised_adj_deps:
831+
for var in self.adjoint_dependencies:
832+
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
833+
else:
834+
# The adjoint dependencies have not been revised yet. At this stage,
835+
# the block nodes are not marked in the path because the control variable(s)
836+
# are not yet determined.
837+
for var in self.adjoint_dependencies.union(self.checkpointable_state):
838+
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
831839

832840
def restore_from_checkpoint(self, from_storage):
833841
"""Restore the block var checkpoints from the timestep checkpoint."""

0 commit comments

Comments
 (0)