
Commit 5dfb0e0

Add baseten integration
1 parent 6d46291 commit 5dfb0e0

3 files changed: +492, -0 lines changed


pyproject.toml

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ packages = ["src/strands"]
 anthropic = [
     "anthropic>=0.21.0,<1.0.0",
 ]
+baseten = [
+    "openai>=1.68.0,<2.0.0",
+]
 dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",

src/strands/models/baseten.py

Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@

"""Baseten model provider.

- Docs: https://docs.baseten.co/
"""

import logging
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import Messages
from ..types.models import OpenAIModel

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    # pragma: no cover
    def chat(self) -> Any:
        """Chat completions interface."""
        ...


class BasetenModel(OpenAIModel):
    """Baseten model provider implementation."""

    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528" or
                "meta-llama/Llama-4-Maverick-17B-128E-Instruct".
                For dedicated deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments: https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the Baseten client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}

        # Set default base URL for Model APIs if not provided
        if "base_url" not in client_args and "base_url" not in self.config:
            client_args["base_url"] = "https://inference.baseten.co/v1"
        elif "base_url" in self.config:
            client_args["base_url"] = self.config["base_url"]

        self.client = openai.OpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[BasetenConfig]) -> None:  # type: ignore[override]
        """Update the Baseten model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> BasetenConfig:
        """Get the Baseten model configuration.

        Returns:
            The Baseten model configuration.
        """
        return cast(BasetenModel.BasetenConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Baseten model and get the streaming response.

        Args:
            request: The formatted request to send to the Baseten model.

        Returns:
            An iterable of response events from the Baseten model.
        """
        response = self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}

        for event in response:
            # Defensive: skip events with empty or missing choices
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "reasoning_content",
                    "data": choice.delta.reasoning_content,
                }

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": choice.finish_reason}

        # Drain remaining events; only the final usage payload is of interest
        for event in response:
            _ = event

        yield {"chunk_type": "metadata", "data": event.usage}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.

        Yields:
            Model events with the last being the structured output.
        """
        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )

        parsed: T | None = None
        # Find the first choice with a successfully parsed output
        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the Baseten response.")

        for choice in response.choices:
            if isinstance(choice.message.parsed, output_model):
                parsed = choice.message.parsed
                break

        if parsed:
            yield {"output": parsed}
        else:
            raise ValueError("No valid structured output was found in the Baseten response.")
