1 change: 1 addition & 0 deletions rtp_llm/BUILD
@@ -97,6 +97,7 @@ requirement([
"portalocker",
"concurrent_log_handler",
"aiter",
"fastsafetensors",
] + tensorrt)

filegroup(
12 changes: 10 additions & 2 deletions rtp_llm/distribute/gang_server.py
@@ -13,6 +13,7 @@
import requests
import uvicorn
from fastapi import FastAPI
import torch.distributed

from rtp_llm.config.py_config_modules import PyEnvConfigs, StaticConfig
from rtp_llm.config.uvicorn_config import UVICORN_LOGGING_CONFIG
@@ -423,9 +424,16 @@ def start(self):
master_url = (
f"tcp://{g_master_info.ip}:{self._gang_info.master.server_port - 1}"
)
logging.info(f"gang worker {g_parallel_info} memory_barrier {master_url}")
logging.info(f"gang worker {g_parallel_info} init_process_group {master_url}")
init_process_timeout = self.py_env_configs.gang_config.dist_barrier_timeout
self.memory_barrier(master_url, timeout=init_process_timeout)
os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
torch.distributed.init_process_group(
backend=torch.distributed.Backend.NCCL,
init_method=master_url,
rank=g_parallel_info.world_rank,
world_size=g_parallel_info.world_size,
timeout=timedelta(seconds=init_process_timeout),
)

logging.info(f"gang worker {g_parallel_info} start_health_check")
self.start_health_check()
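
Note: the change above replaces the custom memory_barrier rendezvous with a standard torch.distributed process group. A minimal sketch of the resulting pattern, with illustrative address, rank, and world-size values (not taken from this PR):

    import os
    from datetime import timedelta

    import torch.distributed

    # Ask PyTorch to run a barrier inside init_process_group so stragglers fail fast.
    os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
    torch.distributed.init_process_group(
        backend=torch.distributed.Backend.NCCL,
        init_method="tcp://10.0.0.1:29499",  # master ip and port, illustrative
        rank=0,                              # this worker's world_rank
        world_size=8,                        # illustrative
        timeout=timedelta(seconds=600),      # mirrors dist_barrier_timeout
    )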
3 changes: 2 additions & 1 deletion rtp_llm/eplb/ep_balancer.py
@@ -20,6 +20,7 @@
ModelWeightInfo,
)
from rtp_llm.utils.database import BaseDatabase
from rtp_llm.model_loader.tensor_source import DatabaseTensorSource
from rtp_llm.utils.model_weight import W


@@ -217,7 +218,7 @@ def load_moe_weight(
f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer {layer_id} for {choose_expert_id}"
)
try:
res = moe_weight.load(self.database, layer_id, "cpu", self._load_config)
res = moe_weight.load(DatabaseTensorSource(self.database), layer_id, "cpu", self._load_config)
except:
logging.error(
f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer failed: 完整堆栈:\n{traceback.format_exc()}"
6 changes: 3 additions & 3 deletions rtp_llm/model_loader/dynamic_fp8_quant_weight.py
@@ -18,7 +18,7 @@
QuantWeight,
WeightModule,
)
from rtp_llm.utils.database import BaseDatabase
from rtp_llm.model_loader.tensor_source import TensorSource
from rtp_llm.utils.model_weight import W, WeightStyle

if utils.is_cuda():
@@ -159,12 +159,12 @@ def __init__(

def _load_raw_tensor(
self,
database: BaseDatabase,
tensor_source: TensorSource,
layer_id: Optional[int],
device: str,
load_config: LoadConfig,
):
kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config)
kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config)
if self.kernel.name in [W.moe_w1, W.moe_w2]:
# per expert quant moe w13 and w2 to fp8
kernel_tensor = kernel[self.kernel.name]
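
Note: in the branch above, each expert slice of a MoE kernel is dynamically quantized to FP8 at load time. A sketch of per-expert dynamic quantization under that assumption (the function name and scale layout are illustrative, not this file's actual kernels):

    import torch

    FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

    def quant_experts_fp8(w: torch.Tensor):
        # w: [num_experts, out_dim, in_dim]; one dynamic scale per expert.
        amax = w.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-4)
        scale = FP8_E4M3_MAX / amax
        q = (w * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
        return q, amax / FP8_E4M3_MAX  # quantized weight and per-expert dequant scale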
25 changes: 21 additions & 4 deletions rtp_llm/model_loader/ffn_weight.py
@@ -14,7 +14,7 @@
QuantWeight,
WeightModule,
)
from rtp_llm.utils.database import BaseDatabase
from rtp_llm.model_loader.tensor_source import TensorSource
from rtp_llm.utils.model_weight import CkptWeightInfo, W, identity
from rtp_llm.utils.util import check_with_info

@@ -291,13 +291,13 @@ def __init__(

def _load_raw_tensor(
self,
database: BaseDatabase,
tensor_source: TensorSource,
layer_id: Optional[int],
device: str,
load_config: LoadConfig,
):
if self.config.weight_stack:
return super()._load_raw_tensor(database, layer_id, device, load_config)
return super()._load_raw_tensor(tensor_source, layer_id, device, load_config)

# weights should be expanded per expert
before_merge_tensors = []
@@ -318,7 +318,7 @@ def _load_raw_tensor(
ckpt_weight.merge_fun(
[
x.to(device)
for x in database.load_tensor(name, convert_type)
for x in tensor_source.load_tensor(name, convert_type)
]
)
)
@@ -331,6 +331,23 @@ def _load_raw_tensor(
after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type)
logging.debug("load weight :%s, %s ", self.name, after_merge_tensor.shape)
return {self.name: after_merge_tensor}

def get_tensor_names(
self, layer_id: Optional[int], load_config: LoadConfig
) -> set[str]:
if self.config.weight_stack:
return super().get_tensor_names(layer_id, load_config)
names = set[str]()
for ckpt_weight in self.weights:
selected_experts = load_config.get_selected_experts(
layer_id, self.config.expert_num
)
for expert_id in selected_experts:
name = ckpt_weight.name.format(
i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id)
)
names.add(name)
return names
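
Note: for the non-stacked case, the returned set enumerates one checkpoint entry per selected expert. A worked example with a hypothetical name template (real templates come from the model's CkptWeightInfo):

    # Hypothetical checkpoint-name template, for illustration only.
    template = "model.layers.{i}.mlp.experts.{expert_id}.w1.weight"
    layer_id, selected_experts = 3, [0, 5]
    names = {
        template.format(i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(e))
        for e in selected_experts
    }
    # -> {"model.layers.3.mlp.experts.0.w1.weight",
    #     "model.layers.3.mlp.experts.5.w1.weight"}
    # (str.format silently ignores the unused i_1 argument.)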


class MoeWeight(CompositeWeight):
140 changes: 133 additions & 7 deletions rtp_llm/model_loader/loader.py
@@ -1,8 +1,9 @@
import gc
import logging
import os
import time
from collections import OrderedDict
from typing import Optional
from typing import Dict, Optional, NamedTuple, List, Tuple

import safetensors
import torch
@@ -21,6 +22,7 @@
ModelWeights,
)
from rtp_llm.model_loader.weight_module import CustomAtomicWeight, WeightModule
from rtp_llm.model_loader.tensor_source import TensorCollector, DatabaseTensorSource
from rtp_llm.utils.database import BaseDatabase, CkptDatabase
from rtp_llm.utils.fuser import fetch_remote_file_to_local
from rtp_llm.utils.model_weight import W, WeightStyle
@@ -29,6 +31,8 @@


class ModelLoader:
WeightInfo = NamedTuple("WeightInfo", [("weight", WeightModule), ("layer_id", Optional[int]), ("collector", TensorCollector)])

def __init__(
self,
task_type: TaskType,
@@ -66,7 +70,8 @@ def load_weights(self, device: str):
if self._load_config.is_ft_style_weight:
weights = self._load_from_ft_style(device)
else:
weights = self._load_from_scratch(device)
weights = self._load_weight(device)
self.force_clean_cuda_memory()

# load dynamic weight
self._load_dynamic_weights(weights, device)
@@ -203,6 +208,81 @@ def _load_from_ft_style(self, device: str):
model_weights.global_weights = global_weights
return model_weights

def _load_weight(self, device: str):
is_safetensor = self._load_config.database.is_safetensor
convert_device = self._choose_weight_convert_device(device)
if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
try:
return self._load_from_fastsafetensor(device)
except Exception as e:
logging.warning(f"Failed to load from fastsafetensors: {e}")

logging.info(
f"database is safetensor: {is_safetensor}, device: {device}, choose devie: {convert_device}"
)
return self._load_from_scratch(device)

def _is_memory_enough_for_fastsafetensor(self):
model_size = self._weights_info.config.eval_model_size()
device_mem_info = self._load_config.exported_device.get_mem_info()
max_file_size = self._load_config.database.get_max_file_size()
if device_mem_info is None:
return False
else:
free_mem = device_mem_info.free / (1024.0**2)
model_mem = model_size / self._load_config.tp_size / (1024.0**2)
max_file_mem = max_file_size / (1024.0**2)
logging.debug(f"free mem: {free_mem}, model mem: {model_mem}, max file mem: {max_file_mem}")
return (free_mem - model_mem) > (3 * max_file_mem)
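
Note: the heuristic enables the fastsafetensors path only if, after reserving this rank's model shard, roughly three of the largest checkpoint files still fit in free device memory. A worked example with illustrative numbers:

    # Illustrative numbers, not from the PR (all in MiB, as above).
    free_mem = 80 * 1024            # 80 GiB free on the device
    model_mem = 140 * 1024 / 4      # 140 GiB of weights sharded over tp_size=4
    max_file_mem = 10 * 1024        # largest safetensors shard: 10 GiB
    (free_mem - model_mem) > 3 * max_file_mem  # 46080 > 30720 -> use fastsafetensors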

def _load_from_fastsafetensor(self, device: str):
all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
device, True
)
logging.info(f"load weight by device: {device}")
model_weights = self._create_model_weights(device)
tensor_to_weight_map, weight_info_list = self._generate_weight_info()
direct_io = self._load_config.exported_device.support_dio_load
for key, loaded_tensor in all_tensors:
if key not in tensor_to_weight_map:
continue
weight_info = tensor_to_weight_map[key]
complete = weight_info.collector.store_tensor(key, loaded_tensor)
if complete:
start = time.time()
tensors = weight_info.weight.load(
tensor_source=weight_info.collector,
layer_id=weight_info.layer_id,
device=device,
load_config=self._load_config,
)
for name, tensor in tensors.items():
if weight_info.layer_id is not None:
model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
else:
model_weights.set_global_weight(name, tensor)
logging.debug(
f"weight: {type(weight_info.weight).__name__} load cost {time.time() - start}"
)
weight_info.collector.clear()

for weight_info in weight_info_list:
weight_info.collector.clear()
if weight_info.collector.is_collection_complete():
continue
tensors = weight_info.weight.load(
tensor_source=DatabaseTensorSource(self._load_config.database),
layer_id=weight_info.layer_id,
device=device,
load_config=self._load_config
)
for name, tensor in tensors.items():
if weight_info.layer_id is not None:
model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
else:
model_weights.set_global_weight(name, tensor)
return model_weights
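
Note: the loop above streams tensors file by file and hands each fully collected group to its WeightModule; whatever remains incomplete is reloaded through the eager DatabaseTensorSource fallback below it. A minimal sketch of the TensorCollector contract this relies on (the real class is in rtp_llm/model_loader/tensor_source.py; note that clear() must not reset completeness, because the fallback pass clears buffers before checking it):

    import torch

    class TensorCollector:  # sketch; also serves as the TensorSource for weight.load()
        def __init__(self, expected_names: set, database):
            self._expected = expected_names
            self._tensors = {}
            self._complete = False
            self._database = database  # source for names that were never streamed

        def store_tensor(self, name: str, tensor: torch.Tensor) -> bool:
            # Buffer one streamed tensor; report True once the group is whole.
            self._tensors[name] = tensor
            self._complete = self._expected.issubset(self._tensors.keys())
            return self._complete

        def is_collection_complete(self) -> bool:
            return self._complete  # deliberately survives clear()

        def clear(self) -> None:
            self._tensors.clear()  # free staged memory, keep the completion flag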

def prepare_weights(self, device: str):
if self._load_config.vit_separation != 1 and not self._is_attn_model:
for id in range(self._load_config.num_layers):
@@ -214,17 +294,62 @@ def prepare_weights(self, device: str):
if self._maybe_skip_weight(weight):
continue
weights = weight.load(
self._load_config.database, None, device, self._load_config
DatabaseTensorSource(self._load_config.database), None, device, self._load_config
)
for name, tensor in weights.items():
yield (None, name, tensor)

for weight in self._misc_weights_info:
weights = weight.load(
self._load_config.database, None, device, self._load_config
DatabaseTensorSource(self._load_config.database), None, device, self._load_config
)
for name, tensor in weights.items():
yield (None, name, tensor)

def _generate_weight_info(self) -> Tuple[Dict[str, WeightInfo], List[WeightInfo]]:
WeightInfo = ModelLoader.WeightInfo
tensor_to_weight_map: Dict[str, WeightInfo] = {}
weight_info_list: List[WeightInfo] = []
if self._load_config.vit_separation != 1:
for layer_id in range(self._load_config.num_layers):
layer_weights = self._model_weights_info.layer_weights[layer_id]
if isinstance(layer_weights, WeightModule):
names = layer_weights.get_tensor_names(layer_id, self._load_config)
collector = TensorCollector(names, self._load_config.database)
weight_info = WeightInfo(weight=layer_weights, layer_id=layer_id, collector=collector)
tensor_to_weight_map.update(
{k: weight_info for k in names}
)
weight_info_list.append(weight_info)
else:
for weight in layer_weights:
names = weight.get_tensor_names(layer_id, self._load_config)
collector = TensorCollector(names, self._load_config.database)
weight_info = WeightInfo(weight=weight, layer_id=layer_id, collector=collector)
tensor_to_weight_map.update(
{k: weight_info for k in names}
)
weight_info_list.append(weight_info)
for weight in self._model_weights_info.weights:
if self._maybe_skip_weight(weight):
continue
names = weight.get_tensor_names(None, self._load_config)
collector = TensorCollector(names, self._load_config.database)
weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
tensor_to_weight_map.update(
{k: weight_info for k in names}
)
weight_info_list.append(weight_info)
for weight in self._misc_weights_info:
names = weight.get_tensor_names(None, self._load_config)
collector = TensorCollector(names, self._load_config.database)
weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
tensor_to_weight_map.update(
{k: weight_info for k in names}
)
weight_info_list.append(weight_info)
return tensor_to_weight_map, weight_info_list

def _maybe_skip_weight(self, weight: WeightModule):
if self._task_type == TaskType.LANGUAGE_MODEL:
@@ -254,7 +379,7 @@ def _choose_weight_convert_device(self, current_device):
else:
free_mem = device_mem_info.free / (1024.0**2)
model_mem = model_size / self._load_config.tp_size / (1024.0**2)
return current_device if free_mem * 0.8 > model_mem else "cpu"
return current_device if free_mem * 0.9 > model_mem else "cpu"

def _load_from_scratch(self, device: str):
weights = self._create_model_weights(device)
Expand All @@ -270,6 +395,7 @@ def _load_from_scratch(self, device: str):
weights.set_layer_weight(layer_id, name, tensor)
else:
weights.set_global_weight(name, tensor)
gc.collect()
return weights

def _load_layer_weights(self, layer_id: int, device: str):
@@ -278,7 +404,7 @@ def _load_layer_weights(self, layer_id: int, device: str):
weights = {}
for weight in layer_weights:
res = weight.load(
self._load_config.database, layer_id, device, self._load_config
DatabaseTensorSource(self._load_config.database), layer_id, device, self._load_config
)
weights.update(res)
return weights
@@ -337,7 +463,7 @@ def _load_dynamic_weights(self, weight: ModelWeights, device: str):
if dynamic_weights:
for dynamic_weight in dynamic_weights:
dynamic_w = dynamic_weight.load(
self._load_config.database, None, device, self._load_config
DatabaseTensorSource(self._load_config.database), None, device, self._load_config
)
weight.set_global_weight(
dynamic_weight.name, dynamic_w.get(dynamic_weight.name)
1 change: 0 additions & 1 deletion rtp_llm/model_loader/model_weight_info.py
@@ -601,7 +601,6 @@ def __init__(self, num_layers: int, device: str, dtype: torch.dtype):

def set_layer_weight(self, layer_id: int, name: str, tensor: torch.Tensor):
self.weights[layer_id][name] = tensor
gc.collect()

def set_global_weight(self, name: str, tensor: torch.Tensor):
self.global_weights[name] = tensor
11 changes: 8 additions & 3 deletions rtp_llm/model_loader/per_block_fp8_quant_weight.py
@@ -14,7 +14,7 @@
QuantWeight,
WeightModule,
)
from rtp_llm.utils.database import BaseDatabase
from rtp_llm.model_loader.tensor_source import TensorSource
from rtp_llm.utils.model_weight import (
FP8_E4M3_MAX,
CkptWeightInfo,
@@ -745,12 +745,12 @@ def __init__(

def _load_raw_tensor(
self,
database: BaseDatabase,
tensor_source: TensorSource,
layer_id: Optional[int],
device: str,
load_config: LoadConfig,
):
kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config)
kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config)

res = {}
scale = None
Expand All @@ -774,3 +774,8 @@ def _load_raw_tensor(
res.update({self.scale.name: scale.contiguous().to(device)})

return res

def get_tensor_names(
self, layer_id: Optional[int], load_config: LoadConfig
) -> set[str]:
return self.kernel.get_tensor_names(layer_id, load_config)