Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions tests/e2e/test_spyre_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,58 +109,6 @@ def test_api_cb_generates_correct_max_tokens(
assert response.usage.completion_tokens == max_tokens


@pytest.mark.cb
@pytest.mark.spyre
@pytest.mark.xfail  # TODO: remove once a spyre-base image supports this
@pytest.mark.parametrize("model", get_spyre_model_list())
def test_continuous_batching_with_long_contexts(model, monkeypatch):
    """Check that continuous batching output matches across backends at 4k.

    The same prompts are generated twice through vllm with continuous
    batching enabled and a max context length of 4096: once with the
    "eager" backend (the cpu reference) and once with the "sendnn" backend
    (the spyre cards). This ensures that the compiler is generating the
    correct programs for long context cases, but we test here with small
    prompts for speed.

    Importantly, the cpu reference is generated using vllm as well, instead
    of using transformers directly. This ensures that the model code is all
    the same, and the only difference is the torch compilation backend.
    """
    max_model_len = 4096
    max_num_seqs = 4
    prompts = get_chicken_soup_prompts(4)

    # Greedy decoding (temperature=0) with a fixed number of generated tokens
    # (ignore_eos=True) so that both runs are directly comparable.
    sampling_params = SamplingParams(max_tokens=20,
                                     temperature=0,
                                     ignore_eos=True,
                                     logprobs=0)

    # Everything except the backend is shared, so the two runs cannot drift
    # apart by accident.
    common_kwargs = dict(model=model,
                         prompts=prompts,
                         max_model_len=max_model_len,
                         sampling_params=sampling_params,
                         tensor_parallel_size=1,
                         max_num_seqs=max_num_seqs,
                         use_cb=True,
                         monkeypatch=monkeypatch)

    vllm_cpu_results = generate_spyre_vllm_output(backend="eager",
                                                  **common_kwargs)
    vllm_spyre_results = generate_spyre_vllm_output(backend="sendnn",
                                                    **common_kwargs)

    # Check the result counts first: zip() stops at the shorter input, so a
    # missing or extra result would otherwise go unnoticed.
    assert len(vllm_cpu_results) == len(vllm_spyre_results), (
        f"eager backend returned {len(vllm_cpu_results)} results but "
        f"sendnn backend returned {len(vllm_spyre_results)}")

    for i, (cpu_result, spyre_result) in enumerate(
            zip(vllm_cpu_results, vllm_spyre_results)):
        # As long as no sequences have top candidate tokens with very close
        # logprobs, the generated text should be identical.
        assert cpu_result["text"] == spyre_result["text"], (
            f"generated text differs for prompt {i}: "
            f"eager={cpu_result['text']!r}, sendnn={spyre_result['text']!r}")


@pytest.mark.cb
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
Expand Down
Loading