diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py
index 85af61e3..d49f69ed 100644
--- a/flax/training/checkpoints.py
+++ b/flax/training/checkpoints.py
@@ -19,6 +19,7 @@
 checkpoint files.
 """
 
+from concurrent.futures import thread
 import os
 import re
 
@@ -110,7 +111,11 @@ def save_checkpoint(ckpt_dir,
   return ckpt_path
 
 
-def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
+def restore_checkpoint(ckpt_dir,
+                       target,
+                       step=None,
+                       prefix='checkpoint_',
+                       parallel=True):
   """Restore last/best checkpoint from checkpoints in path.
 
   Sorts the checkpoint files naturally, returning the highest-valued
@@ -125,6 +130,7 @@ def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
       the deserialized state-dict is returned as-is.
     step: int: step number to load or None to load latest.
     prefix: str: name prefix of checkpoint files.
+    parallel: bool: whether to load seekable checkpoints in parallel, for speed.
 
   Returns:
     Restored `target` updated from checkpoint file, or if no step specified and
@@ -145,7 +151,40 @@ def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
 
   logging.info('Restoring checkpoint from %s', ckpt_path)
   with gfile.GFile(ckpt_path, 'rb') as fp:
+    if parallel and fp.seekable():
+      buf_size = 128 << 20  # 128M buffer.
+      file_size = fp.size()
+      # Ceiling division keeps the chunk count an exact int and avoids a
+      # redundant empty read when the size is a multiple of buf_size.
+      num_bufs = (file_size + buf_size - 1) // buf_size
+      logging.debug('num_bufs: %d', num_bufs)
+      checkpoint_contents = bytearray(file_size)
+
+      def read_chunk(i):
+        # NOTE: We have to re-open the file to read each chunk, otherwise the
+        # parallelism has no effect. But we could reuse the file pointers
+        # within each thread.
+        with gfile.GFile(ckpt_path, 'rb') as f:
+          f.seek(i * buf_size)
+          buf = f.read(buf_size)
+          if buf:
+            checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf
+          return len(buf)
+
+      pool_size = 32
+      # The context manager shuts the pool down even if a chunk read raises.
+      with thread.ThreadPoolExecutor(pool_size) as pool:
+        results = list(pool.map(read_chunk, range(num_bufs)))
+      logging.debug('results: %s', results)
+      # A short read would leave zero-filled gaps in the buffer; fail loudly
+      # instead of handing corrupt bytes to the deserializer.
+      if sum(results) != file_size:
+        raise IOError('Read %d of %d bytes from checkpoint %s' %
+                      (sum(results), file_size, ckpt_path))
+    else:
+      checkpoint_contents = fp.read()
+
     if target is None:
-      return serialization.msgpack_restore(fp.read())
+      return serialization.msgpack_restore(checkpoint_contents)
     else:
-      return serialization.from_bytes(target, fp.read())
+      return serialization.from_bytes(target, checkpoint_contents)