diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
index fe163b6f672d..bcb958a6d2e5 100644
--- a/torch_xla/experimental/distributed_checkpoint/manager.py
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -107,6 +107,7 @@ def __init__(self,
                save_interval: int,
                max_to_keep: Optional[int] = 0,
                max_pending_async: Optional[int] = 1,
+               num_of_threads: Optional[int] = 1,
                process_group: dist.ProcessGroup = None,
                chkpt_on_preemption: bool = True):
     """
@@ -127,6 +128,8 @@ def __init__(self,
                          slow down the active checkpoint.
                          Default: 1, which only allows a single async checkpoint
                          to be pending at a time.
+      num_of_threads: Number of concurrent threads for writing the checkpoint
+                      to the file system. Default: 1.
       process_group: The process group to use when coordinating the checkpoint.
                      Default: None, in which case a subgroup of the default
                      process group will be created.
@@ -142,6 +145,7 @@ def __init__(self,
     self.base_path = os.path.join(path, '')  # Ensure the base path ends in '/'
     self.save_interval = save_interval
     self.max_to_keep = max_to_keep
+    self.num_of_threads = num_of_threads
     self.chkpt_on_preemption = chkpt_on_preemption
 
     # Create a new group if none is provided
@@ -226,6 +230,7 @@ def _save(self, step, state_dict):
         state_dict=state_dict,
         storage_writer=FsspecWriter(
             path,
+            thread_count=self.num_of_threads,
             per_thread_copy_ahead=0,
         ),
         planner=xc.SPMDSavePlanner(),
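
Usage sketch (not part of the patch): a minimal, hypothetical example of constructing a CheckpointManager with the new num_of_threads argument, following the existing torch_xla distributed checkpointing examples. The bucket path, model, optimizer, and step loop below are placeholders for illustration only.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# Create a process group for the CheckpointManager to coordinate over,
# as in the existing torch_xla checkpointing examples.
xr.use_spmd()
dist.init_process_group('gloo', init_method='xla://')

# Placeholder model and optimizer.
model = torch.nn.Linear(8, 8).to(xm.xla_device())
optim = torch.optim.SGD(model.parameters(), lr=0.1)

chkpt_mgr = CheckpointManager(
    path='gs://my-bucket/checkpoints',  # placeholder checkpoint directory
    save_interval=10,
    max_to_keep=3,
    num_of_threads=4,  # new: write the checkpoint with 4 concurrent threads
)

state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
for step in range(100):
  # ... training step ...
  if chkpt_mgr.should_save(step):
    chkpt_mgr.save(step, state_dict)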