Skip to content

Commit 9fece56

Browse files
committed
Fix issues
1 parent f61213c commit 9fece56

File tree

4 files changed

+66
-67
lines changed

4 files changed

+66
-67
lines changed

crab/agents/backend_models/claude_model.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,32 @@ class ClaudeModel(BackendModel):
3131
def __init__(
3232
self,
3333
model: str,
34-
parameters: dict[str, Any] = dict(),
34+
parameters: dict[str, Any] | None = None,
3535
history_messages_len: int = 0,
3636
tool_call_required: bool = False,
3737
) -> None:
3838
if anthropic_model_enable is False:
3939
raise ImportError("Please install anthropic to use ClaudeModel")
40-
super().__init__(
41-
model,
42-
parameters,
43-
history_messages_len,
44-
)
40+
self.model = model
41+
self.parameters = parameters if parameters is not None else {}
42+
self.history_messages_len = history_messages_len
43+
44+
assert self.history_messages_len >= 0
45+
4546
self.client = anthropic.Anthropic()
46-
self.tool_call_required = tool_call_required
47+
self.tool_call_required: bool = tool_call_required
48+
self.system_message: str = "You are a helpful assistant."
49+
self.action_space: list[Action] | None = None
50+
self.action_schema: list[dict] | None = None
51+
self.token_usage: int = 0
52+
self.chat_history: list[list[dict]] = []
4753

4854
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
4955
self.system_message = system_message
5056
self.action_space = action_space
5157
self.action_schema = _convert_action_to_schema(self.action_space)
5258
self.token_usage = 0
53-
self.chat_history: list[list[dict]] = []
59+
self.chat_history = []
5460

5561
def chat(self, message: list[Message] | Message) -> BackendOutput:
5662
if isinstance(message, tuple):
@@ -92,8 +98,8 @@ def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
9298
def fetch_from_memory(self) -> list[dict]:
9399
request: list[dict] = []
94100
if self.history_messages_len > 0:
95-
fetch_hisotry_len = min(self.history_messages_len, len(self.chat_history))
96-
for history_message in self.chat_history[-fetch_hisotry_len:]:
101+
fetch_history_len = min(self.history_messages_len, len(self.chat_history))
102+
for history_message in self.chat_history[-fetch_history_len:]:
97103
request = request + history_message
98104
return request
99105

crab/agents/backend_models/gemini_model.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,32 @@ class GeminiModel(BackendModel):
4040
def __init__(
4141
self,
4242
model: str,
43-
parameters: dict[str, Any] = dict(),
43+
parameters: dict[str, Any] | None = None,
4444
history_messages_len: int = 0,
4545
tool_call_required: bool = False,
4646
) -> None:
4747
if gemini_model_enable is False:
4848
raise ImportError("Please install google.generativeai to use GeminiModel")
49-
super().__init__(
50-
model,
51-
parameters,
52-
history_messages_len,
53-
)
49+
50+
self.model = model
51+
self.parameters = parameters if parameters is not None else {}
52+
self.history_messages_len = history_messages_len
53+
assert self.history_messages_len >= 0
5454
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
5555
self.client = genai
5656
self.tool_call_required = tool_call_required
57+
self.system_message: str = "You are a helpful assistant."
58+
self.action_space: list[Action] | None = None
59+
self.action_schema: list[Tool] | None = None
60+
self.token_usage: int = 0
61+
self.chat_history: list[list[dict]] = []
5762

5863
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
5964
self.system_message = system_message
6065
self.action_space = action_space
6166
self.action_schema = _convert_action_to_schema(self.action_space)
6267
self.token_usage = 0
63-
self.chat_history: list[list[dict]] = []
68+
self.chat_history = []
6469

6570
def chat(self, message: list[Message] | Message) -> BackendOutput:
6671
if isinstance(message, tuple):
@@ -105,8 +110,8 @@ def generate_backend_output(self, response_message: Content) -> BackendOutput:
105110
def fetch_from_memory(self) -> list[dict]:
106111
request: list[dict] = []
107112
if self.history_messages_len > 0:
108-
fetch_hisotry_len = min(self.history_messages_len, len(self.chat_history))
109-
for history_message in self.chat_history[-fetch_hisotry_len:]:
113+
fetch_history_len = min(self.history_messages_len, len(self.chat_history))
114+
for history_message in self.chat_history[-fetch_history_len:]:
110115
request = request + history_message
111116
return request
112117

@@ -161,7 +166,7 @@ def _convert_action_to_schema(action_space: list[Action] | None) -> list[Tool] |
161166
actions = [
162167
Tool(
163168
function_declarations=[
164-
_action_to_funcdec(action) for action in action_space
169+
_action_to_func_dec(action) for action in action_space
165170
]
166171
)
167172
]
@@ -179,7 +184,7 @@ def _clear_schema(schema_dict: dict) -> None:
179184
_clear_schema(schema_dict["items"])
180185

181186

182-
def _action_to_funcdec(action: Action) -> FunctionDeclaration:
187+
def _action_to_func_dec(action: Action) -> FunctionDeclaration:
183188
"Converts crab Action to google FunctionDeclaration"
184189
p_schema = action.parameters.model_json_schema()
185190
if "$defs" in p_schema:

crab/agents/backend_models/openai_model.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,26 @@ class OpenAIModel(BackendModel):
2929
def __init__(
3030
self,
3131
model: str,
32-
parameters: dict[str, Any] = dict(),
32+
parameters: dict[str, Any] | None = None,
3333
history_messages_len: int = 0,
3434
tool_call_required: bool = False,
35+
base_url: str | None = None,
3536
) -> None:
3637
if not openai_model_enable:
3738
raise ImportError("Please install openai to use OpenAIModel")
38-
super().__init__(
39-
model,
40-
parameters,
41-
history_messages_len,
42-
)
43-
self.client = openai.OpenAI()
44-
self.tool_call_required = tool_call_required
45-
self.system_message = "You are a helpful assistant."
46-
self.openai_system_message = {
47-
"role": "system",
48-
"content": self.system_message,
49-
}
50-
self.action_space = None
51-
self.action_schema = None
52-
self.token_usage = 0
39+
40+
self.model = model
41+
self.parameters = parameters if parameters is not None else {}
42+
self.history_messages_len = history_messages_len
43+
44+
assert self.history_messages_len >= 0
45+
46+
self.client = openai.OpenAI(base_url=base_url)
47+
self.tool_call_required: bool = tool_call_required
48+
self.system_message: str = "You are a helpful assistant."
49+
self.action_space: list[Action] | None = None
50+
self.action_schema: list[dict] | None = None
51+
self.token_usage: int = 0
5352
self.chat_history: list[list[ChatCompletionMessage | dict]] = []
5453

5554
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
@@ -59,9 +58,9 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
5958
"content": system_message,
6059
}
6160
self.action_space = action_space
62-
self.action_schema = self._convert_action_to_schema(self.action_space)
61+
self.action_schema = _convert_action_to_schema(self.action_space)
6362
self.token_usage = 0
64-
self.chat_history: list[list[ChatCompletionMessage | dict]] = []
63+
self.chat_history = []
6564

6665
def chat(self, message: list[Message] | Message) -> BackendOutput:
6766
if isinstance(message, tuple):
@@ -93,10 +92,12 @@ def record_message(
9392
}
9493
) # extend conversation with function response
9594

96-
def call_api(self, request_messages: list) -> ChatCompletionMessage:
95+
def call_api(
96+
self, request_messages: list[ChatCompletionMessage | dict]
97+
) -> ChatCompletionMessage:
9798
if self.action_schema is not None:
9899
response = self.client.chat.completions.create(
99-
messages=request_messages,
100+
messages=request_messages, # type: ignore
100101
model=self.model,
101102
tools=self.action_schema,
102103
tool_choice="required" if self.tool_call_required else "auto",
@@ -115,8 +116,8 @@ def call_api(self, request_messages: list) -> ChatCompletionMessage:
115116
def fetch_from_memory(self) -> list[ChatCompletionMessage | dict]:
116117
request: list[ChatCompletionMessage | dict] = [self.openai_system_message]
117118
if self.history_messages_len > 0:
118-
fetch_hisotry_len = min(self.history_messages_len, len(self.chat_history))
119-
for history_message in self.chat_history[-fetch_hisotry_len:]:
119+
fetch_history_len = min(self.history_messages_len, len(self.chat_history))
120+
for history_message in self.chat_history[-fetch_history_len:]:
120121
request = request + history_message
121122
return request
122123

@@ -161,12 +162,14 @@ def generate_backend_output(
161162
action_list=action_list,
162163
)
163164

164-
@staticmethod
165-
def _convert_action_to_schema(action_space):
166-
if action_space is None:
167-
return None
168-
actions = []
169-
for action in action_space:
170-
new_action = action.to_openai_json_schema()
171-
actions.append({"type": "function", "function": new_action})
172-
return actions
165+
166+
def _convert_action_to_schema(
167+
action_space: list[Action] | None,
168+
) -> list[dict] | None:
169+
if action_space is None:
170+
return None
171+
actions = []
172+
for action in action_space:
173+
new_action = action.to_openai_json_schema()
174+
actions.append({"type": "function", "function": new_action})
175+
return actions

crab/core/backend_model.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,11 @@
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
from abc import ABC, abstractmethod
15-
from typing import Any
1615

1716
from .models import Action, BackendOutput, MessageType
1817

1918

2019
class BackendModel(ABC):
21-
def __init__(
22-
self,
23-
model: str,
24-
parameters: dict[str, Any] = dict(),
25-
history_messages_len: int = 0,
26-
) -> None:
27-
self.model = model
28-
self.parameters = parameters
29-
self.history_messages_len = history_messages_len
30-
31-
assert self.history_messages_len >= 0
32-
33-
self.reset("You are a helpful assistant.", None)
34-
3520
@abstractmethod
3621
def chat(self, contents: list[tuple[str, MessageType]]) -> BackendOutput: ...
3722

0 commit comments

Comments
 (0)