Commit 0c67427

Merge branch 'main' into main
2 parents 228e3e7 + bd0712e commit 0c67427

18 files changed (+998 lines, -2285 lines)

lightllm/common/basemodel/basemodel.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -71,8 +71,15 @@ def __init__(self, kvargs):
         self._verify_must()
         self._verify_params()
         self._init_quant()
-        self._init_weights()
-        self._init_mem_manager()
+
+        # More contiguous GPU memory allocation gives better performance
+        if self.max_total_token_num is None:
+            self._init_weights()
+            self._init_mem_manager()
+        else:
+            self._init_mem_manager()
+            self._init_weights()
+
         self._init_kv_move_buffer()
         self._check_mem_size()
         self._init_req_manager()
```
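Note on the reordering above (not part of the diff): when `max_total_token_num` is known up front, the KV-cache pool is reserved before the weights start carving up the CUDA allocator, so the large buffer is more likely to land in one contiguous region. A rough, self-contained sketch of that allocation-ordering idea, with made-up sizes and helper names:

```python
import torch


def init_model(kv_tokens=None):
    """Sketch of the init ordering above; sizes and names are illustrative only."""

    def load_weights():
        # Many medium-sized allocations that tend to fragment the caching allocator.
        return [torch.empty(4096, 4096, dtype=torch.float16, device="cuda") for _ in range(16)]

    def alloc_kv_cache(n_tokens):
        # One large buffer; grabbing it from a mostly-empty allocator makes it
        # more likely to be a single contiguous region.
        return torch.empty(n_tokens, 2 * 32 * 128, dtype=torch.float16, device="cuda")

    if kv_tokens is None:
        weights = load_weights()              # KV size decided later from remaining free VRAM
        kv_cache = alloc_kv_cache(8192)       # placeholder size for the sketch
    else:
        kv_cache = alloc_kv_cache(kv_tokens)  # reserve the big contiguous block first
        weights = load_weights()
    return weights, kv_cache
```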

lightllm/common/basemodel/cuda_graph.py

Lines changed: 58 additions & 62 deletions
```diff
@@ -1,8 +1,7 @@
 import os
 import torch
 from lightllm.utils.log_utils import init_logger
-from lightllm.distributed.parallel_state import graph_capture
-from contextlib import nullcontext
+from lightllm.distributed import lightllm_capture_graph
 
 logger = init_logger(__name__)
 
@@ -32,8 +31,9 @@ def capture_decode(self, decode_func, input_ids, infer_state):
         torch.cuda.synchronize()
         decode_func(input_ids, infer_state)
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph_obj, pool=self.mempool, stream=self.stream):
-            predict_logics = decode_func(input_ids, infer_state)
+        with lightllm_capture_graph():
+            with torch.cuda.graph(graph_obj, pool=self.mempool):
+                predict_logics = decode_func(input_ids, infer_state)
         self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
         graph_obj.replay()
         return predict_logics
@@ -49,65 +49,61 @@ def replay(self, input_ids, infer_state):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in ["ON", "TRUE", "1"]
-        graph_capture_context_manager = graph_capture() if LIGHTLLM_PYNCCL_ENABLE else nullcontext()
-        with graph_capture_context_manager as graph_capture_context:
-            self.stream = graph_capture_context.stream if graph_capture_context is not None else None
-            for batch_size in range(self.max_batch_size, 0, -1):
-                # dummy prefill
-                prefill_input_len = 1
-                dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-                b_req_idx = model.req_manager.alloc(batch_size).int()
-                mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
-                b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-                b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-                b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-                total_token_num = prefill_input_len * batch_size
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len,
-                    dummy_input_ids,
-                    mem_indexes,
-                    b_req_idx,
-                    b_start_loc,
-                    b_seq_len,
-                    b_ready_cache_len=b_ready_cache_len,
-                    is_prefill=True,
-                    multimodal_params=[],
-                )
-                mem_indexes = None
-                prob_out = torch.softmax(logics, dim=-1)
-                logics = None
-                predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-                prob_out = None
-                predict_ids = predict_ids.detach().cpu().numpy()
-                torch.cuda.empty_cache()
+        for batch_size in range(self.max_batch_size, 0, -1):
+            # dummy prefill
+            prefill_input_len = 1
+            dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+            b_req_idx = model.req_manager.alloc(batch_size).int()
+            mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
+            b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+            total_token_num = prefill_input_len * batch_size
+            logics = model.forward(
+                batch_size,
+                total_token_num,
+                prefill_input_len,
+                dummy_input_ids,
+                mem_indexes,
+                b_req_idx,
+                b_start_loc,
+                b_seq_len,
+                b_ready_cache_len=b_ready_cache_len,
+                is_prefill=True,
+                multimodal_params=[],
+            )
+            mem_indexes = None
+            prob_out = torch.softmax(logics, dim=-1)
+            logics = None
+            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+            prob_out = None
+            predict_ids = predict_ids.detach().cpu().numpy()
+            torch.cuda.empty_cache()
 
-                # dummy decoding, capture the cudagraph
-                b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-                total_token_num += batch_size
-                b_seq_len += 1
-                mem_indexes = model.mem_manager.alloc(len(predict_ids))
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len + 1,
-                    torch.from_numpy(predict_ids).cuda().reshape(-1),
-                    mem_indexes,
-                    b_req_idx,
-                    b_start_loc,
-                    b_seq_len,
-                    is_prefill=False,
-                )
-                mem_indexes = None
-                model.mem_manager.free_all()
-                model.req_manager.free_all()
-                # release local tensors
-                for var_name, var_value in list(locals().items()):
-                    if isinstance(var_value, torch.Tensor):
-                        del locals()[var_name]
-                torch.cuda.empty_cache()
+            # dummy decoding, capture the cudagraph
+            b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+            total_token_num += batch_size
+            b_seq_len += 1
+            mem_indexes = model.mem_manager.alloc(len(predict_ids))
+            logics = model.forward(
+                batch_size,
+                total_token_num,
+                prefill_input_len + 1,
+                torch.from_numpy(predict_ids).cuda().reshape(-1),
+                mem_indexes,
+                b_req_idx,
+                b_start_loc,
+                b_seq_len,
+                is_prefill=False,
+            )
+            mem_indexes = None
+            model.mem_manager.free_all()
+            model.req_manager.free_all()
+            # release local tensors
+            for var_name, var_value in list(locals().items()):
+                if isinstance(var_value, torch.Tensor):
+                    del locals()[var_name]
+            torch.cuda.empty_cache()
         logger.info(
             f"Capture cudagraph success, batch_size <={self.max_batch_size} "
             f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."
```

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,5 +1,6 @@
 from .base_weight import BaseWeight
 from .mm_weight import (
+    MMWeightTpl,
     MMWeight,
     MultiMMWeight,
     ROWMMWeight,
```

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2,7 +2,7 @@
 
 # from lightllm.common.layers.mm import MM
 from .base_layer_weight import BaseLayerWeight
-from .meta_weights import BaseWeight, MultiMMWeight, MMWeight, FusedMoeWeight
+from .meta_weights import BaseWeight, MultiMMWeight, MMWeightTpl, FusedMoeWeight
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -51,7 +51,7 @@ def set_quantization(self):
         mix_quant_list = self.quant_cfg.get_mixed_list(self.layer_num_)
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
-            if isinstance(attr, MMWeight) or isinstance(attr, FusedMoeWeight):
+            if isinstance(attr, MMWeightTpl) or isinstance(attr, FusedMoeWeight):
                 if attr_name in mix_quant_list:
                     attr.set_quant_method(self.quant_cfg.get_quant_method(self.layer_num_, attr_name))
                     attr_quant_type = self.quant_cfg.get_quant_type(self.layer_num_, attr_name)
```

lightllm/distributed/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,2 +1 @@
 from .communication_op import *
-from .parallel_state import *
```

lightllm/distributed/communication_op.py

Lines changed: 42 additions & 25 deletions
```diff
@@ -19,42 +19,59 @@
 
 from typing import Any, Dict, Optional, Union
 
+import os
 import torch
-import torch.distributed
-
-from .parallel_state import get_tp_group, original_all_reduce
+import torch.distributed as dist
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
 
+original_all_reduce = torch.distributed.all_reduce
+from contextlib import nullcontext, contextmanager
+
+try:
+    HAS_VLLM = True
+    from .custom_all_reduce import CustomAllreduce
+except:
+    HAS_VLLM = False
 
+vllm_reduce = None
 logger = init_logger(__name__)
-# if op != ReduceOp.SUM or group != None or async_op != False:
-#     logger.warning("This function op, group, async_op will only run with default values")
 
 
-def all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or group is not None or async_op:
-        original_all_reduce(input_, op, group, async_op)
+@contextmanager
+def lightllm_capture_graph():
+    if vllm_reduce is not None:
+        with vllm_reduce.capture():
+            yield
     else:
-        input_.data = tensor_model_parallel_all_reduce(input_)
-
-
-def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
-    """All-reduce the input tensor across model parallel group."""
-    return get_tp_group().all_reduce(input_)
+        yield
+    pass
 
 
-def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-    """All-gather the input tensor across model parallel group."""
-    return get_tp_group().all_gather(input_, dim)
-
+def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
+    if op != ReduceOp.SUM or group is not None or async_op or vllm_reduce is None:
+        original_all_reduce(input_, op, group, async_op)
+    else:
+        if vllm_reduce is not None:
+            can_use = vllm_reduce.should_custom_ar(input_)
+            if can_use:
+                input_.data = vllm_reduce.custom_all_reduce(input_)
+                return
+        original_all_reduce(input_, op, group, async_op)
 
-def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]:
-    """Gather the input tensor across model parallel group."""
-    return get_tp_group().gather(input_, dst, dim)
 
+def set_custom_reduce():
+    global vllm_reduce
 
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0):
-    if not torch.distributed.is_initialized():
-        return tensor_dict
-    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
+    ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
+        "ON",
+        "TRUE",
+        "1",
+    ]
+    if ENABLE_VLLM_REDUCE and HAS_VLLM:
+        world_size = dist.get_world_size()
+        ranks = list(range(world_size))
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+        vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
+        logger.info("Enable VLLM ALLReduce.")
+    dist.all_reduce = _all_reduce
```
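The net effect of the last hunk is a monkey-patch with a fallback: `set_custom_reduce()` optionally builds a vLLM `CustomAllreduce` and rebinds `dist.all_reduce` so that eligible tensors take the custom kernel while everything else goes through the original NCCL op. A self-contained sketch of that pattern (helper names here are illustrative, not LightLLM's):

```python
import torch.distributed as dist

# Keep a handle to the stock collective before patching it.
_original_all_reduce = dist.all_reduce
_custom_backend = None  # e.g. a CustomAllreduce instance, once registered


def patched_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
    # Only the default SUM / default-group / synchronous case is eligible for
    # the custom kernel; everything else falls back to the original op.
    if op == dist.ReduceOp.SUM and group is None and not async_op and _custom_backend is not None:
        if _custom_backend.should_custom_ar(tensor):
            tensor.data = _custom_backend.custom_all_reduce(tensor)
            return None
    return _original_all_reduce(tensor, op=op, group=group, async_op=async_op)


def install_custom_all_reduce(backend):
    """Rebind dist.all_reduce so existing call sites pick up the fast path."""
    global _custom_backend
    _custom_backend = backend
    dist.all_reduce = patched_all_reduce
```

Existing call sites keep calling `dist.all_reduce(...)` unchanged, which is the point of patching the module attribute rather than editing every caller.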