Commit f69ae45

query_available_models() -> list[str] -> check_model_availability(model) -> bool
1 parent c2ab898 commit f69ae45

File tree

    llama_stack/providers/remote/inference/nvidia/nvidia.py
    tests/unit/providers/nvidia/test_supervised_fine_tuning.py

2 files changed: +10, -5 lines


llama_stack/providers/remote/inference/nvidia/nvidia.py

Lines changed: 8 additions & 3 deletions

@@ -89,9 +89,14 @@ def __init__(self, config: NVIDIAConfig) -> None:
 
         self._config = config
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the NVIDIA API."""
-        return [model.id async for model in self._get_client().models.list()]
+    async def check_model_availability(self, model: str) -> bool:
+        """Check if a specific model is available from the NVIDIA API."""
+        try:
+            await self._get_client().models.retrieve(model)
+            return True
+        except Exception:
+            # If we can't retrieve the model, it's not available
+            return False
 
     @lru_cache  # noqa: B019
     def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
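The change swaps a full models.list() scan for a single models.retrieve() call per model id, so a caller can probe one model without paging through the whole catalog. A minimal sketch of how a caller might use the new method during registration; the register_model() wrapper and its error message below are illustrative assumptions, not part of this commit:

# Hypothetical caller sketch: only check_model_availability() comes from
# this commit; register_model() and its ValueError are assumptions.
async def register_model(adapter, model_id: str) -> str:
    # One retrieve() round-trip instead of listing every available model.
    if not await adapter.check_model_availability(model_id):
        raise ValueError(f"model {model_id} is not available on this endpoint")
    return model_id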

tests/unit/providers/nvidia/test_supervised_fine_tuning.py

Lines changed: 2 additions & 2 deletions

@@ -344,8 +344,8 @@ def test_inference_register_model(self):
         )
 
         # simulate a NIM where default/job-1234 is an available model
-        with patch.object(self.inference_adapter, "query_available_models", new_callable=AsyncMock) as mock_query:
-            mock_query.return_value = [model_id]
+        with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
+            mock_check.return_value = True
             result = self.run_async(self.inference_adapter.register_model(model))
             assert result == model
             assert len(self.inference_adapter.alias_to_provider_id_map) > 1
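Because check_model_availability is a coroutine, the test patches it with new_callable=AsyncMock so the mocked attribute can still be awaited. A self-contained sketch of the same pattern, assuming a stand-in Adapter class rather than the real NVIDIA inference adapter:

# Standalone sketch of the AsyncMock patching pattern used above;
# Adapter is a stand-in, not the real inference adapter.
import asyncio
from unittest.mock import AsyncMock, patch

class Adapter:
    async def check_model_availability(self, model: str) -> bool:
        raise NotImplementedError  # would hit the live NVIDIA API

adapter = Adapter()
with patch.object(adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
    mock_check.return_value = True
    # Awaiting the AsyncMock returns the configured value.
    assert asyncio.run(adapter.check_model_availability("default/job-1234")) is True
    mock_check.assert_awaited_once_with("default/job-1234")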
