Commit
save_checkpoint, load_checkpoint and aggregate_checkpoints (microsoft#6136)

* save_checkpoint and load_checkpoint implementations

* checkpoint aggregation logic

* unit tests for save_checkpoint, load_checkpoint and aggregate_checkpoints
baijumeswani authored Dec 18, 2020
1 parent c339bb2 commit adc2071
Showing 6 changed files with 717 additions and 56 deletions.
18 changes: 18 additions & 0 deletions orttraining/orttraining/python/training/_checkpoint_storage.py
@@ -4,6 +4,7 @@

import h5py
from collections.abc import Mapping
import pickle

def _dfs_save(group, save_obj):
"""Recursively go over each level in the save_obj dictionary and save values to a hdf5 group"""
@@ -79,3 +80,20 @@ def load(path, key=None):
_dfs_load(f, load_obj)

return load_obj
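
For orientation, a hedged usage sketch of the key-based load shown above (the file name is hypothetical; the key value matches how checkpoint.py calls load further down in this diff):

# illustrative only: read just the trainer_options sub-dictionary from a saved checkpoint file
trainer_options = load('checkpoint_rank0.ortcp', key='trainer_options')
world_rank = trainer_options['world_rank']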

def to_serialized_hex(user_dict):
"""Serialize the user_dict and convert the serialized bytes to a hex string and return"""

return pickle.dumps(user_dict).hex()

def from_serialized_hex(serialized_hex):
"""Convert serialized_hex to bytes and deserialize it and return"""

# serialized_hex can be either a regular string or a byte string.
# if it is a byte string, convert to regular string using decode()
# if it is a regular string, do nothing to it
try:
serialized_hex = serialized_hex.decode()
except AttributeError:
pass
return pickle.loads(bytes.fromhex(serialized_hex))
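
A minimal round-trip sketch (not part of this diff) of what the two hex helpers above do:

import pickle

user_dict = {'epoch': 3, 'learning_rate': 1e-3}
serialized_hex = pickle.dumps(user_dict).hex()            # what to_serialized_hex computes
restored = pickle.loads(bytes.fromhex(serialized_hex))    # what from_serialized_hex computes
assert restored == user_dict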
45 changes: 44 additions & 1 deletion orttraining/orttraining/python/training/_utils.py
@@ -202,4 +202,47 @@ def state_dict_trainer_options_key():
def state_dict_full_precision_key():
"""Returns the full precision key name in the state dictionary"""

-    return 'fp32'
+    return 'full_precision'

def state_dict_original_dimension_key():
"""Returns the original dimension key name in the state dictionary"""

return 'original_dim'

def state_dict_sharded_optimizer_keys():
"""Returns the optimizer key names that can be sharded in the state dictionary"""

return {
'Moment_1',
'Moment_2'
}

def state_dict_user_dict_key():
"""Returns the user dict key name in the state dictionary"""

return 'user_dict'

def state_dict_trainer_options_mixed_precision_key():
"""Returns the trainer options mixed precision key name in the state dictionary"""

return 'mixed_precision'

def state_dict_trainer_options_zero_stage_key():
"""Returns the trainer options zero_stage key name in the state dictionary"""

return 'zero_stage'

def state_dict_trainer_options_world_rank_key():
"""Returns the trainer options world_rank key name in the state dictionary"""

return 'world_rank'

def state_dict_trainer_options_world_size_key():
"""Returns the trainer options world_size key name in the state dictionary"""

return 'world_size'

def state_dict_trainer_options_optimizer_name_key():
"""Returns the trainer options optimizer_name key name in the state dictionary"""

return 'optimizer_name'
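
Taken together, these helpers suggest a nested checkpoint layout along the following lines. This is an assumed sketch for orientation only; the 'model' and 'optimizer' key names and the sample optimizer name are guesses based on the helper names, not values shown in this diff:

# assumed nesting, keyed by the helper return values above (illustrative placeholders)
assumed_layout = {
    'model': {'full_precision': {'weight_name': 'numpy array with the parameter values'}},
    'optimizer': {'weight_name': {'Moment_1': 'numpy array', 'Moment_2': 'numpy array'}},
    'partition_info': {'weight_name': {'original_dim': 'original shape of the sharded tensor'}},
    'trainer_options': {'mixed_precision': False, 'zero_stage': 0, 'world_rank': 0,
                        'world_size': 1, 'optimizer_name': 'AdamOptimizer'},
    'user_dict': 'pickle-serialized hex string',
}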
230 changes: 229 additions & 1 deletion orttraining/orttraining/python/training/checkpoint.py
Expand Up @@ -3,6 +3,7 @@
import os
import torch
import warnings
from . import _checkpoint_storage, _utils


################################################################################
@@ -108,6 +109,233 @@ def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix=
else:
return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)

def _order_paths(paths):
"""Reorders the given paths in ascending order of rank and return the ordered list"""

trainer_options_path_tuples = []
world_rank = _utils.state_dict_trainer_options_world_rank_key()

for path in paths:
trainer_options_path_tuples.append((_checkpoint_storage.load(path,
key=_utils.state_dict_trainer_options_key()), path))

ordered_paths = [path for _, path in sorted(trainer_options_path_tuples,
key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank])]

return ordered_paths

def _add_or_update_sharded_key_for_zero(state_key, state_value, state_sub_dict,
model_state_key, original_dim, sharded_states_original_dims):
"""Add or update the record for the sharded state_key in the state_sub_dict"""

# record the original dimension for this state
sharded_states_original_dims[model_state_key] = original_dim

if state_key in state_sub_dict:
# state_dict already contains a record for this state
# since this state is sharded, concatenate the state value to
# the record in the state_dict
state_sub_dict[state_key] = \
np.concatenate((state_sub_dict[state_key], state_value))
else:
# create a new entry for this state in the state_dict
state_sub_dict[state_key] = state_value

def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_dict, mismatch_error_string):
"""Add or validate the record for the unsharded state_key in the state_sub_dict"""

if state_key in state_sub_dict:
# state_dict already contains a record for this unsharded state.
# assert that all values are the same for this previously loaded state
assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string
else:
# create a new entry for this state in the state_sub_dict
state_sub_dict[state_key] = state_value

def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict):
"""Aggregates all model states from the rank_state_dict into state_dict"""

model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
partition_info = _utils.state_dict_partition_info_key()
original_dim = _utils.state_dict_original_dimension_key()

# if there are no model states in the rank_state_dict, no model aggregation is needed
if model not in rank_state_dict:
return

if model not in state_dict:
state_dict[model] = {}

if full_precision not in state_dict[model]:
state_dict[model][full_precision] = {}

# iterate over all model state keys
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
if model_state_key in rank_state_dict[partition_info]:
# this model state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key_for_zero(model_state_key, model_state_value,
state_dict[model][full_precision], model_state_key,
rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims)
else:
# this model state is not sharded since a record for it does not exist in the partition_info subdict
_add_or_validate_unsharded_key_for_zero(model_state_key, model_state_value,
state_dict[model][full_precision], "Value mismatch for model state {}".format(model_state_key))

def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict):
"""Aggregates all optimizer states from the rank_state_dict into state_dict"""

optimizer = _utils.state_dict_optimizer_key()
partition_info = _utils.state_dict_partition_info_key()
original_dim = _utils.state_dict_original_dimension_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()

# if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed
if optimizer not in rank_state_dict:
return

if optimizer not in state_dict:
state_dict[optimizer] = {}

# iterate over all optimizer state keys
for model_state_key, optimizer_dict in rank_state_dict[optimizer].items():
for optimizer_key, optimizer_value in optimizer_dict.items():
if model_state_key not in state_dict[optimizer]:
state_dict[optimizer][model_state_key] = {}

if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]:
# this optimizer state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key_for_zero(optimizer_key, optimizer_value,
state_dict[optimizer][model_state_key], model_state_key,
rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims)
else:
# this optimizer state is not sharded since a record for it does not exist in the partition_info subdict
# or this optimizer key is not one of the sharded optimizer keys
_add_or_validate_unsharded_key_for_zero(optimizer_key, optimizer_value,
state_dict[optimizer][model_state_key],
"Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key))

def _reshape_states(sharded_states_original_dims, state_dict):
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""

model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
optimizer = _utils.state_dict_optimizer_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()

for sharded_state_key, original_dim in sharded_states_original_dims.items():
# reshape model states to original_dim
if model in state_dict:
state_dict[model][full_precision][sharded_state_key] = \
state_dict[model][full_precision][sharded_state_key].reshape(original_dim)

# reshape optimizer states to original_dim
if optimizer in state_dict:
for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items():
if optimizer_key in sharded_optimizer_keys:
state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim)
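
A minimal numpy sketch (assumed shapes, not part of this diff) of the concatenate-then-reshape flow that _add_or_update_sharded_key_for_zero and _reshape_states implement for a ZeRO-sharded tensor:

import numpy as np

# two ranks each hold a flattened half of an original (4, 3) weight
shard_rank0 = np.arange(6, dtype=np.float32)
shard_rank1 = np.arange(6, 12, dtype=np.float32)
original_dim = (4, 3)  # as recorded under the state's partition_info entry

aggregated = np.concatenate((shard_rank0, shard_rank1))  # done shard by shard during aggregation
full_weight = aggregated.reshape(original_dim)           # done once at the end by _reshape_states
assert full_weight.shape == (4, 3)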

def _aggregate_trainer_options(rank_state_dict, state_dict):
"""Extracts trainer options from rank_state_dict and loads them accordingly on state_dict"""

state_dict[_utils.state_dict_trainer_options_key()] = {}

mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_rank = _utils.state_dict_trainer_options_world_rank_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()

state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = \
rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = 0
state_dict[_utils.state_dict_trainer_options_key()][world_rank] = 0
state_dict[_utils.state_dict_trainer_options_key()][world_size] = 1
state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = \
rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]

def aggregate_checkpoints(paths, pytorch_format=True):
"""Aggregate checkpoint files and return a single state dictionary
Aggregates checkpoint files specified by paths and laods the checkpoint file one at a time merging
them into a single state dictionary.
The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function.
The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict()
Args:
paths: list of more than one file represented as strings where the checkpoint is saved
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
Returns:
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
"""

# order the paths in ascending order of ranks
ordered_paths = _order_paths(paths)

state_dict = {}
sharded_states_original_dims = {}
world_rank = _utils.state_dict_trainer_options_world_rank_key()
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()

loaded_mixed_precision = None
loaded_world_size = None
loaded_zero_stage = None
loaded_optimizer_name = None

for rank, path in enumerate(ordered_paths):
rank_state_dict = _checkpoint_storage.load(path)

assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info"
assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options"
assert rank == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \
"Unexpected rank in file at path {}. Expected {}, got {}".\
format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank])
if loaded_mixed_precision is None:
loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
else:
assert loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision], \
"Mixed precision state mismatch among checkpoint files. File: {}".format(path)
if loaded_world_size is None:
loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
else:
assert loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size], \
"World size state mismatch among checkpoint files. File: {}".format(path)
if loaded_zero_stage is None:
loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
else:
assert loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage], \
"Zero stage mismatch among checkpoint files. File: {}".format(path)
if loaded_optimizer_name is None:
loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
else:
assert loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name], \
"Optimizer name mismatch among checkpoint files. File: {}".format(path)

# aggregate all model states
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict)

if not pytorch_format:
# aggregate all optimizer states if pytorch_format is False
_aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict)

# entry for trainer_options in the state_dict to perform other sanity checks
if _utils.state_dict_trainer_options_key() not in state_dict:
_aggregate_trainer_options(rank_state_dict, state_dict)

# entry for user_dict in the state_dict if not already present
if _utils.state_dict_user_dict_key() not in state_dict and \
_utils.state_dict_user_dict_key() in rank_state_dict:
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]

# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
_reshape_states(sharded_states_original_dims, state_dict)

# return a flat structure for PyTorch model in case pytorch_format is True
# else return the hierarchical structure for ORTTrainer
return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] if pytorch_format else state_dict
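
A hedged usage sketch of the new aggregation entry point; the import path and checkpoint file names are assumptions for illustration, not taken from this diff:

from onnxruntime.training import checkpoint

paths = ['checkpoint_rank0.ortcp', 'checkpoint_rank1.ortcp']  # hypothetical file names

# flat PyTorch-style state_dict, e.g. for torch.nn.Module.load_state_dict()
pytorch_state = checkpoint.aggregate_checkpoints(paths, pytorch_format=True)

# hierarchical ORTTrainer-style state_dict, including optimizer and trainer options
ort_state = checkpoint.aggregate_checkpoints(paths, pytorch_format=False)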

################################################################################
# Helper functions
@@ -201,7 +429,7 @@ def _split_name(self, name):
name_split = name.split('_view_')
view_num = None
if(len(name_split) > 1):
view_num = int(name_split[1])
view_num = int(name_split[1])
optimizer_key = ''
mp_suffix = ''
if name_split[0].startswith('Moment_1'):