Skip to content

Commit d89de6b

Browse files
committed
test: Fix models.list() recording/replay support
o Create AsyncIterableModelsWrapper that supports both usage patterns: * Direct async iteration: async for m in client.models.list() * Await then iterate: res = await client.models.list(); async for m in res o Preserve all existing recording/replay functionality for other endpoints Signed-off-by: Derek Higgins <[email protected]>
1 parent a7f9ce9 commit d89de6b

File tree

1 file changed

+36
-7
lines changed

1 file changed

+36
-7
lines changed

llama_stack/testing/inference_recorder.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,43 @@ async def patched_embeddings_create(self, *args, **kwargs):
388388
_original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
389389
)
390390

391+
# Special handling for models.list which needs to return something directly async-iterable
392+
# Direct iteration: async for m in client.models.list()
393+
# Await then iterate: res = await client.models.list(); async for m in res
391394
def patched_models_list(self, *args, **kwargs):
392-
async def _iter():
393-
for item in await _patched_inference_method(
394-
_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
395-
):
396-
yield item
397-
398-
return _iter()
395+
class AsyncIterableModelsWrapper:
396+
def __init__(self, original_method, client_self, args, kwargs):
397+
self.original_method = original_method
398+
self.client_self = client_self
399+
self.args = args
400+
self.kwargs = kwargs
401+
self._result = None
402+
403+
def __aiter__(self):
404+
return self._async_iter()
405+
406+
async def _async_iter(self):
407+
# Get the result from the patched method
408+
result = await _patched_inference_method(
409+
self.original_method, self.client_self, "openai", "/v1/models", *self.args, **self.kwargs
410+
)
411+
412+
# result is either a async_generator (replay) or a dict (record)
413+
if isinstance(result, dict):
414+
for item in result["data"]:
415+
yield item
416+
else:
417+
for item in result:
418+
yield item
419+
420+
def __await__(self):
421+
# When awaited, return self (since we're already async-iterable)
422+
async def _return_self():
423+
return self
424+
425+
return _return_self().__await__()
426+
427+
return AsyncIterableModelsWrapper(_original_methods["models_list"], self, args, kwargs)
399428

400429
# Apply OpenAI patches
401430
AsyncChatCompletions.create = patched_chat_completions_create

0 commit comments

Comments
 (0)