Skip to content

Commit 79c9e46

Browse files
committed
feat(openai): add configurable base_url support with OPENAI_BASE_URL env var
- Add base_url field to OpenAIConfig with default "https://api.openai.com/v1"
- Update sample_run_config to support OPENAI_BASE_URL environment variable
- Modify get_base_url() to return configured base_url instead of hardcoded value
- Add comprehensive test suite covering:
  - Default base URL behavior
  - Custom base URL from config
  - Environment variable override
  - Config precedence over environment variables
  - Client initialization with configured URL
  - Model availability checks using configured URL

This enables users to configure custom OpenAI-compatible API endpoints via environment variables or configuration files.
1 parent 09abdb0 commit 79c9e46

File tree

7 files changed

+143
-3
lines changed

7 files changed

+143
-3
lines changed

docs/source/providers/inference/remote_openai.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
99
| Field | Type | Required | Default | Description |
1010
|-------|------|----------|---------|-------------|
1111
| `api_key` | `str \| None` | No | | API key for OpenAI models |
12+
| `base_url` | `str` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
1213

1314
## Sample Configuration
1415

1516
```yaml
1617
api_key: ${env.OPENAI_API_KEY:=}
18+
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
1719

1820
```
1921

llama_stack/providers/remote/inference/openai/config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
2424
default=None,
2525
description="API key for OpenAI models",
2626
)
27+
base_url: str = Field(
28+
default="https://api.openai.com/v1",
29+
description="Base URL for OpenAI API",
30+
)
2731

2832
@classmethod
29-
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
33+
def sample_run_config(
34+
cls,
35+
api_key: str = "${env.OPENAI_API_KEY:=}",
36+
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
37+
**kwargs,
38+
) -> dict[str, Any]:
3039
return {
3140
"api_key": api_key,
41+
"base_url": base_url,
3242
}

llama_stack/providers/remote/inference/openai/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def get_base_url(self) -> str:
6565
"""
6666
Get the OpenAI API base URL.
6767
68-
Returns the standard OpenAI API base URL for direct OpenAI API calls.
68+
Returns the OpenAI API base URL from the configuration.
6969
"""
70-
return "https://api.openai.com/v1"
70+
return self.config.base_url
7171

7272
async def initialize(self) -> None:
7373
await super().initialize()

llama_stack/templates/ci-tests/run.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ providers:
5656
provider_type: remote::openai
5757
config:
5858
api_key: ${env.OPENAI_API_KEY:=}
59+
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
5960
- provider_id: anthropic
6061
provider_type: remote::anthropic
6162
config:

llama_stack/templates/open-benchmark/run.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ providers:
1616
provider_type: remote::openai
1717
config:
1818
api_key: ${env.OPENAI_API_KEY:=}
19+
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
1920
- provider_id: anthropic
2021
provider_type: remote::anthropic
2122
config:

llama_stack/templates/starter/run.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ providers:
5656
provider_type: remote::openai
5757
config:
5858
api_key: ${env.OPENAI_API_KEY:=}
59+
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
5960
- provider_id: anthropic
6061
provider_type: remote::anthropic
6162
config:
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import os
8+
from unittest.mock import AsyncMock, MagicMock, patch
9+
10+
from llama_stack.distribution.stack import replace_env_vars
11+
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
12+
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
13+
14+
15+
class TestOpenAIBaseURLConfig:
    """Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter.

    Covers: the hardcoded default, explicit config values, environment-variable
    substitution via ``sample_run_config``, config precedence, and that the
    configured URL actually reaches the underlying AsyncOpenAI client.
    """

    # Private helpers — shared setup that was previously duplicated across tests.

    @staticmethod
    def _adapter(base_url: str | None = None) -> OpenAIInferenceAdapter:
        """Build an adapter from an explicit config; omit base_url to exercise the default."""
        if base_url is None:
            config = OpenAIConfig(api_key="test-key")
        else:
            config = OpenAIConfig(api_key="test-key", base_url=base_url)
        return OpenAIInferenceAdapter(config)

    @staticmethod
    def _adapter_from_env() -> OpenAIInferenceAdapter:
        """Build an adapter through sample_run_config so the ${env.OPENAI_BASE_URL:=...}
        substitution syntax is exercised, mirroring how a run.yaml is processed."""
        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        return OpenAIInferenceAdapter(config)

    @staticmethod
    def _mock_client_with_model() -> MagicMock:
        """A fake AsyncOpenAI client whose models.retrieve coroutine succeeds."""
        mock_client = MagicMock()
        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
        return mock_client

    def test_default_base_url_without_env_var(self):
        """Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
        adapter = self._adapter()

        assert adapter.get_base_url() == "https://api.openai.com/v1"

    def test_custom_base_url_from_config(self):
        """Test that the adapter uses a custom base URL when provided in config."""
        custom_url = "https://custom.openai.com/v1"
        adapter = self._adapter(custom_url)

        assert adapter.get_base_url() == custom_url

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_base_url_from_environment_variable(self):
        """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
        # Use sample_run_config which has proper environment variable syntax
        adapter = self._adapter_from_env()

        assert adapter.get_base_url() == "https://env.openai.com/v1"

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_config_overrides_environment_variable(self):
        """Test that explicit config value overrides environment variable."""
        custom_url = "https://config.openai.com/v1"
        adapter = self._adapter(custom_url)

        # Config should take precedence over environment variable
        assert adapter.get_base_url() == custom_url

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    def test_client_uses_configured_base_url(self, mock_openai_class):
        """Test that the OpenAI client is initialized with the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        adapter = self._adapter(custom_url)

        # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Access the client property to trigger AsyncOpenAI initialization
        _ = adapter.client

        # Verify AsyncOpenAI was called with the correct base_url
        mock_openai_class.assert_called_once_with(
            api_key="test-key",
            base_url=custom_url,
        )

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
        """Test that check_model_availability uses the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        adapter = self._adapter(custom_url)

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client and its models.retrieve method
        mock_client = self._mock_client_with_model()
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the custom URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url=custom_url,
        )

        # Verify the availability check hit the mocked endpoint exactly once
        mock_client.models.retrieve.assert_called_once_with("gpt-4")

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
        """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
        # Use sample_run_config which has proper environment variable syntax
        adapter = self._adapter_from_env()

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client
        mock_client = self._mock_client_with_model()
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the environment variable URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url="https://proxy.openai.com/v1",
        )

0 commit comments

Comments
 (0)