@@ -1,8 +1,9 @@
 import gc
 import logging
 import os
+import time
 from collections import OrderedDict
-from typing import Optional
+from typing import Dict, Optional, Tuple

 import safetensors
 import torch
@@ -66,7 +67,8 @@ def load_weights(self, device: str):
         if self._load_config.is_ft_style_weight:
             weights = self._load_from_ft_style(device)
         else:
-            weights = self._load_from_scratch(device)
+            weights = self._load_weight(device)
+            self.force_clean_cuda_memory()

         # load dynamic weight
         self._load_dynamic_weights(weights, device)
@@ -203,6 +205,88 @@ def _load_from_ft_style(self, device: str):
         model_weights.global_weights = global_weights
         return model_weights

+    def _load_weight(self, device: str):
+        is_safetensor = self._load_config.database.is_safetensor
+        convert_device = self._choose_weight_convert_device(device)
+        if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
+            return self._load_from_fastsafetensor(device)
+
+        logging.info(
+            f"database is safetensor: {is_safetensor}, device: {device}, convert device: {convert_device}"
+        )
+        return self._load_from_scratch(device)
+
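+    # Rough heuristic gate for the fastsafetensors path: after reserving the
+    # per-rank model footprint, require free device memory of at least 3x the
+    # largest checkpoint file (a conservative margin, presumably for in-flight
+    # staging buffers; not a measured bound).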
+    def _is_memory_enough_for_fastsafetensor(self):
+        model_size = self._weights_info.config.eval_model_size()
+        device_mem_info = self._load_config.exported_device.get_mem_info()
+        max_file_size = self._load_config.database.get_max_file_size()
+        if device_mem_info is None:
+            return False
+        free_mem = device_mem_info.free / (1024.0**2)
+        model_mem = model_size / self._load_config.tp_size / (1024.0**2)
+        max_file_mem = max_file_size / (1024.0**2)
+        logging.debug(
+            f"free mem: {free_mem} MiB, model mem: {model_mem} MiB, max file mem: {max_file_mem} MiB"
+        )
+        return (free_mem - model_mem) > (3 * max_file_mem)
+
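+    # Streaming load path: iterate tensors directly from the safetensors files
+    # and route each one to the WeightModule that owns it (see
+    # _get_tensor_to_weight_map); falls back to the scratch loader when the
+    # optional fastsafetensors dependency is unavailable.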
+    def _load_from_fastsafetensor(self, device: str):
+        try:
+            all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
+                device, True
+            )
+        except (ModuleNotFoundError, ImportError) as e:
+            logging.warning(f"Failed to import fastsafetensors: {e}")
+            return self._load_from_scratch(device)
+
+        logging.info(f"load weight by device: {device}")
+        model_weights = self._create_model_weights(device)
+        tensor_to_weight_map = self._get_tensor_to_weight_map()
+        direct_io = self._load_config.exported_device.support_dio_load
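+        # Route each streamed tensor to its owning module; add_tensor returns
+        # (res, complete), and once a module reports complete its assembled
+        # weights in res are committed to the model.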
+        for key, loaded_tensor in all_tensors:
+            if key not in tensor_to_weight_map:
+                continue
+            layer_id, weight = tensor_to_weight_map[key]
+            start = time.time()
+            res, complete = weight.add_tensor(
+                key, layer_id, loaded_tensor, device, self._load_config
+            )
+            logging.debug(
+                f"weight: {type(weight).__name__} add tensor, complete: {complete}, cost: {time.time() - start:.3f}s"
+            )
+            if complete and res is not None:
+                for name, tensor in res.items():
+                    if layer_id is not None and self._load_config.vit_separation != 1:
+                        model_weights.set_layer_weight(layer_id, name, tensor)
+                    else:
+                        model_weights.set_global_weight(name, tensor)
+        for layer_id, name, tensor in self._load_incomplete_weight_modules(device):
+            if layer_id is not None and self._load_config.vit_separation != 1:
+                model_weights.set_layer_weight(layer_id, name, tensor)
+            else:
+                model_weights.set_global_weight(name, tensor)
+        return model_weights
+
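+    # Second pass for the streaming path: any module whose tensors were not
+    # fully covered by the streamed files (weight.loaded is still False) is
+    # loaded directly from the weight database instead.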
+    def _load_incomplete_weight_modules(self, device: str):
+        if self._load_config.vit_separation != 1 and not self._is_attn_model:
+            for layer_id in range(self._load_config.num_layers):
+                layer_weights = self._model_weights_info.layer_weights[layer_id]
+                for weight in layer_weights:
+                    if weight.loaded:
+                        continue
+                    results = weight.load(
+                        self._load_config.database, layer_id, device, self._load_config
+                    )
+                    for name, tensor in results.items():
+                        yield (layer_id, name, tensor)
+        for weight in self._model_weights_info.weights:
+            if self._maybe_skip_weight(weight) or weight.loaded:
+                continue
+            results = weight.load(
+                self._load_config.database, None, device, self._load_config
+            )
+            for name, tensor in results.items():
+                yield (None, name, tensor)
+
     def prepare_weights(self, device: str):
         if self._load_config.vit_separation != 1 and not self._is_attn_model:
             for id in range(self._load_config.num_layers):
@@ -225,6 +309,34 @@ def prepare_weights(self, device: str):
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)
+
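+    # Reverse index for the streaming loader: maps each expected checkpoint
+    # tensor name to the (layer_id, WeightModule) pair that consumes it, so a
+    # streamed tensor can be routed without scanning every module.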
+    def _get_tensor_to_weight_map(
+        self,
+    ) -> Dict[str, Tuple[Optional[int], WeightModule]]:
+        tensor_to_weight_map: Dict[str, Tuple[Optional[int], WeightModule]] = {}
+        if self._load_config.vit_separation != 1 and not self._is_attn_model:
+            for layer_id in range(self._load_config.num_layers):
+                layer_weights = self._model_weights_info.layer_weights[layer_id]
+                if isinstance(layer_weights, WeightModule):
+                    names = layer_weights.get_tensor_names(layer_id, self._load_config)
+                    tensor_to_weight_map.update(
+                        {k: (layer_id, layer_weights) for k in names}
+                    )
+                else:
+                    for weight in layer_weights:
+                        names = weight.get_tensor_names(layer_id, self._load_config)
+                        tensor_to_weight_map.update(
+                            {k: (layer_id, weight) for k in names}
+                        )
+        for weight in self._model_weights_info.weights:
+            if self._maybe_skip_weight(weight):
+                continue
+            names = weight.get_tensor_names(None, self._load_config)
+            tensor_to_weight_map.update({k: (None, weight) for k in names})
+        for weight in self._misc_weights_info:
+            names = weight.get_tensor_names(None, self._load_config)
+            tensor_to_weight_map.update({k: (None, weight) for k in names})
+        return tensor_to_weight_map

     def _maybe_skip_weight(self, weight: WeightModule):
         if self._task_type == TaskType.LANGUAGE_MODEL:
@@ -270,6 +382,7 @@ def _load_from_scratch(self, device: str):
                 weights.set_layer_weight(layer_id, name, tensor)
             else:
                 weights.set_global_weight(name, tensor)
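+        # Hint the collector to release host-side intermediates from loading;
+        # the caller also force-cleans CUDA memory once loading returns.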
+        gc.collect()
         return weights

     def _load_layer_weights(self, layer_id: int, device: str):