Skip to content

Commit 84f0418

Browse files
committed
fix: allowed_models config did not filter models
1 parent d45137a commit 84f0418

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

src/llama_stack/providers/utils/inference/model_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
class RemoteInferenceProviderConfig(BaseModel):
23-
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
23+
allowed_models: list[str] | None = Field(
2424
default=None,
2525
description="List of models that should be registered with the model registry. If None, all models are allowed.",
2626
)

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
8383
# This is set in list_models() and used in check_model_availability()
8484
_model_cache: dict[str, Model] = {}
8585

86-
# List of allowed models for this provider, if empty all models allowed
87-
allowed_models: list[str] = []
88-
8986
# Optional field name in provider data to look for API key, which takes precedence
9087
provider_data_api_key_field: str | None = None
9188

@@ -441,7 +438,7 @@ async def list_models(self) -> list[Model] | None:
441438
for provider_model_id in provider_models_ids:
442439
if not isinstance(provider_model_id, str):
443440
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
444-
if self.allowed_models and provider_model_id not in self.allowed_models:
441+
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
445442
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
446443
continue
447444
model = self.construct_model_from_identifier(provider_model_id)

tests/unit/providers/utils/inference/test_openai_mixin.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ class TestOpenAIMixinAllowedModels:
455455
"""Test cases for allowed_models filtering functionality"""
456456

457457
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
458-
"""Test that list_models filters models based on allowed_models set"""
459-
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
458+
"""Test that list_models filters models based on allowed_models"""
459+
mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"]
460460

461461
with mock_client_context(mixin, mock_client_with_models):
462462
result = await mixin.list_models()
@@ -470,8 +470,18 @@ async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_w
470470
assert "final-mock-model-id" not in model_ids
471471

472472
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
473-
"""Test that empty allowed_models set allows all models"""
474-
assert len(mixin.allowed_models) == 0
473+
"""Test that empty allowed_models allows no models"""
474+
mixin.config.allowed_models = []
475+
476+
with mock_client_context(mixin, mock_client_with_models):
477+
result = await mixin.list_models()
478+
479+
assert result is not None
480+
assert len(result) == 0 # No models should be included
481+
482+
async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
483+
"""Test that omitted allowed_models allows all models"""
484+
assert mixin.config.allowed_models is None
475485

476486
with mock_client_context(mixin, mock_client_with_models):
477487
result = await mixin.list_models()
@@ -488,7 +498,7 @@ async def test_check_model_availability_with_allowed_models(
488498
self, mixin, mock_client_with_models, mock_client_context
489499
):
490500
"""Test that check_model_availability respects allowed_models"""
491-
mixin.allowed_models = {"final-mock-model-id"}
501+
mixin.config.allowed_models = ["final-mock-model-id"]
492502

493503
with mock_client_context(mixin, mock_client_with_models):
494504
assert await mixin.check_model_availability("final-mock-model-id")
@@ -536,7 +546,7 @@ async def test_register_model_not_available(self, mixin, mock_client_with_models
536546

537547
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
538548
"""Test model registration with allowed_models filtering"""
539-
mixin.allowed_models = {"some-mock-model-id"}
549+
mixin.config.allowed_models = ["some-mock-model-id"]
540550

541551
# Test with allowed model
542552
allowed_model = Model(
@@ -690,7 +700,7 @@ async def test_respects_allowed_models(self, config):
690700
mixin = CustomListProviderModelIdsImplementation(
691701
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
692702
)
693-
mixin.allowed_models = ["model-1"]
703+
mixin.config.allowed_models = ["model-1"]
694704

695705
result = await mixin.list_models()
696706

0 commit comments

Comments
 (0)