@@ -1,8 +1,7 @@
 import os
 import torch
 from lightllm.utils.log_utils import init_logger
-from lightllm.distributed.parallel_state import graph_capture
-from contextlib import nullcontext
+from lightllm.distributed import lightllm_capture_graph
 
 
 logger = init_logger(__name__)
@@ -32,8 +31,9 @@ def capture_decode(self, decode_func, input_ids, infer_state):
         torch.cuda.synchronize()
         decode_func(input_ids, infer_state)
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph_obj, pool=self.mempool, stream=self.stream):
-            predict_logics = decode_func(input_ids, infer_state)
+        with lightllm_capture_graph():
+            with torch.cuda.graph(graph_obj, pool=self.mempool):
+                predict_logics = decode_func(input_ids, infer_state)
         self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
         graph_obj.replay()
         return predict_logics
@@ -49,65 +49,61 @@ def replay(self, input_ids, infer_state):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in ["ON", "TRUE", "1"]
-        graph_capture_context_manager = graph_capture() if LIGHTLLM_PYNCCL_ENABLE else nullcontext()
-        with graph_capture_context_manager as graph_capture_context:
-            self.stream = graph_capture_context.stream if graph_capture_context is not None else None
-            for batch_size in range(self.max_batch_size, 0, -1):
-                # dummy prefill
-                prefill_input_len = 1
-                dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-                b_req_idx = model.req_manager.alloc(batch_size).int()
-                mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
-                b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-                b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-                b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-                total_token_num = prefill_input_len * batch_size
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len,
-                    dummy_input_ids,
-                    mem_indexes,
-                    b_req_idx,
-                    b_start_loc,
-                    b_seq_len,
-                    b_ready_cache_len=b_ready_cache_len,
-                    is_prefill=True,
-                    multimodal_params=[],
-                )
-                mem_indexes = None
-                prob_out = torch.softmax(logics, dim=-1)
-                logics = None
-                predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-                prob_out = None
-                predict_ids = predict_ids.detach().cpu().numpy()
-                torch.cuda.empty_cache()
+        for batch_size in range(self.max_batch_size, 0, -1):
+            # dummy prefill
+            prefill_input_len = 1
+            dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+            b_req_idx = model.req_manager.alloc(batch_size).int()
+            mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
+            b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+            total_token_num = prefill_input_len * batch_size
+            logics = model.forward(
+                batch_size,
+                total_token_num,
+                prefill_input_len,
+                dummy_input_ids,
+                mem_indexes,
+                b_req_idx,
+                b_start_loc,
+                b_seq_len,
+                b_ready_cache_len=b_ready_cache_len,
+                is_prefill=True,
+                multimodal_params=[],
+            )
+            mem_indexes = None
+            prob_out = torch.softmax(logics, dim=-1)
+            logics = None
+            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+            prob_out = None
+            predict_ids = predict_ids.detach().cpu().numpy()
+            torch.cuda.empty_cache()
 
-                # dummy decoding, capture the cudagraph
-                b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-                total_token_num += batch_size
-                b_seq_len += 1
-                mem_indexes = model.mem_manager.alloc(len(predict_ids))
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len + 1,
-                    torch.from_numpy(predict_ids).cuda().reshape(-1),
-                    mem_indexes,
-                    b_req_idx,
-                    b_start_loc,
-                    b_seq_len,
-                    is_prefill=False,
-                )
-                mem_indexes = None
-                model.mem_manager.free_all()
-                model.req_manager.free_all()
-                # release local tensors
-                for var_name, var_value in list(locals().items()):
-                    if isinstance(var_value, torch.Tensor):
-                        del locals()[var_name]
-                torch.cuda.empty_cache()
+            # dummy decoding, capture the cudagraph
+            b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+            total_token_num += batch_size
+            b_seq_len += 1
+            mem_indexes = model.mem_manager.alloc(len(predict_ids))
+            logics = model.forward(
+                batch_size,
+                total_token_num,
+                prefill_input_len + 1,
+                torch.from_numpy(predict_ids).cuda().reshape(-1),
+                mem_indexes,
+                b_req_idx,
+                b_start_loc,
+                b_seq_len,
+                is_prefill=False,
+            )
+            mem_indexes = None
+            model.mem_manager.free_all()
+            model.req_manager.free_all()
+            # release local tensors
+            for var_name, var_value in list(locals().items()):
+                if isinstance(var_value, torch.Tensor):
+                    del locals()[var_name]
+            torch.cuda.empty_cache()
 
         logger.info(
             f"Capture cudagraph success, batch_size <={self.max_batch_size} "
             f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."