1+ """Baseten model provider.
2+
3+ - Docs: https://docs.baseten.co/
4+ """
5+
6+ import logging
7+ from typing import Any , Generator , Iterable , Optional , Protocol , Type , TypedDict , TypeVar , Union , cast
8+
9+ import openai
10+ from openai .types .chat .parsed_chat_completion import ParsedChatCompletion
11+ from pydantic import BaseModel
12+ from typing_extensions import Unpack , override
13+
14+ from ..types .content import Messages
15+ from ..types .models import OpenAIModel
16+
# Module-level logger; handler/level configuration is left to the host application.
logger = logging.getLogger(__name__)

# Type variable for the pydantic model parsed out of structured-output responses.
T = TypeVar("T", bound=BaseModel)
20+
21+
class Client(Protocol):
    """Structural type for the OpenAI-compatible client used by this provider.

    Any object exposing a ``chat`` attribute with a completions interface
    (e.g. ``openai.OpenAI``) satisfies this protocol.
    """

    @property
    def chat(self) -> Any:  # pragma: no cover
        """Chat completions interface."""
        ...
30+
31+
class BasetenModel(OpenAIModel):
    """Baseten model provider implementation.

    Wraps an OpenAI-compatible client pointed at a Baseten endpoint and adapts
    its responses to the streaming event protocol defined by ``OpenAIModel``.
    """

    # OpenAI-compatible client used for all requests.
    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528"
                or "meta-llama/Llama-4-Maverick-17B-128E-Instruct".
                For dedicated deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments:
                https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the underlying OpenAI client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        # Copy so we never mutate the caller's dict when injecting base_url below.
        client_args = dict(client_args) if client_args else {}

        # Endpoint precedence: explicit model config wins over client args;
        # otherwise fall back to the public Model APIs endpoint.
        if "base_url" in self.config:
            client_args["base_url"] = self.config["base_url"]
        elif "base_url" not in client_args:
            client_args["base_url"] = "https://inference.baseten.co/v1"

        self.client = openai.OpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[BasetenConfig]) -> None:  # type: ignore[override]
        """Update the Baseten model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> BasetenConfig:
        """Get the Baseten model configuration.

        Returns:
            The Baseten model configuration.
        """
        return cast(BasetenModel.BasetenConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Baseten model and get the streaming response.

        Args:
            request: The formatted request to send to the Baseten model.

        Returns:
            An iterable of response events from the Baseten model.
        """
        response = self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}
        # Track terminal state explicitly so an empty/degenerate stream cannot
        # leave `choice`/`event` unbound (NameError in the original code).
        finish_reason: Optional[str] = None
        usage: Any = None

        for event in response:
            # Usage may ride on any chunk depending on stream_options; keep the latest.
            if getattr(event, "usage", None) is not None:
                usage = event.usage
            # Defensive: skip events with empty or missing choices.
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            # Some models (e.g. DeepSeek-R1) stream reasoning tokens in a separate field.
            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "reasoning_content",
                    "data": choice.delta.reasoning_content,
                }

            # Tool-call fragments arrive incrementally; group them by call index.
            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                finish_reason = choice.finish_reason
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": finish_reason}

        # Drain the remainder of the stream; only the trailing usage payload matters.
        for event in response:
            if getattr(event, "usage", None) is not None:
                usage = event.usage

        yield {"chunk_type": "metadata", "data": usage}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.

        Yields:
            Model events with the last being the structured output.

        Raises:
            ValueError: If the response has multiple choices or no parsed output.
        """
        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )

        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the Baseten response.")

        parsed: Optional[T] = None
        for choice in response.choices:
            if isinstance(choice.message.parsed, output_model):
                parsed = choice.message.parsed
                break

        # Compare against None: a validly parsed model whose fields are all
        # falsy defaults must still be yielded, not rejected by truthiness.
        if parsed is not None:
            yield {"output": parsed}
        else:
            raise ValueError("No valid structured output was found in the Baseten response.")