diff --git a/pyproject.toml b/pyproject.toml index 9b26f7ae89..3efc08d6a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,8 @@ dev = [ ] # These are the dependencies required for running unit tests. unit = [ + "anthropic", + "databricks-sdk", "sqlite-vec", "ollama", "aiosqlite", diff --git a/src/llama_stack/providers/registry/inference.py b/src/llama_stack/providers/registry/inference.py index 35afb296d0..00967a8ec2 100644 --- a/src/llama_stack/providers/registry/inference.py +++ b/src/llama_stack/providers/registry/inference.py @@ -61,6 +61,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.cerebras", config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator", description="Cerebras inference provider for running models on Cerebras Cloud platform.", ), RemoteProviderSpec( @@ -149,6 +150,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=["databricks-sdk"], module="llama_stack.providers.remote.inference.databricks", config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator", description="Databricks inference provider for running models on Databricks' unified analytics platform.", ), RemoteProviderSpec( @@ -158,6 +160,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.nvidia", config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", + provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator", description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", ), RemoteProviderSpec( @@ -167,6 +170,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.runpod", config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator", description="RunPod inference provider for running models on RunPod's cloud GPU platform.", ), RemoteProviderSpec( diff --git a/src/llama_stack/providers/remote/inference/cerebras/cerebras.py b/src/llama_stack/providers/remote/inference/cerebras/cerebras.py index daf67616b2..d5def9da1d 100644 --- a/src/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/src/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -18,6 +18,8 @@ class CerebrasInferenceAdapter(OpenAIMixin): config: CerebrasImplConfig + provider_data_api_key_field: str = "cerebras_api_key" + def get_base_url(self) -> str: return urljoin(self.config.base_url, "v1") diff --git a/src/llama_stack/providers/remote/inference/cerebras/config.py b/src/llama_stack/providers/remote/inference/cerebras/config.py index dc9a0f5fca..9ba7737245 100644 --- a/src/llama_stack/providers/remote/inference/cerebras/config.py +++ b/src/llama_stack/providers/remote/inference/cerebras/config.py @@ -7,7 +7,7 @@ import os from typing import Any -from pydantic import Field +from pydantic import BaseModel, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -15,6 +15,13 @@ DEFAULT_BASE_URL = "https://api.cerebras.ai" +class CerebrasProviderDataValidator(BaseModel): + cerebras_api_key: str | None = Field( + default=None, + description="API key for Cerebras models", + ) + + @json_schema_type class CerebrasImplConfig(RemoteInferenceProviderConfig): base_url: str = Field( diff --git a/src/llama_stack/providers/remote/inference/databricks/config.py b/src/llama_stack/providers/remote/inference/databricks/config.py index 49d19cd35f..84357f764a 100644 --- a/src/llama_stack/providers/remote/inference/databricks/config.py +++ b/src/llama_stack/providers/remote/inference/databricks/config.py @@ -6,12 +6,19 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class DatabricksProviderDataValidator(BaseModel): + databricks_api_token: str | None = Field( + default=None, + description="API token for Databricks models", + ) + + @json_schema_type class DatabricksImplConfig(RemoteInferenceProviderConfig): url: str | None = Field( diff --git a/src/llama_stack/providers/remote/inference/databricks/databricks.py b/src/llama_stack/providers/remote/inference/databricks/databricks.py index 44996507f1..6b5783ec1e 100644 --- a/src/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/src/llama_stack/providers/remote/inference/databricks/databricks.py @@ -20,6 +20,8 @@ class DatabricksInferenceAdapter(OpenAIMixin): config: DatabricksImplConfig + provider_data_api_key_field: str = "databricks_api_token" + # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models embedding_model_metadata: dict[str, dict[str, int]] = { "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192}, diff --git a/src/llama_stack/providers/remote/inference/nvidia/config.py b/src/llama_stack/providers/remote/inference/nvidia/config.py index 2171877a51..3545d2b114 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/config.py +++ b/src/llama_stack/providers/remote/inference/nvidia/config.py @@ -7,12 +7,19 @@ import os from typing import Any -from pydantic import Field +from pydantic import BaseModel, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class NVIDIAProviderDataValidator(BaseModel): + nvidia_api_key: str | None = Field( + default=None, + description="API key for NVIDIA NIM models", + ) + + @json_schema_type class NVIDIAConfig(RemoteInferenceProviderConfig): """ diff --git a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py index 5aba6bddcd..ea11b49cd0 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -17,6 +17,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin): config: NVIDIAConfig + provider_data_api_key_field: str = "nvidia_api_key" + """ NVIDIA Inference Adapter for Llama Stack. """ diff --git a/src/llama_stack/providers/remote/inference/runpod/config.py b/src/llama_stack/providers/remote/inference/runpod/config.py index 3d16d20fdb..a2a1add97d 100644 --- a/src/llama_stack/providers/remote/inference/runpod/config.py +++ b/src/llama_stack/providers/remote/inference/runpod/config.py @@ -6,12 +6,19 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class RunpodProviderDataValidator(BaseModel): + runpod_api_token: str | None = Field( + default=None, + description="API token for RunPod models", + ) + + @json_schema_type class RunpodImplConfig(RemoteInferenceProviderConfig): url: str | None = Field( diff --git a/src/llama_stack/providers/remote/inference/runpod/runpod.py b/src/llama_stack/providers/remote/inference/runpod/runpod.py index db60644caa..a76e941cbb 100644 --- a/src/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/src/llama_stack/providers/remote/inference/runpod/runpod.py @@ -24,6 +24,7 @@ class RunpodInferenceAdapter(OpenAIMixin): """ config: RunpodImplConfig + provider_data_api_key_field: str = "runpod_api_token" def get_base_url(self) -> str: """Get base URL for OpenAI client.""" diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index 55a6793c2b..aa3a2c77a5 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -10,47 +10,124 @@ import pytest from llama_stack.core.request_headers import request_provider_data_context +from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter +from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig +from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter +from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig +from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig +from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter +from llama_stack.providers.remote.inference.gemini.config import GeminiConfig +from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter +from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig +from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter +from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig +from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter +from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter @pytest.mark.parametrize( - "config_cls,adapter_cls,provider_data_validator", + "config_cls,adapter_cls,provider_data_validator,config_params", [ ( GroqConfig, GroqInferenceAdapter, "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", + {}, ), ( OpenAIConfig, OpenAIInferenceAdapter, "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", + {}, ), ( TogetherImplConfig, TogetherInferenceAdapter, "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", + {}, ), ( LlamaCompatConfig, LlamaCompatInferenceAdapter, "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", + {}, + ), + ( + CerebrasImplConfig, + CerebrasInferenceAdapter, + "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator", + {}, + ), + ( + DatabricksImplConfig, + DatabricksInferenceAdapter, + "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator", + {}, + ), + ( + NVIDIAConfig, + NVIDIAInferenceAdapter, + "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator", + {}, + ), + ( + RunpodImplConfig, + RunpodInferenceAdapter, + "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator", + {}, + ), + ( + FireworksImplConfig, + FireworksInferenceAdapter, + "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", + {}, + ), + ( + AnthropicConfig, + AnthropicInferenceAdapter, + "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", + {}, + ), + ( + GeminiConfig, + GeminiInferenceAdapter, + "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + {}, + ), + ( + SambaNovaImplConfig, + SambaNovaInferenceAdapter, + "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", + {}, + ), + ( + VLLMInferenceAdapterConfig, + VLLMInferenceAdapter, + "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", + { + "url": "http://fake", + }, ), ], ) -def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): +def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict): """Ensure the OpenAI provider does not cache api keys across client requests""" - - inference_adapter = adapter_cls(config=config_cls()) + inference_adapter = adapter_cls(config=config_cls(**config_params)) inference_adapter.__provider_spec__ = MagicMock() inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator diff --git a/uv.lock b/uv.lock index aad77f6a14..9340132438 100644 --- a/uv.lock +++ b/uv.lock @@ -129,6 +129,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.69.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" }, +] + [[package]] name = "anyio" version = "4.9.0" @@ -758,6 +777,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" }, ] +[[package]] +name = "databricks-sdk" +version = "0.67.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/5b/df3e5424d833e4f3f9b42c409ef8b513e468c9cdf06c2a9935c6cbc4d128/databricks_sdk-0.67.0.tar.gz", hash = "sha256:f923227babcaad428b0c2eede2755ebe9deb996e2c8654f179eb37f486b37a36", size = 761000, upload-time = "2025-09-25T13:32:10.858Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/ca/2aff3817041483fb8e4f75a74a36ff4ca3a826e276becd1179a591b6348f/databricks_sdk-0.67.0-py3-none-any.whl", hash = "sha256:ef49e49db45ed12c015a32a6f9d4ba395850f25bb3dcffdcaf31a5167fe03ee2", size = 718422, upload-time = "2025-09-25T13:32:09.011Z" }, +] + [[package]] name = "datasets" version = "4.0.0" @@ -856,6 +888,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -1863,9 +1904,11 @@ test = [ unit = [ { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "blobfile" }, { name = "chardet" }, { name = "coverage" }, + { name = "databricks-sdk" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, @@ -1978,9 +2021,11 @@ test = [ unit = [ { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "blobfile" }, { name = "chardet" }, { name = "coverage" }, + { name = "databricks-sdk" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" },