11import  gc 
22import  logging 
33import  os 
4+ import  time 
45from  collections  import  OrderedDict 
5- from  typing  import  Optional 
6+ from  typing  import  Dict ,  Optional 
67
78import  safetensors 
89import  torch 
@@ -66,7 +67,8 @@ def load_weights(self, device: str):
6667        if  self ._load_config .is_ft_style_weight :
6768            weights  =  self ._load_from_ft_style (device )
6869        else :
69-             weights  =  self ._load_from_scratch (device )
70+             weights  =  self ._load_weight (device )
71+             self .force_clean_cuda_memory ()
7072
7173        # load dynamic weight 
7274        self ._load_dynamic_weights (weights , device )
@@ -203,6 +205,88 @@ def _load_from_ft_style(self, device: str):
203205        model_weights .global_weights  =  global_weights 
204206        return  model_weights 
205207
208+     def  _load_weight (self , device : str ):
209+         is_safetensor  =  self ._load_config .database .is_safetensor 
210+         convert_device  =  self ._choose_weight_convert_device (device )
211+         if  is_safetensor  and  convert_device  !=  "cpu"  and  self ._is_memory_enough_for_fastsafetensor ():
212+             return  self ._load_from_fastsafetensor (device )
213+         
214+         logging .info (
215+             f"database is safetensor: { is_safetensor }  , device: { device }  , choose devie: { convert_device }  " 
216+         )
217+         return  self ._load_from_scratch (device )
218+     
219+     def  _is_memory_enough_for_fastsafetensor (self ):
220+         model_size  =  self ._weights_info .config .eval_model_size ()
221+         device_mem_info  =  self ._load_config .exported_device .get_mem_info ()
222+         max_file_size  =  self ._load_config .database .get_max_file_size ()
223+         if  device_mem_info  is  None :
224+             return  False 
225+         else :
226+             free_mem  =  device_mem_info .free  /  (1024.0 ** 2 )
227+         model_mem  =  model_size  /  self ._load_config .tp_size  /  (1024.0 ** 2 )
228+         max_file_mem  =  max_file_size  /  (1024.0 ** 2 )
229+         logging .debug (f"free mem: { free_mem }  , model mem: { model_mem }  , max file mem: { max_file_mem }  " )
230+         return  (free_mem  -  model_mem ) >  (3  *  max_file_mem )
231+     
232+     def  _load_from_fastsafetensor (self , device : str ):
233+         try :
234+             all_tensors  =  self ._load_config .database .fastsafetensors_weights_iterator (
235+                 device , True 
236+             )
237+         except  (ModuleNotFoundError , ImportError ) as  e :
238+             logging .warning (f"Failed to import fastsafetensors: { e }  " )
239+             return  self ._load_from_scratch (device )
240+ 
241+         logging .info (f"load weight by device: { device }  " )
242+         model_weights  =  self ._create_model_weights (device )
243+         tensor_to_weight_map  =  self ._get_tensor_to_weight_map ()
244+         direct_io  =  self ._load_config .exported_device .support_dio_load 
245+         for  key , loaded_tensor  in  all_tensors :
246+             if  key  not  in   tensor_to_weight_map :
247+                 continue 
248+             layer_id , weight  =  tensor_to_weight_map [key ]
249+             start  =  time .time ()
250+             res , complete  =  weight .add_tensor (
251+                 key , self ._load_config .database , layer_id , loaded_tensor , device , self ._load_config 
252+             )
253+             logging .debug (
254+                 f"weight: { type (weight ).__name__ }   add tensor, complete: { complete }  , cost { time .time () -  start }  " 
255+             )
256+             if  complete  and  res  is  not   None :
257+                 for  name , tensor  in  res .items ():
258+                     if  layer_id  is  not   None  and  self ._load_config .vit_separation  !=  1 :
259+                         model_weights .set_layer_weight (layer_id , name , tensor )
260+                     else :
261+                         model_weights .set_global_weight (name , tensor )
262+         for  layer_id , name , tensor  in  self ._load_uncomplete_weight_modules (device ):
263+             if  layer_id  is  not   None  and  self ._load_config .vit_separation  !=  1 :
264+                 model_weights .set_layer_weight (layer_id , name , tensor )
265+             else :
266+                 model_weights .set_global_weight (name , tensor )
267+         return  model_weights 
268+     
269+     def  _load_uncomplete_weight_modules (self , device : str ):
270+         if  self ._load_config .vit_separation  !=  1  and  not  self ._is_attn_model :
271+             for  layer_id  in  range (self ._load_config .num_layers ):
272+                 layer_weights  =  self ._model_weights_info .layer_weights [layer_id ]
273+                 for  weight  in  layer_weights :
274+                     if  weight .loaded :
275+                         continue 
276+                     results  =  weight .load (
277+                         self ._load_config .database , layer_id , device , self ._load_config 
278+                     )
279+                     for  name , tensor  in  results .items ():
280+                         yield  (layer_id , name , tensor )
281+         for  weight  in  self ._model_weights_info .weights :
282+             if  self ._maybe_skip_weight (weight ) or  weight .loaded :
283+                 continue 
284+             results  =  weight .load (
285+                 self ._load_config .database , None , device , self ._load_config 
286+             )
287+             for  name , tensor  in  results .items ():
288+                 yield  (None , name , tensor )
289+ 
206290    def  prepare_weights (self , device : str ):
207291        if  self ._load_config .vit_separation  !=  1  and  not  self ._is_attn_model :
208292            for  id  in  range (self ._load_config .num_layers ):
@@ -225,6 +309,34 @@ def prepare_weights(self, device: str):
225309            )
226310            for  name , tensor  in  weights .items ():
227311                yield  (None , name , tensor )
312+     
313+     def  _get_tensor_to_weight_map (
314+         self ,
315+     ) ->  Dict [str , tuple [Optional [int ], WeightModule ]]:
316+         tensor_to_weight_map : Dict [str , tuple [Optional [int ], WeightModule ]] =  {}
317+         if  self ._load_config .vit_separation  !=  1  and  not  self ._is_attn_model :
318+             for  layer_id  in  range (self ._load_config .num_layers ):
319+                 layer_weights  =  self ._model_weights_info .layer_weights [layer_id ]
320+                 if  isinstance (layer_weights , WeightModule ):
321+                     names  =  layer_weights .get_tensor_names (layer_id , self ._load_config )
322+                     tensor_to_weight_map .update (
323+                         {k : (layer_id , layer_weights ) for  k  in  names }
324+                     )
325+                 else :
326+                     for  weight  in  layer_weights :
327+                         names  =  weight .get_tensor_names (layer_id , self ._load_config )
328+                         tensor_to_weight_map .update (
329+                             {k : (layer_id , weight ) for  k  in  names }
330+                         )
331+         for  weight  in  self ._model_weights_info .weights :
332+             if  self ._maybe_skip_weight (weight ):
333+                 continue 
334+             names  =  weight .get_tensor_names (None , self ._load_config )
335+             tensor_to_weight_map .update ({k : (None , weight ) for  k  in  names })
336+         for  weight  in  self ._misc_weights_info :
337+             names  =  weight .get_tensor_names (None , self ._load_config )
338+             tensor_to_weight_map .update ({k : (None , weight ) for  k  in  names })
339+         return  tensor_to_weight_map 
228340
229341    def  _maybe_skip_weight (self , weight : WeightModule ):
230342        if  self ._task_type  ==  TaskType .LANGUAGE_MODEL :
@@ -270,6 +382,7 @@ def _load_from_scratch(self, device: str):
270382                weights .set_layer_weight (layer_id , name , tensor )
271383            else :
272384                weights .set_global_weight (name , tensor )
385+             gc .collect ()
273386        return  weights 
274387
275388    def  _load_layer_weights (self , layer_id : int , device : str ):
0 commit comments