Commit 04d1fe9

refactor(llms): remove unused attributes from LLM classes and streamline usage handling

1 parent 91eb184 commit 04d1fe9

14 files changed, +251 −227 lines changed
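
The "streamline usage handling" half of this commit gives every provider's predict/apredict/stream/astream the same shape: usage statistics are computed and attached to the reply only when the model's include_usage flag is set. The sketch below is a minimal, self-contained rendering of that shared shape, assuming include_usage and _prepare_usage_data come from LLMBase (which is not among the files in this diff); AgentMessage is reduced to a stand-in dataclass.

# Minimal sketch of the gating pattern repeated in the diffs below.
# Assumptions (not shown in this commit): `include_usage` and
# `_prepare_usage_data` are provided by LLMBase; AgentMessage is simplified.
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class AgentMessage:
    content: str
    usage: Optional[dict] = None

def add_reply(conversation, message_content: str, usage_data: dict,
              include_usage: bool, prepare_usage_data: Callable[[dict], dict]) -> None:
    # Usage bookkeeping runs only when tracking is enabled; otherwise the
    # reply is stored without a usage payload.
    if include_usage:
        usage = prepare_usage_data(usage_data)
        conversation.add_message(AgentMessage(content=message_content, usage=usage))
    else:
        conversation.add_message(AgentMessage(content=message_content))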

pkgs/swarmauri_standard/swarmauri_standard/llms/AI21StudioModel.py

Lines changed: 32 additions & 19 deletions
@@ -1,12 +1,12 @@
 import asyncio
 import json
-from typing import AsyncIterator, Iterator, List, Literal, Type
+from typing import AsyncIterator, Iterator, List, Type

 import httpx
-from pydantic import PrivateAttr, SecretStr
+from pydantic import PrivateAttr
+from swarmauri_base.ComponentBase import ComponentBase
 from swarmauri_base.llms.LLMBase import LLMBase
 from swarmauri_base.messages.MessageBase import MessageBase
-from swarmauri_base.ComponentBase import ComponentBase

 from swarmauri_standard.conversations.Conversation import Conversation
 from swarmauri_standard.messages.AgentMessage import AgentMessage, UsageData
@@ -32,16 +32,11 @@ class AI21StudioModel(LLMBase):
     Provider resources: https://docs.ai21.com/reference/jamba-15-api-ref
     """

-    api_key: SecretStr
-    allowed_models: List[str] = []
-    name: str = ""
-    type: Literal["AI21StudioModel"] = "AI21StudioModel"
     _client: httpx.Client = PrivateAttr(default=None)
     _async_client: httpx.AsyncClient = PrivateAttr(default=None)
     _BASE_URL: str = PrivateAttr(
         default="https://api.ai21.com/studio/v1/chat/completions"
     )
-    timeout: float = 600.0

     def __init__(self, **data) -> None:
         """
@@ -147,9 +142,12 @@ def predict(
         message_content = response_data["choices"][0]["message"]["content"]
         usage_data = response_data.get("usage", {})

-        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
-
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -196,9 +194,12 @@ async def apredict(
         message_content = response_data["choices"][0]["message"]["content"]
         usage_data = response_data.get("usage", {})

-        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
-
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -267,6 +268,15 @@ def stream(

         conversation.add_message(AgentMessage(content=message_content, usage=usage))

+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
     @retry_on_status_codes((429, 529), max_retries=1)
     async def astream(
         self,
@@ -327,11 +337,14 @@ async def astream(
             except json.JSONDecodeError:
                 pass

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(
         self,
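
From a caller's point of view, the only visible change is whether the returned AgentMessage carries usage data. The usage sketch below is hypothetical: the constructor fields (api_key, name, include_usage) are assumed to be inherited from LLMBase, and the Conversation/HumanMessage helpers and history accessor are assumed to match the rest of swarmauri_standard, so exact names may differ.

# Hypothetical usage sketch -- constructor fields are assumed to be inherited
# from LLMBase, which is not part of this diff.
from swarmauri_standard.llms.AI21StudioModel import AI21StudioModel
from swarmauri_standard.conversations.Conversation import Conversation
from swarmauri_standard.messages.HumanMessage import HumanMessage

llm = AI21StudioModel(api_key="AI21_API_KEY", name="jamba-1.5-mini", include_usage=False)
conversation = Conversation()
conversation.add_message(HumanMessage(content="Say hello."))
conversation = llm.predict(conversation, temperature=0.7, max_tokens=64)
reply = conversation.history[-1]   # accessor name assumed
print(reply.content)               # model reply text
print(reply.usage)                 # None when include_usage=False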

pkgs/swarmauri_standard/swarmauri_standard/llms/AnthropicModel.py

Lines changed: 33 additions & 26 deletions
@@ -1,12 +1,12 @@
 import asyncio
 import json
-from typing import AsyncIterator, Dict, Iterator, List, Literal, Type
+from typing import AsyncIterator, Dict, Iterator, List, Type

 import httpx
-from pydantic import PrivateAttr, SecretStr
+from pydantic import PrivateAttr
+from swarmauri_base.ComponentBase import ComponentBase
 from swarmauri_base.llms.LLMBase import LLMBase
 from swarmauri_base.messages.MessageBase import MessageBase
-from swarmauri_base.ComponentBase import ComponentBase

 from swarmauri_standard.conversations.Conversation import Conversation
 from swarmauri_standard.messages.AgentMessage import AgentMessage, UsageData
@@ -33,13 +33,6 @@ class AnthropicModel(LLMBase):
     _client: httpx.Client = PrivateAttr()
     _async_client: httpx.AsyncClient = PrivateAttr()

-    api_key: SecretStr
-    allowed_models: List[str] = []
-    name: str = ""
-    type: Literal["AnthropicModel"] = "AnthropicModel"
-
-    timeout: float = 600.0
-
     def __init__(self, **data):
         super().__init__(**data)
         headers = {
@@ -162,11 +155,15 @@ def predict(
         message_content = response_data["content"][0]["text"]

         usage_data = response_data["usage"]
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -245,10 +242,13 @@ def stream(
             except (json.JSONDecodeError, KeyError):
                 continue

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     @retry_on_status_codes((429, 529), max_retries=1)
     async def apredict(
@@ -287,11 +287,15 @@ async def apredict(
         message_content = response_data["content"][0]["text"]

         usage_data = response_data["usage"]
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -372,10 +376,13 @@ async def astream(
             except (json.JSONDecodeError, KeyError):
                 continue

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(
         self, conversations: List[Conversation], temperature=0.7, max_tokens=256

pkgs/swarmauri_standard/swarmauri_standard/llms/CohereModel.py

Lines changed: 28 additions & 24 deletions
@@ -1,12 +1,12 @@
 import asyncio
 import json
-from typing import AsyncIterator, Dict, Iterator, List, Literal
+from typing import AsyncIterator, Dict, Iterator, List

 import httpx
-from pydantic import PrivateAttr, SecretStr
+from pydantic import PrivateAttr
+from swarmauri_base.ComponentBase import ComponentBase
 from swarmauri_base.llms.LLMBase import LLMBase
 from swarmauri_base.messages.MessageBase import MessageBase
-from swarmauri_base.ComponentBase import ComponentBase

 from swarmauri_standard.messages.AgentMessage import AgentMessage, UsageData
 from swarmauri_standard.utils.duration_manager import DurationManager
@@ -32,13 +32,6 @@ class CohereModel(LLMBase):
     _BASE_URL: str = PrivateAttr("https://api.cohere.ai/v1")
     _client: httpx.Client = PrivateAttr()

-    api_key: SecretStr
-    allowed_models: List[str] = []
-    name: str = ""
-    type: Literal["CohereModel"] = "CohereModel"
-
-    timeout: float = 600.0
-
     def __init__(self, **data):
         """
         Initialize the CohereModel with the provided configuration.
@@ -185,11 +178,14 @@ def predict(self, conversation, temperature=0.7, max_tokens=256):

         usage_data = data.get("usage", {})

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -242,12 +238,15 @@ async def apredict(self, conversation, temperature=0.7, max_tokens=256):

         usage_data = data.get("usage", {})

+        if self.include_usage:
             usage = self._prepare_usage_data(
                 usage_data, prompt_timer.duration, completion_timer.duration
             )
-
             conversation.add_message(AgentMessage(content=message_content, usage=usage))
-        return conversation
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
+        return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
     def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]:
@@ -309,12 +308,15 @@ def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]
                 elif "usage" in chunk:
                     usage_data = chunk["usage"]

-        full_content = "".join(collected_content)
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+        message_content = "".join(collected_content)

-        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     @retry_on_status_codes((429, 529), max_retries=1)
     async def astream(
@@ -388,12 +390,14 @@ async def astream(
             except json.JSONDecodeError:
                 continue

-        full_content = "".join(collected_content)
+        message_content = "".join(collected_content)
+        if self.include_usage:
             usage = self._prepare_usage_data(
                 usage_data, prompt_timer.duration, completion_timer.duration
             )
-
-        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(self, conversations: List, temperature=0.7, max_tokens=256) -> List:
         """

pkgs/swarmauri_standard/swarmauri_standard/llms/DeepInfraModel.py

Lines changed: 0 additions & 9 deletions
@@ -40,15 +40,6 @@ class DeepInfraModel(LLMBase):
     _client: httpx.Client = PrivateAttr(default=None)
     _async_client: httpx.AsyncClient = PrivateAttr(default=None)

-    api_key: SecretStr
-    allowed_models: List[str] = []
-
-    name: str = ""
-
-    type: Literal["DeepInfraModel"] = "DeepInfraModel"
-
-    timeout: float = 600.0
-
     def __init__(self, **data):
         """
         Initializes the DeepInfraModel instance with the provided API key

pkgs/swarmauri_standard/swarmauri_standard/llms/DeepSeekModel.py

Lines changed: 0 additions & 5 deletions
@@ -33,11 +33,6 @@ class DeepSeekModel(LLMBase):

     _BASE_URL: str = PrivateAttr("https://api.deepseek.com/v1")

-    api_key: SecretStr
-    allowed_models: List[str] = []
-    name: str = ""
-
-    type: Literal["DeepSeekModel"] = "DeepSeekModel"
     _client: httpx.Client = PrivateAttr()
     _async_client: httpx.AsyncClient = PrivateAttr()

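These per-subclass removals stay behavior-preserving only if the shared fields are declared once in a common place. LLMBase is not touched by this commit, so the following is purely an illustrative sketch of the kind of base-class declaration that would make the per-model api_key, allowed_models, name, timeout, and include_usage fields redundant; the real definition in swarmauri_base may differ.

# Illustrative only -- the actual LLMBase in swarmauri_base is not shown in
# this commit and may declare different names, types, or defaults.
from typing import List, Optional
from pydantic import BaseModel, SecretStr

class LLMBase(BaseModel):
    api_key: Optional[SecretStr] = None   # shared credential field (assumed)
    allowed_models: List[str] = []        # provider-specific model list (assumed)
    name: str = ""                        # selected model name (assumed)
    timeout: float = 600.0                # HTTP timeout in seconds (assumed)
    include_usage: bool = True            # gates _prepare_usage_data calls (assumed)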