 # limitations under the License.
 """LLM client class."""

+import asyncio
 import datetime
+import itertools
 import json
 import logging
 import os
 import threading
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
+from asyncio import Task
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union

+import nest_asyncio
 import requests
 from lightning_sdk.lightning_cloud.openapi import V1ConversationResponseChunk
 from lightning_sdk.llm import LLM as SDKLLM

 from litai.tools import LitTool
 from litai.utils.supported_public_models import ModelLiteral
-from litai.utils.utils import handle_model_error
+from litai.utils.utils import handle_empty_response, handle_model_error

 if TYPE_CHECKING:
     from langchain_core.tools import StructuredTool
@@ -206,7 +210,7 @@ def _format_tool_response(
             return LLM.call_tool(result, lit_tools) or ""
         return json.dumps(result)

-    def _model_call(
+    def _model_call(  # noqa: D417
         self,
         model: SDKLLM,
         prompt: str,
@@ -258,6 +262,147 @@ def context_length(self, model: Optional[str] = None) -> int:
             return self._llm.get_context_length(self._model)
         return self._llm.get_context_length(model)

+    async def _peek_and_rebuild_async(
+        self,
+        agen: AsyncIterator[str],
+    ) -> Optional[AsyncIterator[str]]:
+        """Peek into an async iterator to check for non-empty content and rebuild it if necessary."""
+        peeked_items: List[str] = []
+        has_content_found = False
+
+        async for item in agen:
+            peeked_items.append(item)
+            if item != "":
+                has_content_found = True
+                break
+
+        if has_content_found:
+
+            async def rebuilt() -> AsyncIterator[str]:
+                for peeked_item in peeked_items:
+                    yield peeked_item
+
+                async for remaining_item in agen:
+                    yield remaining_item
+
+            return rebuilt()
+
+        return None
+
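The helper above peeks an async stream for a non-empty chunk without losing data: everything consumed while peeking is replayed before the untouched remainder. A minimal standalone sketch of the same pattern (toy generator, not the SDK stream):

```python
import asyncio
from typing import AsyncIterator, List, Optional


async def peek_and_rebuild(agen: AsyncIterator[str]) -> Optional[AsyncIterator[str]]:
    peeked: List[str] = []
    async for item in agen:
        peeked.append(item)
        if item != "":
            async def rebuilt() -> AsyncIterator[str]:
                for p in peeked:          # replay the peeked prefix
                    yield p
                async for rest in agen:   # then pass through the remainder
                    yield rest
            return rebuilt()
    return None  # the stream was entirely empty


async def main() -> None:
    async def chunks() -> AsyncIterator[str]:
        for c in ["", "", "hello", " world"]:
            yield c

    stream = await peek_and_rebuild(chunks())
    assert stream is not None
    print([c async for c in stream])  # ['', '', 'hello', ' world']


asyncio.run(main())
```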
+    async def async_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        full_response: Optional[bool] = None,
+        model: Optional[SDKLLM] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, AsyncIterator[str], None]:
+        """Sends a message to the LLM asynchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self._model_call(  # type: ignore[misc]
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream and response:
+                        non_empty_stream = await self._peek_and_rebuild_async(response)
+                        if non_empty_stream:
+                            return non_empty_stream
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                except Exception as e:
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
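`handle_empty_response` comes from `litai.utils.utils` and its body is not shown in this diff. From the call sites it takes the model, the attempt index, and the retry cap; a purely hypothetical sketch of that shape (the real helper may differ):

```python
# Hypothetical sketch only -- the real helper lives in litai.utils.utils.
import logging

logger = logging.getLogger(__name__)


def handle_empty_response(model: object, attempt: int, max_retries: int) -> None:
    """Log an empty LLM response so the surrounding retry loop can try again."""
    logger.warning(
        "Model %s returned an empty response (attempt %d/%d), retrying...",
        getattr(model, "name", model),
        attempt + 1,
        max_retries,
    )
```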
+    def sync_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        model: Optional[SDKLLM] = None,
+        full_response: Optional[bool] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, Iterator[str], None]:
+        """Sends a message to the LLM synchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = self._model_call(
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream:
+                        try:
+                            peek_iter, return_iter = itertools.tee(response)
+                            has_content = False
+                            for chunk in peek_iter:
+                                if chunk != "":
+                                    has_content = True
+                                    break
+                            if has_content:
+                                return return_iter
+                        except StopIteration:
+                            pass
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+
+                except Exception as e:
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
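The synchronous path cannot rebuild a plain generator the way the async helper does, so it peeks via `itertools.tee`: the second iterator replays buffered chunks from the start. A standalone sketch of that pattern:

```python
import itertools
from typing import Iterator, Optional


def peek_non_empty(stream: Iterator[str]) -> Optional[Iterator[str]]:
    """Return an iterator equivalent to `stream` if any chunk is non-empty."""
    peek_iter, return_iter = itertools.tee(stream)
    for chunk in peek_iter:
        if chunk != "":
            return return_iter  # tee buffered the peeked chunks, nothing is lost
    return None  # every chunk was empty


print(list(peek_non_empty(iter(["", "a", "b"]))))  # ['', 'a', 'b']
print(peek_non_empty(iter(["", ""])))              # None
```

Note that `tee` keeps every peeked chunk in memory, which is fine here because peeking stops at the first non-empty chunk.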
     def chat(  # noqa: D417
         self,
         prompt: str,
@@ -272,7 +417,7 @@ def chat(  # noqa: D417
         auto_call_tools: bool = False,
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Union[str, Task[Union[str, AsyncIterator[str], None]], Iterator[str], None]:
         """Sends a message to the LLM and retrieves a response.

         Args:
@@ -303,57 +448,61 @@ def chat(  # noqa: D417
         self._wait_for_model()
         lit_tools = LitTool.convert_tools(tools)
         processed_tools = [tool.as_tool() for tool in lit_tools] if lit_tools else None
+
+        models_to_try = []
+        sdk_model = None
         if model:
-            try:
-                model_key = f"{model}::{self._teamspace}::{self._enable_async}"
-                if model_key not in self._sdkllm_cache:
-                    self._sdkllm_cache[model_key] = SDKLLM(
-                        name=model, teamspace=self._teamspace, enable_async=self._enable_async
-                    )
-                sdk_model = self._sdkllm_cache[model_key]
-                return self._model_call(
+            model_key = f"{model}::{self._teamspace}::{self._enable_async}"
+            if model_key not in self._sdkllm_cache:
+                self._sdkllm_cache[model_key] = SDKLLM(
+                    name=model, teamspace=self._teamspace, enable_async=self._enable_async
+                )
+            sdk_model = self._sdkllm_cache[model_key]
+            models_to_try.append(sdk_model)
+        models_to_try.extend(self.models)
+
+        if self._enable_async:
+            nest_asyncio.apply()
+
+            loop = asyncio.get_event_loop()
+            return loop.create_task(
+                self.async_chat(
+                    models_to_try=models_to_try,
                     model=sdk_model,
                     prompt=prompt,
                     system_prompt=system_prompt,
-                    max_completion_tokens=max_tokens,
+                    max_tokens=max_tokens,
                     images=images,
                     conversation=conversation,
                     metadata=metadata,
                     stream=stream,
+                    full_response=self._full_response,
                     tools=processed_tools,
                     lit_tools=lit_tools,
                     auto_call_tools=auto_call_tools,
                     reasoning_effort=reasoning_effort,
                     **kwargs,
                 )
-            except Exception as e:
-                print(f"💥 Failed to override with model '{model}'")
-                handle_model_error(e, sdk_model, 0, self.max_retries, self._verbose)
+            )

-        # Retry with fallback models
-        for model in self.models:
-            for attempt in range(self.max_retries):
-                try:
-                    return self._model_call(
-                        model=model,
-                        prompt=prompt,
-                        system_prompt=system_prompt,
-                        max_completion_tokens=max_tokens,
-                        images=images,
-                        conversation=conversation,
-                        metadata=metadata,
-                        stream=stream,
-                        tools=processed_tools,
-                        lit_tools=lit_tools,
-                        auto_call_tools=auto_call_tools,
-                        reasoning_effort=reasoning_effort,
-                        **kwargs,
-                    )
-
-                except Exception as e:
-                    handle_model_error(e, model, attempt, self.max_retries, self._verbose)
-
-        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+        return self.sync_chat(
+            models_to_try=models_to_try,
+            model=sdk_model,
+            prompt=prompt,
+            system_prompt=system_prompt,
+            max_tokens=max_tokens,
+            images=images,
+            conversation=conversation,
+            metadata=metadata,
+            stream=stream,
+            full_response=self._full_response,
+            tools=processed_tools,
+            lit_tools=lit_tools,
+            auto_call_tools=auto_call_tools,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )

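With async enabled, `chat()` now returns an `asyncio.Task` instead of a string, and `nest_asyncio.apply()` patches the loop so the task can be scheduled even from nested contexts. A sketch of how a caller might consume it (constructor arguments and model name are illustrative, not confirmed by this diff):

```python
import asyncio

from litai import LLM  # import path assumed


async def main() -> None:
    llm = LLM(model="openai/gpt-4o", enable_async=True)  # illustrative arguments
    task = llm.chat("Summarize asyncio in one sentence.")  # returns an asyncio.Task
    response = await task  # or gather several tasks concurrently
    print(response)


asyncio.run(main())
```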
     @staticmethod
     def call_tool(
@@ -491,7 +640,11 @@ def if_(self, input: str, question: str) -> bool:
         Answer with only 'yes' or 'no'.
         """

-        response = self.chat(prompt).strip().lower()
+        response = self.chat(prompt)
+        if isinstance(response, str):
+            response = response.strip().lower()
+        else:
+            return False
         return "yes" in response

     def classify(self, input: str, choices: List[str]) -> str:
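Since `chat()` may now return a `Task` or an iterator, `if_` only parses plain-string responses and conservatively answers `False` otherwise. Typical usage (model name illustrative):

```python
from litai import LLM  # import path assumed

llm = LLM(model="openai/gpt-4o")  # sync mode, so chat() returns a str
if llm.if_("The server returned HTTP 503", "Is this an error message?"):
    print("escalate to on-call")
```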
@@ -517,7 +670,11 @@ def classify(self, input: str, choices: List[str]) -> str:
         Answer with only one of the choices.
         """.strip()

-        response = self.chat(prompt).strip().lower()
+        response = self.chat(prompt)
+        if isinstance(response, str):
+            response = response.strip().lower()
+        else:
+            return normalized_choices[0]

         if response in normalized_choices:
             return response
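`classify` gets the same guard: a non-string response falls back to the first normalized choice rather than raising. Usage sketch (model name illustrative):

```python
from litai import LLM  # import path assumed

llm = LLM(model="openai/gpt-4o")
label = llm.classify("My card was charged twice", ["billing", "bug", "other"])
print(label)  # e.g. "billing"
```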