Merged
Changes from 5 commits

1 change: 1 addition & 0 deletions tests/hf_cache.json

Large diffs are not rendered by default.

78 changes: 78 additions & 0 deletions tests/hf_result_cache.py
@@ -0,0 +1,78 @@
import json
import os
from pathlib import Path


class HFResultCache:
    """
    A simple cache for storing and retrieving results from Hugging Face models.
    The cache is stored in a JSON file named 'hf_cache.json' in the same
    directory as this script.

    This cache can be (re)populated by running all tests and committing the
    changes to the .json file.
    """

    def __init__(self):
        """
        Initialize the HFResultCache. Load existing cached results from
        'hf_cache.json'. If the file does not exist, an empty cache dictionary
        is created.
        """
        current_dir = Path(os.path.abspath(os.path.dirname(__file__)))
        self.cached_results_file_path = current_dir / "hf_cache.json"

        if not self.cached_results_file_path.exists():
            self.cached_results = {}
            # Start with empty file
            with open(self.cached_results_file_path, 'w') as f:
                json.dump(self.cached_results, f)
        else:
            with open(self.cached_results_file_path) as f:
                self.cached_results = json.load(f)

        self.dirty = False

    def write_cache(self):
        """
        Write the current cache to 'hf_cache.json' if it has been modified.
        """
        if self.dirty:
            with open(self.cached_results_file_path, 'w') as f:
                json.dump(self.cached_results, f)
            self.dirty = False
    def get_cached_result(self, model: str, prompt: str,
                          max_tokens: int) -> dict:
        """
        Retrieve a cached result for the given model, prompt, and max_tokens.
        Returns an empty dictionary if no cache entry is found.
        """
        if isinstance(prompt, list):
            prompt = self._token_ids_to_string(prompt)
        max_tokens = str(max_tokens)

        return self.cached_results.get(model, {}).get(prompt,
                                                      {}).get(max_tokens, {})

Collaborator: Shouldn't the type annotation for prompt be Union[str, list[int]]?
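A sketch of the annotation the reviewer appears to be suggesting (it would also need a from typing import Union at the top of the file, which is not there yet):

    def get_cached_result(self, model: str, prompt: Union[str, list[int]],
                          max_tokens: int) -> dict: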

    def add_to_cache(self, model: str, prompt: str, max_tokens: int,
                     result: dict):
        """
        Add a new result to the cache for the given model, prompt, and
        max_tokens. Marks the cache as 'dirty' to indicate that it needs to be
        written to disk.
        """
        if isinstance(prompt, list):
            prompt = self._token_ids_to_string(prompt)
        max_tokens = str(max_tokens)

        self.cached_results.setdefault(model,
                                       {}).setdefault(prompt, {}).setdefault(
                                           max_tokens, result)
        self.dirty = True

Collaborator: Same comment about the prompt type annotation.

    def _token_ids_to_string(self, token_ids: list[int]) -> str:
        """Use a string to represent a list of token ids, so that it can be
        hashed and used as a json key."""

        return "__tokens__" + "_".join(str(token_id) for token_id in token_ids)
Collaborator: nice, no tokenizer required.
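To make the intended workflow concrete, a minimal usage sketch (the model name and values are placeholders, not taken from the tests):

cache = HFResultCache()

# String prompts and token-id prompts are both accepted; token-id prompts are
# keyed via _token_ids_to_string, e.g. [1, 2, 3] becomes "__tokens__1_2_3".
cached = cache.get_cached_result("some/model", "Hello, world", 8)
if not cached:
    # In the real tests this dict comes from running the transformers model.
    result = {"text": "...", "token_ids": [1, 2, 3], "tokens": ["a", "b", "c"],
              "logprobs": [0.0, 0.0, 0.0]}
    cache.add_to_cache("some/model", "Hello, world", 8, result)

cache.write_cache()  # only touches hf_cache.json if something was added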

35 changes: 32 additions & 3 deletions tests/spyre_util.py
@@ -13,6 +13,7 @@
import pytest
import requests
import torch
from hf_result_cache import HFResultCache
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
@@ -28,6 +29,8 @@
ISCLOSE_REL_TOL_CPU = 0.35
ISCLOSE_REL_TOL_SPYRE = 0.35

HF_RESULT_CACHE = HFResultCache()


def force_engine_shutdown(llm: LLM):
"""
@@ -253,17 +256,32 @@ def generate_hf_output(
    max_new_tokens: Union[int, list[int]],
    ignore_eos: bool = False,
) -> list[dict[str, Any]]:
    """Loads and runs the model on cpu with transformers, caching the results.
    Returns cached results if any are found to avoid overhead."""

    if not isinstance(max_new_tokens, list):
        max_new_tokens = [max_new_tokens] * len(prompts)

    results = []
    for prompt, max_tokens in zip(prompts, max_new_tokens):
        results.append(
            HF_RESULT_CACHE.get_cached_result(model, prompt, max_tokens))

    if all(results):
        # Everything hit cache
        return results

    hf_model = AutoModelForCausalLM.from_pretrained(model)
    hf_tokenizer = AutoTokenizer.from_pretrained(model)
    if ignore_eos:
        hf_model.generation_config.eos_token_id = None

    results = []
    for prompt_index, prompt in enumerate(prompts):

        if results[prompt_index]:
            # Already have cached result
            continue

        hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids \
            if isinstance(prompt[0], str) \
            else torch.tensor([prompts[prompt_index]])
Expand Down Expand Up @@ -295,8 +313,14 @@ def generate_hf_output(
        result['token_ids'] = tuple(result['token_ids'])
        result['tokens'] = tuple(result['tokens'])
        result['logprobs'] = tuple(result['logprobs'])
        results.append(result)

        # Save and cache new result
        results[prompt_index] = result
        HF_RESULT_CACHE.add_to_cache(model, prompt,
                                     max_new_tokens[prompt_index], result)

    # Write back to the cache
    HF_RESULT_CACHE.write_cache()
    return results
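For illustration, a rough sketch of how a test might call this helper after the change (parameter names follow the signature fragments visible above; the model name and prompts are placeholders):

prompts = ["Hello, world", "Tell me a joke"]
hf_results = generate_hf_output(model="some/model",
                                prompts=prompts,
                                max_new_tokens=8)
# The first run loads the transformers model only for prompts that miss the
# cache, then writes hf_cache.json; later runs return the cached dicts without
# loading the model at all.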


@@ -336,8 +360,13 @@ def compare_results(
print(f" vLLM: {repr(vllm_result['text']):s}{err_msg}")
print()

if isinstance(hf_result['token_ids'], list):
hf_result['token_ids'] = tuple(hf_result['token_ids'])

assert DISABLE_ASSERTS or backend == 'sendnn' or\
hf_result['token_ids'] == vllm_result['token_ids']
hf_result['token_ids'] == vllm_result['token_ids'], \
f"Token ids differ: {hf_result['token_ids']} != " \
f"{vllm_result['token_ids']}"

if len(hf_result['tokens']) > 0:
print(" token id. token logprob "
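The isinstance conversion added above is presumably needed because cached HF results come back from the JSON file with lists where freshly generated results carry tuples; a small round-trip illustration (plain Python, not from the PR):

import json

token_ids = (404, 13, 29871)
restored = json.loads(json.dumps({"token_ids": token_ids}))["token_ids"]
assert restored == [404, 13, 29871]  # JSON has no tuples, so a list comes back
assert tuple(restored) == token_ids  # hence the conversion before comparing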