Skip to content

Commit af92517

Browse files
committed
feat: add fastsafetensors loader
1 parent 7875bd6 commit af92517

File tree

12 files changed

+498
-8
lines changed

12 files changed

+498
-8
lines changed

rtp_llm/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ requirement([
9797
"portalocker",
9898
"concurrent_log_handler",
9999
"aiter",
100+
"fastsafetensors",
100101
] + tensorrt)
101102

102103
filegroup(

rtp_llm/distribute/gang_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import requests
1414
import uvicorn
1515
from fastapi import FastAPI
16+
import torch.distributed
1617

1718
from rtp_llm.config.py_config_modules import PyEnvConfigs, StaticConfig
1819
from rtp_llm.config.uvicorn_config import UVICORN_LOGGING_CONFIG
@@ -423,9 +424,16 @@ def start(self):
423424
master_url = (
424425
f"tcp://{g_master_info.ip}:{self._gang_info.master.server_port - 1}"
425426
)
426-
logging.info(f"gang worker {g_parallel_info} memory_barrier {master_url}")
427+
logging.info(f"gang worker {g_parallel_info} init_process_group {master_url}")
427428
init_process_timeout = self.py_env_configs.gang_config.dist_barrier_timeout
428-
self.memory_barrier(master_url, timeout=init_process_timeout)
429+
os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
430+
torch.distributed.init_process_group(
431+
backend=torch.distributed.Backend.NCCL,
432+
init_method=master_url,
433+
rank=g_parallel_info.world_rank,
434+
world_size=g_parallel_info.world_size,
435+
timeout=timedelta(seconds=init_process_timeout),
436+
)
429437

430438
logging.info(f"gang worker {g_parallel_info} start_health_check")
431439
self.start_health_check()

rtp_llm/model_loader/ffn_weight.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,58 @@ def _load_raw_tensor(
331331
after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type)
332332
logging.debug("load weight :%s, %s ", self.name, after_merge_tensor.shape)
333333
return {self.name: after_merge_tensor}
334+
335+
def get_tensor_names(
336+
self, layer_id: Optional[int], load_config: LoadConfig
337+
) -> set[str]:
338+
if self.config.weight_stack:
339+
return super().get_tensor_names(layer_id, load_config)
340+
names = set[str]()
341+
for ckpt_weight in self.weights:
342+
selected_experts = load_config.get_selected_experts(
343+
layer_id, self.config.expert_num
344+
)
345+
for expert_id in selected_experts:
346+
name = ckpt_weight.name.format(
347+
i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id)
348+
)
349+
names.add(name)
350+
return names
351+
352+
def _process_raw_tensors(
353+
self,
354+
raw_tensors: Dict[str, torch.Tensor],
355+
layer_id: Optional[int],
356+
device: str,
357+
load_config: LoadConfig,
358+
):
359+
if self.config.weight_stack:
360+
return super()._process_raw_tensors(
361+
raw_tensors, layer_id, device, load_config
362+
)
363+
before_merge_tensors: List[torch.Tensor] = []
364+
convert_type = (
365+
self.data_type if self.data_type is not None else load_config.compute_dtype
366+
)
367+
for ckpt_weight in self.weights:
368+
selected_experts = load_config.get_selected_experts(
369+
layer_id, self.config.expert_num
370+
)
371+
for expert_id in selected_experts:
372+
name = ckpt_weight.name.format(
373+
i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id)
374+
)
375+
try:
376+
before_merge_tensors.append(
377+
ckpt_weight.merge_fun([raw_tensors[name]])
378+
)
379+
except Exception as e:
380+
logging.error(
381+
f"加载 {self.name}: {name} 失败,完整堆栈:\n{traceback.format_exc()}"
382+
)
383+
raise e
384+
after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type)
385+
return {self.name: after_merge_tensor}
334386

335387

336388
class MoeWeight(CompositeWeight):

rtp_llm/model_loader/loader.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import gc
22
import logging
33
import os
4+
import time
45
from collections import OrderedDict
5-
from typing import Optional
6+
from typing import Dict, Optional
67

78
import safetensors
89
import torch
@@ -66,7 +67,8 @@ def load_weights(self, device: str):
6667
if self._load_config.is_ft_style_weight:
6768
weights = self._load_from_ft_style(device)
6869
else:
69-
weights = self._load_from_scratch(device)
70+
weights = self._load_weight(device)
71+
self.force_clean_cuda_memory()
7072

7173
# load dynamic weight
7274
self._load_dynamic_weights(weights, device)
@@ -203,6 +205,88 @@ def _load_from_ft_style(self, device: str):
203205
model_weights.global_weights = global_weights
204206
return model_weights
205207

208+
def _load_weight(self, device: str):
209+
is_safetensor = self._load_config.database.is_safetensor
210+
convert_device = self._choose_weight_convert_device(device)
211+
if is_safetensor and convert_device != "cpu" and self._is_memory_enough_for_fastsafetensor():
212+
return self._load_from_fastsafetensor(device)
213+
214+
logging.info(
215+
f"database is safetensor: {is_safetensor}, device: {device}, choose devie: {convert_device}"
216+
)
217+
return self._load_from_scratch(device)
218+
219+
def _is_memory_enough_for_fastsafetensor(self):
220+
model_size = self._weights_info.config.eval_model_size()
221+
device_mem_info = self._load_config.exported_device.get_mem_info()
222+
max_file_size = self._load_config.database.get_max_file_size()
223+
if device_mem_info is None:
224+
return False
225+
else:
226+
free_mem = device_mem_info.free / (1024.0**2)
227+
model_mem = model_size / self._load_config.tp_size / (1024.0**2)
228+
max_file_mem = max_file_size / (1024.0**2)
229+
logging.debug(f"free mem: {free_mem}, model mem: {model_mem}, max file mem: {max_file_mem}")
230+
return (free_mem - model_mem) > (3 * max_file_mem)
231+
232+
def _load_from_fastsafetensor(self, device: str):
233+
try:
234+
all_tensors = self._load_config.database.fastsafetensors_weights_iterator(
235+
device, True
236+
)
237+
except (ModuleNotFoundError, ImportError) as e:
238+
logging.warning(f"Failed to import fastsafetensors: {e}")
239+
return self._load_from_scratch(device)
240+
241+
logging.info(f"load weight by device: {device}")
242+
model_weights = self._create_model_weights(device)
243+
tensor_to_weight_map = self._get_tensor_to_weight_map()
244+
direct_io = self._load_config.exported_device.support_dio_load
245+
for key, loaded_tensor in all_tensors:
246+
if key not in tensor_to_weight_map:
247+
continue
248+
layer_id, weight = tensor_to_weight_map[key]
249+
start = time.time()
250+
res, complete = weight.add_tensor(
251+
key, self._load_config.database, layer_id, loaded_tensor, device, self._load_config
252+
)
253+
logging.debug(
254+
f"weight: {type(weight).__name__} add tensor, complete: {complete}, cost {time.time() - start}"
255+
)
256+
if complete and res is not None:
257+
for name, tensor in res.items():
258+
if layer_id is not None and self._load_config.vit_separation != 1:
259+
model_weights.set_layer_weight(layer_id, name, tensor)
260+
else:
261+
model_weights.set_global_weight(name, tensor)
262+
for layer_id, name, tensor in self._load_uncomplete_weight_modules(device):
263+
if layer_id is not None and self._load_config.vit_separation != 1:
264+
model_weights.set_layer_weight(layer_id, name, tensor)
265+
else:
266+
model_weights.set_global_weight(name, tensor)
267+
return model_weights
268+
269+
def _load_uncomplete_weight_modules(self, device: str):
270+
if self._load_config.vit_separation != 1 and not self._is_attn_model:
271+
for layer_id in range(self._load_config.num_layers):
272+
layer_weights = self._model_weights_info.layer_weights[layer_id]
273+
for weight in layer_weights:
274+
if weight.loaded:
275+
continue
276+
results = weight.load(
277+
self._load_config.database, layer_id, device, self._load_config
278+
)
279+
for name, tensor in results.items():
280+
yield (layer_id, name, tensor)
281+
for weight in self._model_weights_info.weights:
282+
if self._maybe_skip_weight(weight) or weight.loaded:
283+
continue
284+
results = weight.load(
285+
self._load_config.database, None, device, self._load_config
286+
)
287+
for name, tensor in results.items():
288+
yield (None, name, tensor)
289+
206290
def prepare_weights(self, device: str):
207291
if self._load_config.vit_separation != 1 and not self._is_attn_model:
208292
for id in range(self._load_config.num_layers):
@@ -225,6 +309,34 @@ def prepare_weights(self, device: str):
225309
)
226310
for name, tensor in weights.items():
227311
yield (None, name, tensor)
312+
313+
def _get_tensor_to_weight_map(
314+
self,
315+
) -> Dict[str, tuple[Optional[int], WeightModule]]:
316+
tensor_to_weight_map: Dict[str, tuple[Optional[int], WeightModule]] = {}
317+
if self._load_config.vit_separation != 1 and not self._is_attn_model:
318+
for layer_id in range(self._load_config.num_layers):
319+
layer_weights = self._model_weights_info.layer_weights[layer_id]
320+
if isinstance(layer_weights, WeightModule):
321+
names = layer_weights.get_tensor_names(layer_id, self._load_config)
322+
tensor_to_weight_map.update(
323+
{k: (layer_id, layer_weights) for k in names}
324+
)
325+
else:
326+
for weight in layer_weights:
327+
names = weight.get_tensor_names(layer_id, self._load_config)
328+
tensor_to_weight_map.update(
329+
{k: (layer_id, weight) for k in names}
330+
)
331+
for weight in self._model_weights_info.weights:
332+
if self._maybe_skip_weight(weight):
333+
continue
334+
names = weight.get_tensor_names(None, self._load_config)
335+
tensor_to_weight_map.update({k: (None, weight) for k in names})
336+
for weight in self._misc_weights_info:
337+
names = weight.get_tensor_names(None, self._load_config)
338+
tensor_to_weight_map.update({k: (None, weight) for k in names})
339+
return tensor_to_weight_map
228340

229341
def _maybe_skip_weight(self, weight: WeightModule):
230342
if self._task_type == TaskType.LANGUAGE_MODEL:
@@ -270,6 +382,7 @@ def _load_from_scratch(self, device: str):
270382
weights.set_layer_weight(layer_id, name, tensor)
271383
else:
272384
weights.set_global_weight(name, tensor)
385+
gc.collect()
273386
return weights
274387

275388
def _load_layer_weights(self, layer_id: int, device: str):

rtp_llm/model_loader/model_weight_info.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,6 @@ def __init__(self, num_layers: int, device: str, dtype: torch.dtype):
601601

602602
def set_layer_weight(self, layer_id: int, name: str, tensor: torch.Tensor):
603603
self.weights[layer_id][name] = tensor
604-
gc.collect()
605604

606605
def set_global_weight(self, name: str, tensor: torch.Tensor):
607606
self.global_weights[name] = tensor

rtp_llm/model_loader/per_block_fp8_quant_weight.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,42 @@ def _load_raw_tensor(
774774
res.update({self.scale.name: scale.contiguous().to(device)})
775775

776776
return res
777+
778+
def get_tensor_names(
779+
self, layer_id: Optional[int], load_config: LoadConfig
780+
) -> set[str]:
781+
return self.kernel.get_tensor_names(layer_id, load_config)
782+
783+
def _process_raw_tensors(
784+
self,
785+
raw_tensors: Dict[str, torch.Tensor],
786+
layer_id: Optional[int],
787+
device: str,
788+
load_config: LoadConfig,
789+
):
790+
kernel_raw_tensors: Dict[str, torch.Tensor] = {
791+
key: raw_tensors[key]
792+
for key in self.kernel.get_tensor_names(layer_id, load_config)
793+
}
794+
kernel = self.kernel._process_raw_tensors(
795+
kernel_raw_tensors, layer_id, device, load_config
796+
)
797+
res = {}
798+
scale = None
799+
if self.scale:
800+
quant_kernel, scale = per_block_cast_to_fp8(
801+
kernel.get(self.kernel.name), self.group_size
802+
)
803+
if quant_kernel.dim() == 2:
804+
scale = scale.reshape([scale.shape[0], -1])
805+
else:
806+
quant_kernel = cast_to_fp8(kernel.get(self.kernel.name))
807+
if self.kernel.name == W.moe_w1 or self.kernel.name == W.moe_w2:
808+
pass
809+
elif quant_kernel.dim() == 2:
810+
quant_kernel = quant_kernel.T
811+
res = {self.kernel.name: quant_kernel.contiguous().to(device)}
812+
if self.scale:
813+
scale = scale.T if scale.dim() == 2 else scale
814+
res.update({self.scale.name: scale.contiguous().to(device)})
815+
return res

rtp_llm/model_loader/static_fp8_quant_weight.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,59 @@ def _load_raw_tensor(
706706
)
707707
res.update(act_scale_inv)
708708
return res
709+
710+
def get_tensor_names(
711+
self, layer_id: Optional[int], load_config: LoadConfig
712+
) -> set[str]:
713+
names = self.kernel.get_tensor_names(layer_id, load_config)
714+
if self.act_scale:
715+
names = names.union(self.act_scale.get_tensor_names(layer_id, load_config))
716+
if self.act_scale_inv:
717+
names = names.union(
718+
self.act_scale_inv.get_tensor_names(layer_id, load_config)
719+
)
720+
return names
721+
722+
def _process_raw_tensors(
723+
self,
724+
raw_tensors: Dict[str, torch.Tensor],
725+
layer_id: Optional[int],
726+
device: str,
727+
load_config: LoadConfig,
728+
):
729+
kernel_raw_tensors: Dict[str, torch.Tensor] = {
730+
key: raw_tensors[key]
731+
for key in self.kernel.get_tensor_names(layer_id, load_config)
732+
}
733+
kernel = self.kernel._process_raw_tensors(
734+
kernel_raw_tensors, layer_id, device, load_config
735+
)
736+
res = {}
737+
quant_kernel, scale = quantize_weight_to_fp8(kernel.get(self.kernel.name))
738+
quant_kernel = quant_kernel.T
739+
res = {
740+
self.kernel.name: quant_kernel.contiguous().to(device),
741+
self.scale.name: scale.contiguous().to(device),
742+
}
743+
if self.act_scale:
744+
act_scale_raw_tensors: Dict[str, torch.Tensor] = {
745+
key: raw_tensors[key]
746+
for key in self.act_scale.get_tensor_names(layer_id, load_config)
747+
}
748+
act_scale = self.act_scale._process_raw_tensors(
749+
act_scale_raw_tensors, layer_id, device, load_config
750+
)
751+
res.update(act_scale)
752+
if self.act_scale_inv:
753+
act_scale_inv_raw_tensors: Dict[str, torch.Tensor] = {
754+
key: raw_tensors[key]
755+
for key in self.act_scale_inv.get_tensor_names(layer_id, load_config)
756+
}
757+
act_scale_inv = self.act_scale_inv._process_raw_tensors(
758+
act_scale_inv_raw_tensors, layer_id, device, load_config
759+
)
760+
res.update(act_scale_inv)
761+
return res
709762

710763

711764
class Fp8PerTensorCompressedWeight(CompositeWeight, QuantWeight):

0 commit comments

Comments (0)