From adc2071043dffe7fc0739b292afd4502483a51b5 Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Thu, 17 Dec 2020 21:01:36 -0800 Subject: [PATCH] save_checkpoint, load_checkpoint and aggregate_checkpoints (#6136) * save_checkpoint and load_checkpoint implementations * checkpoint aggregation logic * unit tests for save_checkpoint, load_checkpoint and aggregate_checkpoints --- .../python/training/_checkpoint_storage.py | 18 + .../orttraining/python/training/_utils.py | 45 ++- .../orttraining/python/training/checkpoint.py | 230 +++++++++++- .../orttraining/python/training/orttrainer.py | 138 +++++++- .../orttraining_test_checkpoint_storage.py | 9 +- ...ng_test_orttrainer_checkpoint_functions.py | 333 ++++++++++++++++-- 6 files changed, 717 insertions(+), 56 deletions(-) diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py index 2016fb0038cc9..064338b2ad815 100644 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ b/orttraining/orttraining/python/training/_checkpoint_storage.py @@ -4,6 +4,7 @@ import h5py from collections.abc import Mapping +import pickle def _dfs_save(group, save_obj): """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" @@ -79,3 +80,20 @@ def load(path, key=None): _dfs_load(f, load_obj) return load_obj + +def to_serialized_hex(user_dict): + """Serialize the user_dict and convert the serialized bytes to a hex string and return""" + + return pickle.dumps(user_dict).hex() + +def from_serialized_hex(serialized_hex): + """Convert serialized_hex to bytes and deserialize it and return""" + + # serialized_hex can be either a regular string or a byte string. + # if it is a byte string, convert to regular string using decode() + # if it is a regular string, do nothing to it + try: + serialized_hex = serialized_hex.decode() + except AttributeError: + pass + return pickle.loads(bytes.fromhex(serialized_hex)) diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 4444e8327c4c5..7d110f08f737a 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -202,4 +202,47 @@ def state_dict_trainer_options_key(): def state_dict_full_precision_key(): """Returns the full precision key name in the state dictionary""" - return 'fp32' + return 'full_precision' + +def state_dict_original_dimension_key(): + """Returns the original dimension key name in the state dictionary""" + + return 'original_dim' + +def state_dict_sharded_optimizer_keys(): + """Returns the optimizer key names that can be sharded in the state dictionary""" + + return { + 'Moment_1', + 'Moment_2' + } + +def state_dict_user_dict_key(): + """Returns the user dict key name in the state dictionary""" + + return 'user_dict' + +def state_dict_trainer_options_mixed_precision_key(): + """Returns the trainer options mixed precision key name in the state dictionary""" + + return 'mixed_precision' + +def state_dict_trainer_options_zero_stage_key(): + """Returns the trainer options zero_stage key name in the state dictionary""" + + return 'zero_stage' + +def state_dict_trainer_options_world_rank_key(): + """Returns the trainer options world_rank key name in the state dictionary""" + + return 'world_rank' + +def state_dict_trainer_options_world_size_key(): + """Returns the trainer options world_size key name in the state dictionary""" + + return 'world_size' + +def 
state_dict_trainer_options_optimizer_name_key(): + """Returns the trainer options optimizer_name key name in the state dictionary""" + + return 'optimizer_name' diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py index a1b3abc156040..8384367415452 100644 --- a/orttraining/orttraining/python/training/checkpoint.py +++ b/orttraining/orttraining/python/training/checkpoint.py @@ -3,6 +3,7 @@ import os import torch import warnings +from . import _checkpoint_storage, _utils ################################################################################ @@ -108,6 +109,233 @@ def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix= else: return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) +def _order_paths(paths): + """Reorders the given paths in ascending order of rank and return the ordered list""" + + trainer_options_path_tuples = [] + world_rank = _utils.state_dict_trainer_options_world_rank_key() + + for path in paths: + trainer_options_path_tuples.append((_checkpoint_storage.load(path, + key=_utils.state_dict_trainer_options_key()), path)) + + ordered_paths = [path for _, path in sorted(trainer_options_path_tuples, + key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank])] + + return ordered_paths + +def _add_or_update_sharded_key_for_zero(state_key, state_value, state_sub_dict, + model_state_key, original_dim, sharded_states_original_dims): + """Add or update the record for the sharded state_key in the state_sub_dict""" + + # record the original dimension for this state + sharded_states_original_dims[model_state_key] = original_dim + + if state_key in state_sub_dict: + # state_dict already contains a record for this state + # since this state is sharded, concatenate the state value to + # the record in the state_dict + state_sub_dict[state_key] = \ + np.concatenate((state_sub_dict[state_key], state_value)) + else: + # create a new entry for this state in the state_dict + state_sub_dict[state_key] = state_value + +def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_dict, mismatch_error_string): + """Add or validate the record for the unsharded state_key in the state_sub_dict""" + + if state_key in state_sub_dict: + # state_dict already contains a record for this unsharded state. 
+ # assert that all values are the same for this previously loaded state + assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string + else: + # create a new entry for this state in the state_sub_dict + state_sub_dict[state_key] = state_value + +def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict): + """Aggregates all model states from the rank_state_dict into state_dict""" + + model = _utils.state_dict_model_key() + full_precision = _utils.state_dict_full_precision_key() + partition_info = _utils.state_dict_partition_info_key() + original_dim = _utils.state_dict_original_dimension_key() + + # if there are no model states in the rank_state_dict, no model aggregation is needed + if model not in rank_state_dict: + return + + if model not in state_dict: + state_dict[model] = {} + + if full_precision not in state_dict[model]: + state_dict[model][full_precision] = {} + + # iterate over all model state keys + for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): + if model_state_key in rank_state_dict[partition_info]: + # this model state is sharded since a record exists in the partition_info subdict + _add_or_update_sharded_key_for_zero(model_state_key, model_state_value, + state_dict[model][full_precision], model_state_key, + rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims) + else: + # this model state is not sharded since a record for it does not exist in the partition_info subdict + _add_or_validate_unsharded_key_for_zero(model_state_key, model_state_value, + state_dict[model][full_precision], "Value mismatch for model state {}".format(model_state_key)) + +def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict): + """Aggregates all optimizer states from the rank_state_dict into state_dict""" + + optimizer = _utils.state_dict_optimizer_key() + partition_info = _utils.state_dict_partition_info_key() + original_dim = _utils.state_dict_original_dimension_key() + sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() + + # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed + if optimizer not in rank_state_dict: + return + + if optimizer not in state_dict: + state_dict[optimizer] = {} + + # iterate over all optimizer state keys + for model_state_key, optimizer_dict in rank_state_dict[optimizer].items(): + for optimizer_key, optimizer_value in optimizer_dict.items(): + if model_state_key not in state_dict[optimizer]: + state_dict[optimizer][model_state_key] = {} + + if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: + # this optimizer state is sharded since a record exists in the partition_info subdict + _add_or_update_sharded_key_for_zero(optimizer_key, optimizer_value, + state_dict[optimizer][model_state_key], model_state_key, + rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims) + else: + # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict + # or this optimizer key is not one of the sharded optimizer keys + _add_or_validate_unsharded_key_for_zero(optimizer_key, optimizer_value, + state_dict[optimizer][model_state_key], + "Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key)) + +def _reshape_states(sharded_states_original_dims, state_dict): + """Reshape model and optimizer states in the 
state_dict according to dimensions in sharded_states_original_dims""" + + model = _utils.state_dict_model_key() + full_precision = _utils.state_dict_full_precision_key() + optimizer = _utils.state_dict_optimizer_key() + sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() + + for sharded_state_key, original_dim in sharded_states_original_dims.items(): + # reshape model states to original_dim + if model in state_dict: + state_dict[model][full_precision][sharded_state_key] = \ + state_dict[model][full_precision][sharded_state_key].reshape(original_dim) + + # reshape optimizer states to original_dim + if optimizer in state_dict: + for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items(): + if optimizer_key in sharded_optimizer_keys: + state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) + +def _aggregate_trainer_options(rank_state_dict, state_dict): + """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" + + state_dict[_utils.state_dict_trainer_options_key()] = {} + + mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() + zero_stage = _utils.state_dict_trainer_options_zero_stage_key() + world_rank = _utils.state_dict_trainer_options_world_rank_key() + world_size = _utils.state_dict_trainer_options_world_size_key() + optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() + + state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = \ + rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] + state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = 0 + state_dict[_utils.state_dict_trainer_options_key()][world_rank] = 0 + state_dict[_utils.state_dict_trainer_options_key()][world_size] = 1 + state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = \ + rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] + +def aggregate_checkpoints(paths, pytorch_format=True): + """Aggregate checkpoint files and return a single state dictionary + + Aggregates checkpoint files specified by paths and laods the checkpoint file one at a time merging + them into a single state dictionary. + The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. 
+ The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() + + Args: + paths: list of more than one file represented as strings where the checkpoint is saved + pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict + Returns: + state_dict that can be loaded into an ORTTrainer or into a PyTorch model + """ + + # order the paths in ascending order of ranks + ordered_paths = _order_paths(paths) + + state_dict = {} + sharded_states_original_dims = {} + world_rank = _utils.state_dict_trainer_options_world_rank_key() + mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() + zero_stage = _utils.state_dict_trainer_options_zero_stage_key() + world_size = _utils.state_dict_trainer_options_world_size_key() + optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() + + loaded_mixed_precision = None + loaded_world_size = None + loaded_zero_stage = None + loaded_optimizer_name = None + + for rank, path in enumerate(ordered_paths): + rank_state_dict = _checkpoint_storage.load(path) + + assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" + assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" + assert rank == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \ + "Unexpected rank in file at path {}. Expected {}, got {}".\ + format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]) + if loaded_mixed_precision is None: + loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] + else: + assert loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision], \ + "Mixed precision state mismatch among checkpoint files. File: {}".format(path) + if loaded_world_size is None: + loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] + else: + assert loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size], \ + "World size state mismatch among checkpoint files. File: {}".format(path) + if loaded_zero_stage is None: + loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] + else: + assert loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage], \ + "Zero stage mismatch among checkpoint files. File: {}".format(path) + if loaded_optimizer_name is None: + loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] + else: + assert loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name], \ + "Optimizer name mismatch among checkpoint files. 
File: {}".format(path) + + # aggregate all model states + _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict) + + if not pytorch_format: + # aggregate all optimizer states if pytorch_format is False + _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict) + + # entry for trainer_options in the state_dict to perform other sanity checks + if _utils.state_dict_trainer_options_key() not in state_dict: + _aggregate_trainer_options(rank_state_dict, state_dict) + + # entry for user_dict in the state_dict if not already present + if _utils.state_dict_user_dict_key() not in state_dict and \ + _utils.state_dict_user_dict_key() in rank_state_dict: + state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] + + # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims + _reshape_states(sharded_states_original_dims, state_dict) + + # return a flat structure for PyTorch model in case pytorch_format is True + # else return the hierarchical structure for ORTTrainer + return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] if pytorch_format else state_dict ################################################################################ # Helper functions @@ -201,7 +429,7 @@ def _split_name(self, name): name_split = name.split('_view_') view_num = None if(len(name_split) > 1): - view_num = int(name_split[1]) + view_num = int(name_split[1]) optimizer_key = '' mp_suffix = '' if name_split[0].startswith('Moment_1'): diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 37c8b4ed51df1..8a1877c9fe84d 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -9,7 +9,7 @@ import numpy as np import onnxruntime as ort -from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions +from . 
import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions, _checkpoint_storage from .model_desc_validation import _ORTTrainerModelDesc from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference @@ -568,7 +568,7 @@ def forward(self, *inputs): return onnx_model - def _create_ort_training_session(self, state_dict = {}): + def _create_ort_training_session(self, optimizer_state_dict={}): # Validating frozen_weights names unused_frozen_weights = [n for n in self.options.utils.frozen_weights\ if n not in [i.name for i in self._onnx_model.graph.initializer]] @@ -639,8 +639,8 @@ def _create_ort_training_session(self, state_dict = {}): ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map if bool(self._optim_state_dict): ort_parameters.set_optimizer_initial_state(self._optim_state_dict) - if bool(state_dict) and bool(state_dict[_utils.state_dict_optimizer_key()]): - ort_parameters.set_optimizer_initial_state(state_dict[_utils.state_dict_optimizer_key()]) + if bool(optimizer_state_dict): + ort_parameters.set_optimizer_initial_state(optimizer_state_dict) ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute @@ -688,13 +688,13 @@ def _init_onnx_model(self, inputs): if self.options._internal_use.extra_postprocess: self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - state_dict = {} + optimizer_state_dict = {} if self._load_state_dict: - state_dict = self._load_state_dict() + optimizer_state_dict = self._load_state_dict() - self._init_session(state_dict) + self._init_session(optimizer_state_dict) - def _init_session(self, state_dict = {}): + def _init_session(self, optimizer_state_dict={}): if self._onnx_model is None: return @@ -703,7 +703,7 @@ def _init_session(self, state_dict = {}): # Create training session used by train_step # pass all optimizer states to the backend - self._create_ort_training_session(state_dict) + self._create_ort_training_session(optimizer_state_dict) # Update model description to update dtype when mixed precision is enabled # C++ backend modifies model's output dtype from float32 to float16 for mixed precision @@ -886,13 +886,19 @@ def _extract_model_states(self, state_dict, pytorch_format): def _extract_trainer_options(self, state_dict): """Extract relevant trainer configuration and load it into the state_dict""" + mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() + zero_stage = _utils.state_dict_trainer_options_zero_stage_key() + world_rank = _utils.state_dict_trainer_options_world_rank_key() + world_size = _utils.state_dict_trainer_options_world_size_key() + optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() + state_dict[_utils.state_dict_trainer_options_key()] = {} - state_dict[_utils.state_dict_trainer_options_key()]['mixed_precision'] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()]['zero_stage'] = \ - self.options.distributed.deepspeed_zero_optimization.stage or 0 - state_dict[_utils.state_dict_trainer_options_key()]['world_rank'] = self.options.distributed.world_rank or 0 - state_dict[_utils.state_dict_trainer_options_key()]['world_size'] = self.options.distributed.world_size or 1 - state_dict[_utils.state_dict_trainer_options_key()]['optimizer_name'] = self.optim_config.name + state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled + 
state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = \ + self.options.distributed.deepspeed_zero_optimization.stage + state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank + state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size + state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name def state_dict(self, pytorch_format=False): """Returns a dictionary with model, and optionally, optimizer states @@ -911,7 +917,7 @@ def state_dict(self, pytorch_format=False): type: dict, schema: { - "fp32": + "full_precision": { type: dict, schema: @@ -1082,9 +1088,30 @@ def _load_model_states(self, state_dict, strict): def _load_optimizer_states(self, current_state_dict, state_dict): """Load the optimizer states onto the training session state dictionary""" + def _check_optimizer_mismatch(state_dict): + """Assert that the loaded optimizer has the same config as the current training session config""" + + # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) + # or can be a regular string (coming from user) + optimizer_name = \ + state_dict[_utils.state_dict_trainer_options_key()][_utils.state_dict_trainer_options_optimizer_name_key()] + + # optimizer_name can be either a regular string or a byte string. + # if it is a byte string, convert to regular string using decode() + # if it is a regular string, do nothing to it + try: + optimizer_name = optimizer_name.decode() + except AttributeError: + pass + assert self.optim_config.name == optimizer_name, \ + "Optimizer mismatch: expected {}, got {}".format(self.optim_config.name, optimizer_name) + if _utils.state_dict_optimizer_key() not in state_dict: return + # check optimizer config names are the same for current session and the sessino being loaded + _check_optimizer_mismatch(state_dict) + # create an entry for the optimizer in the training session state dictionary if _utils.state_dict_optimizer_key() not in current_state_dict: current_state_dict[_utils.state_dict_optimizer_key()] = {} @@ -1179,7 +1206,8 @@ def _check_key_mismatch(current_state_dict, state_dict): # dictionary self._load_optimizer_states(current_state_dict, state_dict) - return current_state_dict + return current_state_dict[_utils.state_dict_optimizer_key()] if \ + _utils.state_dict_optimizer_key() in current_state_dict else {} def load_state_dict(self, state_dict, strict=True): """Loads state_dict containing model/optimizer states into ORTTrainer @@ -1203,8 +1231,80 @@ def load_state_dict(self, state_dict, strict=True): return # load states onto the frontend onnx graph - state_dict = self._load_state_dict_impl(state_dict, strict=strict) + optimizer_state_dict = self._load_state_dict_impl(state_dict, strict=strict) # create a new training session after loading initializer states onto the onnx graph # pass the populated states to the training session to populate the backend graph - self._init_session(state_dict) + self._init_session(optimizer_state_dict) + + def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): + """Persists ORTTrainer state dictionary on disk along with user_dict. + + Saves the state_dict along with the user_dict to a file specified by path. + + Args: + path: string representation to a file path or a python file-like object. + if file already exists at path, an exception is raised. + user_dict: custom data to be saved along with the state_dict. 
This data will be returned + to the user when load_checkpoint is called. + include_optimizer_states: boolean flag indicating whether or not to persist the optimizer states. + on load_checkpoint, only model states will be loaded if include_optimizer_states==True + """ + + # extract state_dict to be saved in the checkpoint + state_dict = self.state_dict() + + # if user_dict is provided, serialize to bytes and convert to hex string. + # this helps in loading the types as they are given by the user since hdf5 + # converts to numpy types otherwise + if bool(user_dict): + state_dict[_utils.state_dict_user_dict_key()] = _checkpoint_storage.to_serialized_hex(user_dict) + + # if include_optimizer_states is False, only save the model states in the checkpoint file + if not include_optimizer_states: + if _utils.state_dict_optimizer_key() in state_dict: + del state_dict[_utils.state_dict_optimizer_key()] + + _checkpoint_storage.save(state_dict, path) + + def _aggregation_required(self, loaded_trainer_options): + """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" + + # To load states in the backend, aggregation is required for every ZeRO checkpoint + return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 + + def load_checkpoint(self, *paths, strict=True): + """Loads the saved checkpoint state dictionary into the ORTTrainer + + Reads the saved checkpoint files specified by paths from disk and loads the state dictionary + onto the ORTTrainer. + Aggregates the checkpoint files if aggregation is required. + + Args: + paths: one or more files represented as strings where the checkpoint is saved + strict: boolean flag to strictly enforce that the saved checkpoint state_dict + keys match the keys from ORTTrainer.state_dict + Returns: + dictionary that the user had saved when calling save_checkpoint + """ + state_dict = {} + + # check if aggregation is required + loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) + if self._aggregation_required(loaded_trainer_options): + # if aggregation is required, aggregation logic must be run on the saved checkpoints + state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False) + else: + # if aggregation is not required, there must only be a single file that needs to be loaded + assert len(paths) == 1, "Expected number of files to load: 1, got {}".format(len(paths)) + state_dict = _checkpoint_storage.load(paths[0]) + + # extract user dict from the saved checkpoint + user_dict = {} + if _utils.state_dict_user_dict_key() in state_dict: + user_dict = _checkpoint_storage.from_serialized_hex(state_dict[_utils.state_dict_user_dict_key()]) + del state_dict[_utils.state_dict_user_dict_key()] + + self.load_state_dict(state_dict, strict=strict) + + return user_dict diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py index 6ca7b131b2616..6beceab19b580 100644 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py @@ -8,7 +8,6 @@ import os import shutil import pickle -import binascii from onnxruntime.training import _checkpoint_storage @@ -215,7 +214,7 @@ def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_tes 'custom_class': custom_class } - pickled_bytes = binascii.b2a_hex(pickle.dumps(user_dict)) + 
pickled_bytes = pickle.dumps(user_dict).hex() to_save = { 'a': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), 'user_dict': pickled_bytes @@ -224,7 +223,11 @@ def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_tes loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) assert (loaded_dict['a'] == to_save['a'].numpy()).all() - loaded_obj = pickle.loads(binascii.a2b_hex(loaded_dict['user_dict'])) + try: + loaded_dict['user_dict'] = loaded_dict['user_dict'].decode() + except AttributeError: + pass + loaded_obj = pickle.loads(bytes.fromhex(loaded_dict['user_dict'])) assert torch.all(loaded_obj['tensor1'].eq(user_dict['tensor1'])) assert loaded_obj['custom_class'] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py index f3359f6f73e42..cade9e09c339d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, Mock from orttraining_test_orttrainer_frontend import _load_pytorch_transformer_model -from onnxruntime.training import amp, checkpoint, optim, orttrainer +from onnxruntime.training import amp, checkpoint, optim, orttrainer, _checkpoint_storage import numpy as np import onnx import torch @@ -59,7 +59,7 @@ def _get_load_state_dict_strict_error_arguments(): training_session_state_dict = { 'model': { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -77,20 +77,20 @@ def _get_load_state_dict_strict_error_arguments(): # input state dictionaries precision_key_missing = {'model': {}, 'optimizer': {}} - precision_key_unexpected = {'model': {'fp32': {}, 'fp16': {}}, 'optimizer': {}} - model_state_key_missing = {'model': {'fp32': {}}, 'optimizer': {}} - model_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3, 'c': 4}}, 'optimizer': {}} - optimizer_model_state_key_missing = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': {}} - optimizer_model_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + precision_key_unexpected = {'model': {'full_precision': {}, 'mixed_precision': {}}, 'optimizer': {}} + model_state_key_missing = {'model': {'full_precision': {}}, 'optimizer': {}} + model_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3, 'c': 4}}, 'optimizer': {}} + optimizer_model_state_key_missing = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': {}} + optimizer_model_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ {'a': {}, 'shared_optimizer_state': {}, 'b': {}}} - optimizer_state_key_missing = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + optimizer_state_key_missing = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ {'a': {}, 'shared_optimizer_state': {'step': np.arange(5)}}} - optimizer_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + optimizer_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ {'a': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)}, 'shared_optimizer_state': {'step': np.arange(5), 'another_step': np.arange(1)}}} input_arguments = [ - (training_session_state_dict, precision_key_missing, ['fp32']), - (training_session_state_dict, precision_key_unexpected, ['fp16']), + 
(training_session_state_dict, precision_key_missing, ['full_precision']), + (training_session_state_dict, precision_key_unexpected, ['mixed_precision']), (training_session_state_dict, model_state_key_missing, ['a', 'b']), (training_session_state_dict, model_state_key_unexpected, ['c']), (training_session_state_dict, optimizer_model_state_key_missing, ['a', 'shared_optimizer_state']), @@ -126,7 +126,7 @@ def test_training_session_provides_empty_model_states(onnx_model_mock): def test_training_session_provides_model_states(onnx_model_mock): trainer = _create_trainer() model_states = { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -136,14 +136,14 @@ def test_training_session_provides_model_states(onnx_model_mock): trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() - assert (state_dict['model']['fp32']['b'] == np.arange(7)).all() + assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() + assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() @patch('onnx.ModelProto') def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): trainer = _create_trainer() model_states = { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -160,7 +160,7 @@ def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): trainer = _create_trainer() model_states = { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -176,11 +176,11 @@ def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): ] state_dict = trainer.state_dict() - assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() - assert (state_dict['model']['fp32']['b'] == np.arange(7)).all() - assert (state_dict['model']['fp32']['a_frozen_weight'] == np.array([1, 2, 3], dtype=np.float32)).all() - assert 'a_non_fronzen_weight' not in state_dict['model']['fp32'] - assert (state_dict['model']['fp32']['a_float16_weight'] == np.array([7, 8, 9], dtype=np.float32)).all() + assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() + assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() + assert (state_dict['model']['full_precision']['a_frozen_weight'] == np.array([1, 2, 3], dtype=np.float32)).all() + assert 'a_non_fronzen_weight' not in state_dict['model']['full_precision'] + assert (state_dict['model']['full_precision']['a_float16_weight'] == np.array([7, 8, 9], dtype=np.float32)).all() @patch('onnx.ModelProto') def test_training_session_provides_empty_optimizer_states(onnx_model_mock): @@ -217,7 +217,7 @@ def test_training_session_provides_optimizer_states(onnx_model_mock): def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): trainer = _create_trainer() model_states = { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -267,7 +267,7 @@ def test_training_session_provides_partition_info_map(onnx_model_mock): def test_training_session_provides_all_states(onnx_model_mock): trainer = _create_trainer(zero_enabled=True) model_states = { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -291,8 +291,8 @@ def test_training_session_provides_all_states(onnx_model_mock): trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() - assert 
(state_dict['model']['fp32']['b'] == np.arange(7)).all() + assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() + assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() assert (state_dict['optimizer']['model_weight']['Moment_1'] == np.arange(5)).all() assert (state_dict['optimizer']['model_weight']['Moment_2'] == np.arange(7)).all() assert (state_dict['optimizer']['shared_optimizer_state']['step'] == np.arange(1)).all() @@ -302,7 +302,7 @@ def test_load_state_dict_holds_when_training_session_not_initialized(): trainer = _create_trainer() state_dict = { 'model': { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -321,7 +321,27 @@ def test_load_state_dict_holds_when_training_session_not_initialized(): state_dict = trainer.load_state_dict(state_dict) assert trainer._load_state_dict -@pytest.mark.parametrize("state_dict, input_state_dict, error_key", [({'optimizer':{}}, {'optimizer':{}}, 'model'), ({'model':{}}, {'model':{}}, 'optimizer')]) +@pytest.mark.parametrize("state_dict, input_state_dict, error_key", [ + ({ + 'optimizer':{}, + }, + { + 'optimizer':{}, + 'trainer_options': { + 'optimizer_name': 'LambOptimizer' + } + }, + 'model'), + ({ + 'model':{} + }, + { + 'model':{}, + 'trainer_options': { + 'optimizer_name': 'LambOptimizer' + } + }, + 'optimizer')]) def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): trainer = _create_trainer() trainer._training_session = _training_session_mock({}, {}, {}) @@ -351,7 +371,7 @@ def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_ trainer = _create_trainer() training_session_state_dict = { 'model': { - 'fp32': { + 'full_precision': { 'a': np.arange(5), 'b': np.arange(7) } @@ -369,7 +389,7 @@ def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_ input_state_dict = { 'model': { - 'fp32': { + 'full_precision': { 'a': np.array([1, 2]), 'b': np.array([3, 4]) } @@ -382,6 +402,9 @@ def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_ 'shared_optimizer_state': { 'step': np.array([9]) } + }, + 'trainer_options': { + 'optimizer_name': 'LambOptimizer' } } trainer._training_session = _training_session_mock({}, {}, {}) @@ -404,6 +427,252 @@ def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_ assert 'b' in loaded_initializers[0] assert (loaded_initializers[0]['b'] == np.array([3, 4])).all() - assert (state_dict_to_load[0]['optimizer']['a']['Moment_1'] == np.array([5, 6])).all() - assert (state_dict_to_load[0]['optimizer']['a']['Moment_2'] == np.array([7, 8])).all() - assert (state_dict_to_load[0]['optimizer']['shared_optimizer_state']['step'] == np.array([9])).all() + assert (state_dict_to_load[0]['a']['Moment_1'] == np.array([5, 6])).all() + assert (state_dict_to_load[0]['a']['Moment_2'] == np.array([7, 8])).all() + assert (state_dict_to_load[0]['shared_optimizer_state']['step'] == np.array([9])).all() + +@patch('onnxruntime.training._checkpoint_storage.save') +def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): + trainer = _create_trainer() + state_dict = { + 'model': {}, + 'optimizer': {} + } + trainer.state_dict = Mock(return_value=state_dict) + + trainer.save_checkpoint('abc') + + save_args, _ = save_mock.call_args + assert 'model' in save_args[0] + assert not bool(save_args[0]['model']) + assert 'optimizer' in save_args[0] + assert not bool(save_args[0]['optimizer']) + assert save_args[1] == 'abc' + 
+@patch('onnxruntime.training._checkpoint_storage.save') +def test_save_checkpoint_exclude_optimizer_states(save_mock): + trainer = _create_trainer() + state_dict = { + 'model': {}, + 'optimizer': {} + } + trainer.state_dict = Mock(return_value=state_dict) + + trainer.save_checkpoint('abc', include_optimizer_states=False) + + save_args, _ = save_mock.call_args + assert 'model' in save_args[0] + assert not bool(save_args[0]['model']) + assert 'optimizer' not in save_args[0] + assert save_args[1] == 'abc' + +@patch('onnxruntime.training._checkpoint_storage.save') +def test_save_checkpoint_user_dict(save_mock): + trainer = _create_trainer() + state_dict = { + 'model': {}, + 'optimizer': {} + } + trainer.state_dict = Mock(return_value=state_dict) + + trainer.save_checkpoint('abc', user_dict={'abc': np.arange(4)}) + + save_args, _ = save_mock.call_args + assert 'user_dict' in save_args[0] + assert save_args[0]['user_dict'] == _checkpoint_storage.to_serialized_hex({'abc': np.arange(4)}) + +@patch('onnxruntime.training._checkpoint_storage.load') +@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') +def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): + trainer = _create_trainer() + trainer_options = { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0) + } + state_dict = { + 'model': {}, + 'optimizer': {}, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0) + } + } + trainer.load_state_dict = Mock() + + load_mock.side_effect = [trainer_options, state_dict] + trainer.load_checkpoint('abc') + + args_list = load_mock.call_args_list + load_args, load_kwargs = args_list[0] + assert load_args[0] == 'abc' + assert load_kwargs['key'] == 'trainer_options' + load_args, load_kwargs = args_list[1] + assert load_args[0] == 'abc' + assert 'key' not in load_kwargs + assert not aggregate_checkpoints_mock.called + +@patch('onnxruntime.training._checkpoint_storage.load') +@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') +@pytest.mark.parametrize("trainer_options", [ + { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(4), + 'zero_stage': np.int64(1) + }, + { + 'mixed_precision': np.bool_(True), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(1) + }, + { + 'mixed_precision': np.bool_(True), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(1) + } +]) +def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): + trainer = _create_trainer() + trainer.load_state_dict = Mock() + + load_mock.side_effect = [trainer_options] + trainer.load_checkpoint('abc') + + args_list = load_mock.call_args_list + load_args, load_kwargs = args_list[0] + assert load_args[0] == 'abc' + assert load_kwargs['key'] == 'trainer_options' + assert aggregate_checkpoints_mock.called + call_args, _ = aggregate_checkpoints_mock.call_args + assert call_args[0] == tuple(['abc']) + +@patch('onnxruntime.training._checkpoint_storage.load') +@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') +def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): + trainer = _create_trainer() + trainer_options = { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0) + } + state_dict = { + 'model': {}, + 
'optimizer': {}, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0) + }, + 'user_dict': _checkpoint_storage.to_serialized_hex({'array': torch.tensor(np.arange(5))}) + } + trainer.load_state_dict = Mock() + + load_mock.side_effect = [trainer_options, state_dict] + user_dict = trainer.load_checkpoint('abc') + + assert torch.all(torch.eq(user_dict['array'], torch.tensor(np.arange(5)))) + +@patch('onnxruntime.training._checkpoint_storage.load') +def test_checkpoint_aggregation(load_mock): + trainer_options1 = { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(2), + 'zero_stage': np.int64(1), + 'optimizer_name': b'Adam' + } + trainer_options2 = { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(1), + 'world_size': np.int64(2), + 'zero_stage': np.int64(1), + 'optimizer_name': b'Adam' + } + + state_dict1 = { + 'model': { + 'full_precision': { + 'sharded': np.array([1, 2, 3]), + 'non_sharded': np.array([11, 22, 33]) + } + }, + 'optimizer': { + 'sharded': { + 'Moment_1': np.array([9, 8, 7]), + 'Moment_2': np.array([99, 88, 77]), + 'Step': np.array([5]) + }, + 'non_sharded': { + 'Moment_1': np.array([666, 555, 444]), + 'Moment_2': np.array([6666, 5555, 4444]), + 'Step': np.array([55]) + } + }, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0), + 'optimizer_name': b'Adam' + }, + 'partition_info': { + 'sharded': {'original_dim': np.array([2, 3])} + } + } + + state_dict2 = { + 'model': { + 'full_precision': { + 'sharded': np.array([4, 5, 6]), + 'non_sharded': np.array([11, 22, 33]) + } + }, + 'optimizer': { + 'sharded': { + 'Moment_1': np.array([6, 5, 4]), + 'Moment_2': np.array([66, 55, 44]), + 'Step': np.array([5]) + }, + 'non_sharded': { + 'Moment_1': np.array([666, 555, 444]), + 'Moment_2': np.array([6666, 5555, 4444]), + 'Step': np.array([55]) + } + }, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(1), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0), + 'optimizer_name': b'Adam' + }, + 'partition_info': { + 'sharded': {'original_dim': np.array([2, 3])} + } + } + + load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2] + state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) + + assert (state_dict['model']['full_precision']['sharded'] == np.array([[1, 2, 3], [4, 5, 6]])).all() + assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all() + assert (state_dict['optimizer']['sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all() + assert (state_dict['optimizer']['sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all() + assert (state_dict['optimizer']['sharded']['Step'] == np.array([5])).all() + assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all() + assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all() + assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all() + + assert state_dict['trainer_options']['mixed_precision'] == False + assert state_dict['trainer_options']['world_rank'] == 0 + assert state_dict['trainer_options']['world_size'] == 1 + assert state_dict['trainer_options']['zero_stage'] == 0 + assert state_dict['trainer_options']['optimizer_name'] == b'Adam'
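
The user_dict serialization helpers added to _checkpoint_storage.py are small enough to exercise in isolation. The sketch below restates them outside the module purely for illustration and round-trips a dictionary through the hex encoding; the decode() fallback matters because h5py may hand the stored hex string back as bytes rather than str.

import pickle

def to_serialized_hex(user_dict):
    # pickle the dict and store the bytes as a hex string so hdf5 keeps it
    # as an opaque string instead of coercing the values to numpy types
    return pickle.dumps(user_dict).hex()

def from_serialized_hex(serialized_hex):
    # the stored value may come back as bytes; decode() only when needed
    try:
        serialized_hex = serialized_hex.decode()
    except AttributeError:
        pass
    return pickle.loads(bytes.fromhex(serialized_hex))

original = {'epoch': 3, 'best_loss': 0.1234}
assert from_serialized_hex(to_serialized_hex(original)) == original
# byte-string input, as h5py might return it
assert from_serialized_hex(to_serialized_hex(original).encode()) == original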
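
A minimal usage sketch of the new ORTTrainer checkpoint API introduced by this patch. It assumes `trainer` is an already-constructed and initialized ORTTrainer; the file names are placeholders, not part of the patch.

# persist model and optimizer states plus arbitrary user data
trainer.save_checkpoint('checkpoint.ortcp', user_dict={'epoch': 5})

# drop optimizer states to produce a smaller, model-only checkpoint
trainer.save_checkpoint('model_only.ortcp', include_optimizer_states=False)

# later, on a trainer built with the same model and optimizer configuration,
# load_checkpoint restores the states and returns the saved user_dict
user_dict = trainer.load_checkpoint('checkpoint.ortcp')
assert user_dict['epoch'] == 5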
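
For ZeRO runs, each rank saves its own shard and the shards are merged with the new aggregate_checkpoints function. A sketch under assumed file names (rank0.ortcp, rank1.ortcp from a world_size=2 run); the function orders the files by the world_rank recorded inside them, so the list order need not match rank order.

from onnxruntime.training import checkpoint

paths = ['rank0.ortcp', 'rank1.ortcp']

# hierarchical ORTTrainer schema (model / optimizer / trainer_options / ...)
ort_state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False)

# flat {parameter_name: value} view of the full-precision model states,
# intended for loading back into the original PyTorch model
pytorch_state_dict = checkpoint.aggregate_checkpoints(paths)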