-import copy
+import json
 import logging
 import os
-from functools import reduce
 from pathlib import Path
-from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-import json
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils._pytree import tree_map
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.padded_tensor import (
-    init_as_padded_tensor,
-    is_padded_tensor,
-    to_padded_tensor,
-    to_unpadded_tensor,
-)
-from colossalai.utils import get_current_device, get_non_persistent_buffers_set
-from torch.distributed.distributed_c10d import _get_default_group
+from colossalai.interface import ModelWrapper
+from colossalai.utils import get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
     StateDictSharder,
     async_save_state_dict_shards,
     create_pinned_state_dict,
-    gather_distributed_param,
     get_model_base_filenames,
-    get_optimizer_base_filenames,
-    is_safetensors_available,
-    load_shard_state_dict,
     load_state_dict,
-    load_state_dict_into_model,
-    load_states_into_optimizer,
-    save_config_file,
-    save_param_groups,
     save_state_dict,
     save_state_dict_shards,
-    search_padding_dim,
     search_tp_partition_dim,
-    sharded_optimizer_loading_epilogue,
 )
 
 try:
@@ -97,7 +76,6 @@ def __init__(
         self.model_metadata = None
         self.optimizer_metadata = None
         self.global_rank = dist.get_rank(_get_default_group())
-
 
     @staticmethod
     def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
@@ -106,13 +84,13 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            destination[prefix + name] = param
+            destination[prefix + name] = param
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                destination[prefix + name] = buffer
+                destination[prefix + name] = buffer
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -123,22 +101,24 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
             extra_state = model.get_extra_state()
             destination[extra_state_key] = extra_state
         return destination
-
+
     @staticmethod
-    def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False):
+    def load_state_dict(
+        model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
+    ):
         destination = dict()
         # Save parameters.
         for name, param in model.named_parameters():
             if param is None:
                 continue
             with torch.no_grad():
-                param.copy_(state_dict[prefix + name])
+                param.copy_(state_dict[prefix + name])
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 with torch.no_grad():
-                    buf.copy_(state_dict[prefix + name])
+                    buf.copy_(state_dict[prefix + name])
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -151,26 +131,33 @@ def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_v
                 extra_state.copy_(state_dict[extra_state_key])
         return destination
 
-    def create_model_metadata(self, model: nn.Module, prefix: str = "",):
+    def create_model_metadata(
+        self,
+        model: nn.Module,
+        prefix: str = "",
+    ):
         param_origin_shape = model.param_origin_shape
         model = model.unwrap()
         self.model_metadata = {}
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            self.model_metadata[prefix + name] = {}
+            self.model_metadata[prefix + name] = {}
             original_shape = param_origin_shape[name]
-            tp_partition_dim = search_tp_partition_dim(current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size)
-            self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
-            self.model_metadata[prefix + name]["lengths"] = list(param.shape)
-            self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
+            tp_partition_dim = search_tp_partition_dim(
+                current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
+            )
+            self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+            self.model_metadata[prefix + name]["lengths"] = list(param.shape)
+            self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
             if tp_partition_dim is not None:
                 partition_size = param.shape[tp_partition_dim]
-                self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
+                self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
                 if self.tp_rank == self.tp_size - 1:
-                    self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (partition_size * (self.tp_size - 1))
+                    self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
+                        tp_partition_dim
+                    ] - (partition_size * (self.tp_size - 1))
 
-
     def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
         metadata_dicts = {
             "checkpoint_version": "1.0",
@@ -188,7 +175,7 @@ def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
             metadata_dicts["metadata"][name]["rank"] = self.global_rank
         with open(metadata_file, "w") as json_file:
             json.dump(metadata_dicts, json_file, indent=4)
-
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
@@ -249,13 +236,13 @@ def load_metadata(self, checkpoint: str):
             try:
                 with open(file_path, "r") as f:
                     metadata_json = json.load(f)
-                for name, item in metadata_json['metadata'].items():
+                for name, item in metadata_json["metadata"].items():
                     if name not in metadata_dict:
                         metadata_dict[name] = {}
-                        metadata_dict[name]["global_shape"] = item['global_shape']
+                        metadata_dict[name]["global_shape"] = item["global_shape"]
                         metadata_dict[name]["shards"] = {}
                     else:
-                        assert metadata_dict[name]["global_shape"] == item['global_shape']
+                        assert metadata_dict[name]["global_shape"] == item["global_shape"]
                     shard = {}
                     shard[item["rank"]] = {}
                     shard[item["rank"]]["file"] = item["file"]
@@ -304,7 +291,7 @@ def find_covering_shards(self, shards, target_offsets, target_lengths):
 
         assert total_lengths == global_shape
         return covering_shards
-
+
     def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths):
         """
         Extract the target range of weights from shard data, supporting partial overlap.
@@ -314,14 +301,16 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         param target_lengths: A 1D array indicating the length of the target tensor in each dimension.
         return: The extracted sub-tensor of the target weights and its position within the target range.
         """
-        shard_offsets = shard['offsets']
-        shard_lengths = shard['lengths']
-        weight = shard['weight']
+        shard_offsets = shard["offsets"]
+        shard_lengths = shard["lengths"]
+        weight = shard["weight"]
 
         slices = []
         target_slices = []
 
-        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(zip(target_offsets, target_lengths, shard_offsets, shard_lengths)):
+        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(
+            zip(target_offsets, target_lengths, shard_offsets, shard_lengths)
+        ):
             intersection_start = max(t_offset, s_offset)
             intersection_end = min(t_offset + t_length, s_offset + s_length)
 
@@ -339,7 +328,6 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         target_weight = weight[tuple(slices)]
         return target_weight, target_slices
 
-
     def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype):
         target_tensor = torch.zeros(target_lengths, dtype=dtype)
 
@@ -351,15 +339,14 @@ def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_len
 
         return target_tensor
 
-
-    def load_unsharded_model(
+    def load_unsharded_model(
         self,
         model: ModelWrapper,
         checkpoint: str,
         strict: bool = False,
         low_cpu_mem_mode: bool = True,
         num_threads: int = 1,
-    ):
+    ):
         """
         Load model from a single file with the given path of checkpoint.
 
@@ -390,30 +377,34 @@ def load_unsharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
             # load_files.update({rank: shard['file'] for rank, shard in covering_shards.items()})
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
+                load_files[rank].add(shard["file"])
 
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items():
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
                         dtype = weight.dtype
                     covered_shards[key][rank]["weight"] = weight
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
@@ -424,7 +415,6 @@ def load_unsharded_model(
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()
 
-
     @staticmethod
     def _model_sharder(
         model: nn.Module,
@@ -571,7 +561,7 @@ def save_sharded_model(
         )
         for k, _ in self.model_metadata.items():
             self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
-
+
         self.save_metadata(metadata_file, total_size=total_size)
 
     def load_sharded_model(
@@ -606,30 +596,34 @@ def load_sharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
-
+                load_files[rank].add(shard["file"])
+
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items():
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
                         dtype = weight.dtype
                     covered_shards[key][rank]["weight"] = weight
-
+
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
@@ -638,4 +632,4 @@ def load_sharded_model(
         DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict)
 
         # Update master params if mixed-precision training is enabled.
-        model_before_wrapping.update_master_params()
+        model_before_wrapping.update_master_params()
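
Below is a minimal standalone sketch (not part of the commit; the shapes, `tp_size`, and variable names are illustrative) of the offsets/lengths/global_shape bookkeeping that `create_model_metadata` records and the overlap-copy that `extract_weight_from_shard_partial` / `assemble_tensor_from_shards_partial` perform when a full tensor is rebuilt from tensor-parallel shards:

```python
# Hypothetical sketch: an [8, 4] weight split along dim 0 across tp_size = 2 ranks.
import torch

tp_size = 2
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)
global_shape = list(full_weight.shape)

# Per-rank metadata in the offsets/lengths/global_shape convention used above.
shards = {}
for tp_rank in range(tp_size):
    partition_size = global_shape[0] // tp_size
    offsets = [partition_size * tp_rank, 0]
    lengths = [partition_size, global_shape[1]]
    shards[tp_rank] = {
        "offsets": offsets,
        "lengths": lengths,
        "weight": full_weight[offsets[0] : offsets[0] + lengths[0], :],
    }

# Rebuild the full tensor by copying each shard's overlapping region into a zero
# buffer, mirroring extract_weight_from_shard_partial / assemble_tensor_from_shards_partial.
target_offsets, target_lengths = [0, 0], global_shape
target = torch.zeros(target_lengths, dtype=full_weight.dtype)
for shard in shards.values():
    src_slices, dst_slices = [], []
    for t_off, t_len, s_off, s_len in zip(target_offsets, target_lengths, shard["offsets"], shard["lengths"]):
        start = max(t_off, s_off)
        end = min(t_off + t_len, s_off + s_len)
        if start >= end:  # no overlap in this dimension: the shard contributes nothing
            src_slices = None
            break
        src_slices.append(slice(start - s_off, end - s_off))
        dst_slices.append(slice(start - t_off, end - t_off))
    if src_slices is not None:
        target[tuple(dst_slices)] = shard["weight"][tuple(src_slices)]

assert torch.equal(target, full_weight)
```

When the partitioned dimension does not divide evenly, the last tensor-parallel rank holds a longer slice; that is what the `lengths[tp_partition_dim]` adjustment for `tp_rank == tp_size - 1` in `create_model_metadata` accounts for.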