Changes from all commits
.github/workflows/test.yml (16 changes: 11 additions & 5 deletions)

@@ -61,7 +61,7 @@ jobs:
         markers: "cpu and cb and not quantized"
         flags: "--timeout=300 --durations=0 -s"
       - name: "worker and utils"
-        markers: "not e2e"
+        markers: "not e2e and not quantized"
         flags: "--timeout=300"
       - name: "compatibility"
         markers: "compat"
@@ -70,7 +70,7 @@ jobs:
         markers: "cpu and other_e2e and not quantized"
         flags: "--timeout=300"
       - name: "precompilation"
-        markers: "precompilation"
+        markers: "precompilation and not quantized"
         flags: "--timeout=300"
     include:
       - vllm_version:
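Note (not part of the PR): the `markers` strings above are pytest marker expressions, presumably passed to pytest's `-m` option by the workflow, so appending `and not quantized` deselects every test tagged with a `quantized` marker in these suites. A minimal sketch of how such filtering behaves, assuming a `quantized` marker is registered in the project's pytest configuration:

# Sketch only: illustrates pytest marker filtering as used by the CI matrix above.
# Assumes a "quantized" marker is registered (e.g. in pyproject.toml or pytest.ini).
import pytest

@pytest.mark.quantized
def test_fp8_path():
    # Selected by `pytest -m quantized`, deselected by `pytest -m "not e2e and not quantized"`.
    assert True

def test_cpu_path():
    # No markers: still selected by `pytest -m "not e2e and not quantized"`.
    assert True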
@@ -167,7 +167,7 @@ jobs:
 
       - name: "Restore HF models cache"
         id: cache_restore
-        if: steps.changed-src-files.outputs.any_changed == 'true'
+        if: "!contains(env.model_key, 'micro')" # <------ TODO: revert this condition! --------
        uses: actions/cache/restore@v4
        with:
          path: ${{ env.model_path }}
@@ -182,7 +182,10 @@ jobs:
          # be removed by an admin or can be left to expire after 7 days.
 
          download_tinygranite() {
-            python -c "from transformers import pipeline, AutoTokenizer; pipeline('text-generation', model='$1'); tokenizer=AutoTokenizer.from_pretrained('$1')"
+            python -c "from transformers import pipeline; pipeline('text-generation', model='$1', revision='2714578f54cfb744ece40df9326ee0b47e879e03');"
          }
+         download_tinygranite_FP8() {
+            python -c "from transformers import pipeline; pipeline('text-generation', model='$1');"
+         }
          download_roberta_large() {
            python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('$1')"
@@ -203,9 +206,12 @@ jobs:
          for model in "${models[@]}"; do
            echo "Downloading $model ..."
            case "$model" in
-             "ibm-ai-platform/micro-g3.3-8b-instruct-1b"*)
+             "ibm-ai-platform/micro-g3.3-8b-instruct-1b")
                download_tinygranite "$model" &
                ;;
+             "ibm-ai-platform/micro-g3.3-8b-instruct-1b-FP8")
+               download_tinygranite_FP8 "$model" &
+               ;;
              "JackFram/llama-160m")
                download_tinyllama "$model" &
                ;;
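Note (not part of the PR): the trailing `*` in the old case pattern matched any model id that starts with the base name, including the new FP8 variant, so both would have been routed to `download_tinygranite`. Dropping the glob lets the FP8 id fall through to its own branch. Shell `case` patterns use glob matching with rules similar to Python's `fnmatch`, which this sketch uses to illustrate the difference:

# Sketch only: shell `case` patterns are globs; fnmatch shows the matching behavior.
from fnmatch import fnmatch

base = "ibm-ai-platform/micro-g3.3-8b-instruct-1b"
fp8 = "ibm-ai-platform/micro-g3.3-8b-instruct-1b-FP8"

# The old pattern with the trailing glob matches both ids:
assert fnmatch(base, "ibm-ai-platform/micro-g3.3-8b-instruct-1b*")
assert fnmatch(fp8, "ibm-ai-platform/micro-g3.3-8b-instruct-1b*")

# Exact patterns (no glob) route each id to its own download function:
assert fnmatch(base, "ibm-ai-platform/micro-g3.3-8b-instruct-1b")
assert not fnmatch(fp8, "ibm-ai-platform/micro-g3.3-8b-instruct-1b")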
tests/e2e/test_spyre_basic.py (5 changes: 3 additions & 2 deletions)

@@ -193,7 +193,7 @@ def test_full_batch_scheduling(model: ModelInfo, backend: str, monkeypatch):
                 request_id=i,
                 num_tokens=max_batched_tokens,
                 sampling_params=vllm_sampling_params,
-                model=model.name,
+                model=model,
             ))
     schedule = scheduler.schedule()
 
@@ -217,7 +217,8 @@ def test_max_model_len_override(model: ModelInfo, backend, warmup_shapes, cb,
 
     patch_environment(**kwargs, backend=backend, monkeypatch=monkeypatch)
     vllm_config = EngineArgs(
-        model=model.name, max_model_len=max_model_len).create_engine_config()
+        model=model.name, revision=model.revision,
+        max_model_len=max_model_len).create_engine_config()
     model_config = vllm_config.model_config
 
     if not cb:
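Note (not part of the PR): several call sites in this diff now pass a `ModelInfo` object from `spyre_util` instead of a bare model name, so helpers can pin both the repo id and the revision. The real class lives in the test utilities; a minimal stand-in showing only the shape these call sites rely on (an assumption, not the repository's definition):

# Sketch only: minimal stand-in for spyre_util.ModelInfo, based on the
# attributes used in this diff (.name and .revision); the real class may differ.
from dataclasses import dataclass

@dataclass(frozen=True)
class ModelInfo:
    name: str                    # Hugging Face repo id
    revision: str | None = None  # pinned commit hash or branch; None means the default branch

model = ModelInfo(name="some-org/some-tiny-model", revision="abc123")
# Call sites then forward both fields, e.g.:
# EngineArgs(model=model.name, revision=model.revision, ...)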
tests/e2e/test_spyre_cb.py (2 changes: 1 addition & 1 deletion)

@@ -151,7 +151,7 @@ def test_long_context_batches(
     )
 
     for batch_size, token_len in batch_token_pairs:
-        prompt = create_seq_prompt(model.name, token_length=token_len)
+        prompt = create_seq_prompt(model, token_length=token_len)
         prompts = [prompt] * batch_size
 
         vllm_outputs = vllm_model.generate(prompts, sampling_params)
tests/e2e/test_spyre_prompt_logprobs.py (19 changes: 12 additions & 7 deletions)

@@ -8,7 +8,8 @@
 import torch
 import torch.nn.functional
 from llm_cache_util import force_engine_shutdown
-from spyre_util import get_chicken_soup_prompts, skip_unsupported_tp_size
+from spyre_util import (ModelInfo, get_chicken_soup_prompts,
+                        skip_unsupported_tp_size)
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import LLM, RequestOutput, SamplingParams
 from vllm.config import ModelConfig, VllmConfig
@@ -18,7 +19,7 @@
 
 # Skip for now until prompt logprobs are fixed
 @pytest.mark.skip
-def test_prompt_logprobs(backend: str, model: str, tp_size: int,
+def test_prompt_logprobs(backend: str, model: str | ModelInfo, tp_size: int,
                          monkeypatch: pytest.MonkeyPatch) -> None:
     '''
     This test checks the prompt_logprobs output from vllm against a reference
@@ -33,14 +34,16 @@ def test_prompt_logprobs(backend: str, model: str, tp_size: int,
 
     monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
     monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", "1")
-    llm = LLM(model, tensor_parallel_size=tp_size, tokenizer=model)
+    tokenizer = AutoTokenizer.from_pretrained(model.name,
+                                              revision=model.revision)
+    llm = LLM(model, tensor_parallel_size=tp_size, tokenizer=tokenizer)
 
     responses: list[RequestOutput] = llm.generate(
         prompts,
         sampling_params=SamplingParams(prompt_logprobs=num_prompt_logprobs))
 
     expected_prompt_logprobs: dict[str, list] = _get_hf_prompt_logprobs(
-        model_name=model, prompts=prompts, n=num_prompt_logprobs)
+        model_info=model, prompts=prompts, n=num_prompt_logprobs)
 
     for prompt, response in zip(prompts, responses):
         actual_logprobs = response.prompt_logprobs
@@ -128,11 +131,13 @@ def _compare_prompt_logprobs(expected: list, actual: list,
                              rel_tol=relative_tolerance)
 
 
-def _get_hf_prompt_logprobs(model_name, prompts, n) -> dict[str, list]:
+def _get_hf_prompt_logprobs(model_info: ModelInfo, prompts,
+                            n) -> dict[str, list]:
     """Get prompt logprobs from HF model directly, including top n candidates
     for each token"""
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_info.name,
+                                              revision=model_info.revision)
+    model = AutoModelForCausalLM.from_pretrained(model_info.name)
 
     prompt_logprobs = {}
     for prompt in prompts:
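Note (not part of the PR): `_get_hf_prompt_logprobs` builds the reference by running the Hugging Face model directly. A minimal sketch of that kind of computation (tokenize, forward pass, log-softmax, gather each prompt token's log-probability), using a placeholder model id; the test's actual implementation may differ in details:

# Sketch only: per-token prompt logprobs from a HF causal LM.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

def prompt_logprobs(model_id: str, prompt: str, revision: str | None = None) -> list[float]:
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    model = AutoModelForCausalLM.from_pretrained(model_id, revision=revision)
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(ids).logits  # shape: [1, seq_len, vocab_size]
    logprobs = F.log_softmax(logits, dim=-1)
    # Token at position i is predicted from logits at position i - 1,
    # so the first prompt token has no logprob.
    return [logprobs[0, i - 1, ids[0, i]].item() for i in range(1, ids.shape[1])]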
tests/e2e/test_spyre_static_batching_limits.py (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ def test_max_prompt_len_and_new_tokens(model: ModelInfo,
 
     # Craft a request with a prompt that is slightly too long for the warmup
     # shape
-    prompt = create_text_prompt(model.name,
+    prompt = create_text_prompt(model,
                                 min_token_length=max_prompt_length,
                                 max_token_length=max_prompt_length +
                                 max_new_tokens - 1)
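Note (not part of the PR): `create_text_prompt` now receives the whole `ModelInfo`, so it can tokenize with the matching model (and, presumably, revision) when sizing the prompt. A sketch of how a prompt inside a token-length window might be built; this is a hypothetical helper, not the repository's implementation:

# Sketch only: build a prompt whose token count lands in [min_tokens, max_tokens).
from transformers import AutoTokenizer

def make_prompt(model_id: str, min_tokens: int, max_tokens: int,
                revision: str | None = None) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    words: list[str] = []
    # Grow the prompt one word at a time until it reaches the minimum length.
    while len(tokenizer.encode(" ".join(words))) < min_tokens:
        words.append("hello")
    prompt = " ".join(words)
    assert len(tokenizer.encode(prompt)) < max_tokens, "overshot the token window"
    return prompt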