Commit 9ee85bc: add-tests

1 parent: f8bb56d

2 files changed, +122 -3 lines changed

src/litai/llm.py (4 additions, 3 deletions)

@@ -330,9 +330,10 @@ async def async_chat(

             if not stream and response:
                 return response
-            non_empty_stream = await self._peek_and_rebuild_async(response)
-            if non_empty_stream:
-                return non_empty_stream
+            if stream and response:
+                non_empty_stream = await self._peek_and_rebuild_async(response)
+                if non_empty_stream:
+                    return non_empty_stream
             handle_empty_response(sdk_model, attempt, self.max_retries)
             if sdk_model == model:
                 print(f"💥 Failed to override with model '{model}'")

tests/test_llm.py (118 additions, 0 deletions)

@@ -298,6 +298,124 @@ def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs):
     )


+def test_empty_response_retries_sync_stream(monkeypatch):
+    """Test that retries work correctly for sync streaming when empty responses are returned."""
+    from litai.llm import LLM as LLMCLIENT
+
+    LLMCLIENT._sdkllm_cache.clear()
+
+    class MockSyncIterator:
+        def __init__(self, items):
+            self.items = items
+            self.index = 0
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            if self.index < len(self.items):
+                item = self.items[self.index]
+                self.index += 1
+                return item
+            raise StopIteration
+
+    mock_responses = [
+        MockSyncIterator([]),
+        MockSyncIterator([]),
+        MockSyncIterator(["hello", " world"]),
+    ]
+
+    mock_main_model = MagicMock()
+
+    def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs):
+        if name == "main-model":
+            mock_main_model.chat.side_effect = mock_responses
+            mock_main_model.name = "main-model"
+            return mock_main_model
+        raise ValueError(f"Unknown model: {name}")
+
+    monkeypatch.setattr("litai.llm.SDKLLM", mock_llm_constructor)
+
+    llm = LLM(
+        model="main-model",
+    )
+
+    response = llm.chat("test prompt", stream=True)
+
+    assert mock_main_model.chat.call_count == 3
+
+    result = ""
+    for chunk in response:
+        result += chunk
+    assert result == "hello world"
+
+
+@pytest.mark.asyncio
+async def test_empty_response_retries_async(monkeypatch):
+    """Test that retries work correctly for async non-streaming when empty responses are returned."""
+    from litai.llm import LLM as LLMCLIENT
+
+    LLMCLIENT._sdkllm_cache.clear()
+    mock_sdkllm = MagicMock()
+    mock_sdkllm.name = "mock-model"
+
+    mock_sdkllm.chat = AsyncMock(side_effect=["", "", "Main response"])
+
+    monkeypatch.setattr("litai.llm.SDKLLM", lambda *args, **kwargs: mock_sdkllm)
+
+    llm = LLM(
+        model="main-model",
+        enable_async=True,
+    )
+    response = await llm.chat(prompt="Hello", stream=False)
+
+    assert response == "Main response"
+    assert mock_sdkllm.chat.call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_empty_response_retries_async_stream(monkeypatch):
+    """Test that retries work correctly for async streaming when empty responses are returned."""
+    from litai.llm import LLM as LLMCLIENT
+
+    LLMCLIENT._sdkllm_cache.clear()
+    mock_sdkllm = MagicMock()
+    mock_sdkllm.name = "mock-model"
+
+    class MockAsyncIterator:
+        def __init__(self, items):
+            self.items = items
+            self.index = 0
+
+        def __aiter__(self):
+            return self
+
+        async def __anext__(self):
+            if self.index < len(self.items):
+                item = self.items[self.index]
+                self.index += 1
+                return item
+            raise StopAsyncIteration
+
+    mock_sdkllm.chat = AsyncMock(
+        side_effect=[MockAsyncIterator([]), MockAsyncIterator([]), MockAsyncIterator(["Main", " response"])]
+    )
+
+    monkeypatch.setattr("litai.llm.SDKLLM", lambda *args, **kwargs: mock_sdkllm)
+
+    llm = LLM(
+        model="main-model",
+        enable_async=True,
+    )
+
+    response = await llm.chat(prompt="Hello", stream=True)
+    result = ""
+    async for chunk in response:
+        result += chunk
+    assert result == "Main response"
+    assert mock_sdkllm.chat.call_count == 3
+
+
 @pytest.mark.asyncio
 async def test_llm_async_chat(monkeypatch):
     """Test async requests."""
