From 786a7786547443a8408b5a41dce7245d4c059b5c Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:13:19 -0700 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=8E=A8=20add=20cb=20parameterization?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index e8d21586d..a847a1370 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -10,6 +10,8 @@ from vllm import SamplingParams +@pytest.mark.parametrize("cb", + [pytest.param(1, marks=pytest.mark.cb, id="cb"), 0]) @pytest.mark.parametrize("model", get_spyre_model_list()) @pytest.mark.parametrize("stop_last", [True, False]) @pytest.mark.parametrize( @@ -20,6 +22,7 @@ def test_output( stop_last: bool, warmup_shape: tuple[int, int, int], backend: str, + cb: int, monkeypatch: pytest.MonkeyPatch, ) -> None: ''' @@ -66,16 +69,21 @@ def test_output( ] + vllm_sampling_params hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens + kwargs = ( + {"max_num_seqs": 2, "use_cb": True} + if cb == 1 + else {"warmup_shapes": (warmup_shape,)} + ) vllm_results = generate_spyre_vllm_output( model=model, prompts=prompts, - warmup_shapes=[warmup_shape], max_model_len=2048, block_size=2048, sampling_params=vllm_sampling_params, tensor_parallel_size=1, backend=backend, - monkeypatch=monkeypatch) + monkeypatch=monkeypatch, + **kwargs) hf_results = generate_hf_output(model=model, prompts=prompts, From e85cf2ce7b641fdc298a7accde688bb4fb1e898d Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:23:33 -0700 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=8E=A8=20set=20max=5Fmodel=5Flen=20to?= =?UTF-8?q?=20128=20for=20cb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index a847a1370..ffdc49a16 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -70,14 +70,13 @@ def test_output( hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens kwargs = ( - {"max_num_seqs": 2, "use_cb": True} + {"max_num_seqs": 2, "use_cb": True, "max_model_len": 128} if cb == 1 - else {"warmup_shapes": (warmup_shape,)} + else {"warmup_shapes": (warmup_shape,), "max_model_len": 2048} ) vllm_results = generate_spyre_vllm_output( model=model, prompts=prompts, - max_model_len=2048, block_size=2048, sampling_params=vllm_sampling_params, tensor_parallel_size=1, From 0b6581e9cf288eb7920db60d988e35853b53b99e Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:26:18 -0700 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=8E=A8=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index ffdc49a16..cc3b9ef27 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -69,12 +69,15 @@ def test_output( ] + vllm_sampling_params hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens - kwargs = ( - {"max_num_seqs": 2, "use_cb": True, "max_model_len": 128} - if cb == 1 - else {"warmup_shapes": (warmup_shape,), "max_model_len": 2048} - ) - vllm_results = generate_spyre_vllm_output( + kwargs = ({ + "max_num_seqs": 2, + "use_cb": True, + "max_model_len": 128 + } if cb == 1 else { + "warmup_shapes": (warmup_shape, ), + "max_model_len": 2048 + }) + vllm_results = generate_spyre_vllm_output( model=model, prompts=prompts, block_size=2048, From ad9002cf432fdd9d0bfec7bbf5d5c7be1131cc20 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:28:44 -0700 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=8E=A8=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index cc3b9ef27..0123d00f9 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -78,14 +78,14 @@ def test_output( "max_model_len": 2048 }) vllm_results = generate_spyre_vllm_output( - model=model, - prompts=prompts, - block_size=2048, - sampling_params=vllm_sampling_params, - tensor_parallel_size=1, - backend=backend, - monkeypatch=monkeypatch, - **kwargs) + model=model, + prompts=prompts, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend, + monkeypatch=monkeypatch, + **kwargs) hf_results = generate_hf_output(model=model, prompts=prompts, From bc39adb5dd34a78e94660fe88ae7f9c019b763db Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:34:56 -0700 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=8E=A8=20set=20to=20256=20for=20consi?= =?UTF-8?q?stency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index 0123d00f9..1306aa21f 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -72,7 +72,7 @@ def test_output( kwargs = ({ "max_num_seqs": 2, "use_cb": True, - "max_model_len": 128 + "max_model_len": 256 } if cb == 1 else { "warmup_shapes": (warmup_shape, ), "max_model_len": 2048 From a501cf6d0830b19a330e29e01b37ad5dea7b28b2 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 11:48:38 -0700 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=8E=A8=20fix=20indentation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index 1306aa21f..c68e8d3c0 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -77,15 +77,16 @@ def test_output( "warmup_shapes": (warmup_shape, ), "max_model_len": 2048 }) - vllm_results = generate_spyre_vllm_output( - model=model, - prompts=prompts, - block_size=2048, - sampling_params=vllm_sampling_params, - tensor_parallel_size=1, - backend=backend, - monkeypatch=monkeypatch, - **kwargs) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend, + monkeypatch=monkeypatch, + **kwargs) hf_results = generate_hf_output(model=model, prompts=prompts, From 000ed74eab0388582da0d06286bf77c785cb52a0 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 16:35:39 -0700 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=8E=A8=20fix=20indentation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/e2e/test_spyre_max_new_tokens.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py index c68e8d3c0..01ddb13f9 100644 --- a/tests/e2e/test_spyre_max_new_tokens.py +++ b/tests/e2e/test_spyre_max_new_tokens.py @@ -69,14 +69,14 @@ def test_output( ] + vllm_sampling_params hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens - kwargs = ({ - "max_num_seqs": 2, - "use_cb": True, - "max_model_len": 256 - } if cb == 1 else { - "warmup_shapes": (warmup_shape, ), - "max_model_len": 2048 - }) + kwargs = ({ + "max_num_seqs": 2, + "use_cb": True, + "max_model_len": 256 + } if cb == 1 else { + "warmup_shapes": (warmup_shape, ), + "max_model_len": 2048 + }) vllm_results = generate_spyre_vllm_output( model=model,