
Commit afe0df1
Reads checkpoints in parallel for files that support seeking.
PiperOrigin-RevId: 346214034
Daniel Andor authored and Flax Authors committed Dec 8, 2020
1 parent 9a0b06d commit afe0df1
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions flax/training/checkpoints.py
@@ -19,6 +19,7 @@
 checkpoint files.
 """
 
+from concurrent.futures import thread
 import os
 import re
 
@@ -110,7 +111,11 @@ def save_checkpoint(ckpt_dir,
   return ckpt_path
 
 
-def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
+def restore_checkpoint(ckpt_dir,
+                       target,
+                       step=None,
+                       prefix='checkpoint_',
+                       parallel=True):
   """Restore last/best checkpoint from checkpoints in path.
 
   Sorts the checkpoint files naturally, returning the highest-valued
@@ -125,6 +130,7 @@ def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
       the deserialized state-dict is returned as-is.
     step: int: step number to load or None to load latest.
     prefix: str: name prefix of checkpoint files.
+    parallel: bool: whether to load seekable checkpoints in parallel, for speed.
 
   Returns:
     Restored `target` updated from checkpoint file, or if no step specified and
@@ -145,7 +151,33 @@ def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
 
   logging.info('Restoring checkpoint from %s', ckpt_path)
   with gfile.GFile(ckpt_path, 'rb') as fp:
+    if parallel and fp.seekable():
+      buf_size = 128 << 20  # 128M buffer.
+      num_bufs = fp.size() / buf_size
+      logging.debug('num_bufs: %d', num_bufs)
+      checkpoint_contents = bytearray(fp.size())
+
+      def read_chunk(i):
+        # NOTE: We have to re-open the file to read each chunk, otherwise the
+        # parallelism has no effect. But we could reuse the file pointers
+        # within each thread.
+        with gfile.GFile(ckpt_path, 'rb') as f:
+          f.seek(i * buf_size)
+          buf = f.read(buf_size)
+          if buf:
+            checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf
+          return len(buf) / buf_size
+
+      pool_size = 32
+      pool = thread.ThreadPoolExecutor(pool_size)
+      results = pool.map(read_chunk, range(int(num_bufs) + 1))
+      results = list(results)
+      pool.shutdown(wait=False)
+      logging.debug('results: %s', results)
+    else:
+      checkpoint_contents = fp.read()
+
     if target is None:
-      return serialization.msgpack_restore(fp.read())
+      return serialization.msgpack_restore(checkpoint_contents)
     else:
-      return serialization.from_bytes(target, fp.read())
+      return serialization.from_bytes(target, checkpoint_contents)
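
For reference, a minimal usage sketch of the new flag (not part of the commit; the checkpoint directory and target pytree below are made up). With parallel=True, which is the default, a seekable checkpoint file is read in 128 MB chunks across a 32-thread pool; non-seekable files fall back to a single read.

# A minimal usage sketch, not part of the commit. The directory path and the
# target pytree are hypothetical; any pytree previously written with
# save_checkpoint restores the same way.
import numpy as np
from flax.training import checkpoints

target = {'step': 0, 'params': {'w': np.zeros((4, 4))}}

# parallel=True (the default) reads seekable files in 128 MB chunks on a
# 32-thread pool; non-seekable files fall back to a single fp.read().
restored = checkpoints.restore_checkpoint(
    '/tmp/flax_ckpts',  # assumed checkpoint directory
    target,
    parallel=True)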
