import torch
-
import triton
import triton.language as tl
import numpy as np
+from lightllm.common.req_manager import ReqSamplingParamsManager


@triton.jit
def _fwd_kernel_apply_penalty(
    Logits,
-    presence_penalty,
-    freqency_penalty,
-    repetition_penalty,
+    stride_logit_b,
+    stride_logit_s,
+    b_req_idx,
+    req_to_presence_penalty,
+    req_to_frequency_penalty,
+    req_to_repetition_penalty,
+    req_to_exponential_decay_length_penalty,
+    b_length_penalty_param,
    p_token_ids,
    p_token_counts,
    p_cumsum_seq_len,
-    exponential_decay_length_penalties,
-    length_penalty_idx,
    eos_ids,
-    mask_eos_reqs,
-    stride_logit_b,
-    stride_logit_s,
+    b_mask_eos_reqs,
+    vocab_size,
    BLOCK_P: tl.constexpr,
    EOS_ID_NUM: tl.constexpr,
):
    cur_batch = tl.program_id(0)
-    cur_freqency = tl.load(freqency_penalty + cur_batch)
-    cur_presence = tl.load(presence_penalty + cur_batch)
-    cur_repetition = tl.load(repetition_penalty + cur_batch)
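+    # Penalties are now stored in per-request tables owned by the sampling-params
+    # manager and fetched through an extra req-index indirection.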
+    cur_req_idx = tl.load(b_req_idx + cur_batch)
+    cur_frequency = tl.load(req_to_frequency_penalty + cur_req_idx)
+    cur_presence = tl.load(req_to_presence_penalty + cur_req_idx)
+    cur_repetition = tl.load(req_to_repetition_penalty + cur_req_idx)

    cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)
    cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)
    for block_start_index in range(cur_batch_start_index, cur_batch_end_index, BLOCK_P):
        cur_batch_id_offset = block_start_index + tl.arange(0, BLOCK_P)
-        batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0)
-        batch_ids_count = tl.load(
+        token_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0)
+        token_ids_count = tl.load(
            p_token_counts + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0
        )

        row_start_ptr = Logits + cur_batch * stride_logit_b
-        cur_offset = row_start_ptr + batch_ids
-        cur_logits = tl.load(cur_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0.0)
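+        # The extra (token_ids < vocab_size) term guards the gather below and the
+        # scatter further down against padded or out-of-range token ids.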
+        cur_offset = row_start_ptr + token_ids
+        cur_logits = tl.load(
+            cur_offset, mask=(cur_batch_id_offset < cur_batch_end_index) & (token_ids < vocab_size), other=0.0
+        )
        rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
-        freq_logits = rep_logits - batch_ids_count * cur_freqency
+        freq_logits = rep_logits - token_ids_count * cur_frequency
        pre_logits = freq_logits - cur_presence
-        output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
-        tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset < cur_batch_end_index)
+        output_ptr = Logits + cur_batch * stride_logit_b + token_ids
+        tl.store(output_ptr, pre_logits, mask=(cur_batch_id_offset < cur_batch_end_index) & (token_ids < vocab_size))

-    mask_eos = tl.load(mask_eos_reqs + cur_batch)
-    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
-    length_penalty = tl.load(length_penalty_idx + cur_batch)
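+    # The decay-length penalty is a per-request parameter looked up via cur_req_idx;
+    # the EOS mask and the length exponent remain per-batch, per-step inputs.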
+    mask_eos = tl.load(b_mask_eos_reqs + cur_batch)
+    exponential_decay_length_penalty = tl.load(req_to_exponential_decay_length_penalty + cur_req_idx)
+    length_penalty = tl.load(b_length_penalty_param + cur_batch)
    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1

    for eos_index in range(EOS_ID_NUM):
@@ -63,35 +68,35 @@ def _fwd_kernel_apply_penalty(

@torch.no_grad()
def apply_penalty(
-    Logits,
-    presence_penalty,
-    freqency_penalty,
-    repetition_penalty,
-    p_token_ids,
-    p_token_counts,
-    p_cumsum_seq_len,
-    exponential_decay_length_penalties,
-    length_penalty_idx,
-    eos_ids,
-    mask_eos_reqs,
+    Logits: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_length_penalty_param: torch.Tensor,
+    b_mask_eos_reqs: torch.Tensor,
+    p_token_ids: torch.Tensor,
+    p_token_counts: torch.Tensor,
+    p_cumsum_seq_len: torch.Tensor,
+    eos_ids: torch.Tensor,
+    sampling_params_manager: ReqSamplingParamsManager,
):
    assert Logits.is_contiguous()
    BLOCK_P = 1024
    num_warps = 8
    _fwd_kernel_apply_penalty[(Logits.shape[0],)](
-        Logits,
-        presence_penalty,
-        freqency_penalty,
-        repetition_penalty,
-        p_token_ids,
-        p_token_counts,
-        p_cumsum_seq_len,
-        exponential_decay_length_penalties,
-        length_penalty_idx,
-        eos_ids,
-        mask_eos_reqs,
-        Logits.stride(0),
-        Logits.stride(1),
+        Logits=Logits,
+        stride_logit_b=Logits.stride(0),
+        stride_logit_s=Logits.stride(1),
+        b_req_idx=b_req_idx,
+        req_to_presence_penalty=sampling_params_manager.req_to_presence_penalty,
+        req_to_frequency_penalty=sampling_params_manager.req_to_frequency_penalty,
+        req_to_repetition_penalty=sampling_params_manager.req_to_repetition_penalty,
+        req_to_exponential_decay_length_penalty=sampling_params_manager.req_to_exponential_decay_length_penalty,
+        b_length_penalty_param=b_length_penalty_param,
+        p_token_ids=p_token_ids,
+        p_token_counts=p_token_counts,
+        p_cumsum_seq_len=p_cumsum_seq_len,
+        eos_ids=eos_ids,
+        b_mask_eos_reqs=b_mask_eos_reqs,
+        vocab_size=sampling_params_manager.vocab_size,
        num_warps=num_warps,
        BLOCK_P=BLOCK_P,
        EOS_ID_NUM=eos_ids.shape[0],
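For reference, here is a minimal sketch of how the reworked `apply_penalty` entry point might be invoked. The import path, the `SimpleNamespace` stand-in for `ReqSamplingParamsManager`, and all shapes and values are assumptions for illustration, not part of this diff:

```python
import torch
from types import SimpleNamespace

# Assumed module path; adjust to wherever this file lives in the tree.
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty

batch_size, vocab_size, max_reqs = 2, 32000, 8
logits = torch.randn(batch_size, vocab_size, device="cuda")

# Stand-in exposing only the attributes apply_penalty reads; real code passes
# the ReqSamplingParamsManager instance itself.
manager = SimpleNamespace(
    vocab_size=vocab_size,
    req_to_presence_penalty=torch.zeros(max_reqs, device="cuda"),
    req_to_frequency_penalty=torch.full((max_reqs,), 0.2, device="cuda"),
    req_to_repetition_penalty=torch.full((max_reqs,), 1.1, device="cuda"),
    req_to_exponential_decay_length_penalty=torch.ones(max_reqs, device="cuda"),
)

# Batch slot -> row in the manager's req_to_* penalty tables.
b_req_idx = torch.tensor([3, 7], dtype=torch.int32, device="cuda")

# CSR-style layout: unique generated token ids and their counts per request,
# delimited by a cumulative prefix sum over batch rows.
p_token_ids = torch.tensor([11, 42, 7], dtype=torch.int32, device="cuda")
p_token_counts = torch.tensor([1, 2, 1], dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")

b_length_penalty_param = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
b_mask_eos_reqs = torch.tensor([True, False], device="cuda")
eos_ids = torch.tensor([2], dtype=torch.int64, device="cuda")

apply_penalty(
    Logits=logits,
    b_req_idx=b_req_idx,
    b_length_penalty_param=b_length_penalty_param,
    b_mask_eos_reqs=b_mask_eos_reqs,
    p_token_ids=p_token_ids,
    p_token_counts=p_token_counts,
    p_cumsum_seq_len=p_cumsum_seq_len,
    eos_ids=eos_ids,
    sampling_params_manager=manager,
)
```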