
Commit 6a4a100

Merge remote-tracking branch 'origin/main' into add-lightllm-kernel

2 parents: 597dc1a + 4b170a2
File tree

20 files changed: +713 −167 lines


.github/workflows/docker-publish.yml

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,17 @@ jobs:
       id-token: write

     steps:
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: true
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          swap-storage: false
+          docker-images: false
+
      - name: Checkout repository
        uses: actions/checkout@v3
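The added step runs jlumbroso/free-disk-space to delete preinstalled toolchains (tool cache, Android SDK, .NET, Haskell, large apt packages) from the GitHub-hosted runner before the image build, presumably to make room for a large CUDA image; swap and existing Docker images are deliberately left in place (swap-storage: false, docker-images: false).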

lightllm/common/basemodel/triton_kernel/apply_penalty.py

Lines changed: 51 additions & 46 deletions
@@ -1,54 +1,59 @@
 import torch
-
 import triton
 import triton.language as tl
 import numpy as np
+from lightllm.common.req_manager import ReqSamplingParamsManager


 @triton.jit
 def _fwd_kernel_apply_penalty(
     Logits,
-    presence_penalty,
-    freqency_penalty,
-    repetition_penalty,
+    stride_logit_b,
+    stride_logit_s,
+    b_req_idx,
+    req_to_presence_penalty,
+    req_to_frequency_penalty,
+    req_to_repetition_penalty,
+    req_to_exponential_decay_length_penalty,
+    b_length_penalty_param,
     p_token_ids,
     p_token_counts,
     p_cumsum_seq_len,
-    exponential_decay_length_penalties,
-    length_penalty_idx,
     eos_ids,
-    mask_eos_reqs,
-    stride_logit_b,
-    stride_logit_s,
+    b_mask_eos_reqs,
+    vocab_size,
     BLOCK_P: tl.constexpr,
     EOS_ID_NUM: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_freqency = tl.load(freqency_penalty + cur_batch)
-    cur_presence = tl.load(presence_penalty + cur_batch)
-    cur_repetition = tl.load(repetition_penalty + cur_batch)
+    cur_req_idx = tl.load(b_req_idx + cur_batch)
+    cur_freqency = tl.load(req_to_frequency_penalty + cur_req_idx)
+    cur_presence = tl.load(req_to_presence_penalty + cur_req_idx)
+    cur_repetition = tl.load(req_to_repetition_penalty + cur_req_idx)

     cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)
     cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)
     for block_start_index in range(cur_batch_start_index, cur_batch_end_index, BLOCK_P):
         cur_batch_id_offset = block_start_index + tl.arange(0, BLOCK_P)
-        batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0)
-        batch_ids_count = tl.load(
+        token_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0)
+        token_ids_count = tl.load(
             p_token_counts + cur_batch_id_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0
         )

         row_start_ptr = Logits + cur_batch * stride_logit_b
-        cur_offset = row_start_ptr + batch_ids
-        cur_logits = tl.load(cur_offset, mask=cur_batch_id_offset < cur_batch_end_index, other=0.0)
+        cur_offset = row_start_ptr + token_ids
+        cur_logits = tl.load(
+            cur_offset, mask=(cur_batch_id_offset < cur_batch_end_index) & (token_ids < vocab_size), other=0.0
+        )
         rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition)
-        freq_logits = rep_logits - batch_ids_count * cur_freqency
+        freq_logits = rep_logits - token_ids_count * cur_freqency
         pre_logits = freq_logits - cur_presence
-        output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
-        tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset < cur_batch_end_index)
+        output_ptr = Logits + cur_batch * stride_logit_b + token_ids
+        tl.store(output_ptr, pre_logits, mask=(cur_batch_id_offset < cur_batch_end_index) & (token_ids < vocab_size))

-    mask_eos = tl.load(mask_eos_reqs + cur_batch)
-    exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + cur_batch)
-    length_penalty = tl.load(length_penalty_idx + cur_batch)
+    mask_eos = tl.load(b_mask_eos_reqs + cur_batch)
+    exponential_decay_length_penalty = tl.load(req_to_exponential_decay_length_penalty + cur_req_idx)
+    length_penalty = tl.load(b_length_penalty_param + cur_batch)
     penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1

     for eos_index in range(EOS_ID_NUM):
@@ -63,35 +68,35 @@ def _fwd_kernel_apply_penalty(

 @torch.no_grad()
 def apply_penalty(
-    Logits,
-    presence_penalty,
-    freqency_penalty,
-    repetition_penalty,
-    p_token_ids,
-    p_token_counts,
-    p_cumsum_seq_len,
-    exponential_decay_length_penalties,
-    length_penalty_idx,
-    eos_ids,
-    mask_eos_reqs,
+    Logits: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_length_penalty_param: torch.Tensor,
+    b_mask_eos_reqs: torch.Tensor,
+    p_token_ids: torch.Tensor,
+    p_token_counts: torch.Tensor,
+    p_cumsum_seq_len: torch.Tensor,
+    eos_ids: torch.Tensor,
+    sampling_params_manager: ReqSamplingParamsManager,
 ):
     assert Logits.is_contiguous()
     BLOCK_P = 1024
     num_warps = 8
     _fwd_kernel_apply_penalty[(Logits.shape[0],)](
-        Logits,
-        presence_penalty,
-        freqency_penalty,
-        repetition_penalty,
-        p_token_ids,
-        p_token_counts,
-        p_cumsum_seq_len,
-        exponential_decay_length_penalties,
-        length_penalty_idx,
-        eos_ids,
-        mask_eos_reqs,
-        Logits.stride(0),
-        Logits.stride(1),
+        Logits=Logits,
+        stride_logit_b=Logits.stride(0),
+        stride_logit_s=Logits.stride(1),
+        b_req_idx=b_req_idx,
+        req_to_presence_penalty=sampling_params_manager.req_to_presence_penalty,
+        req_to_frequency_penalty=sampling_params_manager.req_to_frequency_penalty,
+        req_to_repetition_penalty=sampling_params_manager.req_to_repetition_penalty,
+        req_to_exponential_decay_length_penalty=sampling_params_manager.req_to_exponential_decay_length_penalty,
+        b_length_penalty_param=b_length_penalty_param,
+        p_token_ids=p_token_ids,
+        p_token_counts=p_token_counts,
+        p_cumsum_seq_len=p_cumsum_seq_len,
+        eos_ids=eos_ids,
+        b_mask_eos_reqs=b_mask_eos_reqs,
+        vocab_size=sampling_params_manager.vocab_size,
         num_warps=num_warps,
         BLOCK_P=BLOCK_P,
         EOS_ID_NUM=eos_ids.shape[0],
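
This refactor changes where the kernel reads its parameters, not what it computes: penalties now come from per-request rows of ReqSamplingParamsManager (indexed through b_req_idx) instead of per-batch tensors, and the gather/scatter masks gain an explicit token_ids < vocab_size bound. As a reading aid, a minimal eager-PyTorch sketch of the per-request arithmetic (not part of this commit; shapes and the in-place contract are assumed from the kernel):

    import torch

    def apply_penalty_reference(
        logits: torch.Tensor,        # [vocab_size] logits of one request, modified in place
        token_ids: torch.Tensor,     # [n] ids of tokens seen so far for this request
        token_counts: torch.Tensor,  # [n] occurrence count per token id
        repetition: float,
        frequency: float,
        presence: float,
    ) -> torch.Tensor:
        cur = logits[token_ids]
        # repetition penalty: divide positive logits, multiply negative ones
        cur = torch.where(cur > 0, cur / repetition, cur * repetition)
        # frequency penalty grows with the token's occurrence count
        cur = cur - token_counts.to(cur.dtype) * frequency
        # presence penalty: flat subtraction for every token seen at least once
        cur = cur - presence
        logits[token_ids] = cur
        return logits

For the EOS term, penalty_scale = tl.exp2(tl.log2(p) * L) - 1 is simply p**L - 1 written with base-2 primitives, where p is the exponential decay length penalty and L the per-request length parameter. The tail of the EOS loop is cut off in this view, but the standalone _eos_penalty kernel in the next file applies the same scale: each EOS logit becomes cur_eos_logit + |cur_eos_logit| * penalty_scale, forced to -10000000.0 where b_mask_eos_reqs is set.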
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+import torch
+import triton
+import triton.language as tl
+import torch.nn.functional as F
+import numpy as np
+from lightllm.common.req_manager import ReqSamplingParamsManager
+
+
+@triton.jit
+def _fwd_kernel_apply_penalty_cache(
+    Logits,
+    stride_logit_b,
+    stride_logit_s,
+    b_req_idx,
+    req_to_presence_penalty,
+    req_to_frequency_penalty,
+    req_to_repetition_penalty,
+    req_to_out_token_id_counter,
+    stride_counter_r,
+    stride_counter_s,
+    vocab_size,
+    BLOCK_P: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_req_idx = tl.load(b_req_idx + cur_batch)
+    block_idx = tl.program_id(1)
+    cur_freqency = tl.load(req_to_frequency_penalty + cur_req_idx)
+    cur_presence = tl.load(req_to_presence_penalty + cur_req_idx)
+    cur_repetition = tl.load(req_to_repetition_penalty + cur_req_idx)
+
+    token_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P)
+    mask = token_ids < vocab_size
+    token_ids_count = tl.load(
+        req_to_out_token_id_counter + cur_req_idx * stride_counter_r + token_ids,
+        mask=mask,
+        other=0,
+    )
+    row_start_ptr = Logits + cur_batch * stride_logit_b
+    cur_offset = row_start_ptr + token_ids
+    origin_logits = tl.load(cur_offset, mask=mask, other=0.0)
+    p_logits = tl.where(origin_logits > 0, origin_logits / cur_repetition, origin_logits * cur_repetition)
+    p_logits = tl.where(token_ids_count > 0, p_logits, origin_logits)
+    p_logits = p_logits - token_ids_count * cur_freqency
+    p_logits = p_logits - tl.where(token_ids_count > 0, cur_presence, 0.0)
+    output_ptr = Logits + cur_batch * stride_logit_b + token_ids
+    tl.store(output_ptr, p_logits, mask=mask)
+    return
+
+
+@triton.jit
+def _eos_penalty(
+    Logits,
+    stride_logit_b,
+    stride_logit_s,
+    b_req_idx,
+    req_to_exponential_decay_length_penalty,
+    b_length_penalty_param,
+    eos_ids,
+    b_mask_eos_reqs,
+    batch_size,
+    BLOCK: tl.constexpr,
+    EOS_ID_NUM: tl.constexpr,
+):
+    block_index = tl.program_id(0)
+    offs = block_index * BLOCK + tl.arange(0, BLOCK)
+    mask = offs < batch_size
+    req_idxes = tl.load(b_req_idx + offs, mask=mask, other=0)
+    exponential_decay_length_penalty = tl.load(
+        req_to_exponential_decay_length_penalty + req_idxes, mask=mask, other=1.0
+    )
+    length_penalty = tl.load(b_length_penalty_param + offs, mask=mask, other=0)
+    penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1
+    mask_eos = tl.load(b_mask_eos_reqs + offs, mask=mask, other=True)
+    for eos_index in range(EOS_ID_NUM):
+        eos_id = tl.load(eos_ids + eos_index)
+        cur_eos_logit_ptr = Logits + offs * stride_logit_b + eos_id
+        cur_eos_logit = tl.load(cur_eos_logit_ptr, mask=mask, other=0.0)
+        cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
+        cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+        tl.store(cur_eos_logit_ptr, cur_eos_logit, mask=mask)
+    return
+
+
+@torch.no_grad()
+def apply_penalty_gpu_cache(
+    Logits: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_length_penalty_param: torch.Tensor,
+    b_mask_eos_reqs: torch.Tensor,
+    eos_ids: torch.Tensor,
+    sampling_params_manager: ReqSamplingParamsManager,
+):
+    assert Logits.is_contiguous()
+    BLOCK_P = 2048
+    num_warps = 8
+    vocab_size = sampling_params_manager.vocab_size
+    req_to_out_token_id_counter = sampling_params_manager.req_to_out_token_id_counter
+    _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
+        Logits=Logits,
+        stride_logit_b=Logits.stride(0),
+        stride_logit_s=Logits.stride(1),
+        b_req_idx=b_req_idx,
+        req_to_presence_penalty=sampling_params_manager.req_to_presence_penalty,
+        req_to_frequency_penalty=sampling_params_manager.req_to_frequency_penalty,
+        req_to_repetition_penalty=sampling_params_manager.req_to_repetition_penalty,
+        req_to_out_token_id_counter=req_to_out_token_id_counter,
+        stride_counter_r=req_to_out_token_id_counter.stride(0),
+        stride_counter_s=req_to_out_token_id_counter.stride(1),
+        vocab_size=vocab_size,
+        BLOCK_P=BLOCK_P,
+        num_warps=num_warps,
+    )
+
+    BLOCK = 128
+    grid = (triton.cdiv(Logits.shape[0], BLOCK),)
+    _eos_penalty[grid](
+        Logits=Logits,
+        stride_logit_b=Logits.stride(0),
+        stride_logit_s=Logits.stride(1),
+        b_req_idx=b_req_idx,
+        req_to_exponential_decay_length_penalty=sampling_params_manager.req_to_exponential_decay_length_penalty,
+        b_length_penalty_param=b_length_penalty_param,
+        eos_ids=eos_ids,
+        b_mask_eos_reqs=b_mask_eos_reqs,
+        batch_size=Logits.shape[0],
+        BLOCK=BLOCK,
+        EOS_ID_NUM=eos_ids.shape[0],
+        num_warps=1,
+    )
+    return
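
This new file (its path is not shown in this extract) splits the work into two kernels: a dense penalty kernel that reads occurrence counts straight from req_to_out_token_id_counter, a per-request [num_requests, vocab_size] counter table kept on the GPU, and a small batch-wide EOS kernel. Because each counter row covers the whole vocabulary, the launch grid is 2D, (batch, cdiv(vocab_size, BLOCK_P)), and the tl.where(token_ids_count > 0, ...) guards ensure tokens never emitted keep their original logits. A hypothetical call-site sketch follows; the SimpleNamespace stub stands in for ReqSamplingParamsManager, whose construction is not part of this diff, and only uses the field names this file actually reads:

    import torch
    from types import SimpleNamespace

    # Stub manager with the fields this diff reads; the real class lives in
    # lightllm.common.req_manager. Shapes and dtypes here are assumptions.
    max_reqs, vocab_size, batch = 8, 32000, 2
    mgr = SimpleNamespace(
        vocab_size=vocab_size,
        req_to_presence_penalty=torch.zeros(max_reqs, device="cuda"),
        req_to_frequency_penalty=torch.zeros(max_reqs, device="cuda"),
        req_to_repetition_penalty=torch.ones(max_reqs, device="cuda"),
        req_to_exponential_decay_length_penalty=torch.ones(max_reqs, device="cuda"),
        req_to_out_token_id_counter=torch.zeros(
            (max_reqs, vocab_size), dtype=torch.int32, device="cuda"
        ),
    )

    logits = torch.randn(batch, vocab_size, device="cuda")
    b_req_idx = torch.tensor([0, 3], dtype=torch.int32, device="cuda")  # request slots
    b_length_penalty_param = torch.zeros(batch, dtype=torch.int32, device="cuda")
    b_mask_eos_reqs = torch.zeros(batch, dtype=torch.bool, device="cuda")
    eos_ids = torch.tensor([2], device="cuda")

    # apply_penalty_gpu_cache is the entry point defined in the new file above;
    # this assumes it is importable from wherever the file lands.
    apply_penalty_gpu_cache(
        logits, b_req_idx, b_length_penalty_param, b_mask_eos_reqs, eos_ids,
        sampling_params_manager=mgr,
    )

Running the EOS adjustment as its own tiny kernel (BLOCK=128, num_warps=1) keeps the serial per-EOS-id loop out of the vocab-wide kernel. With all penalties at their neutral values, as in the sketch (repetition 1.0, frequency and presence 0.0, decay penalty 1.0, zero counters), the call should leave logits unchanged, which makes a convenient sanity check.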
