Skip to content

Commit 80393ea

Browse files
committed
test: Fix models.list() recording/replay support
- Create AsyncIterableModelsWrapper that supports both usage patterns:
  - Direct async iteration: async for m in client.models.list()
  - Await then iterate: res = await client.models.list(); async for m in res
- Preserve all existing recording/replay functionality for other endpoints

Signed-off-by: Derek Higgins <[email protected]>
1 parent 65d45c7 commit 80393ea

File tree

1 file changed

+36
-7
lines changed

1 file changed

+36
-7
lines changed

llama_stack/testing/inference_recorder.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,14 +383,43 @@ async def patched_embeddings_create(self, *args, **kwargs):
383383
_original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
384384
)
385385

386+
# Special handling for models.list which needs to return something directly async-iterable
387+
# Direct iteration: async for m in client.models.list()
388+
# Await then iterate: res = await client.models.list(); async for m in res
386389
def patched_models_list(self, *args, **kwargs):
    """Return a lazy handle for ``models.list`` that is awaitable and async-iterable.

    The returned object supports both calling conventions:

    * direct iteration: ``async for m in client.models.list()``
    * await then iterate: ``res = await client.models.list()`` followed by
      ``async for m in res``

    Nothing is sent through ``_patched_inference_method`` until iteration
    actually starts; awaiting the handle only hands the same handle back.

    Args:
        self: The client resource object the patched method is bound to.
        *args: Positional arguments forwarded unchanged to the recorded call.
        **kwargs: Keyword arguments forwarded unchanged to the recorded call.

    Returns:
        An ``AsyncIterableModelsWrapper`` yielding the listed model items.
    """

    class AsyncIterableModelsWrapper:
        """Deferred ``/v1/models`` call exposing ``__aiter__`` and ``__await__``."""

        def __init__(self, original_method, client_self, args, kwargs):
            # Only capture the call here; the request (or replay lookup) is
            # deferred until the wrapper is iterated.
            self.original_method = original_method
            self.client_self = client_self
            self.args = args
            self.kwargs = kwargs

        def __aiter__(self):
            # Each ``async for`` gets its own generator, and therefore its own
            # pass through the patched inference method.
            return self._async_iter()

        async def _async_iter(self):
            # Route the call through the shared record/replay machinery.
            result = await _patched_inference_method(
                self.original_method, self.client_self, "openai", "/v1/models", *self.args, **self.kwargs
            )

            if isinstance(result, dict):
                # Dict-shaped response: the items live under the "data" key.
                for item in result["data"]:
                    yield item
            elif hasattr(result, "__aiter__"):
                # Async generators (and any other async iterable) cannot be
                # consumed with a plain ``for`` -- that raises TypeError -- so
                # they must be drained with ``async for``.
                async for item in result:
                    yield item
            else:
                # Fallback for an ordinary synchronous iterable of items.
                for item in result:
                    yield item

        def __await__(self):
            # Awaiting resolves to the wrapper itself, which is already
            # async-iterable, so ``await`` followed by ``async for`` works.
            async def _return_self():
                return self

            return _return_self().__await__()

    return AsyncIterableModelsWrapper(_original_methods["models_list"], self, args, kwargs)
394423

395424
# Apply OpenAI patches
396425
AsyncChatCompletions.create = patched_chat_completions_create

0 commit comments

Comments
 (0)