From fac7c24da90b56e90d4274f5ad56cef6b52868fa Mon Sep 17 00:00:00 2001
From: phact
Date: Tue, 7 May 2024 11:46:32 -0400
Subject: [PATCH] back to standard error handling for litellm errors

---
 impl/routes/stateless.py         | 23 ++++++++++--
 impl/services/inference_utils.py | 64 ++++++++++++++++----------------
 2 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/impl/routes/stateless.py b/impl/routes/stateless.py
index 86b8ee3..7650582 100644
--- a/impl/routes/stateless.py
+++ b/impl/routes/stateless.py
@@ -1,4 +1,5 @@
 """The stateless endpoints that do not depend on information from DB"""
+import logging
 import time
 import uuid
 from typing import Any, Dict
@@ -6,7 +7,8 @@
 
 import json
 
-from fastapi import APIRouter, Depends, Request
+from fastapi import APIRouter, Depends, Request, HTTPException
+from litellm import APIError
 from starlette.responses import StreamingResponse, JSONResponse
 
 from openapi_server.models.chat_completion_stream_response_delta import ChatCompletionStreamResponseDelta
@@ -28,8 +30,12 @@
 
 router = APIRouter()
 
+logger = logging.getLogger(__name__)
+
+
 async def _completion_from_request(
     chat_request: CreateChatCompletionRequest,
+    using_openai: bool,
     **litellm_kwargs: Any,
 ) -> CreateChatCompletionResponse | StreamingResponse:
     # NOTE: litellm_kwargs should contain auth
@@ -56,6 +62,11 @@
 
         messages.append(message_dict)
 
+    tools = []
+    if chat_request.tools is not None:
+        for tool in chat_request.tools:
+            tools.append(tool.to_dict())
+
     functions = []
     if chat_request.functions is not None:
         for function in chat_request.functions:
@@ -98,6 +109,10 @@
     if functions:
         kwargs["functions"] = functions
 
+    if tools:
+        kwargs["tools"] = tools
+
+
     # workaround for https://github.com/BerriAI/litellm/pull/3439
     #if "function" not in kwargs and "tools" in kwargs:
     #    kwargs["functions"] = kwargs["tools"]
@@ -107,10 +122,9 @@
         kwargs["logit_bias"] = chat_request.logit_bias
 
     if chat_request.user is not None:
-        kwargs["user"] = chat_request.user
+        kwargs["user"] = chat_request.user
 
     response = await get_async_chat_completion_response(**kwargs)
-    # TODO - throw error if response fails
 
     choices = []
     if chat_request.stream is not None and chat_request.stream:
@@ -221,8 +235,9 @@ async def create_moderation(
 async def create_chat_completion(
     create_chat_completion_request: CreateChatCompletionRequest,
     litellm_kwargs: Dict[str, Any] = Depends(get_litellm_kwargs),
+    using_openai: bool = Depends(check_if_using_openai),
 ) -> Any:
-    return await _completion_from_request(create_chat_completion_request, **litellm_kwargs)
+    return await _completion_from_request(create_chat_completion_request, using_openai, **litellm_kwargs)
 
 
 @router.post(
diff --git a/impl/services/inference_utils.py b/impl/services/inference_utils.py
index df581cf..2f53007 100644
--- a/impl/services/inference_utils.py
+++ b/impl/services/inference_utils.py
@@ -88,38 +88,38 @@
     if model is None and deployment_id is None:
         raise ValueError("Must provide either a model or a deployment id")
 
-    #try:
-    if model is None:
-        model = deployment_id
-
-    type_hints = get_type_hints(acompletion)
-
-    for key, value in litellm_kwargs.items():
-        if value is not None and key in type_hints and isinstance(value, str):
-            type_hint = type_hints[key]
-            # handle optional
-            if hasattr(type_hint, "__origin__") and type_hint.__origin__ == Union:
-                litellm_kwargs[key] = type_hint.__args__[0](value)
-            else:
-                litellm_kwargs[key] = type_hints[key](value)
-
-    litellm.set_verbose=True
-    completion = await acompletion(
-        model=model,
-        messages=messages,
-        deployment_id=deployment_id,
-        **litellm_kwargs
-    )
-    return completion
-    #except Exception as e:
-    #    if "LLM Provider NOT provided" in e.args[0]:
-    #        logger.error(f"Error: error {model} is not currently supported")
-    #        raise ValueError(f"Model {model} is not currently supported")
-    #    logger.error(f"Error: {e}")
-    #    raise ValueError(f"Error: {e}")
-    #except asyncio.CancelledError:
-    #    logger.error("litellm call cancelled")
-    #    raise RuntimeError("litellm call cancelled")
+    try:
+        if model is None:
+            model = deployment_id
+
+        type_hints = get_type_hints(acompletion)
+
+        for key, value in litellm_kwargs.items():
+            if value is not None and key in type_hints and isinstance(value, str):
+                type_hint = type_hints[key]
+                # handle optional
+                if hasattr(type_hint, "__origin__") and type_hint.__origin__ == Union:
+                    litellm_kwargs[key] = type_hint.__args__[0](value)
+                else:
+                    litellm_kwargs[key] = type_hints[key](value)
+
+        #litellm.set_verbose = True
+        completion = await acompletion(
+            model=model,
+            messages=messages,
+            deployment_id=deployment_id,
+            **litellm_kwargs
+        )
+        return completion
+    except Exception as e:
+        if "LLM Provider NOT provided" in e.args[0]:
+            logger.error(f"Error: error {model} is not currently supported")
+            raise ValueError(f"Model {model} is not currently supported")
+        logger.error(f"Error: {e}")
+        raise ValueError(f"Error: {e}")
+    except asyncio.CancelledError:
+        logger.error("litellm call cancelled")
+        raise RuntimeError("litellm call cancelled")
 
 
 async def get_chat_completion(
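
The stateless.py hunk adds HTTPException (and litellm's APIError) to the imports without using either in the hunks shown above. As a minimal sketch only, and not part of this patch, the route layer could surface the ValueError raised by the restored try/except in get_async_chat_completion_response roughly as follows; the wrapper name and the 400 status code are illustrative assumptions:

# Hypothetical sketch, not part of the patch: map the ValueError raised by
# get_async_chat_completion_response to an HTTP error at the route layer.
from fastapi import HTTPException

from impl.routes.stateless import _completion_from_request


async def create_chat_completion_safe(chat_request, using_openai, **litellm_kwargs):
    try:
        return await _completion_from_request(chat_request, using_openai, **litellm_kwargs)
    except ValueError as e:
        # e.g. "Model <name> is not currently supported" raised in inference_utils.py
        raise HTTPException(status_code=400, detail=str(e))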