Skip to content

Commit aff284e

Browse files
committed
Fix all agent tests and add create_backend_model function
1 parent 138c1d8 commit aff284e

File tree

15 files changed

+218
-126
lines changed

15 files changed

+218
-126
lines changed

crab/agents/backend_models/__init__.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,47 @@
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
# ruff: noqa: F401
15+
from typing import Any, Literal
16+
17+
from pydantic import BaseModel
18+
19+
from crab.core.backend_model import BackendModel
20+
1521
from .camel_model import CamelModel
1622
from .claude_model import ClaudeModel
1723
from .gemini_model import GeminiModel
1824
from .openai_model import OpenAIModel
25+
26+
27+
class BackendModelConfig(BaseModel):
28+
model_class: Literal["openai", "claude", "gemini", "camel"]
29+
model_name: str
30+
history_messages_len: int = 0
31+
parameters: dict[str, Any] = {}
32+
tool_call_required: bool = False
33+
34+
35+
def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
36+
match model_config.model_class:
37+
case "claude":
38+
return ClaudeModel(
39+
model=model_config.model_name,
40+
parameters=model_config.parameters,
41+
history_messages_len=model_config.history_messages_len,
42+
)
43+
case "gemini":
44+
return GeminiModel(
45+
model=model_config.model_name,
46+
parameters=model_config.parameters,
47+
history_messages_len=model_config.history_messages_len,
48+
)
49+
case "openai":
50+
return OpenAIModel(
51+
model=model_config.model_name,
52+
parameters=model_config.parameters,
53+
history_messages_len=model_config.history_messages_len,
54+
)
55+
case "camel":
56+
raise NotImplementedError("Cannot support camel model currently.")
57+
case _:
58+
raise ValueError(f"Unsupported model name: {model_config.model_name}")

crab/agents/backend_models/claude_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
model: str,
3333
parameters: dict[str, Any] = dict(),
3434
history_messages_len: int = 0,
35+
tool_call_required: bool = False,
3536
) -> None:
3637
if anthropic_model_enable is False:
3738
raise ImportError("Please install anthropic to use ClaudeModel")
@@ -41,6 +42,7 @@ def __init__(
4142
history_messages_len,
4243
)
4344
self.client = anthropic.Anthropic()
45+
self.tool_call_required = tool_call_required
4446

4547
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
4648
self.system_message = system_message
@@ -93,6 +95,7 @@ def record_message(self, new_message: dict, response_message: dict) -> None:
9395
"content": "success",
9496
}
9597
for call in tool_calls
98+
if call is ToolUseBlock
9699
],
97100
}
98101
)
@@ -101,12 +104,14 @@ def call_api(self, request_messages: list):
101104
while True:
102105
try:
103106
if self.action_schema is not None:
104-
response = self.client.beta.tools.messages.create(
107+
response = self.client.messages.create(
105108
system=self.system_message, # <-- system prompt
106109
messages=request_messages, # type: ignore
107110
model=self.model,
108111
tools=self.action_schema,
109-
tool_choice={"type": "any"},
112+
tool_choice={
113+
"type": "any" if self.tool_call_required else "auto"
114+
},
110115
**self.parameters,
111116
)
112117
else:

crab/agents/backend_models/gemini_model.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
model: str,
3636
parameters: dict[str, Any] = dict(),
3737
history_messages_len: int = 0,
38+
tool_call_required: bool = False,
3839
) -> None:
3940
if gemini_model_enable is False:
4041
raise ImportError("Please install google.generativeai to use GeminiModel")
@@ -45,6 +46,7 @@ def __init__(
4546
)
4647
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
4748
self.client = genai
49+
self.tool_call_required = tool_call_required
4850

4951
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
5052
self.system_message = system_message
@@ -98,7 +100,11 @@ def call_api(self, request_messages: list):
98100
try:
99101
if self.action_schema is not None:
100102
tool_config = content_types.to_tool_config(
101-
{"function_calling_config": {"mode": "ANY"}}
103+
{
104+
"function_calling_config": {
105+
"mode": "ANY" if self.tool_call_required else "AUTO"
106+
}
107+
}
102108
)
103109
response = self.client.GenerativeModel(
104110
self.model, system_instruction=self.system_message
@@ -141,9 +147,7 @@ def _convert_action_to_schema(cls, action_space):
141147
return None
142148
actions = []
143149
for action in action_space:
144-
actions.append(
145-
Tool(function_declarations=[cls._action_to_funcdec_policy(action)])
146-
)
150+
actions.append(Tool(function_declarations=[cls._action_to_funcdec(action)]))
147151
return actions
148152

149153
@staticmethod
@@ -171,14 +175,14 @@ def _clear_schema(cls, schema_dict: dict):
171175
cls._clear_schema(schema_dict["items"])
172176

173177
@classmethod
174-
def _action_to_funcdec(cls, action: Action, env: str):
178+
def _action_to_funcdec(cls, action: Action) -> FunctionDeclaration:
175179
"Converts crab Action to google FunctionDeclaration"
176180
p_schema = action.parameters.model_json_schema()
177181
if "$defs" in p_schema:
178182
p_schema = json_expand_refs(p_schema)
179183
cls._clear_schema(p_schema)
180184
return FunctionDeclaration(
181-
name=action.name + "__in__" + env,
182-
description="In {} environment, {}".format(env, action.description),
185+
name=action.name,
186+
description=action.description,
183187
parameters=p_schema,
184188
)

crab/agents/backend_models/openai_model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
model: str,
3232
parameters: dict[str, Any] = dict(),
3333
history_messages_len: int = 0,
34+
tool_call_required: bool = False,
3435
) -> None:
3536
if not openai_model_enable:
3637
raise ImportError("Please install openai to use OpenAIModel")
@@ -40,6 +41,16 @@ def __init__(
4041
history_messages_len,
4142
)
4243
self.client = openai.OpenAI()
44+
self.tool_call_required = tool_call_required
45+
self.system_message = "You are a helpful assistant."
46+
self.openai_system_message = {
47+
"role": "system",
48+
"content": self.system_message,
49+
}
50+
self.action_space = None
51+
self.action_schema = None
52+
self.token_usage = 0
53+
self.chat_history: list[list[ChatCompletionMessage | dict]] = []
4354

4455
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
4556
self.system_message = system_message
@@ -88,12 +99,12 @@ def call_api(self, request_messages: list) -> ChatCompletionMessage:
8899
messages=request_messages,
89100
model=self.model,
90101
tools=self.action_schema,
91-
tool_choice="required",
102+
tool_choice="required" if self.tool_call_required else "auto",
92103
**self.parameters,
93104
)
94105
else:
95106
response = self.client.chat.completions.create(
96-
messages=request_messages,
107+
messages=request_messages, # type: ignore
97108
model=self.model,
98109
**self.parameters,
99110
)

crab/agents/policies/multi_agent_by_env.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
from copy import copy
15-
1614
from crab import Action, ActionOutput
15+
from crab.agents.backend_models import BackendModelConfig, create_backend_model
16+
from crab.agents.utils import generate_action_prompt
1717
from crab.core.agent_policy import AgentPolicy
1818
from crab.core.backend_model import (
1919
BackendModel,
@@ -57,12 +57,12 @@ class MultiAgentByEnvPolicy(AgentPolicy):
5757

5858
def __init__(
5959
self,
60-
main_agent_model_backend: BackendModel,
61-
env_agent_model_backend: BackendModel,
60+
main_agent_model_backend: BackendModelConfig,
61+
env_agent_model_backend: BackendModelConfig,
6262
):
63-
self.main_agent_model_backend = copy(main_agent_model_backend)
64-
self.env_agent_model_backend = env_agent_model_backend
65-
self.reset(task_description="", action_spaces=None, env_descriptions={})
63+
self.main_agent_model_backend = create_backend_model(main_agent_model_backend)
64+
self.env_agent_model_backend_config = env_agent_model_backend
65+
self.reset(task_description="", action_spaces={}, env_descriptions={})
6666

6767
def reset(
6868
self,
@@ -82,15 +82,16 @@ def reset(
8282
)
8383
self.env_agent_model_backends: dict[str, BackendModel] = {}
8484
for env in action_spaces:
85-
backend = copy(self.env_agent_model_backend)
85+
backend = create_backend_model(self.env_agent_model_backend_config)
8686
if env == "root":
8787
backend.reset(root_agent_system_message, action_spaces[env])
8888
else:
89+
backend.require_tool = True
8990
env_agent_system_message = self._env_agent_prompt.format(
9091
task_description=task_description,
9192
environment=env,
9293
env_description=env_descriptions[env],
93-
action_descriptions=self.generate_action_prompt(action_spaces[env]),
94+
action_descriptions=generate_action_prompt(action_spaces[env]),
9495
)
9596
backend.reset(env_agent_system_message, action_spaces[env])
9697
self.env_agent_model_backends[env] = backend
@@ -140,5 +141,7 @@ def chat(
140141
)
141142
else:
142143
output = backend.chat((main_agent_message, MessageType.TEXT))
144+
for action in output.action_list:
145+
action.env = env
143146
tool_calls.extend(output.action_list)
144-
return self.decode_combined_action(tool_calls)
147+
return tool_calls

crab/agents/policies/multi_agent_by_func.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
from copy import copy
15-
16-
from crab import Action, ActionOutput
17-
from crab.core.agent_policy import AgentPolicy
18-
from crab.core.backend_model import (
19-
BackendModel,
20-
MessageType,
14+
from crab.agents.backend_models import BackendModelConfig, create_backend_model
15+
from crab.agents.utils import (
16+
combine_multi_env_action_space,
17+
decode_combined_action,
18+
generate_action_prompt,
2119
)
20+
from crab.core import Action, ActionOutput
21+
from crab.core.agent_policy import AgentPolicy
22+
from crab.core.backend_model import MessageType
2223

2324

2425
class MultiAgentByFuncPolicy(AgentPolicy):
@@ -40,11 +41,11 @@ class MultiAgentByFuncPolicy(AgentPolicy):
4041

4142
def __init__(
4243
self,
43-
main_agent_model_backend: BackendModel,
44-
tool_agent_model_backend: BackendModel,
44+
main_agent_model_backend: BackendModelConfig,
45+
tool_agent_model_backend: BackendModelConfig,
4546
):
46-
self.main_agent_model_backend = copy(main_agent_model_backend)
47-
self.tool_agent_model_backend = copy(tool_agent_model_backend)
47+
self.main_agent_model_backend = create_backend_model(main_agent_model_backend)
48+
self.tool_agent_model_backend = create_backend_model(tool_agent_model_backend)
4849
self.reset(task_description="", action_spaces=None, env_descriptions={})
4950

5051
def reset(
@@ -54,11 +55,11 @@ def reset(
5455
env_descriptions: dict[str, str],
5556
) -> list[ActionOutput]:
5657
self.task_description = task_description
57-
self.action_space = self.combine_multi_env_action_space(action_spaces)
58+
self.action_space = combine_multi_env_action_space(action_spaces)
5859

5960
main_agent_system_message = self._system_prompt.format(
6061
task_description=task_description,
61-
action_descriptions=self.generate_action_prompt(self.action_space),
62+
action_descriptions=generate_action_prompt(self.action_space),
6263
env_description=str(env_descriptions),
6364
)
6465
self.main_agent_model_backend.reset(main_agent_system_message, None)
@@ -95,4 +96,4 @@ def chat(
9596
tool_output = self.tool_agent_model_backend.chat(
9697
(output.message, MessageType.TEXT)
9798
)
98-
return self.decode_combined_action(tool_output.action_list)
99+
return decode_combined_action(tool_output.action_list)

crab/agents/policies/single_agent.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
from copy import copy
15-
1614
from crab import Action, ActionOutput
15+
from crab.agents.backend_models import BackendModelConfig, create_backend_model
16+
from crab.agents.utils import (
17+
combine_multi_env_action_space,
18+
decode_combined_action,
19+
generate_action_prompt,
20+
)
1721
from crab.core.agent_policy import AgentPolicy
1822
from crab.core.backend_model import (
19-
BackendModel,
2023
MessageType,
2124
)
2225
from crab.utils.measure import timed
@@ -46,9 +49,9 @@ class SingleAgentPolicy(AgentPolicy):
4649

4750
def __init__(
4851
self,
49-
model_backend: BackendModel,
52+
model_backend: BackendModelConfig,
5053
):
51-
self.model_backend = copy(model_backend)
54+
self.model_backend = create_backend_model(model_backend)
5255
self.reset(task_description="", action_spaces=None, env_descriptions={})
5356

5457
def reset(
@@ -58,10 +61,10 @@ def reset(
5861
env_descriptions: dict[str, str],
5962
) -> list:
6063
self.task_description = task_description
61-
self.action_space = self.combine_multi_env_action_space(action_spaces)
64+
self.action_space = combine_multi_env_action_space(action_spaces)
6265
system_message = self._system_prompt.format(
6366
task_description=task_description,
64-
action_descriptions=self.generate_action_prompt(self.action_space),
67+
action_descriptions=generate_action_prompt(self.action_space),
6568
env_description=str(env_descriptions),
6669
)
6770
self.model_backend.reset(system_message, self.action_space)
@@ -87,4 +90,4 @@ def chat(
8790
)
8891
)
8992
output = self.model_backend.chat(prompt)
90-
return self.decode_combined_action(output.action_list)
93+
return decode_combined_action(output.action_list)

0 commit comments

Comments
 (0)