@@ -36,7 +36,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
3636 prompt_bs_limit = math .ceil (math .log2 (max_num_prefill_seqs )) + 1
3737 prompt_bs_bucket_cfg = [1 , 2 , max_num_prefill_seqs , prompt_bs_limit ]
3838 max_prompt_seq_limit = math .ceil (math .log2 (max_num_batched_tokens ))
39- prompt_query_bucket_cfg = [block_size , block_size , max_num_batched_tokens , max_prompt_seq_limit ]
39+ prompt_query_bucket_cfg = [1 , block_size , max_num_batched_tokens , max_prompt_seq_limit ]
4040 max_ctx = max (1 , math .ceil ((max_model_len - prompt_query_bucket_cfg [0 ]) // block_size ))
4141 max_prompt_ctx_limit = 2 if max_ctx == 1 else math .ceil (math .log2 (max_ctx )) + 1
4242 prompt_ctx_bucket_cfg = [0 , 1 , max_ctx , max_prompt_ctx_limit ]
@@ -125,7 +125,8 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
125125
126126 bmin , bstep , bmax , num_buckets = config
127127 add_zero_bucket = bmin == 0
128- if add_zero_bucket :
128+ add_one_bucket = bmin == 1
129+ if add_zero_bucket or add_one_bucket :
129130 bmin = bstep
130131 linear_buckets = set (np .arange (bmin , bmax + 1 , step = bstep ))
131132 assert num_buckets > 0 , "num_buckets must be a positive integer"
@@ -176,4 +177,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
176177 buckets .add (bucket )
177178 if add_zero_bucket :
178179 buckets .add (0 )
180+ if add_one_bucket :
181+ buckets .add (1 )
179182 return list (sorted (buckets ))
0 commit comments