1
1
import asyncio
2
2
import json
3
- from typing import AsyncIterator , Iterator , List , Literal , Type
3
+ from typing import AsyncIterator , Iterator , List , Type
4
4
5
5
import httpx
6
- from pydantic import PrivateAttr , SecretStr
6
+ from pydantic import PrivateAttr
7
+ from swarmauri_base .ComponentBase import ComponentBase
7
8
from swarmauri_base .llms .LLMBase import LLMBase
8
9
from swarmauri_base .messages .MessageBase import MessageBase
9
- from swarmauri_base .ComponentBase import ComponentBase
10
10
11
11
from swarmauri_standard .conversations .Conversation import Conversation
12
12
from swarmauri_standard .messages .AgentMessage import AgentMessage , UsageData
@@ -32,6 +32,7 @@ class AI21StudioModel(LLMBase):
32
32
Provider resources: https://docs.ai21.com/reference/jamba-15-api-ref
33
33
"""
34
34
35
+
35
36
api_key : SecretStr
36
37
allowed_models : List [str ] = ["jamba-1.5-large" , "jamba-1.5-mini" ]
37
38
name : str = "jamba-1.5-large"
@@ -41,7 +42,6 @@ class AI21StudioModel(LLMBase):
41
42
_BASE_URL : str = PrivateAttr (
42
43
default = "https://api.ai21.com/studio/v1/chat/completions"
43
44
)
44
- timeout : float = 600.0
45
45
46
46
def __init__ (self , ** data ) -> None :
47
47
"""
@@ -145,9 +145,12 @@ def predict(
145
145
message_content = response_data ["choices" ][0 ]["message" ]["content" ]
146
146
usage_data = response_data .get ("usage" , {})
147
147
148
- usage = self ._prepare_usage_data (usage_data , prompt_timer .duration )
149
- conversation .add_message (AgentMessage (content = message_content , usage = usage ))
150
-
148
+ # Prepare usage data if tracking is enabled
149
+ if self .include_usage :
150
+ usage = self ._prepare_usage_data (usage_data , prompt_timer .duration )
151
+ conversation .add_message (AgentMessage (content = message_content , usage = usage ))
152
+ else :
153
+ conversation .add_message (AgentMessage (content = message_content ))
151
154
return conversation
152
155
153
156
@retry_on_status_codes ((429 , 529 ), max_retries = 1 )
@@ -194,9 +197,12 @@ async def apredict(
194
197
message_content = response_data ["choices" ][0 ]["message" ]["content" ]
195
198
usage_data = response_data .get ("usage" , {})
196
199
197
- usage = self ._prepare_usage_data (usage_data , prompt_timer .duration )
198
- conversation .add_message (AgentMessage (content = message_content , usage = usage ))
199
-
200
+ # Prepare usage data if tracking is enabled
201
+ if self .include_usage :
202
+ usage = self ._prepare_usage_data (usage_data , prompt_timer .duration )
203
+ conversation .add_message (AgentMessage (content = message_content , usage = usage ))
204
+ else :
205
+ conversation .add_message (AgentMessage (content = message_content ))
200
206
return conversation
201
207
202
208
@retry_on_status_codes ((429 , 529 ), max_retries = 1 )
@@ -265,6 +271,15 @@ def stream(
265
271
266
272
conversation .add_message (AgentMessage (content = message_content , usage = usage ))
267
273
274
+ # Prepare usage data if tracking is enabled
275
+ if self .include_usage :
276
+ usage = self ._prepare_usage_data (
277
+ usage_data , prompt_timer .duration , completion_timer .duration
278
+ )
279
+ conversation .add_message (AgentMessage (content = message_content , usage = usage ))
280
+ else :
281
+ conversation .add_message (AgentMessage (content = message_content ))
282
+
268
283
@retry_on_status_codes ((429 , 529 ), max_retries = 1 )
269
284
async def astream (
270
285
self ,
@@ -325,11 +340,14 @@ async def astream(
325
340
except json .JSONDecodeError :
326
341
pass
327
342
328
- usage = self ._prepare_usage_data (
329
- usage_data , prompt_timer .duration , completion_timer .duration
330
- )
331
-
332
- conversation .add_message (AgentMessage (content = message_content , usage = usage ))
343
+ # Prepare usage data if tracking is enabled
344
+ if self .include_usage :
345
+ usage = self ._prepare_usage_data (
346
+ usage_data , prompt_timer .duration , completion_timer .duration
347
+ )
348
+ conversation .add_message (AgentMessage (content = message_content , usage = usage ))
349
+ else :
350
+ conversation .add_message (AgentMessage (content = message_content ))
333
351
334
352
def batch (
335
353
self ,
0 commit comments