Skip to content

Commit c9a6ba3

Browse files
⏪ parameterize max_num_seqs
Signed-off-by: Prashant Gupta <[email protected]>
1 parent 958d483 commit c9a6ba3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/e2e/test_spyre_basic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@
2929
ids=lambda val: f"TP({val})",
3030
)
3131
@pytest.mark.parametrize("backend", get_spyre_backend_list())
32+
@pytest.mark.parametrize("max_num_seqs", [4],
33+
ids=lambda val: f"max_num_seqs({val})")
3234
def test_output(
3335
model: str,
3436
tp_size: int,
3537
backend: str,
3638
cb: int,
39+
max_num_seqs: int,
3740
monkeypatch: pytest.MonkeyPatch,
3841
) -> None:
3942
'''
@@ -55,10 +58,10 @@ def test_output(
5558
warmup_shape = (64, 20, 4)
5659

5760
kwargs = ({
58-
"max_num_seqs": 2,
61+
"max_num_seqs": max_num_seqs,
5962
"use_cb": True,
6063
"max_model_len": 256,
61-
"block_size": 256
64+
"block_size": 256,
6265
} if cb == 1 else {
6366
"warmup_shapes": (warmup_shape, ),
6467
"max_model_len": 2048,

0 commit comments

Comments
 (0)