Skip to content

Commit 6711fd4

Browse files
committed
update SambaNovaInferenceAdapter to use _get_params from LiteLLMOpenAIMixin by adding extra params to the mixin
1 parent 037d28f commit 6711fd4

File tree

2 files changed

+12
-58
lines changed

2 files changed

+12
-58
lines changed

llama_stack/providers/remote/inference/sambanova/sambanova.py

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,9 @@
66

77
import requests
88

9-
from llama_stack.apis.inference import (
10-
ChatCompletionRequest,
11-
JsonSchemaResponseFormat,
12-
ToolChoice,
13-
)
149
from llama_stack.apis.models import Model
1510
from llama_stack.log import get_logger
1611
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
17-
from llama_stack.providers.utils.inference.openai_compat import (
18-
convert_message_to_openai_dict_new,
19-
convert_tooldef_to_openai_tool,
20-
get_sampling_options,
21-
)
2212

2313
from .config import SambaNovaImplConfig
2414
from .models import MODEL_ENTRIES
@@ -39,54 +29,10 @@ def __init__(self, config: SambaNovaImplConfig):
3929
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
4030
provider_data_api_key_field="sambanova_api_key",
4131
openai_compat_api_base=self.config.url,
32+
download_images=True, # SambaNova requires base64 image encoding
33+
json_schema_strict=False, # SambaNova doesn't support strict=True yet
4234
)
4335

44-
async def _get_params(self, request: ChatCompletionRequest) -> dict:
45-
input_dict = {}
46-
47-
input_dict["messages"] = [
48-
await convert_message_to_openai_dict_new(m, download_images=True) for m in request.messages
49-
]
50-
if fmt := request.response_format:
51-
if not isinstance(fmt, JsonSchemaResponseFormat):
52-
raise ValueError(
53-
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
54-
)
55-
56-
fmt = fmt.json_schema
57-
name = fmt["title"]
58-
del fmt["title"]
59-
fmt["additionalProperties"] = False
60-
61-
# Apply additionalProperties: False recursively to all objects
62-
fmt = self._add_additional_properties_recursive(fmt)
63-
64-
input_dict["response_format"] = {
65-
"type": "json_schema",
66-
"json_schema": {
67-
"name": name,
68-
"schema": fmt,
69-
"strict": False,
70-
},
71-
}
72-
if request.tools:
73-
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
74-
if request.tool_config.tool_choice:
75-
input_dict["tool_choice"] = (
76-
request.tool_config.tool_choice.value
77-
if isinstance(request.tool_config.tool_choice, ToolChoice)
78-
else request.tool_config.tool_choice
79-
)
80-
81-
return {
82-
"model": request.model,
83-
"api_key": self.get_api_key(),
84-
"api_base": self.api_base,
85-
**input_dict,
86-
"stream": request.stream,
87-
**get_sampling_options(request.sampling_params),
88-
}
89-
9036
async def register_model(self, model: Model) -> Model:
9137
model_id = self.get_provider_model_id(model.provider_resource_id)
9238

llama_stack/providers/utils/inference/litellm_openai_mixin.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def __init__(
7272
api_key_from_config: str | None,
7373
provider_data_api_key_field: str,
7474
openai_compat_api_base: str | None = None,
75+
download_images: bool = False,
76+
json_schema_strict: bool = True,
7577
):
7678
"""
7779
Initialize the LiteLLMOpenAIMixin.
@@ -81,13 +83,17 @@ def __init__(
8183
:param provider_data_api_key_field: The field in the provider data that contains the API key.
8284
:param litellm_provider_name: The name of the provider, used for model lookups.
8385
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
86+
:param download_images: Whether to download images and convert to base64 for message conversion.
87+
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
8488
"""
8589
ModelRegistryHelper.__init__(self, model_entries)
8690

8791
self.litellm_provider_name = litellm_provider_name
8892
self.api_key_from_config = api_key_from_config
8993
self.provider_data_api_key_field = provider_data_api_key_field
9094
self.api_base = openai_compat_api_base
95+
self.download_images = download_images
96+
self.json_schema_strict = json_schema_strict
9197

9298
if openai_compat_api_base:
9399
self.is_openai_compat = True
@@ -206,7 +212,9 @@ def _add_additional_properties_recursive(self, schema):
206212
async def _get_params(self, request: ChatCompletionRequest) -> dict:
207213
input_dict = {}
208214

209-
input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
215+
input_dict["messages"] = [
216+
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
217+
]
210218
if fmt := request.response_format:
211219
if not isinstance(fmt, JsonSchemaResponseFormat):
212220
raise ValueError(
@@ -226,7 +234,7 @@ async def _get_params(self, request: ChatCompletionRequest) -> dict:
226234
"json_schema": {
227235
"name": name,
228236
"schema": fmt,
229-
"strict": True,
237+
"strict": self.json_schema_strict,
230238
},
231239
}
232240
if request.tools:

Comments (0)