From 5466e8743f0646d9d8a12b090631a98453a19048 Mon Sep 17 00:00:00 2001
From: "suohe.lx"
Date: Thu, 9 Oct 2025 16:17:53 +0800
Subject: [PATCH] feat: add fastsafetensors loader

---
 rtp_llm/BUILD                                 |   1 +
 rtp_llm/distribute/gang_server.py             |  12 +-
 rtp_llm/eplb/ep_balancer.py                   |   3 +-
 .../model_loader/dynamic_fp8_quant_weight.py  |   6 +-
 rtp_llm/model_loader/ffn_weight.py            |  25 +++-
 rtp_llm/model_loader/loader.py                | 140 +++++++++++++++++-
 rtp_llm/model_loader/model_weight_info.py     |   1 -
 .../per_block_fp8_quant_weight.py             |  11 +-
 .../model_loader/static_fp8_quant_weight.py   |  22 ++-
 rtp_llm/model_loader/tensor_source.py         |  66 +++++++++
 rtp_llm/model_loader/weight_module.py         |  41 ++++-
 .../model_loader/weight_only_quant_weight.py  |  11 +-
 .../distributed/process_group_state.py        |   2 +-
 rtp_llm/utils/ckpt_file_info.py               |   4 +
 rtp_llm/utils/database.py                     |  57 ++++++-
 15 files changed, 362 insertions(+), 40 deletions(-)
 create mode 100644 rtp_llm/model_loader/tensor_source.py

diff --git a/rtp_llm/BUILD b/rtp_llm/BUILD
index 6da22a002..69a5858cd 100755
--- a/rtp_llm/BUILD
+++ b/rtp_llm/BUILD
@@ -97,6 +97,7 @@ requirement([
     "portalocker",
     "concurrent_log_handler",
     "aiter",
+    "fastsafetensors",
 ] + tensorrt)

 filegroup(
diff --git a/rtp_llm/distribute/gang_server.py b/rtp_llm/distribute/gang_server.py
index 5fe65463f..06c42f63d 100644
--- a/rtp_llm/distribute/gang_server.py
+++ b/rtp_llm/distribute/gang_server.py
@@ -13,6 +13,7 @@
 import requests
 import uvicorn
 from fastapi import FastAPI
+import torch.distributed

 from rtp_llm.config.py_config_modules import PyEnvConfigs, StaticConfig
 from rtp_llm.config.uvicorn_config import UVICORN_LOGGING_CONFIG
@@ -423,9 +424,16 @@ def start(self):
         master_url = (
             f"tcp://{g_master_info.ip}:{self._gang_info.master.server_port - 1}"
         )
-        logging.info(f"gang worker {g_parallel_info} memory_barrier {master_url}")
+        logging.info(f"gang worker {g_parallel_info} init_process_group {master_url}")
         init_process_timeout = self.py_env_configs.gang_config.dist_barrier_timeout
-        self.memory_barrier(master_url, timeout=init_process_timeout)
+        os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
+        torch.distributed.init_process_group(
+            backend=torch.distributed.Backend.NCCL,
+            init_method=master_url,
+            rank=g_parallel_info.world_rank,
+            world_size=g_parallel_info.world_size,
+            timeout=timedelta(seconds=init_process_timeout),
+        )
         logging.info(f"gang worker {g_parallel_info} start_health_check")
         self.start_health_check()

diff --git a/rtp_llm/eplb/ep_balancer.py b/rtp_llm/eplb/ep_balancer.py
index c96814220..cb3cf0726 100644
--- a/rtp_llm/eplb/ep_balancer.py
+++ b/rtp_llm/eplb/ep_balancer.py
@@ -20,6 +20,7 @@
     ModelWeightInfo,
 )
 from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import DatabaseTensorSource
 from rtp_llm.utils.model_weight import W


@@ -217,7 +218,7 @@ def load_moe_weight(
             f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer {layer_id} for {choose_expert_id}"
         )
         try:
-            res = moe_weight.load(self.database, layer_id, "cpu", self._load_config)
+            res = moe_weight.load(DatabaseTensorSource(self.database), layer_id, "cpu", self._load_config)
         except:
             logging.error(
                 f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer failed: full traceback:\n{traceback.format_exc()}"
             )
diff --git a/rtp_llm/model_loader/dynamic_fp8_quant_weight.py b/rtp_llm/model_loader/dynamic_fp8_quant_weight.py
index 0cd0b5515..13444f896 100644
--- a/rtp_llm/model_loader/dynamic_fp8_quant_weight.py
+++ b/rtp_llm/model_loader/dynamic_fp8_quant_weight.py
@@ -18,7 +18,7 @@
     QuantWeight,
WeightModule, ) -from rtp_llm.utils.database import BaseDatabase +from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.utils.model_weight import W, WeightStyle if utils.is_cuda(): @@ -159,12 +159,12 @@ def __init__( def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): - kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config) + kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config) if self.kernel.name in [W.moe_w1, W.moe_w2]: # per expert quant moe w13 and w2 to fp8 kernel_tensor = kernel[self.kernel.name] diff --git a/rtp_llm/model_loader/ffn_weight.py b/rtp_llm/model_loader/ffn_weight.py index b49ea9c5e..04b1ba2cc 100644 --- a/rtp_llm/model_loader/ffn_weight.py +++ b/rtp_llm/model_loader/ffn_weight.py @@ -14,7 +14,7 @@ QuantWeight, WeightModule, ) -from rtp_llm.utils.database import BaseDatabase +from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.utils.model_weight import CkptWeightInfo, W, identity from rtp_llm.utils.util import check_with_info @@ -291,13 +291,13 @@ def __init__( def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): if self.config.weight_stack: - return super()._load_raw_tensor(database, layer_id, device, load_config) + return super()._load_raw_tensor(tensor_source, layer_id, device, load_config) # weight should be expand by experts before_merge_tensors = [] @@ -318,7 +318,7 @@ def _load_raw_tensor( ckpt_weight.merge_fun( [ x.to(device) - for x in database.load_tensor(name, convert_type) + for x in tensor_source.load_tensor(name, convert_type) ] ) ) @@ -331,6 +331,23 @@ def _load_raw_tensor( after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type) logging.debug("load weight :%s, %s ", self.name, after_merge_tensor.shape) return {self.name: after_merge_tensor} + + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> set[str]: + if self.config.weight_stack: + return super().get_tensor_names(layer_id, load_config) + names = set[str]() + for ckpt_weight in self.weights: + selected_experts = load_config.get_selected_experts( + layer_id, self.config.expert_num + ) + for expert_id in selected_experts: + name = ckpt_weight.name.format( + i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id) + ) + names.add(name) + return names class MoeWeight(CompositeWeight): diff --git a/rtp_llm/model_loader/loader.py b/rtp_llm/model_loader/loader.py index 1446e0c2b..cace8802a 100644 --- a/rtp_llm/model_loader/loader.py +++ b/rtp_llm/model_loader/loader.py @@ -1,8 +1,9 @@ import gc import logging import os +import time from collections import OrderedDict -from typing import Optional +from typing import Dict, Optional, NamedTuple, List, Tuple import safetensors import torch @@ -21,6 +22,7 @@ ModelWeights, ) from rtp_llm.model_loader.weight_module import CustomAtomicWeight, WeightModule +from rtp_llm.model_loader.tensor_source import TensorCollector, DatabaseTensorSource from rtp_llm.utils.database import BaseDatabase, CkptDatabase from rtp_llm.utils.fuser import fetch_remote_file_to_local from rtp_llm.utils.model_weight import W, WeightStyle @@ -29,6 +31,8 @@ class ModelLoader: + WeightInfo = NamedTuple("WeightInfo", [("weight", WeightModule), ("layer_id", Optional[int]), ("collector", TensorCollector)]) + def __init__( self, task_type: TaskType, 
@@ -66,7 +70,8 @@ def load_weights(self, device: str):
         if self._load_config.is_ft_style_weight:
             weights = self._load_from_ft_style(device)
         else:
-            weights = self._load_from_scratch(device)
+            weights = self._load_weight(device)
+            self.force_clean_cuda_memory()

         # load dynamic weight
         self._load_dynamic_weights(weights, device)
@@ -203,6 +208,84 @@ def _load_from_ft_style(self, device: str):
         model_weights.global_weights = global_weights
         return model_weights

+    def _load_weight(self, device: str):
+        is_safetensor = self._load_config.database.is_safetensor
+        convert_device = self._choose_weight_convert_device(device)
+        if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
+            try:
+                return self._load_from_fastsafetensor(device)
+            except Exception as e:
+                logging.warning(f"Failed to load from fastsafetensors: {e}")
+
+        logging.info(
+            f"database is safetensor: {is_safetensor}, device: {device}, chosen device: {convert_device}"
+        )
+        return self._load_from_scratch(device)
+
+    def _is_memory_enough_for_fastsafetensor(self):
+        model_size = self._weights_info.config.eval_model_size()
+        device_mem_info = self._load_config.exported_device.get_mem_info()
+        max_file_size = self._load_config.database.get_max_file_size()
+        if device_mem_info is None:
+            return False
+        else:
+            free_mem = device_mem_info.free / (1024.0**2)
+            model_mem = model_size / self._load_config.tp_size / (1024.0**2)
+            max_file_mem = max_file_size / (1024.0**2)
+            logging.debug(f"free mem: {free_mem}, model mem: {model_mem}, max file mem: {max_file_mem}")
+            # Heuristic: after this rank's share of the model is resident, keep
+            # roughly 3x the largest checkpoint shard free as staging headroom.
+            return (free_mem - model_mem) > (3 * max_file_mem)
+
+    def _load_from_fastsafetensor(self, device: str):
+        all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
+            device, True
+        )
+        logging.info(f"load weight by device: {device}")
+        model_weights = self._create_model_weights(device)
+        tensor_to_weight_map, weight_info_list = self._generate_weight_info()
+        for key, loaded_tensor in all_tensors:
+            if key not in tensor_to_weight_map:
+                continue
+            weight_info = tensor_to_weight_map[key]
+            complete = weight_info.collector.store_tensor(key, loaded_tensor)
+            if complete:
+                start = time.time()
+                tensors = weight_info.weight.load(
+                    tensor_source=weight_info.collector,
+                    layer_id=weight_info.layer_id,
+                    device=device,
+                    load_config=self._load_config,
+                )
+                for name, tensor in tensors.items():
+                    if weight_info.layer_id is not None:
+                        model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                    else:
+                        model_weights.set_global_weight(name, tensor)
+                logging.debug(
+                    f"weight: {type(weight_info.weight).__name__} load cost {time.time() - start}"
+                )
+                weight_info.collector.clear()
+
+        # Any weight whose tensor set never completed in the stream falls back
+        # to loading directly from the database.
+        for weight_info in weight_info_list:
+            weight_info.collector.clear()
+            if weight_info.collector.is_collection_complete():
+                continue
+            tensors = weight_info.weight.load(
+                tensor_source=DatabaseTensorSource(self._load_config.database),
+                layer_id=weight_info.layer_id,
+                device=device,
+                load_config=self._load_config,
+            )
+            for name, tensor in tensors.items():
+                if weight_info.layer_id is not None:
+                    model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                else:
+                    model_weights.set_global_weight(name, tensor)
+        return model_weights
+
     def prepare_weights(self, device: str):
         if self._load_config.vit_separation != 1 and not self._is_attn_model:
             for id in range(self._load_config.num_layers):
@@ -214,17 +294,47 @@ def prepare_weights(self, device: str):
             if self._maybe_skip_weight(weight):
                 continue
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)
         for weight in self._misc_weights_info:
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)
+
+    def _generate_weight_info(self) -> Tuple[Dict[str, WeightInfo], List[WeightInfo]]:
+        WeightInfo = ModelLoader.WeightInfo
+        tensor_to_weight_map: Dict[str, WeightInfo] = {}
+        weight_info_list: List[WeightInfo] = []
+
+        # Map every checkpoint tensor name to the single weight that consumes
+        # it, sharing one TensorCollector per weight.
+        def register(weight: WeightModule, layer_id: Optional[int]):
+            names = weight.get_tensor_names(layer_id, self._load_config)
+            collector = TensorCollector(names, self._load_config.database)
+            weight_info = WeightInfo(weight=weight, layer_id=layer_id, collector=collector)
+            tensor_to_weight_map.update({k: weight_info for k in names})
+            weight_info_list.append(weight_info)
+
+        if self._load_config.vit_separation != 1:
+            for layer_id in range(self._load_config.num_layers):
+                layer_weights = self._model_weights_info.layer_weights[layer_id]
+                if isinstance(layer_weights, WeightModule):
+                    register(layer_weights, layer_id)
+                else:
+                    for weight in layer_weights:
+                        register(weight, layer_id)
+        for weight in self._model_weights_info.weights:
+            if self._maybe_skip_weight(weight):
+                continue
+            register(weight, None)
+        for weight in self._misc_weights_info:
+            register(weight, None)
+        return tensor_to_weight_map, weight_info_list

     def _maybe_skip_weight(self, weight: WeightModule):
         if self._task_type == TaskType.LANGUAGE_MODEL:
@@ -254,7 +379,7 @@ def _choose_weight_convert_device(self, current_device):
         else:
             free_mem = device_mem_info.free / (1024.0**2)
             model_mem = model_size / self._load_config.tp_size / (1024.0**2)
-            return current_device if free_mem * 0.8 > model_mem else "cpu"
+            return current_device if free_mem * 0.9 > model_mem else "cpu"

     def _load_from_scratch(self, device: str):
         weights = self._create_model_weights(device)
@@ -270,6 +395,7 @@ def _load_from_scratch(self, device: str):
                 weights.set_layer_weight(layer_id, name, tensor)
             else:
                 weights.set_global_weight(name, tensor)
+            gc.collect()
         return weights

     def _load_layer_weights(self, layer_id: int, device: str):
@@ -278,7 +404,7 @@
         weights = {}
         for weight in layer_weights:
             res = weight.load(
-                self._load_config.database, layer_id, device, self._load_config
+
DatabaseTensorSource(self._load_config.database), layer_id, device, self._load_config ) weights.update(res) return weights @@ -337,7 +463,7 @@ def _load_dynamic_weights(self, weight: ModelWeights, device: str): if dynamic_weights: for dynamic_weight in dynamic_weights: dynamic_w = dynamic_weight.load( - self._load_config.database, None, device, self._load_config + DatabaseTensorSource(self._load_config.database), None, device, self._load_config ) weight.set_global_weight( dynamic_weight.name, dynamic_w.get(dynamic_weight.name) diff --git a/rtp_llm/model_loader/model_weight_info.py b/rtp_llm/model_loader/model_weight_info.py index 6e59d2d04..73d170f21 100644 --- a/rtp_llm/model_loader/model_weight_info.py +++ b/rtp_llm/model_loader/model_weight_info.py @@ -601,7 +601,6 @@ def __init__(self, num_layers: int, device: str, dtype: torch.dtype): def set_layer_weight(self, layer_id: int, name: str, tensor: torch.Tensor): self.weights[layer_id][name] = tensor - gc.collect() def set_global_weight(self, name: str, tensor: torch.Tensor): self.global_weights[name] = tensor diff --git a/rtp_llm/model_loader/per_block_fp8_quant_weight.py b/rtp_llm/model_loader/per_block_fp8_quant_weight.py index a06aa12f4..4d9c6198e 100644 --- a/rtp_llm/model_loader/per_block_fp8_quant_weight.py +++ b/rtp_llm/model_loader/per_block_fp8_quant_weight.py @@ -14,7 +14,7 @@ QuantWeight, WeightModule, ) -from rtp_llm.utils.database import BaseDatabase +from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.utils.model_weight import ( FP8_E4M3_MAX, CkptWeightInfo, @@ -745,12 +745,12 @@ def __init__( def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): - kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config) + kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config) res = {} scale = None @@ -774,3 +774,8 @@ def _load_raw_tensor( res.update({self.scale.name: scale.contiguous().to(device)}) return res + + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> set[str]: + return self.kernel.get_tensor_names(layer_id, load_config) diff --git a/rtp_llm/model_loader/static_fp8_quant_weight.py b/rtp_llm/model_loader/static_fp8_quant_weight.py index b21febdc6..fc040825c 100644 --- a/rtp_llm/model_loader/static_fp8_quant_weight.py +++ b/rtp_llm/model_loader/static_fp8_quant_weight.py @@ -19,7 +19,7 @@ QuantWeight, WeightModule, ) -from rtp_llm.utils.database import BaseDatabase +from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.utils.model_weight import ( FP8_E4M3_MAX, CkptWeightInfo, @@ -681,12 +681,12 @@ def __init__( def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): - kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config) + kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config) res = {} quant_kernel, scale = quantize_weight_to_fp8(kernel.get(self.kernel.name)) quant_kernel = quant_kernel.T @@ -697,15 +697,27 @@ def _load_raw_tensor( if self.act_scale: act_scale = self.act_scale._load_raw_tensor( - database, layer_id, device, load_config + tensor_source, layer_id, device, load_config ) res.update(act_scale) if self.act_scale_inv: act_scale_inv = self.act_scale_inv._load_raw_tensor( - database, layer_id, device, load_config + tensor_source, layer_id, device, 
load_config
             )
             res.update(act_scale_inv)
         return res
+
+    def get_tensor_names(
+        self, layer_id: Optional[int], load_config: LoadConfig
+    ) -> set[str]:
+        names = self.kernel.get_tensor_names(layer_id, load_config)
+        if self.act_scale:
+            names = names.union(self.act_scale.get_tensor_names(layer_id, load_config))
+        if self.act_scale_inv:
+            names = names.union(
+                self.act_scale_inv.get_tensor_names(layer_id, load_config)
+            )
+        return names


 class Fp8PerTensorCompressedWeight(CompositeWeight, QuantWeight):
diff --git a/rtp_llm/model_loader/tensor_source.py b/rtp_llm/model_loader/tensor_source.py
new file mode 100644
index 000000000..e762904bd
--- /dev/null
+++ b/rtp_llm/model_loader/tensor_source.py
@@ -0,0 +1,76 @@
+from typing import Dict, List, Optional, Set
+from rtp_llm.utils.database import BaseDatabase
+
+import torch
+
+
+class TensorSource:
+    """Minimal read interface weight modules load checkpoint tensors from."""
+
+    def load_tensor(
+        self, name: str, data_type: Optional[torch.dtype] = torch.float16
+    ) -> List[torch.Tensor]:
+        raise NotImplementedError
+
+    def get_database(self) -> BaseDatabase:
+        raise NotImplementedError
+
+
+class DatabaseTensorSource(TensorSource):
+    """Adapts a BaseDatabase to the TensorSource interface."""
+
+    _database: BaseDatabase
+
+    def __init__(self, database: BaseDatabase):
+        self._database = database
+
+    def load_tensor(self, name, data_type=torch.float16):
+        return self._database.load_tensor(name, data_type)
+
+    def get_database(self) -> BaseDatabase:
+        return self._database
+
+
+class TensorCollector(TensorSource):
+    """Buffers tensors streamed in by the loader until every target name
+    has been seen at least once."""
+
+    _target_keys: Set[str]
+    _tensors: Dict[str, torch.Tensor]
+    _completed_once: bool
+    _database: BaseDatabase
+
+    def __init__(self, target_keys: Set[str], database: BaseDatabase):
+        self._target_keys = set(target_keys)
+        self._tensors = {}
+        self._completed_once = False
+        self._database = database
+
+    def load_tensor(self, name, data_type=torch.float16):
+        tensors = []
+        t = self._tensors.get(name)
+        if t is not None:
+            tensors.append(t.to(data_type))
+        return tensors
+
+    def store_tensor(self, name: str, tensor: torch.Tensor) -> bool:
+        if name not in self._target_keys:
+            raise ValueError(f"Tensor name '{name}' not in target list.")
+        self._tensors[name] = tensor
+        self._check_completion()
+        return self.is_collection_complete()
+
+    def _check_completion(self):
+        if self._target_keys.issubset(self._tensors.keys()):
+            self._completed_once = True
+
+    def clear(self):
+        # Drops buffered tensors only; completion is remembered so a cleared
+        # collector still reports is_collection_complete() == True.
+        self._tensors.clear()
+
+    def is_collection_complete(self) -> bool:
+        return self._completed_once
+
+    def get_database(self) -> BaseDatabase:
+        return self._database
diff --git a/rtp_llm/model_loader/weight_module.py b/rtp_llm/model_loader/weight_module.py
index 425ee86c0..0066de405 100644
--- a/rtp_llm/model_loader/weight_module.py
+++ b/rtp_llm/model_loader/weight_module.py
@@ -1,6 +1,7 @@
 import functools
 import inspect
 import logging
+import time
 import traceback
 import weakref
 from abc import ABC, abstractmethod
@@ -11,6 +12,7 @@
 from rtp_llm.config.quant_config import QuantizationConfig
 from rtp_llm.model_loader.load_config import LoadConfig
 from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.utils.model_weight import CkptWeightInfo, W, WeightStyle, identity, sp_id


@@ -151,16 +153,16 @@ def support(
     @torch.inference_mode()
     def load(
         self,
-        database: BaseDatabase,
+        tensor_source: TensorSource,
         layer_id: Optional[int],
         device: str,
         load_config: LoadConfig,
     ):
-        raw_tensors = self._load_raw_tensor(database, layer_id, device, load_config)
+        raw_tensors = self._load_raw_tensor(tensor_source, layer_id, device,
load_config) if load_config.merge_lora: merged_tensors = self._merge_lora( - raw_tensors, database, layer_id, load_config + raw_tensors, tensor_source.get_database(), layer_id, load_config ) else: merged_tensors = raw_tensors @@ -230,13 +232,19 @@ def __extract_tensor(tensors): @abstractmethod def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): pass + @abstractmethod + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> set[str]: + pass + @abstractmethod def _split(self, tensor: torch.Tensor, load_config: LoadConfig): pass @@ -308,7 +316,7 @@ def need_transpose(self) -> bool: def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, @@ -322,7 +330,7 @@ def _load_raw_tensor( try: before_merge_tensors.append( ckpt_weight.merge_fun( - [x.to(device) for x in database.load_tensor(name, convert_type)] + [x.to(device) for x in tensor_source.load_tensor(name, convert_type)] ) ) except Exception as e: @@ -617,6 +625,14 @@ def _postprocess( ) } + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> set[str]: + names = set[str]() + for ckpt_weight in self.weights: + names.add(ckpt_weight.tensor_name(layer_id)) + return names + def _get_split_func(self): return W.gpt_style_tp_strategy[self.name] @@ -737,7 +753,7 @@ def __repr__(self) -> str: def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, @@ -745,7 +761,7 @@ def _load_raw_tensor( raw_tensors = {} for name, sub_weight in self.sub_weights.items(): sub_tensors = sub_weight._load_raw_tensor( - database, layer_id, device, load_config + tensor_source, layer_id, device, load_config ) if isinstance(sub_weight, AtomicWeight) and isinstance(sub_tensors, dict): raw_tensors.update(sub_tensors) @@ -836,3 +852,12 @@ def _postprocess( else: processed_tensors.update({name: sub_tensors}) return processed_tensors + + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> set[str]: + names = set[str]() + for _, sub_weight in self.sub_weights.items(): + sub_names = sub_weight.get_tensor_names(layer_id, load_config) + names = names.union(sub_names) + return names diff --git a/rtp_llm/model_loader/weight_only_quant_weight.py b/rtp_llm/model_loader/weight_only_quant_weight.py index 96d29370d..c3c6adebb 100644 --- a/rtp_llm/model_loader/weight_only_quant_weight.py +++ b/rtp_llm/model_loader/weight_only_quant_weight.py @@ -14,7 +14,7 @@ QuantWeight, WeightModule, ) -from rtp_llm.utils.database import BaseDatabase +from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.utils.model_weight import W @@ -78,12 +78,12 @@ def __init__( def _load_raw_tensor( self, - database: BaseDatabase, + tensor_source: TensorSource, layer_id: Optional[int], device: str, load_config: LoadConfig, ): - kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config) + kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config) return kernel def _split(self, tensor: torch.Tensor, load_config: LoadConfig): @@ -102,3 +102,8 @@ def _postprocess( else: weight, scale = load_config.exported_device.apply_int8(kernel, device) return {self.kernel.name: weight, self.scale.name: scale} + + def get_tensor_names( + self, layer_id: Optional[int], load_config: LoadConfig + ) -> 
set[str]:
+        return self.kernel.get_tensor_names(layer_id, load_config)
diff --git a/rtp_llm/models_py/distributed/process_group_state.py b/rtp_llm/models_py/distributed/process_group_state.py
index 19b990621..b77462b1b 100644
--- a/rtp_llm/models_py/distributed/process_group_state.py
+++ b/rtp_llm/models_py/distributed/process_group_state.py
@@ -152,7 +152,7 @@ def init_distributed_environment(
         device_id=torch.device(f"cuda:{local_rank}"),
         timeout=timeout,  # pyright: ignore[reportArgumentType]
     )
-    initialize_expert_parallel(params, backend)
+    initialize_expert_parallel(params, backend)


 _EP: Optional[ProcessGroupState] = None
diff --git a/rtp_llm/utils/ckpt_file_info.py b/rtp_llm/utils/ckpt_file_info.py
index e12e9386b..2adff6de0 100644
--- a/rtp_llm/utils/ckpt_file_info.py
+++ b/rtp_llm/utils/ckpt_file_info.py
@@ -55,6 +55,10 @@ def get_tensor_names(self) -> List[str]:
     def tensor_num(self) -> int:
         return len(self.metadata.keys())

+    @property
+    def file_size(self) -> int:
+        return os.path.getsize(self.file_name)
+
     def is_safetensor(self) -> bool:
         if self.ckpt_type == CkptType.safetensors:
             return True
diff --git a/rtp_llm/utils/database.py b/rtp_llm/utils/database.py
index d36a0c670..d72a5ded3 100644
--- a/rtp_llm/utils/database.py
+++ b/rtp_llm/utils/database.py
@@ -2,10 +2,12 @@
 import logging
 import os
 import re
+import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional

 import torch
+from tqdm.auto import tqdm

 from rtp_llm.lora.lora_file import LoraCkpt
 from rtp_llm.utils.ckpt_file_info import CkptFileInfo, FinetuneType
@@ -29,6 +31,13 @@ def get_tensor_order(self, name: str) -> List[int]:

     def get_tensor_type(self, name: str) -> torch.dtype:
         raise NotImplementedError
+
+    def get_max_file_size(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def is_safetensor(self) -> bool:
+        return False

     @property
     def is_ft_style(self) -> bool:
@@ -76,10 +85,17 @@ def __init__(self, path: Optional[str], ptuning_path: Optional[str] = None) -> N
     @property
     def is_ft_style(self) -> bool:
         return self._is_ft_style
+
+    @property
+    def is_safetensor(self) -> bool:
+        return all(map(lambda file: file.is_safetensor(), self.pretrain_file_list))

     @property
     def ft_weight_params(self) -> Optional[Dict[str, Any]]:
         return self._ft_weight_params
+
+    def get_max_file_size(self) -> int:
+        return max([file.file_size for file in self.pretrain_file_list])

     def load_hf_meta(self, path: str):
         # avoid consolidated.safetensors in Mistral-Nemo-Instruct-2407
@@ -171,12 +187,19 @@ def get_tensor_order(self, name: str) -> List[int]:
     def load_tensors_by_prefix(
         self, prefix_list: List[str], device: str, direct_io: bool
     ) -> dict[str, List[torch.Tensor]]:
+        try:
+            # Optional shared-memory fast path; the except branch keeps the
+            # original per-file loader working when it is unavailable.
+            from fastsafetensors import LoadWithShm
+            loader = LoadWithShm(2 * 1024 * 1024 * 1024, device, direct_io)
+            load_tensors = lambda ckptfile: loader.load_safetensors_to_device(ckptfile.file_name)
+        except (ModuleNotFoundError, ImportError):
+            load_tensors = lambda ckptfile: ckptfile.load_tensors(device, direct_io)
+
         res = {}
         for ckptfile in self.pretrain_file_list:
             if any(
                 tensor.startswith(prefix_list) for tensor in ckptfile.get_tensor_names()
             ):
-                tensors = ckptfile.load_tensors(device, direct_io)
+                tensors = load_tensors(ckptfile)
                 for k, v in tensors.items():
                     if not k.startswith(prefix_list):
                         continue
                     if k not in res:
                         res[k] = [v]
                     else:
                         res[k].append(v)
         return res
+
+    def fastsafetensors_weights_iterator(self, device: str, use_tqdm_on_load: bool):
+        from fastsafetensors import ParallelLoader, SingleGroup
+        def iterator(device: str, use_tqdm_on_load: bool):
+            if torch.distributed.is_initialized():
+                pg = torch.distributed.group.WORLD
+            else:
+                pg = SingleGroup()
+
+            hf_weights_files = sorted(
+                [file.file_name for file in self.pretrain_file_list]
+            )
+            if device == "cuda":
+                device = f"cuda:{pg.rank()}"
+                logging.debug(f"origin device is cuda, set to {device}")
+            # Create loader; named so it does not shadow the enclosing
+            # iterator() function.
+            parallel_loader = ParallelLoader(
+                pg,
+                hf_weights_files=hf_weights_files,
+                use_tqdm_on_load=use_tqdm_on_load,
+                device=device,
+                bbuf_size_kb=1024 * 1024 * 2,
+                use_shm=True,
+            )
+            try:
+                # Execute parallel iteration
+                yield from parallel_loader.iterate_weights()
+            finally:
+                parallel_loader.loader.close()
+
+        return iterator(device, use_tqdm_on_load)

     def get_lora_tensor_names(self, config_name: str) -> List[str]:
         return self.lora_ckpt.get_lora_tensor_names(config_name)
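
Usage sketch (illustrative, not part of the patch): the streaming path in _load_from_fastsafetensor relies entirely on the TensorCollector contract from tensor_source.py, namely store until complete, load once, and remember completion across clear(). The snippet below exercises that contract; the tensor names are made up, and None stands in for a real BaseDatabase.

import torch

from rtp_llm.model_loader.tensor_source import TensorCollector

# Hypothetical tensor names; a real collector gets them from
# WeightModule.get_tensor_names(layer_id, load_config).
collector = TensorCollector({"layers.0.w1", "layers.0.w2"}, None)

# store_tensor buffers streamed tensors and returns True once every target
# name has been seen, which is when the loader calls weight.load().
assert not collector.store_tensor("layers.0.w1", torch.zeros(4))
assert collector.store_tensor("layers.0.w2", torch.zeros(4))

# While buffered, the collector serves load_tensor like a database would.
(w1,) = collector.load_tensor("layers.0.w1", torch.float16)

# clear() frees the buffers but completion is remembered, so the fallback
# loop in _load_from_fastsafetensor knows this weight needs no reload.
collector.clear()
assert collector.is_collection_complete()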