Skip to content

WIP - do not merge - Vllm v1 hidden states #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
722b739
checkpointing before implementing rest of hidden states
kyle-pena-kuzco Jun 3, 2025
9b257ae
checkpoint
kyle-pena-kuzco Jun 4, 2025
37e424f
core engine hidden states implementation possibly complete
kyle-pena-kuzco Jun 4, 2025
dd65e97
another checkpoint - partial API integration
kyle-pena-kuzco Jun 4, 2025
5c2e114
checkpointing on hidden states extraction
kyle-pena-kuzco Jun 4, 2025
0b138e3
implemented true test of hidden states core engine functionality
kyle-pena-kuzco Jun 5, 2025
dd34eff
implemented basic API support. stremaing to follow.
kyle-pena-kuzco Jun 5, 2025
c5d164f
cleaned up several unneeded files. fixed some other bugs.
kyle-pena-kuzco Jun 6, 2025
afcae9f
removal of more unneeded stuff
kyle-pena-kuzco Jun 6, 2025
ca4a83a
removed more stuff
kyle-pena-kuzco Jun 6, 2025
b55a6ed
continuing cleanup and centralization of tests
kyle-pena-kuzco Jun 6, 2025
8b513e1
more cleanup, expanded test coverage
kyle-pena-kuzco Jun 9, 2025
0c09f1e
more cleanup of unneeded test files
kyle-pena-kuzco Jun 9, 2025
c43b5eb
fixed chat completion streaming test (although the hidden states dict…
kyle-pena-kuzco Jun 9, 2025
8daca13
fixed streaming api tests
kyle-pena-kuzco Jun 9, 2025
e9f7c65
changed property name to be more understandable
kyle-pena-kuzco Jun 9, 2025
fdcc2ff
some progress on hdiden states being keyed by req id instead of token…
kyle-pena-kuzco Jun 9, 2025
88d1f44
fixed streaming for chat completions and completions
kyle-pena-kuzco Jun 9, 2025
a67c221
removed print statements
kyle-pena-kuzco Jun 9, 2025
8a16a81
implemented server level flag for enabling/disabling return hidden st…
kyle-pena-kuzco Jun 9, 2025
461261f
pushed the engine flag validation further down the stack so that raw …
kyle-pena-kuzco Jun 9, 2025
4c58a97
more changes to test coverage
kyle-pena-kuzco Jun 10, 2025
fe78b9e
fixes for the gathering of hidden states and their consumption in the…
kyle-pena-kuzco Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/v1/hidden_states/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: Apache-2.0
308 changes: 308 additions & 0 deletions tests/v1/hidden_states/test_hidden_states_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
Integration test for vLLM Hidden States API

This test spins up a vLLM server with V1 engine and tests the hidden states functionality
using the same patterns as other vLLM integration tests.
"""

import pytest
import requests
from typing import Dict, Any

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)

# Test model - use a small model for faster testing
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"


@pytest.fixture(scope="module")
def default_server_args():
"""Default server arguments for hidden states testing."""
return [
# Use half precision for speed and memory savings
"--max-model-len", "2048",
"--max-num-seqs", "128",
"--enforce-eager",
"--enable-return-hidden-states"
]


@pytest.fixture(scope="module")
def server(default_server_args):
"""Start vLLM server with V1 engine for hidden states testing."""
env_dict = {"VLLM_USE_V1": "1"} # Ensure V1 engine is enabled
with RemoteOpenAIServer(MODEL_NAME, default_server_args, env_dict=env_dict) as remote_server:
yield remote_server


class TestHiddenStatesAPI:
"""Test suite for hidden states API functionality."""

def test_chat_completion_without_hidden_states(self, server):
"""Test chat completion without hidden states (baseline functionality)."""
client = server.get_client()

response = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Hello! How are you today?"}],
max_tokens=10,
temperature=0.7
)

# Validate standard response structure
assert response.id
assert response.object == "chat.completion"
assert response.model == MODEL_NAME
assert len(response.choices) > 0

choice = response.choices[0]
assert choice.message
assert choice.message.role == "assistant"
assert choice.message.content

# Convert to dict to check for hidden_states field
choice_dict = choice.model_dump()

# With exclude_if_none, hidden_states should not be present when None
# But if it is present, it should be None (backward compatibility)
if "hidden_states" in choice_dict:
assert choice_dict["hidden_states"] is None
print(" NOTE: hidden_states field present but None (expected with current implementation)")
else:
print(" hidden_states field properly excluded")

def test_chat_completion_with_hidden_states(self, server):
"""Test chat completion with hidden states extraction."""

# Make raw HTTP request to test our custom parameters
url = server.url_for("v1", "chat", "completions")
headers = {"Content-Type": "application/json"}
payload = {
"model": MODEL_NAME,
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"max_tokens": 10,
"temperature": 0.7,
"return_hidden_states": True,
"hidden_states_token_positions": [-1] # Last token
}

response = requests.post(url, json=payload, headers=headers)
assert response.status_code == 200

data = response.json()

# Validate response structure
assert "choices" in data
assert len(data["choices"]) > 0
choice = data["choices"][0]
assert "message" in choice

# Check if hidden states are present
# NOTE: This test may initially fail until the full hidden states pipeline is working
# For now, we'll check that the API accepts the parameters without error
print(f" Response received: {choice.get('message', {}).get('content', '')[:50]}...")

if "hidden_states" in choice:
if choice["hidden_states"] is not None:
assert isinstance(choice["hidden_states"], list)
assert len(choice["hidden_states"]) > 0
print(f" Hidden states extracted: {len(choice['hidden_states'])} dimensions")
else:
print(" Hidden states requested but None returned (pipeline may not be fully connected)")
else:
print(" Hidden states field not present (may indicate exclude_if_none is working)")

def test_completion_without_hidden_states(self, server):
"""Test completion without hidden states (baseline functionality)."""
client = server.get_client()

response = client.completions.create(
model=MODEL_NAME,
prompt="The capital of France is",
max_tokens=5,
temperature=0.7
)

# Validate standard response structure
assert response.id
assert response.object == "text_completion"
assert response.model == MODEL_NAME
assert len(response.choices) > 0

choice = response.choices[0]
assert choice.text

# Convert to dict to check for hidden_states field
choice_dict = choice.model_dump()

# With exclude_if_none, hidden_states should not be present when None
assert "hidden_states" not in choice_dict, "hidden_states field should not be present when None"

def test_completion_with_hidden_states(self, server):
"""Test completion with hidden states extraction."""

# Make raw HTTP request to test our custom parameters
url = server.url_for("v1", "completions")
headers = {"Content-Type": "application/json"}
payload = {
"model": MODEL_NAME,
"prompt": "The capital of France is",
"max_tokens": 5,
"temperature": 0.7,
"return_hidden_states": True,
"hidden_states_token_positions": [-1] # Last token
}

response = requests.post(url, json=payload, headers=headers)
assert response.status_code == 200

data = response.json()

# Validate response structure
assert "choices" in data
assert len(data["choices"]) > 0
choice = data["choices"][0]
assert "text" in choice

print(f" Response received: {choice.get('text', '')[:50]}...")

assert "hidden_states" in choice, "hidden_states field should be present"
assert choice["hidden_states"] is not None, "hidden_states should not be None"

def test_invalid_hidden_states_parameters(self, server):
"""Test API validation for invalid hidden states parameters."""

url = server.url_for("v1", "chat", "completions")
headers = {"Content-Type": "application/json"}

# Test invalid return_hidden_states type
payload = {
"model": MODEL_NAME,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 5,
"return_hidden_states": "true" # Should be boolean
}

response = requests.post(url, json=payload, headers=headers)
# This should either work (if server converts string to bool) or return 422
if response.status_code == 422:
print(" Invalid parameter type correctly rejected")
else:
print(" Server accepted string 'true' for boolean field")

def test_chat_completion_with_hidden_states_streaming(self, server):
import requests
import json

url = server.url_for("v1/chat/completions")
payload = {
"model": MODEL_NAME,
"messages": [{"role": "user", "content": "Hello, can you help?"}],
"return_hidden_states": True,
"stream": True
}
response = requests.post(url, json=payload, stream=True)
response.raise_for_status()

full_content = ""
hidden_states_found = False

for line in response.iter_lines():
if line:
line_text = line.decode('utf-8')
print(line_text)
if line_text.startswith('data: '):
data_text = line_text[6:]
if data_text.strip() == '[DONE]':
break
try:
chunk = json.loads(data_text)
choice = chunk.get('choices', [{}])[0]
delta = choice.get('delta', {})
if 'hidden_states' in delta:
hidden_states_found = True
except json.JSONDecodeError:
continue

assert hidden_states_found, "Chat completion streaming should include hidden states."


def test_completion_with_hidden_states_streaming(self, server):
import requests
import json

url = server.url_for("v1/completions")
payload = {
"model": MODEL_NAME,
"prompt": "What is the answer?",
"return_hidden_states": True,
"stream": True
}
response = requests.post(url, json=payload, stream=True)
response.raise_for_status()

full_content = ""
hidden_states_found = False

for line in response.iter_lines():
if line:
line_text = line.decode('utf-8')
if line_text.startswith('data: '):
data_text = line_text[6:]
if data_text.strip() == '[DONE]':
break
try:
chunk = json.loads(data_text)
choice = chunk.get('choices', [{}])[0]
if 'hidden_states' in choice:
hidden_states_found = True
except json.JSONDecodeError:
continue

assert hidden_states_found, "Completion streaming should include hidden states."


def test_chat_completion_parallel_sampling(self, server):
"""Test retrieving hidden states via parallel sampling."""
print("Testing parallel sampling hidden states extraction...")

client = server.get_client()

response = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Hello! How are you today?"}],
temperature=0.7,
n = 2,
extra_body = {"return_hidden_states": True},
)

for choice in response.choices:
assert choice.hidden_states is not None

def test_completion_parallel_sampling(self, server):
"""Test retrieving hidden states via parallel sampling."""
print("Testing parallel sampling hidden states extraction...")

client = server.get_client()

response = client.completions.create(
model=MODEL_NAME,
prompt="Hello! How are you today?",
temperature=0.7,
n = 2,
extra_body = {"return_hidden_states": True},
)

for choice in response.choices:
choice = choice.model_dump()
assert "hidden_states" in choice and choice["hidden_states"] is not None

if __name__ == "__main__":
# Allow running this test directly
pytest.main([__file__, "-v", "-s"])
Loading