
Commit 6bb6ba7

Merge pull request #1266 from MichaelDecent/generic_llm
Added Generic LLM 🐝
2 parents 0e980df + c3d6a89 · commit 6bb6ba7

17 files changed · +1139 −162 lines changed

pkgs/base/swarmauri_base/llms/LLMBase.py

+12 −2
@@ -1,8 +1,9 @@
 from abc import abstractmethod
-from typing import Optional, List, Literal
-from pydantic import ConfigDict, model_validator, Field
+from typing import Dict, List, Literal, Optional

+from pydantic import ConfigDict, Field, PrivateAttr, SecretStr, model_validator
 from swarmauri_core.llms.IPredict import IPredict
+
 from swarmauri_base.ComponentBase import ComponentBase, ResourceTypes


@@ -13,6 +14,15 @@ class LLMBase(IPredict, ComponentBase):
     model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     type: Literal["LLMBase"] = "LLMBase"

+    api_key: Optional[SecretStr] = None
+    name: str = ""
+    timeout: float = 600.0
+    include_usage: bool = True
+
+    # Base URL to be overridden by subclasses
+    BASE_URL: Optional[str] = None
+    _headers: Dict[str, str] = PrivateAttr(default=None)
+
     @model_validator(mode="after")
     @classmethod
     def _validate_name_in_allowed_models(cls, values):
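In practice, provider models can now rely on these shared fields instead of redeclaring them. A minimal sketch of a subclass written against the updated LLMBase (the ExampleProviderModel name, its allowed_models entries, and the example URL are placeholders for illustration, not code from this commit):

from typing import List, Literal, Optional

from swarmauri_base.llms.LLMBase import LLMBase


class ExampleProviderModel(LLMBase):
    """Hypothetical provider model illustrating the fields inherited from LLMBase."""

    type: Literal["ExampleProviderModel"] = "ExampleProviderModel"
    allowed_models: List[str] = ["example-model"]  # placeholder list
    name: str = "example-model"

    # Override the base-class default instead of declaring a private _BASE_URL attribute.
    BASE_URL: Optional[str] = "https://api.example.com/v1/chat/completions"

    # predict/apredict/stream/astream would wrap the provider API here; they are
    # omitted from this sketch.


# Construction would then look roughly like:
#   ExampleProviderModel(api_key="sk-example", timeout=120.0, include_usage=False)
# api_key is Optional[SecretStr] (a plain string is coerced), timeout defaults to
# 600.0 seconds, and include_usage toggles whether usage data is attached to replies.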

pkgs/swarmauri_standard/swarmauri_standard/llms/AI21StudioModel.py

+33 −15
@@ -1,12 +1,12 @@
 import asyncio
 import json
-from typing import AsyncIterator, Iterator, List, Literal, Type
+from typing import AsyncIterator, Iterator, List, Type

 import httpx
-from pydantic import PrivateAttr, SecretStr
+from pydantic import PrivateAttr
+from swarmauri_base.ComponentBase import ComponentBase
 from swarmauri_base.llms.LLMBase import LLMBase
 from swarmauri_base.messages.MessageBase import MessageBase
-from swarmauri_base.ComponentBase import ComponentBase

 from swarmauri_standard.conversations.Conversation import Conversation
 from swarmauri_standard.messages.AgentMessage import AgentMessage, UsageData
@@ -32,6 +32,7 @@ class AI21StudioModel(LLMBase):
     Provider resources: https://docs.ai21.com/reference/jamba-15-api-ref
     """

+
     api_key: SecretStr
     allowed_models: List[str] = ["jamba-1.5-large", "jamba-1.5-mini"]
     name: str = "jamba-1.5-large"
@@ -41,7 +42,6 @@ class AI21StudioModel(LLMBase):
     _BASE_URL: str = PrivateAttr(
         default="https://api.ai21.com/studio/v1/chat/completions"
     )
-    timeout: float = 600.0

     def __init__(self, **data) -> None:
         """
@@ -145,9 +145,12 @@ def predict(
         message_content = response_data["choices"][0]["message"]["content"]
         usage_data = response_data.get("usage", {})

-        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
-
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -194,9 +197,12 @@ async def apredict(
         message_content = response_data["choices"][0]["message"]["content"]
         usage_data = response_data.get("usage", {})

-        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
-
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -265,6 +271,15 @@ def stream(

         conversation.add_message(AgentMessage(content=message_content, usage=usage))

+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
     @retry_on_status_codes((429, 529), max_retries=1)
     async def astream(
         self,
@@ -325,11 +340,14 @@ async def astream(
             except json.JSONDecodeError:
                 pass

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        # Prepare usage data if tracking is enabled
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(
         self,
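Each provider now guards usage accounting behind the inherited include_usage flag, as the hunks above show for predict, apredict, stream, and astream. A hedged usage sketch from the caller's side (the HumanMessage import path and the history accessor are assumptions about the surrounding swarmauri_standard API, not part of this diff):

from swarmauri_standard.conversations.Conversation import Conversation
from swarmauri_standard.llms.AI21StudioModel import AI21StudioModel
from swarmauri_standard.messages.HumanMessage import HumanMessage  # assumed import path

# include_usage is inherited from the updated LLMBase; when False, the branch added
# in this commit skips _prepare_usage_data and attaches no usage data to the reply.
llm = AI21StudioModel(api_key="YOUR_AI21_API_KEY", include_usage=False)

conversation = Conversation()
conversation.add_message(HumanMessage(content="Summarize Jamba 1.5 in one sentence."))

conversation = llm.predict(conversation, temperature=0.7, max_tokens=64)
reply = conversation.history[-1]  # assumed accessor for the newest AgentMessage
print(reply.content)
print(reply.usage)  # expected to be empty/None when include_usage=False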

pkgs/swarmauri_standard/swarmauri_standard/llms/AnthropicModel.py

+31 −18
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterator, Dict, Iterator, List, Literal, Type
+from typing import AsyncIterator, Dict, Iterator, List, Type

 import httpx
 from pydantic import PrivateAttr, SecretStr
@@ -32,7 +32,6 @@ class AnthropicModel(LLMBase):
     _BASE_URL: str = PrivateAttr("https://api.anthropic.com/v1")
     _client: httpx.Client = PrivateAttr()
     _async_client: httpx.AsyncClient = PrivateAttr()
-
     api_key: SecretStr
     allowed_models: List[str] = [
         "claude-3-7-sonnet-latest",
@@ -171,11 +170,15 @@ def predict(
         message_content = response_data["content"][0]["text"]

         usage_data = response_data["usage"]
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -254,10 +257,13 @@ def stream(
             except (json.JSONDecodeError, KeyError):
                 continue

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     @retry_on_status_codes((429, 529), max_retries=1)
     async def apredict(
@@ -296,11 +302,15 @@ async def apredict(
         message_content = response_data["content"][0]["text"]

         usage_data = response_data["usage"]
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -381,10 +391,13 @@ async def astream(
             except (json.JSONDecodeError, KeyError):
                 continue

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(
         self, conversations: List[Conversation], temperature=0.7, max_tokens=256

pkgs/swarmauri_standard/swarmauri_standard/llms/CohereModel.py

+26 −16
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterator, Dict, Iterator, List, Literal
+from typing import AsyncIterator, Dict, Iterator, List

 import httpx
 from pydantic import PrivateAttr, SecretStr
@@ -47,7 +47,6 @@ class CohereModel(LLMBase):
     type: Literal["CohereModel"] = "CohereModel"

     timeout: float = 600.0
-
     def __init__(self, **data):
         """
         Initialize the CohereModel with the provided configuration.
@@ -192,11 +191,14 @@ def predict(self, conversation, temperature=0.7, max_tokens=256):

         usage_data = data.get("usage", {})

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
         return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
@@ -249,12 +251,15 @@ async def apredict(self, conversation, temperature=0.7, max_tokens=256):

         usage_data = data.get("usage", {})

+        if self.include_usage:
             usage = self._prepare_usage_data(
                 usage_data, prompt_timer.duration, completion_timer.duration
             )
-
             conversation.add_message(AgentMessage(content=message_content, usage=usage))
-        return conversation
+        else:
+            conversation.add_message(AgentMessage(content=message_content))
+
+        return conversation

     @retry_on_status_codes((429, 529), max_retries=1)
     def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]:
@@ -316,12 +321,15 @@ def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]
                 elif "usage" in chunk:
                     usage_data = chunk["usage"]

-        full_content = "".join(collected_content)
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+        message_content = "".join(collected_content)

-        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+        if self.include_usage:
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     @retry_on_status_codes((429, 529), max_retries=1)
     async def astream(
@@ -395,12 +403,14 @@ async def astream(
             except json.JSONDecodeError:
                 continue

-        full_content = "".join(collected_content)
+        message_content = "".join(collected_content)
+        if self.include_usage:
             usage = self._prepare_usage_data(
                 usage_data, prompt_timer.duration, completion_timer.duration
             )
-
-        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        else:
+            conversation.add_message(AgentMessage(content=message_content))

     def batch(self, conversations: List, temperature=0.7, max_tokens=256) -> List:
         """

pkgs/swarmauri_standard/swarmauri_standard/llms/DeepInfraModel.py

-1
@@ -95,7 +95,6 @@ class DeepInfraModel(LLMBase):
     type: Literal["DeepInfraModel"] = "DeepInfraModel"

     timeout: float = 600.0
-
     def __init__(self, **data):
         """
         Initializes the DeepInfraModel instance with the provided API key

pkgs/swarmauri_standard/swarmauri_standard/llms/DeepSeekModel.py

+1
@@ -33,6 +33,7 @@ class DeepSeekModel(LLMBase):

     _BASE_URL: str = PrivateAttr("https://api.deepseek.com/v1")

+
     api_key: SecretStr
     allowed_models: List[str] = ["deepseek-chat", "deepseek-reasoner"]
     name: str = "deepseek-chat"
