
Commit 2842c59

bucket: add query len 1 to prefill bucket
Signed-off-by: Xinyu Chen <[email protected]>
1 parent 2e56818 commit 2842c59

2 files changed: 6 additions and 3 deletions

vllm_gaudi/extension/bucketing/exponential.py

Lines changed: 5 additions & 2 deletions
@@ -36,7 +36,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
         prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
         max_prompt_seq_limit = math.ceil(math.log2(max_num_batched_tokens))
-        prompt_query_bucket_cfg = [block_size, block_size, max_num_batched_tokens, max_prompt_seq_limit]
+        prompt_query_bucket_cfg = [1, block_size, max_num_batched_tokens, max_prompt_seq_limit]
         max_ctx = max(1, math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size))
         max_prompt_ctx_limit = 2 if max_ctx == 1 else math.ceil(math.log2(max_ctx)) + 1
         prompt_ctx_bucket_cfg = [0, 1, max_ctx, max_prompt_ctx_limit]
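
The bucket config here is the four-tuple [min, step, max, num_buckets] that warmup_range_with_limit unpacks below; lowering the query minimum from block_size to 1 gives single-token prefill queries their own bucket instead of padding them up to a full block. A minimal illustration with assumed example values (not taken from the commit):

# Assuming block_size = 128 and max_num_batched_tokens = 2048, so
# max_prompt_seq_limit = math.ceil(math.log2(2048)) = 11.
old_prompt_query_bucket_cfg = [128, 128, 2048, 11]  # smallest warmed-up query is one full block
new_prompt_query_bucket_cfg = [1, 128, 2048, 11]    # adds a dedicated query-length-1 bucket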
@@ -125,7 +125,8 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
 
     bmin, bstep, bmax, num_buckets = config
     add_zero_bucket = bmin == 0
-    if add_zero_bucket:
+    add_one_bucket = bmin == 1
+    if add_zero_bucket or add_one_bucket:
         bmin = bstep
     linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
     assert num_buckets > 0, "num_buckets must be a positive integer"
@@ -176,4 +177,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
         buckets.add(bucket)
     if add_zero_bucket:
         buckets.add(0)
+    if add_one_bucket:
+        buckets.add(1)
     return list(sorted(buckets))
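
Taken together, a config whose minimum is 1 is now handled like the existing zero-bucket case: the exponential range is generated starting from one step, and the length-1 bucket is appended at the end. The following is a hypothetical, self-contained sketch of that behaviour, not the vllm_gaudi implementation; the exponential interpolation is simplified and the example values are assumed.

# Hypothetical sketch, not vllm_gaudi's warmup_range_with_limit: shows how a
# min of 1 yields an explicit query-length-1 bucket on top of a block-aligned,
# roughly exponential range.
import math

import numpy as np


def sketch_exponential_buckets(bmin, bstep, bmax, num_buckets):
    add_one_bucket = bmin == 1      # mirrors the add_one_bucket handling in the diff
    if add_one_bucket:
        bmin = bstep                # generate the exponential range from one step upwards
    exps = np.linspace(math.log2(bmin), math.log2(bmax), num_buckets)
    buckets = {max(bstep, int(bstep * round(2**e / bstep))) for e in exps}
    if add_one_bucket:
        buckets.add(1)              # explicit bucket for single-token prefill queries
    return sorted(buckets)


# Assumed values: block_size=128, max_num_batched_tokens=2048, num_buckets=11.
print(sketch_exponential_buckets(1, 128, 2048, 11))

With the old minimum of block_size, the same sketch would never emit a bucket smaller than 128.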

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
         prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=1, max=max_num_prefill_seqs)
         prompt_query_bucket_cfg = read_bucket_settings('prompt',
                                                        'query',
-                                                       min=block_size,
+                                                       min=1,
                                                        step=block_size,
                                                        max=max_num_batched_tokens)
         max_ctx = math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size)
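
For the linear strategy the effect is analogous. Below is a hypothetical sketch (not the actual read_bucket_settings helper or the linear warmup code), assuming buckets run from min in steps of step up to max:

# Hypothetical sketch of a linear warmup range; the function name and exact
# behaviour are assumptions for illustration, not the vllm_gaudi code.
def sketch_linear_buckets(bmin, bstep, bmax):
    buckets = list(range(bstep, bmax + 1, bstep))   # block_size, 2*block_size, ...
    if bmin < bstep:
        buckets.insert(0, bmin)                     # e.g. the new query-length-1 bucket
    return buckets


# Assumed values: block_size=128, max_num_batched_tokens=1024.
print(sketch_linear_buckets(1, 128, 1024))   # -> [1, 128, 256, 384, 512, 640, 768, 896, 1024]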
