We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 958d483 commit c9a6ba3Copy full SHA for c9a6ba3
tests/e2e/test_spyre_basic.py
@@ -29,11 +29,14 @@
29
ids=lambda val: f"TP({val})",
30
)
31
@pytest.mark.parametrize("backend", get_spyre_backend_list())
32
+@pytest.mark.parametrize("max_num_seqs", [4],
33
+ ids=lambda val: f"max_num_seqs({val})")
34
def test_output(
35
model: str,
36
tp_size: int,
37
backend: str,
38
cb: int,
39
+ max_num_seqs: int,
40
monkeypatch: pytest.MonkeyPatch,
41
) -> None:
42
'''
@@ -55,10 +58,10 @@ def test_output(
55
58
warmup_shape = (64, 20, 4)
56
59
57
60
kwargs = ({
- "max_num_seqs": 2,
61
+ "max_num_seqs": max_num_seqs,
62
"use_cb": True,
63
"max_model_len": 256,
- "block_size": 256
64
+ "block_size": 256,
65
} if cb == 1 else {
66
"warmup_shapes": (warmup_shape, ),
67
"max_model_len": 2048,
0 commit comments