Skip to content

Commit

Permalink
Remove usages of orbax_utils.save_args_from_target, as this function …
Browse files Browse the repository at this point in the history
…does nothing (it used to control a checkpointing behavior that has since been optimized away).

PiperOrigin-RevId: 715575757
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Jan 15, 2025
1 parent 1961c12 commit d9b4539
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,8 @@ def save_checkpoint(
' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-you-don-t-save-pytrees'
)

save_args = orbax_utils.save_args_from_target(target)
orbax_checkpointer.save(
ckpt_path, target, save_args=save_args, force=overwrite
ckpt_path, target, force=overwrite
)
# Do a process check here in case people call this for multihost.
if process_index() == 0:
Expand Down Expand Up @@ -843,9 +842,8 @@ def save_checkpoint_multiprocess(
_remove_invalid_ckpts(
ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True
)
save_args = orbax_utils.save_args_from_target(target)
orbax_checkpointer.save(
ckpt_path, target, save_args=save_args, force=overwrite
ckpt_path, target, force=overwrite
)
end_time = time.time()
monitoring.record_event_duration_secs(
Expand Down

0 comments on commit d9b4539

Please sign in to comment.