Skip to content

Commit

Permalink
Explicitly create GDA directories when saving checkpoints.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 471385155
  • Loading branch information
IvyZX authored and Flax Authors committed Aug 31, 2022
1 parent 6781c59 commit 0451a55
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flax import traverse_util
from jax import process_index
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.multihost_utils import sync_global_devices
from tensorflow.io import gfile # pytype: disable=import-error

_IMPORT_GDAM_SUCCESSFUL = False
Expand Down Expand Up @@ -113,6 +114,13 @@ def on_commit_callback(tmp_path, final_path):
logging.info('Finished saving checkpoint to `%s`.', final_path)


def _make_gda_dirs(gda_targets: List[Tuple[GlobalDeviceArray, str]],
tmp_path: str):
_, gda_subpaths = zip(*gda_targets)
for subpath in gda_subpaths:
gfile.makedirs(os.path.join(tmp_path, subpath))


def _save_gdas(gda_manager,
gda_targets: List[Tuple[GlobalDeviceArray, str]],
tmp_path: str, final_path: str):
Expand All @@ -124,9 +132,8 @@ def _save_gdas(gda_manager,
# TODO: figure out a way to unit-test the behavior.
if tmp_path.startswith('gs://'):
tmp_path = final_path
ts_specs = [
get_tensorstore_spec(os.path.join(tmp_path, x)) for x in gda_subpaths
]
gda_paths = [os.path.join(tmp_path, x) for x in gda_subpaths]
ts_specs = [get_tensorstore_spec(x) for x in gda_paths]
gda_manager.serialize(
list(gda_list),
ts_specs,
Expand Down Expand Up @@ -380,6 +387,11 @@ def save_task():
if not gda_manager:
raise errors.GDACheckpointingRequiredError(ckpt_path, step)
gda_tmp_path, gda_final_path = ckpt_tmp_path + '_gda', ckpt_path + '_gda'
# Creating the directory containing GDAs explicitly. This should happen only
# on process 0 and before any worker starts to write data.
if process_index() == 0:
_make_gda_dirs(gda_targets, gda_tmp_path)
sync_global_devices('sync_after_create_dir')
_save_gdas(gda_manager, gda_targets, gda_tmp_path, gda_final_path)

return ckpt_path
Expand Down

0 comments on commit 0451a55

Please sign in to comment.