
Commit 7a7de4c

Danidapena and k223kim authored
fix: handle empty responses (#66)
Squashed commits:

* fix: handle empty responses
* fix: handle empty responses
* fix: handle empty responses
* fix: handle empty responses
* add test
* fix output
* fix output
* fix output
* async will work
* async will work
* fix-some-typing
* fix-more-typing
* fix-more-typing
* fix-more-typing
* fix-more-typing
* fix-req
* remove-some-code
* add-tests
* nit: return type

Co-authored-by: Kaeun Kim <[email protected]>
1 parent: 1854f73 · commit: 7a7de4c

4 files changed: +379 −42 lines changed

requirements.txt
Lines changed: 1 addition & 0 deletions

```diff
@@ -1 +1,2 @@
 lightning_sdk >= 2025.09.16
+nest-asyncio
```
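The new `nest-asyncio` dependency backs the async path added to `chat()` below: when async is enabled, `chat()` schedules a task on the current event loop after calling `nest_asyncio.apply()`, which allows re-entering a loop that is already running (for example, inside Jupyter). A minimal standalone sketch of the problem it solves; the `inner`/`outer` helpers are illustrative, not part of litai:

```python
# Without nest_asyncio.apply(), re-entering a running event loop raises
# "RuntimeError: This event loop is already running".
import asyncio

import nest_asyncio

nest_asyncio.apply()  # patch asyncio to tolerate nested run_until_complete calls


async def inner() -> str:
    return "ok"


async def outer() -> str:
    loop = asyncio.get_event_loop()
    # Re-enter the loop that is currently running outer(); legal only
    # because nest_asyncio has patched it.
    return loop.run_until_complete(inner())


print(asyncio.get_event_loop().run_until_complete(outer()))  # -> ok
```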

src/litai/llm.py
Lines changed: 199 additions & 42 deletions

```diff
@@ -13,21 +13,25 @@
 # limitations under the License.
 """LLM client class."""
 
+import asyncio
 import datetime
+import itertools
 import json
 import logging
 import os
 import threading
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
+from asyncio import Task
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union
 
+import nest_asyncio
 import requests
 from lightning_sdk.lightning_cloud.openapi import V1ConversationResponseChunk
 from lightning_sdk.llm import LLM as SDKLLM
 
 from litai.tools import LitTool
 from litai.utils.supported_public_models import ModelLiteral
-from litai.utils.utils import handle_model_error
+from litai.utils.utils import handle_empty_response, handle_model_error
 
 if TYPE_CHECKING:
     from langchain_core.tools import StructuredTool
```
```diff
@@ -206,7 +210,7 @@ def _format_tool_response(
             return LLM.call_tool(result, lit_tools) or ""
         return json.dumps(result)
 
-    def _model_call(
+    def _model_call(  # noqa: D417
         self,
         model: SDKLLM,
         prompt: str,
```
```diff
@@ -258,6 +262,147 @@ def context_length(self, model: Optional[str] = None) -> int:
             return self._llm.get_context_length(self._model)
         return self._llm.get_context_length(model)
 
+    async def _peek_and_rebuild_async(
+        self,
+        agen: AsyncIterator[str],
+    ) -> Optional[AsyncIterator[str]]:
+        """Peek into an async iterator to check for non-empty content and rebuild it if necessary."""
+        peeked_items: List[str] = []
+        has_content_found = False
+
+        async for item in agen:
+            peeked_items.append(item)
+            if item != "":
+                has_content_found = True
+                break
+
+        if has_content_found:
+
+            async def rebuilt() -> AsyncIterator[str]:
+                for peeked_item in peeked_items:
+                    yield peeked_item
+
+                async for remaining_item in agen:
+                    yield remaining_item
+
+            return rebuilt()
+
+        return None
+
+    async def async_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        full_response: Optional[bool] = None,
+        model: Optional[SDKLLM] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, AsyncIterator[str], None]:
+        """Sends a message to the LLM asynchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self._model_call(  # type: ignore[misc]
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream and response:
+                        non_empty_stream = await self._peek_and_rebuild_async(response)
+                        if non_empty_stream:
+                            return non_empty_stream
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                except Exception as e:
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
+    def sync_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        model: Optional[SDKLLM] = None,
+        full_response: Optional[bool] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, Iterator[str], None]:
+        """Sends a message to the LLM synchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = self._model_call(
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream:
+                        try:
+                            peek_iter, return_iter = itertools.tee(response)
+                            has_content = False
+                            for chunk in peek_iter:
+                                if chunk != "":
+                                    has_content = True
+                                    break
+                            if has_content:
+                                return return_iter
+                        except StopIteration:
+                            pass
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+
+                except Exception as e:
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
     def chat(  # noqa: D417
         self,
         prompt: str,
```
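Both new paths run the same empty-stream check before accepting a streamed response; the sync version leans on `itertools.tee` rather than the manual buffering of `_peek_and_rebuild_async`. A standalone sketch of that tee-based peek, with toy generators standing in for the SDK's stream type:

```python
# Sketch of the peek used in sync_chat(): tee the iterator, scan one copy
# for a non-empty chunk, and return the untouched copy so the caller still
# receives every chunk (tee buffers whatever the peek copy consumed).
import itertools
from typing import Iterator, Optional


def peek_non_empty(chunks: Iterator[str]) -> Optional[Iterator[str]]:
    peek_iter, return_iter = itertools.tee(chunks)
    for chunk in peek_iter:
        if chunk != "":
            return return_iter  # found content; hand back the full stream
    return None  # exhausted with nothing but empty chunks


print(peek_non_empty(iter(["", "", ""])))   # None -> counts as empty, triggers a retry
stream = peek_non_empty(iter(["", "Hi", "!"]))
print(list(stream))                         # ['', 'Hi', '!']
```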
```diff
@@ -272,7 +417,7 @@ def chat(  # noqa: D417
         auto_call_tools: bool = False,
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Union[str, Task[Union[str, AsyncIterator[str], None]], Iterator[str], None]:
         """Sends a message to the LLM and retrieves a response.
 
         Args:
```
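Since `chat()` can now hand back a `Task` (async mode) or an iterator (streaming) instead of a plain string, call sites need to branch on mode. A hedged usage sketch; the import path, constructor arguments, and model names are assumptions for illustration, not confirmed by this diff:

```python
import asyncio

from litai import LLM  # assumed import path

llm = LLM(model="openai/gpt-4o")          # hypothetical model name
text = llm.chat("Hello!")                 # sync, non-streaming -> str

for chunk in llm.chat("Hello!", stream=True):  # streaming -> Iterator[str]
    print(chunk, end="")

async_llm = LLM(model="openai/gpt-4o", enable_async=True)  # assumed flag
task = async_llm.chat("Hello!")           # async mode -> asyncio.Task, not str
print(asyncio.get_event_loop().run_until_complete(task))
```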
```diff
@@ -303,57 +448,61 @@ def chat(  # noqa: D417
         self._wait_for_model()
         lit_tools = LitTool.convert_tools(tools)
         processed_tools = [tool.as_tool() for tool in lit_tools] if lit_tools else None
+
+        models_to_try = []
+        sdk_model = None
         if model:
-            try:
-                model_key = f"{model}::{self._teamspace}::{self._enable_async}"
-                if model_key not in self._sdkllm_cache:
-                    self._sdkllm_cache[model_key] = SDKLLM(
-                        name=model, teamspace=self._teamspace, enable_async=self._enable_async
-                    )
-                sdk_model = self._sdkllm_cache[model_key]
-                return self._model_call(
+            model_key = f"{model}::{self._teamspace}::{self._enable_async}"
+            if model_key not in self._sdkllm_cache:
+                self._sdkllm_cache[model_key] = SDKLLM(
+                    name=model, teamspace=self._teamspace, enable_async=self._enable_async
+                )
+            sdk_model = self._sdkllm_cache[model_key]
+            models_to_try.append(sdk_model)
+        models_to_try.extend(self.models)
+
+        if self._enable_async:
+            nest_asyncio.apply()
+            nest_asyncio.apply()
+
+            loop = asyncio.get_event_loop()
+            return loop.create_task(
+                self.async_chat(
+                    models_to_try=models_to_try,
                     model=sdk_model,
                     prompt=prompt,
                     system_prompt=system_prompt,
-                    max_completion_tokens=max_tokens,
+                    max_tokens=max_tokens,
                     images=images,
                     conversation=conversation,
                     metadata=metadata,
                     stream=stream,
+                    full_response=self._full_response,
                     tools=processed_tools,
                     lit_tools=lit_tools,
                     auto_call_tools=auto_call_tools,
                     reasoning_effort=reasoning_effort,
                     **kwargs,
                 )
-            except Exception as e:
-                print(f"💥 Failed to override with model '{model}'")
-                handle_model_error(e, sdk_model, 0, self.max_retries, self._verbose)
+            )
 
-        # Retry with fallback models
-        for model in self.models:
-            for attempt in range(self.max_retries):
-                try:
-                    return self._model_call(
-                        model=model,
-                        prompt=prompt,
-                        system_prompt=system_prompt,
-                        max_completion_tokens=max_tokens,
-                        images=images,
-                        conversation=conversation,
-                        metadata=metadata,
-                        stream=stream,
-                        tools=processed_tools,
-                        lit_tools=lit_tools,
-                        auto_call_tools=auto_call_tools,
-                        reasoning_effort=reasoning_effort,
-                        **kwargs,
-                    )
-
-                except Exception as e:
-                    handle_model_error(e, model, attempt, self.max_retries, self._verbose)
-
-        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+        return self.sync_chat(
+            models_to_try=models_to_try,
+            model=sdk_model,
+            prompt=prompt,
+            system_prompt=system_prompt,
+            max_tokens=max_tokens,
+            images=images,
+            conversation=conversation,
+            metadata=metadata,
+            stream=stream,
+            full_response=self._full_response,
+            tools=processed_tools,
+            lit_tools=lit_tools,
+            auto_call_tools=auto_call_tools,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
 
     @staticmethod
     def call_tool(
```
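The rewritten body no longer returns (or raises) from inside the `model=` override branch; it builds a single `models_to_try` list, override first, then delegates to `sync_chat`/`async_chat`, so a failing override now falls back to the configured models instead of aborting. A toy illustration of the ordering, with hypothetical model names:

```python
# The override (if given) is tried first, then every configured fallback;
# each entry gets up to max_retries attempts inside sync_chat()/async_chat().
override = "openai/gpt-4o"                      # hypothetical model= argument
configured = ["google/gemini-2.5-flash", "anthropic/claude-3-5-sonnet"]

models_to_try = []
if override:
    models_to_try.append(override)
models_to_try.extend(configured)

assert models_to_try == [
    "openai/gpt-4o",
    "google/gemini-2.5-flash",
    "anthropic/claude-3-5-sonnet",
]
```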
```diff
@@ -491,7 +640,11 @@ def if_(self, input: str, question: str) -> bool:
         Answer with only 'yes' or 'no'.
         """
 
-        response = self.chat(prompt).strip().lower()
+        response = self.chat(prompt)
+        if isinstance(response, str):
+            response = response.strip().lower()
+        else:
+            return False
         return "yes" in response
 
     def classify(self, input: str, choices: List[str]) -> str:
```
```diff
@@ -517,7 +670,11 @@ def classify(self, input: str, choices: List[str]) -> str:
         Answer with only one of the choices.
         """.strip()
 
-        response = self.chat(prompt).strip().lower()
+        response = self.chat(prompt)
+        if isinstance(response, str):
+            response = response.strip().lower()
+        else:
+            return normalized_choices[0]
 
         if response in normalized_choices:
             return response
```
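Both helpers now guard against `chat()` returning a non-string (a `Task` in async mode, an iterator when streaming): `if_` falls back to `False` and `classify` to the first normalized choice rather than crashing on `.strip()`. A hedged sketch of the resulting behavior; the constructor arguments are assumptions, as above:

```python
llm = LLM(model="openai/gpt-4o")  # hypothetical sync client
print(llm.if_("It is raining hard today", "is this about weather?"))  # True or False
print(llm.classify("I loved this movie!", ["positive", "negative"]))  # one of the choices

async_llm = LLM(model="openai/gpt-4o", enable_async=True)  # assumed flag
print(async_llm.if_("anything", "anything?"))      # False: chat() returned a Task
print(async_llm.classify("anything", ["a", "b"]))  # "a": safe first-choice default
```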

src/litai/utils/utils.py
Lines changed: 10 additions & 0 deletions

```diff
@@ -182,3 +182,13 @@ def handle_model_error(e: Exception, model: SDKLLM, attempt: int, max_retries: i
         print("-" * 50)
         print(f"❌ All {max_retries} attempts failed for model {model.name}")
         print("-" * 50)
+
+
+def handle_empty_response(model: SDKLLM, attempt: int, max_retries: int) -> None:
+    """Handles empty responses from model calls."""
+    if attempt < max_retries - 1:
+        print(f"🔁 Received empty response. Attempt {attempt + 1}/{max_retries} failed. Retrying...")
+    else:
+        print("-" * 50)
+        print(f"❌ All {max_retries} attempts received empty responses for model {model.name}.")
+        print("-" * 50)
```
