
Commit 3003c64

test: add models.list() recording/replay support
o Add patching for OpenAI AsyncModels.list method to inference recorder
o Create AsyncIterableModelsWrapper that supports both usage patterns:
  * Direct async iteration: async for m in client.models.list()
  * Await then iterate: res = await client.models.list(); async for m in res
o Update streaming detection to handle AsyncPage objects from models.list
o Preserve all existing recording/replay functionality for other endpoints

Signed-off-by: Derek Higgins <[email protected]>
1 parent 657a1fa commit 3003c64
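For orientation, the two usage patterns named in the commit message look like this from the caller's side. This is a minimal sketch assuming a configured AsyncOpenAI client; the base URL and API key below are placeholders and are not part of this commit:

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    # Placeholder client setup; any AsyncOpenAI-compatible endpoint works here.
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="dummy")

    # Pattern 1: direct async iteration over the call itself.
    async for m in client.models.list():
        print(m.id)

    # Pattern 2: await first, then iterate the returned object.
    res = await client.models.list()
    async for m in res:
        print(m.id)

asyncio.run(main())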

1 file changed

llama_stack/testing/inference_recorder.py

Lines changed: 40 additions & 2 deletions
@@ -16,6 +16,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai.pagination import AsyncPage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
 from llama_stack.log import get_logger
@@ -255,7 +256,8 @@ async def replay_stream():
         }
 
         # Determine if this is a streaming request based on request parameters
-        is_streaming = body.get("stream", False)
+        # or if the response is an AsyncPage (like models.list returns)
+        is_streaming = body.get("stream", False) or isinstance(response, AsyncPage)
 
         if is_streaming:
             # For streaming responses, we need to collect all chunks immediately before yielding
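Because models.list() responses are async-iterable pages rather than plain objects, the recorder treats them like streams: drain the iterator once at record time, then re-yield the stored items at replay time. A rough sketch of that collect-then-replay idea follows; the helper names are illustrative, not the recorder's actual internals:

from typing import Any, AsyncIterator

async def collect_chunks(stream: AsyncIterator[Any]) -> list[Any]:
    # Drain the async iterator up front so every chunk can be serialized to disk.
    return [chunk async for chunk in stream]

async def replay_chunks(chunks: list[Any]) -> AsyncIterator[Any]:
    # Re-yield the recorded chunks so callers can still consume them with `async for`.
    for chunk in chunks:
        yield chunk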
@@ -291,9 +293,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "models_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -305,7 +309,38 @@ def patch_inference_clients():
         "ollama_list": OllamaAsyncClient.list,
     }
 
-    # Create patched methods for OpenAI client
+    # Special handling for models.list which needs to return something directly async-iterable
+    # Direct iteration: async for m in client.models.list()
+    # Await then iterate: res = await client.models.list(); async for m in res
+    def patched_models_list(self, *args, **kwargs):
+        class AsyncIterableModelsWrapper:
+            def __init__(self, original_method, client_self, args, kwargs):
+                self.original_method = original_method
+                self.client_self = client_self
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __aiter__(self):
+                return self._async_iter()
+
+            async def _async_iter(self):
+                # Get the result from the patched method
+                result = await _patched_inference_method(
+                    self.original_method, self.client_self, "openai", "/v1/models", *self.args, **self.kwargs
+                )
+                async for item in result:
+                    yield item
+
+            def __await__(self):
+                # When awaited, return self (since we're already async-iterable)
+                async def _return_self():
+                    return self
+
+                return _return_self().__await__()
+
+        return AsyncIterableModelsWrapper(_original_methods["models_list"], self, args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -322,6 +357,7 @@ async def patched_embeddings_create(self, *args, **kwargs):
         )
 
     # Apply OpenAI patches
+    AsyncModels.list = patched_models_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -378,8 +414,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["models_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
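One plausible way to use the patch/unpatch pair in a test suite is a fixture that always restores the original methods, even when a test fails. The fixture below is a hedged sketch and does not exist in the repository; only patch_inference_clients and unpatch_inference_clients come from the module being changed:

import pytest

from llama_stack.testing.inference_recorder import (
    patch_inference_clients,
    unpatch_inference_clients,
)

@pytest.fixture
def patched_openai_clients():
    # Swap in the recording/replay methods for the duration of one test.
    patch_inference_clients()
    try:
        yield
    finally:
        # Restore the real OpenAI/Ollama client methods no matter how the test ends.
        unpatch_inference_clients()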
