
Commit 2d0293d

fix crashes with the usage of top_k (#543)
# Description

We found a couple of ways to crash the server with edge cases in top_k:

- setting `top_k > vocab_size` in the request
- mixing greedy requests and sampling requests with `top_k > 0` in the same batch

See #542 for details on the crashes.

The "fix" in this PR is simply to copy the logic from vLLM's GPU `InputBatch` for setting the value of top_k: clamp the value to `vocab_size`, and set the default top_k to `vocab_size` instead of 0 in a mixed batch.

REF: https://github.com/vllm-project/vllm/blob/fc168c33f35e0610d41206e864b6bf90fe613f19/vllm/v1/worker/gpu_input_batch.py#L353-L357

## Related Issues

FIX #542

---------

Signed-off-by: Travis Johnson <[email protected]>
1 parent 7ed0611 commit 2d0293d

File tree: 2 files changed, +37 −6 lines changed
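For context on the first failure mode, here is a minimal sketch (not taken from this PR, using a made-up vocab size) of how an oversized top_k breaks a plain `torch.topk` call; clamping to `vocab_size` is what avoids this class of crash:

```python
import torch

# Hypothetical, tiny vocab size purely for illustration; the real value
# comes from the model configuration.
vocab_size = 8
logits = torch.randn(1, vocab_size)

# k within range works fine.
values, indices = torch.topk(logits, k=vocab_size)

# k > vocab_size raises a RuntimeError, which is the kind of failure the
# commit guards against by clamping top_k to vocab_size.
try:
    torch.topk(logits, k=vocab_size + 1)
except RuntimeError as exc:
    print(f"oversized top_k fails: {exc}")
```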

tests/v1/worker/test_spyre_input_batch.py

Lines changed: 32 additions & 4 deletions
@@ -63,7 +63,7 @@ def _construct_expected_sampling_metadata(
     presence_penalties = [0.0 for _ in range(num_reqs)]
     frequency_penalties = [0.0 for _ in range(num_reqs)]
     repetition_penalties = [1.0 for _ in range(num_reqs)]
-    top_k = [0 for _ in range(num_reqs)]
+    top_k = [VOCAB_SIZE for _ in range(num_reqs)]
     top_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
     allowed_token_ids_mask = torch.zeros(num_reqs,
@@ -85,7 +85,8 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.frequency_penalty)
         repetition_penalties[index_in_input_batch] = (
             req.sampling_params.repetition_penalty)
-        top_k[index_in_input_batch] = req.sampling_params.top_k
+        if req.sampling_params.top_k > 0:
+            top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
         if req.sampling_params.allowed_token_ids:
@@ -102,7 +103,7 @@ def _construct_expected_sampling_metadata(
         all_random=True,
         top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
             top_p, dtype=torch.float, device=device),
-        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+        top_k=None if all(x == VOCAB_SIZE for x in top_k) else torch.tensor(
             top_k, dtype=torch.int, device=device),
         generators={},
         max_num_logprobs=0,
@@ -133,7 +134,7 @@ def _construct_expected_sampling_metadata(

 def _create_sampling_params():
     return SamplingParams(
-        top_k=np.random.randint(1, 10),
+        top_k=np.random.randint(0, 10),
         top_p=np.random.uniform(0.0, 1.0),
         presence_penalty=np.random.uniform(-2.0, 2.0),
         repetition_penalty=np.random.uniform(0.0, 2.0),
@@ -267,3 +268,30 @@ def test_sampling_metadata_in_input_batch(batch_size: int):
         reqs, req_ids_retained, input_batch, device=torch.device(device))

     compare_results(sampling_metadata, expected_sampling_metadata)
+
+
+@pytest.mark.cpu
+@pytest.mark.worker
+def test_sampling_metadata_topk_edges():
+    device = torch.device('cpu')
+    input_batch: SamplingInputBatch = SamplingInputBatch(
+        max_num_reqs=2,
+        max_model_len=1024,
+        device=device,
+        pin_memory=is_pin_memory_available(),
+        vocab_size=VOCAB_SIZE,
+    )
+
+    # top_k should be clamped to VOCAB_SIZE
+    req = _construct_cached_request_state(0)
+    req.sampling_params = SamplingParams(temperature=1.0, top_k=VOCAB_SIZE + 1)
+    input_batch.add_request(req, 0)
+
+    # in a batch with both greedy and sampling, default top_k should be
+    # VOCAB_SIZE
+    req = _construct_cached_request_state(1)
+    req.sampling_params = SamplingParams(temperature=0)
+    input_batch.add_request(req, 1)
+
+    assert input_batch.top_k[0] == VOCAB_SIZE
+    assert input_batch.top_k[1] == VOCAB_SIZE
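Why `VOCAB_SIZE` is the new "disabled" value in the expected metadata above: a top-k mask with k equal to the vocabulary size keeps every logit, so greedy rows in a mixed batch are left untouched. A rough, self-contained sketch of that masking idea (an illustration only, not the sampler code vLLM actually runs):

```python
import torch


def apply_top_k(logits: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
    """Keep the k largest logits per row; k == vocab_size keeps everything."""
    vocab_size = logits.shape[-1]
    k = top_k.clamp(max=vocab_size)
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    # Per-row threshold: the value of the k-th largest logit.
    thresholds = sorted_logits.gather(-1, (k - 1).unsqueeze(-1))
    return logits.masked_fill(logits < thresholds, float("-inf"))


vocab_size = 8
logits = torch.randn(2, vocab_size)
# Row 0 samples with top_k=2; row 1 is "greedy" and uses the vocab_size sentinel.
top_k = torch.tensor([2, vocab_size])
masked = apply_top_k(logits, top_k)
assert int(torch.isinf(masked[0]).sum()) == vocab_size - 2  # only 2 logits survive
assert not torch.isinf(masked[1]).any()  # greedy row is untouched
```

With that convention the batch needs no special case for greedy requests: storing `vocab_size` makes the top-k step a no-op, which is exactly what the worker-side change below does.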

vllm_spyre/v1/worker/spyre_input_batch.py

Lines changed: 5 additions & 2 deletions
@@ -425,9 +425,12 @@ def add_request(
         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
             self.top_p_reqs.add(req_id)
-        self.top_k_cpu[req_index] = sampling_params.top_k
-        if sampling_params.top_k > 0:
+        top_k = sampling_params.top_k
+        if 0 < top_k < self.vocab_size:
             self.top_k_reqs.add(req_id)
+        else:
+            top_k = self.vocab_size
+        self.top_k_cpu[req_index] = top_k
         self.frequency_penalties_cpu[
             req_index] = sampling_params.frequency_penalty
         if sampling_params.frequency_penalty != 0.0:
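Pulled out of `add_request`, the new rule amounts to: a value strictly between 0 and `vocab_size` is a genuine top-k request; anything else (0, a negative sentinel, or an oversized value) is stored as `vocab_size` so top-k becomes a no-op for that request. A standalone sketch of that rule, with a hypothetical vocab size for illustration:

```python
def clamp_top_k(requested_top_k: int, vocab_size: int) -> tuple[int, bool]:
    """Return (stored_top_k, is_top_k_request) following the rule above."""
    if 0 < requested_top_k < vocab_size:
        return requested_top_k, True
    return vocab_size, False


VOCAB = 32_000  # hypothetical vocab size, for illustration only
assert clamp_top_k(50, VOCAB) == (50, True)             # regular sampling request
assert clamp_top_k(0, VOCAB) == (VOCAB, False)          # greedy / top_k disabled
assert clamp_top_k(VOCAB + 1, VOCAB) == (VOCAB, False)  # oversized: clamped
```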
