 import gc
 import logging
 import os
+import time
 from collections import OrderedDict
-from typing import Optional
+from typing import Dict, Optional, NamedTuple, List, Tuple

 import safetensors
 import torch
     ModelWeights,
 )
 from rtp_llm.model_loader.weight_module import CustomAtomicWeight, WeightModule
+from rtp_llm.model_loader.tensor_source import TensorCollector, DatabaseTensorSource
 from rtp_llm.utils.database import BaseDatabase, CkptDatabase
 from rtp_llm.utils.fuser import fetch_remote_file_to_local
 from rtp_llm.utils.model_weight import W, WeightStyle


 class ModelLoader:
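+    # WeightInfo ties a WeightModule to its layer id (None for global weights) and the
+    # TensorCollector that gathers the source tensors the module needs.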
+    WeightInfo = NamedTuple("WeightInfo", [("weight", WeightModule), ("layer_id", Optional[int]), ("collector", TensorCollector)])
+
     def __init__(
         self,
         task_type: TaskType,
@@ -66,7 +70,8 @@ def load_weights(self, device: str):
         if self._load_config.is_ft_style_weight:
             weights = self._load_from_ft_style(device)
         else:
-            weights = self._load_from_scratch(device)
+            weights = self._load_weight(device)
+            self.force_clean_cuda_memory()

         # load dynamic weight
         self._load_dynamic_weights(weights, device)
@@ -203,6 +208,81 @@ def _load_from_ft_style(self, device: str):
         model_weights.global_weights = global_weights
         return model_weights

+    def _load_weight(self, device: str):
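+        # Use the fastsafetensors path only when the checkpoint is safetensors, the weight
+        # conversion can stay on the target device, and memory headroom is sufficient;
+        # otherwise fall back to the regular scratch loader.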
+        is_safetensor = self._load_config.database.is_safetensor
+        convert_device = self._choose_weight_convert_device(device)
+        if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
+            try:
+                return self._load_from_fastsafetensor(device)
+            except Exception as e:
+                logging.warning(f"Failed to load from fastsafetensors: {e}")
+
+        logging.info(
+            f"database is safetensor: {is_safetensor}, device: {device}, chosen device: {convert_device}"
+        )
+        return self._load_from_scratch(device)
+
+    def _is_memory_enough_for_fastsafetensor(self):
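+        # Heuristic: after reserving this rank's share of the model, require headroom of
+        # roughly three of the largest checkpoint files before enabling fastsafetensors.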
+        model_size = self._weights_info.config.eval_model_size()
+        device_mem_info = self._load_config.exported_device.get_mem_info()
+        max_file_size = self._load_config.database.get_max_file_size()
+        if device_mem_info is None:
+            return False
+        else:
+            free_mem = device_mem_info.free / (1024.0**2)
+            model_mem = model_size / self._load_config.tp_size / (1024.0**2)
+            max_file_mem = max_file_size / (1024.0**2)
+            logging.debug(f"free mem: {free_mem}, model mem: {model_mem}, max file mem: {max_file_mem}")
+            return (free_mem - model_mem) > (3 * max_file_mem)
+
+    def _load_from_fastsafetensor(self, device: str):
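+        # Stream tensors from the checkpoint via the fastsafetensors iterator and build
+        # the model weights incrementally.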
+        all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
+            device, True
+        )
+        logging.info(f"load weight by device: {device}")
+        model_weights = self._create_model_weights(device)
+        tensor_to_weight_map, weight_info_list = self._generate_weight_info()
+        direct_io = self._load_config.exported_device.support_dio_load
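+        # Route each streamed tensor to the WeightModule that owns it; once a module's
+        # collector holds every tensor it needs, convert and place its weights, then
+        # drop the staged tensors.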
+        for key, loaded_tensor in all_tensors:
+            if key not in tensor_to_weight_map:
+                continue
+            weight_info = tensor_to_weight_map[key]
+            complete = weight_info.collector.store_tensor(key, loaded_tensor)
+            if complete:
+                start = time.time()
+                tensors = weight_info.weight.load(
+                    tensor_source=weight_info.collector,
+                    layer_id=weight_info.layer_id,
+                    device=device,
+                    load_config=self._load_config,
+                )
+                for name, tensor in tensors.items():
+                    if weight_info.layer_id is not None:
+                        model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                    else:
+                        model_weights.set_global_weight(name, tensor)
+                logging.debug(
+                    f"weight: {type(weight_info.weight).__name__} load cost {time.time() - start}"
+                )
+                weight_info.collector.clear()
+
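+        # Fallback: any module whose tensors never fully arrived through the iterator is
+        # loaded directly from the checkpoint database instead.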
+        for weight_info in weight_info_list:
+            weight_info.collector.clear()
+            if weight_info.collector.is_collection_complete():
+                continue
+            tensors = weight_info.weight.load(
+                tensor_source=DatabaseTensorSource(self._load_config.database),
+                layer_id=weight_info.layer_id,
+                device=device,
+                load_config=self._load_config,
+            )
+            for name, tensor in tensors.items():
+                if weight_info.layer_id is not None:
+                    model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                else:
+                    model_weights.set_global_weight(name, tensor)
+        return model_weights
+
     def prepare_weights(self, device: str):
         if self._load_config.vit_separation != 1 and not self._is_attn_model:
             for id in range(self._load_config.num_layers):
@@ -214,17 +294,62 @@ def prepare_weights(self, device: str):
             if self._maybe_skip_weight(weight):
                 continue
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)

         for weight in self._misc_weights_info:
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)
+
+    def _generate_weight_info(self) -> Tuple[Dict[str, WeightInfo], List[WeightInfo]]:
+        # Map every expected source tensor name to the WeightInfo that consumes it, and
+        # collect all WeightInfos for the post-iteration fallback pass.
+        WeightInfo = ModelLoader.WeightInfo
+        tensor_to_weight_map: Dict[str, WeightInfo] = {}
+        weight_info_list: List[WeightInfo] = []
+        if self._load_config.vit_separation != 1:
+            for layer_id in range(self._load_config.num_layers):
+                layer_weights = self._model_weights_info.layer_weights[layer_id]
+                if isinstance(layer_weights, WeightModule):
+                    names = layer_weights.get_tensor_names(layer_id, self._load_config)
+                    collector = TensorCollector(names, self._load_config.database)
+                    weight_info = WeightInfo(weight=layer_weights, layer_id=layer_id, collector=collector)
+                    tensor_to_weight_map.update({k: weight_info for k in names})
+                    weight_info_list.append(weight_info)
+                else:
+                    for weight in layer_weights:
+                        names = weight.get_tensor_names(layer_id, self._load_config)
+                        collector = TensorCollector(names, self._load_config.database)
+                        weight_info = WeightInfo(weight=weight, layer_id=layer_id, collector=collector)
+                        tensor_to_weight_map.update({k: weight_info for k in names})
+                        weight_info_list.append(weight_info)
+        for weight in self._model_weights_info.weights:
+            if self._maybe_skip_weight(weight):
+                continue
+            names = weight.get_tensor_names(None, self._load_config)
+            collector = TensorCollector(names, self._load_config.database)
+            weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
+            tensor_to_weight_map.update({k: weight_info for k in names})
+            weight_info_list.append(weight_info)
+        for weight in self._misc_weights_info:
+            names = weight.get_tensor_names(None, self._load_config)
+            collector = TensorCollector(names, self._load_config.database)
+            weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
+            tensor_to_weight_map.update({k: weight_info for k in names})
+            weight_info_list.append(weight_info)
+        return tensor_to_weight_map, weight_info_list

     def _maybe_skip_weight(self, weight: WeightModule):
         if self._task_type == TaskType.LANGUAGE_MODEL:
@@ -254,7 +379,7 @@ def _choose_weight_convert_device(self, current_device):
         else:
             free_mem = device_mem_info.free / (1024.0**2)
             model_mem = model_size / self._load_config.tp_size / (1024.0**2)
-            return current_device if free_mem * 0.8 > model_mem else "cpu"
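+            # Convert on the current device only when this rank's share of the model fits
+            # within 90% of the free device memory.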
+            return current_device if free_mem * 0.9 > model_mem else "cpu"

     def _load_from_scratch(self, device: str):
         weights = self._create_model_weights(device)
@@ -270,6 +395,7 @@ def _load_from_scratch(self, device: str):
                 weights.set_layer_weight(layer_id, name, tensor)
             else:
                 weights.set_global_weight(name, tensor)
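+        # Force a garbage-collection pass so temporary buffers created during loading are
+        # reclaimed before returning the weights.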
+        gc.collect()
         return weights

     def _load_layer_weights(self, layer_id: int, device: str):
@@ -278,7 +404,7 @@ def _load_layer_weights(self, layer_id: int, device: str):
         weights = {}
         for weight in layer_weights:
             res = weight.load(
-                self._load_config.database, layer_id, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), layer_id, device, self._load_config
             )
             weights.update(res)
         return weights
@@ -337,7 +463,7 @@ def _load_dynamic_weights(self, weight: ModelWeights, device: str):
         if dynamic_weights:
             for dynamic_weight in dynamic_weights:
                 dynamic_w = dynamic_weight.load(
-                    self._load_config.database, None, device, self._load_config
+                    DatabaseTensorSource(self._load_config.database), None, device, self._load_config
                 )
                 weight.set_global_weight(
                     dynamic_weight.name, dynamic_w.get(dynamic_weight.name)