From b1043293679a693ad61284b78b6ecc380457152f Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Thu, 17 Jul 2025 13:49:05 +0200
Subject: [PATCH 1/2] relax constraint of min num output tokens from 3 to 2

Signed-off-by: Yannick Schnider
---
 vllm_spyre/v1/worker/spyre_worker.py | 4 ++--
 vllm_spyre/worker/spyre_worker.py    | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 1af005935..949d4e56e 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -79,9 +79,9 @@ def compile_or_warm_up_model(self) -> None:
             if self.model_config.task != "embed":
                 # TODO: remove if spyre supports
                 # lower number of output tokens
-                assert num_decode_tokens >= 3, (
+                assert num_decode_tokens >= 2, (
                     "VLLM_SPYRE_WARMUP_NEW_TOKENS must be "
-                    "at least 3 (spyre requirement).")
+                    "at least 2 (spyre requirement).")
             # warmup individual combination
             logger.info(
                 "[WARMUP] (%d/%d) for prompt length %d, decoding %d tokens "
diff --git a/vllm_spyre/worker/spyre_worker.py b/vllm_spyre/worker/spyre_worker.py
index 60b4845ab..7f9f7e83b 100644
--- a/vllm_spyre/worker/spyre_worker.py
+++ b/vllm_spyre/worker/spyre_worker.py
@@ -208,9 +208,9 @@ def load_model(self):
             if self.model_config.task != "embed":
                 # TODO: remove if spyre supports
                 # lower number of output tokens
-                assert num_decode_tokens >= 3, (
+                assert num_decode_tokens >= 2, (
                     "VLLM_SPYRE_WARMUP_NEW_TOKENS must be "
-                    "at least 3 (spyre requirement).")
+                    "at least 2 (spyre requirement).")
             # warmup individual combination
             print(f"[SpyreWorker] Warmup {i+1}/"
                   f"{len(wup_new_tokens)} "

From 243d5713585e36c3fe10d7eab0cde782a87b9467 Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Thu, 17 Jul 2025 13:49:53 +0200
Subject: [PATCH 2/2] refactor test script with for loop

Signed-off-by: Yannick Schnider
---
 tests/e2e/test_spyre_basic.py | 32 +++++++-------------------------
 1 file changed, 7 insertions(+), 25 deletions(-)

diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py
index 043105fea..7f1297ac8 100644
--- a/tests/e2e/test_spyre_basic.py
+++ b/tests/e2e/test_spyre_basic.py
@@ -160,32 +160,14 @@ def test_batch_handling(model: str, backend: str, cb: int,
 
     prompts = get_chicken_soup_prompts(4)
 
-    sampling_params1 = SamplingParams(max_tokens=5,
-                                      min_tokens=5,
-                                      temperature=0,
-                                      ignore_eos=True,
-                                      logprobs=0)
-    sampling_params2 = SamplingParams(max_tokens=20,
-                                      min_tokens=20,
-                                      temperature=0,
-                                      ignore_eos=True,
-                                      logprobs=0)
-    sampling_params3 = SamplingParams(max_tokens=10,
-                                      min_tokens=10,
-                                      temperature=0,
-                                      ignore_eos=True,
-                                      logprobs=0)
-    sampling_params4 = SamplingParams(max_tokens=5,
-                                      min_tokens=5,
-                                      temperature=0,
-                                      ignore_eos=True,
-                                      logprobs=0)
+    max_new_tokens = [5, 20, 10, 5]
 
     vllm_sampling_params = [
-        sampling_params1,
-        sampling_params2,
-        sampling_params3,
-        sampling_params4,
+        SamplingParams(max_tokens=max_new_tokens[i],
+                       min_tokens=max_new_tokens[i],
+                       temperature=0,
+                       ignore_eos=True,
+                       logprobs=0) for i in range(len(max_new_tokens))
     ]
 
     kwargs = {
@@ -207,7 +189,7 @@ def test_batch_handling(model: str, backend: str, cb: int,
                                               **kwargs)
     hf_results = generate_hf_output(model=model,
                                     prompts=prompts,
-                                    max_new_tokens=[5, 20, 10, 5])
+                                    max_new_tokens=max_new_tokens)
 
     compare_results(
         model=model,
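
Note (reviewer illustration, not part of the patch series): the sketch below restates the
two changes as standalone Python, assuming vllm is installed and using only the
SamplingParams keyword arguments that already appear in tests/e2e/test_spyre_basic.py.
The wup_new_tokens values are hypothetical sample data, not taken from any config.

    from vllm import SamplingParams

    # Per-prompt decode lengths, as in the refactored test (PATCH 2/2).
    max_new_tokens = [5, 20, 10, 5]

    # One greedy, fixed-length SamplingParams object per prompt, built with a
    # comprehension instead of four hand-written copies.
    vllm_sampling_params = [
        SamplingParams(max_tokens=n,
                       min_tokens=n,
                       temperature=0,
                       ignore_eos=True,
                       logprobs=0) for n in max_new_tokens
    ]
    assert len(vllm_sampling_params) == len(max_new_tokens)

    # PATCH 1/2 relaxes the warmup constraint: each configured decode length in
    # VLLM_SPYRE_WARMUP_NEW_TOKENS now only needs to be at least 2, not 3.
    wup_new_tokens = [2, 20]  # hypothetical example values
    assert all(n >= 2 for n in wup_new_tokens), (
        "VLLM_SPYRE_WARMUP_NEW_TOKENS must be at least 2 (spyre requirement).")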