Support checking provider-specific /models endpoints for available models based on key (#7538)

* test(test_utils.py): initial test for valid models

Addresses #7525

* fix: test

* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint

* refactor(fireworks_ai/): support checking model info on `/v1/models` route

* docs(set_keys.md): update docs to clarify check llm provider api usage

* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth

* fix(watsonx): read in watsonx token from env var

* fix: fix linting errors

* fix(utils.py): fix provider config check

* style: cleanup unused imports
krrishdholakia authored Jan 4, 2025
1 parent cac06a3 commit f770dd0
Showing 12 changed files with 352 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/my-website/docs/providers/watsonx.md
@@ -14,6 +14,7 @@ os.environ["WATSONX_TOKEN"] = "" # IAM auth token
# optional - can also be passed as params to completion() or embedding()
os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance
os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models
os.environ["WATSONX_ZENAPIKEY"] = "" # Zen API key (use for long-term api token)
```

See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.
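
For reference, a minimal sketch of using the Zen API key path end-to-end; the URL, key, project ID, and model name below are placeholder values, not part of this change:

```python
import os
from litellm import completion

# Placeholder values -- substitute your own watsonx.ai instance details.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_ZENAPIKEY"] = "your-zen-api-key"  # long-term API token
os.environ["WATSONX_PROJECT_ID"] = "your-project-id"

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # example model ID
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```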
16 changes: 16 additions & 0 deletions docs/my-website/docs/set_keys.md
@@ -179,6 +179,22 @@ assert(valid_models == expected_models)
os.environ = old_environ
```

### `get_valid_models(check_provider_endpoint: True)`

This helper will check the provider's endpoint for valid models.

Currently implemented for:
- OpenAI (if OPENAI_API_KEY is set)
- Fireworks AI (if FIREWORKS_AI_API_KEY is set)
- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)

```python
from litellm import get_valid_models

valid_models = get_valid_models(check_provider_endpoint=True)
print(valid_models)
```
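
Note that the Fireworks AI implementation added in this commit also reads `FIREWORKS_ACCOUNT_ID` when querying the models endpoint. A minimal sketch with placeholder credentials:

```python
import os
from litellm import get_valid_models

# Placeholder credentials -- both variables are read by the Fireworks AI config.
os.environ["FIREWORKS_AI_API_KEY"] = "fw_..."
os.environ["FIREWORKS_ACCOUNT_ID"] = "my-account"

valid_models = get_valid_models(check_provider_endpoint=True)
# Fireworks entries are returned with the "fireworks_ai/" prefix.
print([m for m in valid_models if m.startswith("fireworks_ai/")])
```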

### `validate_environment(model: str)`

This helper tells you if you have all the required environment variables for a model, and if not - what's missing.
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -1168,6 +1168,7 @@ def add_known_models():
from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
from .llms.azure.completion.transformation import AzureOpenAITextConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
from .llms.vllm.completion.transformation import VLLMConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
3 changes: 1 addition & 2 deletions litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -488,11 +488,10 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
(
model,
api_base,
dynamic_api_key,
) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
model=model, api_base=api_base, api_key=api_key
api_base=api_base, api_key=api_key
)
elif custom_llm_provider == "azure_ai":
(
11 changes: 10 additions & 1 deletion litellm/llms/base_llm/base_utils.py
@@ -1,9 +1,18 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from litellm.types.utils import ModelInfoBase


class BaseLLMModelInfo(ABC):
@abstractmethod
def get_model_info(self, model: str) -> ModelInfoBase:
def get_model_info(
self,
model: str,
existing_model_info: Optional[ModelInfoBase] = None,
) -> Optional[ModelInfoBase]:
pass

@abstractmethod
def get_models(self) -> List[str]:
pass
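
For context, `BaseLLMModelInfo` is the interface a provider config implements to expose model listing and per-model metadata. A hypothetical minimal subclass (illustrative only, not part of this commit) could look like:

```python
from typing import List, Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.types.utils import ModelInfoBase


class ExampleProviderModelInfo(BaseLLMModelInfo):
    """Hypothetical provider config implementing the model-info interface."""

    def get_model_info(
        self,
        model: str,
        existing_model_info: Optional[ModelInfoBase] = None,
    ) -> Optional[ModelInfoBase]:
        # Prefer any info the caller has already resolved (e.g. from litellm.model_cost).
        if existing_model_info is not None:
            return existing_model_info
        return None

    def get_models(self) -> List[str]:
        # A real implementation would query the provider's /models endpoint here.
        return ["example_provider/model-a", "example_provider/model-b"]
```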
37 changes: 32 additions & 5 deletions litellm/llms/fireworks_ai/chat/transformation.py
@@ -1,15 +1,14 @@
from typing import List, Literal, Optional, Tuple, Union, cast

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo

from ...openai.chat.gpt_transformation import OpenAIGPTConfig


class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
class FireworksAIConfig(OpenAIGPTConfig):
"""
Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
@@ -209,8 +208,8 @@ def transform_request(
)

def _get_openai_compatible_provider_info(
self, model: str, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[str, Optional[str], Optional[str]]:
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = (
api_base
or get_secret_str("FIREWORKS_API_BASE")
@@ -222,4 +221,32 @@ def _get_openai_compatible_provider_info(
or get_secret_str("FIREWORKSAI_API_KEY")
or get_secret_str("FIREWORKS_AI_TOKEN")
)
return model, api_base, dynamic_api_key
return api_base, dynamic_api_key

def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
api_base, api_key = self._get_openai_compatible_provider_info(
api_base=api_base, api_key=api_key
)
if api_base is None or api_key is None:
raise ValueError(
"FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
)

account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
if account_id is None:
raise ValueError(
"FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
)

response = litellm.module_level_client.get(
url=f"{api_base}/v1/accounts/{account_id}/models",
headers={"Authorization": f"Bearer {api_key}"},
)

if response.status_code != 200:
raise ValueError(
f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
)

models = response.json()["models"]
return ["fireworks_ai/" + model["name"] for model in models]
29 changes: 29 additions & 0 deletions litellm/llms/litellm_proxy/chat/transformation.py
@@ -0,0 +1,29 @@
"""
Translate from OpenAI's `/v1/chat/completions` to the LiteLLM Proxy's `/v1/chat/completions`
"""

from typing import List, Optional, Tuple

from litellm.secret_managers.main import get_secret_str

from ...openai.chat.gpt_transformation import OpenAIGPTConfig


class LiteLLMProxyChatConfig(OpenAIGPTConfig):
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
return api_base, dynamic_api_key

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
api_base, api_key = self._get_openai_compatible_provider_info(api_base, api_key)
if api_base is None:
raise ValueError(
"api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
)
models = super().get_models(api_key=api_key, api_base=api_base)
return [f"litellm_proxy/{model}" for model in models]
46 changes: 44 additions & 2 deletions litellm/llms/openai/chat/gpt_transformation.py
@@ -7,9 +7,11 @@
import httpx

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from litellm.types.utils import ModelInfoBase, ModelResponse

from ..common_utils import OpenAIError

@@ -21,7 +23,7 @@
LiteLLMLoggingObj = Any


class OpenAIGPTConfig(BaseConfig):
class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -229,3 +231,43 @@ def validate_environment(
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
"""
Calls OpenAI's `/v1/models` endpoint and returns the list of models.
"""

if api_base is None:
api_base = "https://api.openai.com"
if api_key is None:
api_key = get_secret_str("OPENAI_API_KEY")

response = litellm.module_level_client.get(
url=f"{api_base}/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)

if response.status_code != 200:
raise Exception(f"Failed to get models: {response.text}")

models = response.json()["data"]
return [model["id"] for model in models]

def get_model_info(
self, model: str, existing_model_info: Optional[ModelInfoBase] = None
) -> ModelInfoBase:

if existing_model_info is not None:
return existing_model_info
return ModelInfoBase(
key=model,
litellm_provider="openai",
mode="chat",
input_cost_per_token=0.0,
output_cost_per_token=0.0,
max_tokens=None,
max_input_tokens=None,
max_output_tokens=None,
)
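
As an illustration, the base OpenAI config can now enumerate models directly; a minimal sketch, assuming `OPENAI_API_KEY` is set and using an example model name:

```python
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig

config = OpenAIGPTConfig()

# Calls https://api.openai.com/v1/models with the key from OPENAI_API_KEY.
model_ids = config.get_models()
print(model_ids[:5])

# With no existing model info supplied, a zero-cost placeholder entry is returned.
print(config.get_model_info(model="gpt-4o-mini"))
```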
8 changes: 7 additions & 1 deletion litellm/llms/watsonx/common_utils.py
@@ -175,7 +175,12 @@ def validate_environment(

if "Authorization" in headers:
return {**default_headers, **headers}
token = cast(Optional[str], optional_params.get("token"))
token = cast(
Optional[str],
optional_params.get("token")
or get_secret_str("WATSONX_ZENAPIKEY")
or get_secret_str("WATSONX_TOKEN"),
)
if token:
headers["Authorization"] = f"Bearer {token}"
else:
@@ -245,6 +250,7 @@ def get_watsonx_credentials(
)

token: Optional[str] = None

if wx_credentials is not None:
api_base = wx_credentials.get("url", api_base)
api_key = wx_credentials.get(
50 changes: 36 additions & 14 deletions litellm/utils.py
@@ -4223,6 +4223,7 @@ def _get_model_info_helper(  # noqa: PLR0915

_model_info: Optional[Dict[str, Any]] = None
key: Optional[str] = None
provider_config: Optional[BaseLLMModelInfo] = None
if combined_model_name in litellm.model_cost:
key = combined_model_name
_model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
_model_info = None
if _model_info is None and ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
):

if custom_llm_provider:
provider_config = ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
)
if provider_config is not None:
_model_info = cast(
dict, provider_config.get_model_info(model=model)
)

if _model_info is None and provider_config is not None:
_model_info = cast(
Optional[Dict],
provider_config.get_model_info(
model=model, existing_model_info=_model_info
),
)
if key is None:
key = "provider_specific_model_info"
if _model_info is None or key is None:
raise ValueError(
@@ -5706,12 +5711,12 @@ def trim_messages(
return messages


def get_valid_models() -> List[str]:
def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
"""
Returns a list of valid LLMs based on the set environment variables
Args:
None
check_provider_endpoint: If True, will check the provider's endpoint for valid models.
Returns:
A list of valid LLMs
@@ -5725,22 +5730,36 @@ def get_valid_models():

for provider in litellm.provider_list:
# edge case litellm has together_ai as a provider, it should be togetherai
provider = provider.replace("_", "")
env_provider_1 = provider.replace("_", "")
env_provider_2 = provider

# litellm standardizes expected provider keys to
# PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
expected_provider_key = f"{provider.upper()}_API_KEY"
if expected_provider_key in environ_keys:
expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
if (
expected_provider_key_1 in environ_keys
or expected_provider_key_2 in environ_keys
):
# key is set
valid_providers.append(provider)

for provider in valid_providers:
provider_config = ProviderConfigManager.get_provider_model_info(
model=None,
provider=LlmProviders(provider),
)

if provider == "azure":
valid_models.append("Azure-LLM")
elif provider_config is not None and check_provider_endpoint:
valid_models.extend(provider_config.get_models())
else:
models_for_provider = litellm.models_by_provider.get(provider, [])
valid_models.extend(models_for_provider)
return valid_models
except Exception:
except Exception as e:
verbose_logger.debug(f"Error getting valid models: {e}")
return [] # NON-Blocking
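
The looser env-key matching above means either spelling of a provider key is now recognized; for example (illustrative key value):

```python
import os
from litellm import get_valid_models

# Either spelling now marks together_ai as a configured provider.
os.environ["TOGETHER_AI_API_KEY"] = "example-key"   # raw provider name
# os.environ["TOGETHERAI_API_KEY"] = "example-key"  # underscore-stripped spelling also works

print(len(get_valid_models()))
```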


@@ -6291,11 +6310,14 @@ def get_provider_text_completion_config(

@staticmethod
def get_provider_model_info(
model: str,
model: Optional[str],
provider: LlmProviders,
) -> Optional[BaseLLMModelInfo]:
if LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIConfig()
elif LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()

return None


