Don't use saved_output to create checkpoints for timesteps #206

Open
wants to merge 1 commit into
base: master

angus-g
Contributor

@angus-g angus-g commented Apr 9, 2025

The tape holds a chain of blocks: T (the BC function) <- AssignBlock (for T += 1) <- DirichletBCBlock <- AssignBlock <- DirichletBCBlock. At the end of timestep 0 (equivalently, the beginning of timestep 1), the first 3 of these are on the tape. During the add_dependency phase of the solve, the DirichletBCBlock was also checkpointed, so its associated block variable has a _checkpoint member, which is just the T function (https://github.com/firedrakeproject/firedrake/blob/a9f2c621633f0e56c69c9ab42c6f2343051f4839/firedrake/adjoint_utils/dirichletbc.py#L36-L43).

Now, when timestep 1 is due to begin, the process_taping method checkpoints all the blocks:

self.tape.timesteps[timestep - 1].checkpoint(
    _store_checkpointable_state, _store_adj_dependencies, self._global_deps)

This process goes through the saved_output property:

pyadjoint/pyadjoint/tape.py

Lines 839 to 840 in 0b13480

for var in self.adjoint_dependencies.union(self.checkpointable_state):
    self._checkpoint[var] = var.saved_output._ad_create_checkpoint()

Consider the case where var is the block variable associated with the DirichletBC. The saved_output property will use the existing checkpoint:

@property
def saved_output(self):
    if self.checkpoint is not None:
        return self.output._ad_restore_at_checkpoint(self.checkpoint)
    else:
        return self.output

This has the effect of calling DirichletBCMixin._ad_restore_at_checkpoint on a checkpointed function: https://github.com/firedrakeproject/firedrake/blob/a9f2c621633f0e56c69c9ab42c6f2343051f4839/firedrake/adjoint_utils/dirichletbc.py#L45-L48
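
To make the side effect concrete, here is a minimal sketch that uses toy stand-ins rather than the real pyadjoint or Firedrake classes. ToyBC and ToyBlockVariable are hypothetical names; only the body of saved_output is copied from the excerpt above, and the weak-checkpoint behaviour of ToyBC is my assumption about what the linked DirichletBC code does.

```python
class ToyBC:
    """Hypothetical stand-in for a boundary condition with a weak checkpoint."""

    def __init__(self, value):
        self.value = value

    def set_value(self, value):
        self.value = value

    def _ad_create_checkpoint(self):
        # Assumed weak checkpoint: remember whatever the BC currently refers to.
        return self.value

    def _ad_restore_at_checkpoint(self, checkpoint):
        # Restoring mutates the object in place and returns it.
        self.set_value(checkpoint)
        return self


class ToyBlockVariable:
    """Hypothetical stand-in that copies the saved_output logic quoted above."""

    def __init__(self, output):
        self.output = output
        self.checkpoint = None

    @property
    def saved_output(self):
        if self.checkpoint is not None:
            return self.output._ad_restore_at_checkpoint(self.checkpoint)
        return self.output


bc = ToyBC(value=1.0)                        # BC built while T == 1.0
var = ToyBlockVariable(bc)
var.checkpoint = bc._ad_create_checkpoint()  # checkpoint taken at T == 1.0

bc.set_value(2.0)                            # T += 1, so the BC now refers to T == 2.0
value_before_read = bc.value

restored = var.saved_output                  # looks like a read, but triggers a restore
value_after_read = bc.value
```

In this toy model, merely reading saved_output rolls the BC back to the checkpointed value, and the object returned is the same (now mutated) BC.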

Because _ad_restore_at_checkpoint uses set_value, the DirichletBC is modified in place and ends up referring to the previous value of T. The weak checkpointing in DirichletBC is probably correct, which suggests that accessing saved_output during TimeStep.checkpoint is the mistake. I'm really unclear about what is checkpointed and when, and about when to use .output vs. .saved_output, so I don't know whether this change has further implications.
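
As a rough illustration of the difference, the sketch below runs a toy version of the timestep checkpoint loop in both modes. It assumes the change amounts to reading var.output instead of var.saved_output; the actual diff is not shown in this description. checkpoint_timestep, make_var, ToyBC and ToyBlockVariable are all hypothetical names, not pyadjoint API.

```python
class ToyBC:
    """Hypothetical boundary condition whose checkpoint is its current value."""

    def __init__(self, value):
        self.value = value

    def set_value(self, value):
        self.value = value

    def _ad_create_checkpoint(self):
        return self.value

    def _ad_restore_at_checkpoint(self, checkpoint):
        self.set_value(checkpoint)  # in-place restore
        return self


class ToyBlockVariable:
    """Hypothetical block variable with the saved_output logic from the PR text."""

    def __init__(self, output):
        self.output = output
        self.checkpoint = None

    @property
    def saved_output(self):
        if self.checkpoint is not None:
            return self.output._ad_restore_at_checkpoint(self.checkpoint)
        return self.output


def checkpoint_timestep(variables, use_saved_output):
    """Toy version of the loop in TimeStep.checkpoint (assumed shape)."""
    stored = {}
    for var in variables:
        source = var.saved_output if use_saved_output else var.output
        stored[var] = source._ad_create_checkpoint()
    return stored


def make_var():
    """Build a variable checkpointed at T == 1.0 whose BC has moved on to 2.0."""
    bc = ToyBC(1.0)
    var = ToyBlockVariable(bc)
    var.checkpoint = bc._ad_create_checkpoint()
    bc.set_value(2.0)
    return var


var_a = make_var()
via_saved_output = checkpoint_timestep([var_a], use_saved_output=True)[var_a]

var_b = make_var()
via_output = checkpoint_timestep([var_b], use_saved_output=False)[var_b]
```

With these toy classes, the saved_output path stores the stale value and also rolls the BC back, while the output path stores the current value and leaves the BC untouched. Whether the same holds for every overloaded type in pyadjoint is exactly the open question above.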

This can cause the restoration of the saved_output's checkpoint, which
can lead to stale data being stored.