Commit 77b439c
feat: add fastsafetensors loader
1 parent af77b7c

15 files changed, +362 −40 lines

rtp_llm/BUILD

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ requirement([
     "portalocker",
     "concurrent_log_handler",
     "aiter",
+    "fastsafetensors",
 ] + tensorrt)

 filegroup(
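
Note: fastsafetensors is a safetensors reader that can deserialize checkpoint shards straight to device memory (optionally via GPUDirect Storage); the new loader path in rtp_llm/model_loader/loader.py below is built on it.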

rtp_llm/distribute/gang_server.py

Lines changed: 10 additions & 2 deletions

@@ -13,6 +13,7 @@
 import requests
 import uvicorn
 from fastapi import FastAPI
+import torch.distributed

 from rtp_llm.config.py_config_modules import PyEnvConfigs, StaticConfig
 from rtp_llm.config.uvicorn_config import UVICORN_LOGGING_CONFIG
@@ -423,9 +424,16 @@ def start(self):
         master_url = (
             f"tcp://{g_master_info.ip}:{self._gang_info.master.server_port - 1}"
         )
-        logging.info(f"gang worker {g_parallel_info} memory_barrier {master_url}")
+        logging.info(f"gang worker {g_parallel_info} init_process_group {master_url}")
         init_process_timeout = self.py_env_configs.gang_config.dist_barrier_timeout
-        self.memory_barrier(master_url, timeout=init_process_timeout)
+        os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
+        torch.distributed.init_process_group(
+            backend=torch.distributed.Backend.NCCL,
+            init_method=master_url,
+            rank=g_parallel_info.world_rank,
+            world_size=g_parallel_info.world_size,
+            timeout=timedelta(seconds=init_process_timeout),
+        )

         logging.info(f"gang worker {g_parallel_info} start_health_check")
         self.start_health_check()
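
Aside: this hunk replaces the custom memory_barrier with a plain NCCL process-group init. A minimal standalone sketch of the same pattern (addresses, env defaults, and the timeout below are illustrative, not from this repo):

    import os
    from datetime import timedelta

    import torch.distributed as dist

    # TORCH_DIST_INIT_BARRIER=1 asks PyTorch to end init_process_group with a
    # barrier, so every rank blocks until the whole group has joined (or the
    # timeout fires), which is the synchronization memory_barrier used to provide.
    os.environ["TORCH_DIST_INIT_BARRIER"] = "1"

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # stands in for master_url above
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        timeout=timedelta(seconds=600),  # stands in for dist_barrier_timeout
    )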

rtp_llm/eplb/ep_balancer.py

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
     ModelWeightInfo,
 )
 from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import DatabaseTensorSource
 from rtp_llm.utils.model_weight import W


@@ -217,7 +218,7 @@ def load_moe_weight(
             f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer {layer_id} for {choose_expert_id}"
         )
         try:
-            res = moe_weight.load(self.database, layer_id, "cpu", self._load_config)
+            res = moe_weight.load(DatabaseTensorSource(self.database), layer_id, "cpu", self._load_config)
         except:
             logging.error(
                 f"[EPLB_py][RANK {self._load_config.ep_rank}] Load MOE weight layer failed, full traceback:\n{traceback.format_exc()}"

rtp_llm/model_loader/dynamic_fp8_quant_weight.py

Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@
     QuantWeight,
     WeightModule,
 )
-from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.utils.model_weight import W, WeightStyle

 if utils.is_cuda():
@@ -159,12 +159,12 @@ def __init__(

     def _load_raw_tensor(
         self,
-        database: BaseDatabase,
+        tensor_source: TensorSource,
         layer_id: Optional[int],
         device: str,
         load_config: LoadConfig,
     ):
-        kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config)
+        kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config)
         if self.kernel.name in [W.moe_w1, W.moe_w2]:
             # per expert quant moe w13 and w2 to fp8
             kernel_tensor = kernel[self.kernel.name]

rtp_llm/model_loader/ffn_weight.py

Lines changed: 21 additions & 4 deletions

@@ -14,7 +14,7 @@
     QuantWeight,
     WeightModule,
 )
-from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.utils.model_weight import CkptWeightInfo, W, identity
 from rtp_llm.utils.util import check_with_info

@@ -291,13 +291,13 @@ def __init__(

     def _load_raw_tensor(
         self,
-        database: BaseDatabase,
+        tensor_source: TensorSource,
         layer_id: Optional[int],
         device: str,
         load_config: LoadConfig,
     ):
         if self.config.weight_stack:
-            return super()._load_raw_tensor(database, layer_id, device, load_config)
+            return super()._load_raw_tensor(tensor_source, layer_id, device, load_config)

         # weight should be expanded per expert
         before_merge_tensors = []
@@ -318,7 +318,7 @@ def _load_raw_tensor(
                 ckpt_weight.merge_fun(
                     [
                         x.to(device)
-                        for x in database.load_tensor(name, convert_type)
+                        for x in tensor_source.load_tensor(name, convert_type)
                     ]
                 )
             )
@@ -331,6 +331,23 @@ def _load_raw_tensor(
         after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type)
         logging.debug("load weight :%s, %s ", self.name, after_merge_tensor.shape)
         return {self.name: after_merge_tensor}
+
+    def get_tensor_names(
+        self, layer_id: Optional[int], load_config: LoadConfig
+    ) -> set[str]:
+        if self.config.weight_stack:
+            return super().get_tensor_names(layer_id, load_config)
+        names = set[str]()
+        for ckpt_weight in self.weights:
+            selected_experts = load_config.get_selected_experts(
+                layer_id, self.config.expert_num
+            )
+            for expert_id in selected_experts:
+                name = ckpt_weight.name.format(
+                    i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id)
+                )
+                names.add(name)
+        return names


 class MoeWeight(CompositeWeight):
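
For intuition, a toy expansion of what the new get_tensor_names returns when each expert is stored as its own checkpoint tensor (template and ids are illustrative, not from this repo; str.format silently ignores the unused i_1 placeholder):

    template = "model.layers.{i}.mlp.experts.{expert_id}.down_proj.weight"
    layer_id = 3
    selected_experts = [0, 1, 7]
    names = {
        template.format(i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(e))
        for e in selected_experts
    }
    # {'model.layers.3.mlp.experts.0.down_proj.weight',
    #  'model.layers.3.mlp.experts.1.down_proj.weight',
    #  'model.layers.3.mlp.experts.7.down_proj.weight'}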

rtp_llm/model_loader/loader.py

Lines changed: 133 additions & 7 deletions

@@ -1,8 +1,9 @@
 import gc
 import logging
 import os
+import time
 from collections import OrderedDict
-from typing import Optional
+from typing import Dict, Optional, NamedTuple, List, Tuple

 import safetensors
 import torch
@@ -21,6 +22,7 @@
     ModelWeights,
 )
 from rtp_llm.model_loader.weight_module import CustomAtomicWeight, WeightModule
+from rtp_llm.model_loader.tensor_source import TensorCollector, DatabaseTensorSource
 from rtp_llm.utils.database import BaseDatabase, CkptDatabase
 from rtp_llm.utils.fuser import fetch_remote_file_to_local
 from rtp_llm.utils.model_weight import W, WeightStyle
@@ -29,6 +31,8 @@


 class ModelLoader:
+    WeightInfo = NamedTuple("WeightInfo", [("weight", WeightModule), ("layer_id", Optional[int]), ("collector", TensorCollector)])
+
     def __init__(
         self,
         task_type: TaskType,
@@ -66,7 +70,8 @@ def load_weights(self, device: str):
         if self._load_config.is_ft_style_weight:
             weights = self._load_from_ft_style(device)
         else:
-            weights = self._load_from_scratch(device)
+            weights = self._load_weight(device)
+            self.force_clean_cuda_memory()

         # load dynamic weight
         self._load_dynamic_weights(weights, device)
@@ -203,6 +208,81 @@ def _load_from_ft_style(self, device: str):
         model_weights.global_weights = global_weights
         return model_weights

+    def _load_weight(self, device: str):
+        is_safetensor = self._load_config.database.is_safetensor
+        convert_device = self._choose_weight_convert_device(device)
+        if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
+            try:
+                return self._load_from_fastsafetensor(device)
+            except Exception as e:
+                logging.warning(f"Failed to load from fastsafetensors: {e}")
+
+        logging.info(
+            f"database is safetensor: {is_safetensor}, device: {device}, chosen device: {convert_device}"
+        )
+        return self._load_from_scratch(device)
+
+    def _is_memory_enough_for_fastsafetensor(self):
+        model_size = self._weights_info.config.eval_model_size()
+        device_mem_info = self._load_config.exported_device.get_mem_info()
+        max_file_size = self._load_config.database.get_max_file_size()
+        if device_mem_info is None:
+            return False
+        else:
+            free_mem = device_mem_info.free / (1024.0**2)
+            model_mem = model_size / self._load_config.tp_size / (1024.0**2)
+            max_file_mem = max_file_size / (1024.0**2)
+            logging.debug(f"free mem: {free_mem}, model mem: {model_mem}, max file mem: {max_file_mem}")
+            return (free_mem - model_mem) > (3 * max_file_mem)
+
+    def _load_from_fastsafetensor(self, device: str):
+        all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
+            device, True
+        )
+        logging.info(f"load weight by device: {device}")
+        model_weights = self._create_model_weights(device)
+        tensor_to_weight_map, weight_info_list = self._generate_weight_info()
+        direct_io = self._load_config.exported_device.support_dio_load
+        for key, loaded_tensor in all_tensors:
+            if key not in tensor_to_weight_map:
+                continue
+            weight_info = tensor_to_weight_map[key]
+            complete = weight_info.collector.store_tensor(key, loaded_tensor)
+            if complete:
+                start = time.time()
+                tensors = weight_info.weight.load(
+                    tensor_source=weight_info.collector,
+                    layer_id=weight_info.layer_id,
+                    device=device,
+                    load_config=self._load_config,
+                )
+                for name, tensor in tensors.items():
+                    if weight_info.layer_id is not None:
+                        model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                    else:
+                        model_weights.set_global_weight(name, tensor)
+                logging.debug(
+                    f"weight: {type(weight_info.weight).__name__} load cost {time.time() - start}"
+                )
+                weight_info.collector.clear()
+
+        for weight_info in weight_info_list:
+            weight_info.collector.clear()
+            if weight_info.collector.is_collection_complete():
+                continue
+            tensors = weight_info.weight.load(
+                tensor_source=DatabaseTensorSource(self._load_config.database),
+                layer_id=weight_info.layer_id,
+                device=device,
+                load_config=self._load_config,
+            )
+            for name, tensor in tensors.items():
+                if weight_info.layer_id is not None:
+                    model_weights.set_layer_weight(weight_info.layer_id, name, tensor)
+                else:
+                    model_weights.set_global_weight(name, tensor)
+        return model_weights
+
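Worked example of the heuristic above, with made-up numbers: given 70,000 MiB free on the device, a model costing 30,000 MiB per tensor-parallel rank, and a largest safetensors shard of 10,000 MiB, the check is 70,000 − 30,000 = 40,000 > 3 × 10,000, so the fastsafetensors path is taken. The 3x shard headroom presumably leaves room for a shard in flight plus conversion scratch space on top of the resident weights.
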
     def prepare_weights(self, device: str):
         if self._load_config.vit_separation != 1 and not self._is_attn_model:
             for id in range(self._load_config.num_layers):
@@ -214,17 +294,62 @@ def prepare_weights(self, device: str):
             if self._maybe_skip_weight(weight):
                 continue
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)

         for weight in self._misc_weights_info:
             weights = weight.load(
-                self._load_config.database, None, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), None, device, self._load_config
             )
             for name, tensor in weights.items():
                 yield (None, name, tensor)
+
+    def _generate_weight_info(self) -> Tuple[Dict[str, WeightInfo], List[WeightInfo]]:
+        WeightInfo = ModelLoader.WeightInfo
+        tensor_to_weight_map: Dict[str, WeightInfo] = {}
+        weight_info_list: List[WeightInfo] = []
+        if self._load_config.vit_separation != 1:
+            for layer_id in range(self._load_config.num_layers):
+                layer_weights = self._model_weights_info.layer_weights[layer_id]
+                if isinstance(layer_weights, WeightModule):
+                    names = layer_weights.get_tensor_names(layer_id, self._load_config)
+                    collector = TensorCollector(names, self._load_config.database)
+                    weight_info = WeightInfo(weight=layer_weights, layer_id=layer_id, collector=collector)
+                    tensor_to_weight_map.update({k: weight_info for k in names})
+                    weight_info_list.append(weight_info)
+                else:
+                    for weight in layer_weights:
+                        names = weight.get_tensor_names(layer_id, self._load_config)
+                        collector = TensorCollector(names, self._load_config.database)
+                        weight_info = WeightInfo(weight=weight, layer_id=layer_id, collector=collector)
+                        tensor_to_weight_map.update({k: weight_info for k in names})
+                        weight_info_list.append(weight_info)
+        for weight in self._model_weights_info.weights:
+            if self._maybe_skip_weight(weight):
+                continue
+            names = weight.get_tensor_names(None, self._load_config)
+            collector = TensorCollector(names, self._load_config.database)
+            weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
+            tensor_to_weight_map.update({k: weight_info for k in names})
+            weight_info_list.append(weight_info)
+        for weight in self._misc_weights_info:
+            names = weight.get_tensor_names(None, self._load_config)
+            collector = TensorCollector(names, self._load_config.database)
+            weight_info = WeightInfo(weight=weight, layer_id=None, collector=collector)
+            tensor_to_weight_map.update({k: weight_info for k in names})
+            weight_info_list.append(weight_info)
+        return tensor_to_weight_map, weight_info_list

     def _maybe_skip_weight(self, weight: WeightModule):
         if self._task_type == TaskType.LANGUAGE_MODEL:
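
Illustration of the two structures _generate_weight_info returns (mock values): every checkpoint-tensor name a WeightModule needs maps back to the same WeightInfo, so the streaming loop can route any arriving tensor to its owning weight and detect when that weight is complete:

    # tensor_to_weight_map = {
    #     "model.layers.0.self_attn.q_proj.weight": weight_info_attn_0,
    #     "model.layers.0.self_attn.k_proj.weight": weight_info_attn_0,
    # }
    # weight_info_list = [weight_info_attn_0, ...]  # scanned by the fallback pass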
@@ -254,7 +379,7 @@ def _choose_weight_convert_device(self, current_device):
         else:
             free_mem = device_mem_info.free / (1024.0**2)
             model_mem = model_size / self._load_config.tp_size / (1024.0**2)
-            return current_device if free_mem * 0.8 > model_mem else "cpu"
+            return current_device if free_mem * 0.9 > model_mem else "cpu"

     def _load_from_scratch(self, device: str):
         weights = self._create_model_weights(device)
@@ -270,6 +395,7 @@ def _load_from_scratch(self, device: str):
                 weights.set_layer_weight(layer_id, name, tensor)
             else:
                 weights.set_global_weight(name, tensor)
+        gc.collect()
         return weights

     def _load_layer_weights(self, layer_id: int, device: str):
@@ -278,7 +404,7 @@ def _load_layer_weights(self, layer_id: int, device: str):
         weights = {}
         for weight in layer_weights:
             res = weight.load(
-                self._load_config.database, layer_id, device, self._load_config
+                DatabaseTensorSource(self._load_config.database), layer_id, device, self._load_config
            )
             weights.update(res)
         return weights
@@ -337,7 +463,7 @@ def _load_dynamic_weights(self, weight: ModelWeights, device: str):
         if dynamic_weights:
             for dynamic_weight in dynamic_weights:
                 dynamic_w = dynamic_weight.load(
-                    self._load_config.database, None, device, self._load_config
+                    DatabaseTensorSource(self._load_config.database), None, device, self._load_config
                 )
                 weight.set_global_weight(
                     dynamic_weight.name, dynamic_w.get(dynamic_weight.name)
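
TensorCollector also lives in the unshown tensor_source.py. A hypothetical sketch consistent with how the loader uses it: store_tensor returns True once every expected name has arrived, the collector then doubles as the TensorSource handed to weight.load, and clear() drops the buffered tensors while keeping the completion flag (note the fallback loop above calls clear() before checking is_collection_complete()):

    from typing import Dict, List, Set

    import torch

    class TensorCollector(TensorSource):  # TensorSource as sketched earlier
        def __init__(self, expected_names: Set[str], database):
            self._expected = set(expected_names)
            self._database = database
            self._tensors: Dict[str, torch.Tensor] = {}
            self._seen: Set[str] = set()

        def store_tensor(self, name: str, tensor: torch.Tensor) -> bool:
            # True once all expected tensors for this weight have arrived.
            if name in self._expected:
                self._tensors[name] = tensor
                self._seen.add(name)
            return self._seen == self._expected

        def is_collection_complete(self) -> bool:
            return self._seen == self._expected

        def load_tensor(self, name: str, data_type: torch.dtype) -> List[torch.Tensor]:
            return [self._tensors[name].to(data_type)]

        def clear(self) -> None:
            self._tensors.clear()  # keeps self._seen, i.e. the completion state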

rtp_llm/model_loader/model_weight_info.py

Lines changed: 0 additions & 1 deletion

@@ -601,7 +601,6 @@ def __init__(self, num_layers: int, device: str, dtype: torch.dtype):

     def set_layer_weight(self, layer_id: int, name: str, tensor: torch.Tensor):
         self.weights[layer_id][name] = tensor
-        gc.collect()

     def set_global_weight(self, name: str, tensor: torch.Tensor):
         self.global_weights[name] = tensor
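
Related cleanup: the gc.collect() removed here ran a full garbage-collection pass for every single tensor assignment during loading; with this commit, _load_from_scratch (see loader.py above) calls gc.collect() once after all weights are set instead.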

rtp_llm/model_loader/per_block_fp8_quant_weight.py

Lines changed: 8 additions & 3 deletions

@@ -14,7 +14,7 @@
     QuantWeight,
     WeightModule,
 )
-from rtp_llm.utils.database import BaseDatabase
+from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.utils.model_weight import (
     FP8_E4M3_MAX,
     CkptWeightInfo,
@@ -745,12 +745,12 @@ def __init__(

     def _load_raw_tensor(
         self,
-        database: BaseDatabase,
+        tensor_source: TensorSource,
         layer_id: Optional[int],
         device: str,
         load_config: LoadConfig,
     ):
-        kernel = self.kernel._load_raw_tensor(database, layer_id, device, load_config)
+        kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config)

         res = {}
         scale = None
@@ -774,3 +774,8 @@ def _load_raw_tensor(
         res.update({self.scale.name: scale.contiguous().to(device)})

         return res
+
+    def get_tensor_names(
+        self, layer_id: Optional[int], load_config: LoadConfig
+    ) -> set[str]:
+        return self.kernel.get_tensor_names(layer_id, load_config)
