Skip to content

Commit 8503080

Browse files
Enhancement: CAMEL model update (#28)
Co-authored-by: Tianqi Xu <[email protected]>
1 parent b9ddf11 commit 8503080

File tree

13 files changed

+1442
-1296
lines changed

13 files changed

+1442
-1296
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
33
# Ruff version.
4-
rev: v0.4.2
4+
rev: v0.6.5
55
hooks:
66
# Run the linter.
77
- id: ruff
@@ -13,4 +13,4 @@ repos:
1313
name: Check License
1414
entry: python licenses/update_license.py . licenses/license_template.txt
1515
language: system
16-
types: [python]
16+
types: [python]

crab-benchmark-v0/dataset/ubuntu_subtasks.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -540,9 +540,9 @@ def get_rgb_values_outside_bbox(
540540

541541
# Create a mask for the bounding box area with margin
542542
mask = np.ones(img.shape[:2], dtype=bool)
543-
mask[
544-
y_min_with_margin:y_max_with_margin, x_min_with_margin:x_max_with_margin
545-
] = False
543+
mask[y_min_with_margin:y_max_with_margin, x_min_with_margin:x_max_with_margin] = (
544+
False
545+
)
546546

547547
# Extract the RGB values outside the bounding box with margin
548548
rgb_values = img[mask]

crab-benchmark-v0/main.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,10 +180,9 @@ def get_benchmark(env: str, ubuntu_url: str):
180180
loglevel = args.loglevel
181181
numeric_level = getattr(logging, loglevel.upper(), None)
182182
if not isinstance(numeric_level, int):
183-
raise ValueError('Invalid log level: %s' % loglevel)
183+
raise ValueError("Invalid log level: %s" % loglevel)
184184
logging.basicConfig(level=numeric_level)
185185

186-
187186
benchmark = get_benchmark(args.env, args.remote_url)
188187

189188
if args.model == "gpt4o":

crab/actions/android_actions.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,7 @@ def swipe(element: int, direction: SwipeDirection, dist: SwipeDist, env) -> None
148148
offset = unit_dist, 0
149149
else:
150150
return "ERROR"
151-
adb_command = f"shell input swipe {x} {y} {x+offset[0]} {y+offset[1]} 200"
151+
adb_command = f"shell input swipe {x} {y} {x + offset[0]} {y + offset[1]} 200"
152152
execute_adb(adb_command, env)
153153
sleep(_DURATION)
154154

@@ -213,7 +213,9 @@ def stop_all_apps(env) -> None:
213213
execute_adb("shell input keyevent KEYCODE_HOME", env)
214214
execute_adb("shell input keyevent KEYCODE_APP_SWITCH", env)
215215
sleep(0.5)
216-
command = f"shell input swipe 100 {env.height/2} {env.width-100} {env.height/2} 200"
216+
command = (
217+
f"shell input swipe 100 {env.height / 2} {env.width - 100} {env.height / 2} 200"
218+
)
217219
execute_adb(command, env)
218220
sleep(0.5)
219221
execute_adb("shell input tap 300 1400", env)

crab/agents/backend_models/camel_model.py

Lines changed: 36 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,10 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
import base64
15-
import io
1614
import json
17-
from typing import Any, Dict, List, Optional, Tuple, Union
15+
from typing import Any
1816

17+
from openai.types.chat import ChatCompletionMessageToolCall
1918
from PIL import Image
2019

2120
from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
@@ -34,7 +33,7 @@
3433
CAMEL_ENABLED = False
3534

3635

37-
def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
36+
def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
3837
for platform in ModelPlatformType:
3938
if platform.value.lower() == model_platform_name.lower():
4039
return platform
@@ -44,33 +43,51 @@ def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
4443
)
4544

4645

47-
def find_model_type(model_name: str) -> Union[ModelType, str]:
46+
def _find_model_type(model_name: str) -> "str | ModelType":
4847
for model in ModelType:
4948
if model.value.lower() == model_name.lower():
5049
return model
5150
return model_name
5251

5352

54-
def decode_image(encoded_image: str) -> Image:
55-
data = base64.b64decode(encoded_image)
56-
return Image.open(io.BytesIO(data))
53+
def _convert_action_to_schema(
54+
action_space: list[Action] | None,
55+
) -> "list[OpenAIFunction] | None":
56+
if action_space is None:
57+
return None
58+
return [OpenAIFunction(action.entry) for action in action_space]
59+
60+
61+
def _convert_tool_calls_to_action_list(
62+
tool_calls: list[ChatCompletionMessageToolCall] | None,
63+
) -> list[ActionOutput] | None:
64+
if tool_calls is None:
65+
return None
66+
67+
return [
68+
ActionOutput(
69+
name=call.function.name,
70+
arguments=json.loads(call.function.arguments),
71+
)
72+
for call in tool_calls
73+
]
5774

5875

5976
class CamelModel(BackendModel):
6077
def __init__(
6178
self,
6279
model: str,
6380
model_platform: str,
64-
parameters: Optional[Dict[str, Any]] = None,
81+
parameters: dict[str, Any] | None = None,
6582
history_messages_len: int = 0,
6683
) -> None:
6784
if not CAMEL_ENABLED:
6885
raise ImportError("Please install camel-ai to use CamelModel")
6986
self.parameters = parameters or {}
7087
# TODO: a better way?
71-
self.model_type = find_model_type(model)
72-
self.model_platform_type = find_model_platform_type(model_platform)
73-
self.client: Optional[ChatAgent] = None
88+
self.model_type = _find_model_type(model)
89+
self.model_platform_type = _find_model_platform_type(model_platform)
90+
self.client: ChatAgent | None = None
7491
self.token_usage = 0
7592

7693
super().__init__(
@@ -79,11 +96,11 @@ def __init__(
7996
history_messages_len,
8097
)
8198

82-
def get_token_usage(self):
99+
def get_token_usage(self) -> int:
83100
return self.token_usage
84101

85-
def reset(self, system_message: str, action_space: Optional[List[Action]]) -> None:
86-
action_schema = self._convert_action_to_schema(action_space)
102+
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
103+
action_schema = _convert_action_to_schema(action_space)
87104
config = self.parameters.copy()
88105
if action_schema is not None:
89106
config["tool_choice"] = "required"
@@ -109,30 +126,9 @@ def reset(self, system_message: str, action_space: Optional[List[Action]]) -> No
109126
)
110127
self.token_usage = 0
111128

112-
@staticmethod
113-
def _convert_action_to_schema(
114-
action_space: Optional[List[Action]],
115-
) -> Optional[List[OpenAIFunction]]:
116-
if action_space is None:
117-
return None
118-
return [OpenAIFunction(action.entry) for action in action_space]
119-
120-
@staticmethod
121-
def _convert_tool_calls_to_action_list(tool_calls) -> List[ActionOutput]:
122-
if tool_calls is None:
123-
return tool_calls
124-
125-
return [
126-
ActionOutput(
127-
name=call.function.name,
128-
arguments=json.loads(call.function.arguments),
129-
)
130-
for call in tool_calls
131-
]
132-
133-
def chat(self, messages: List[Tuple[str, MessageType]]):
129+
def chat(self, messages: list[tuple[str, MessageType]]) -> BackendOutput:
134130
# TODO: handle multiple text messages after message refactoring
135-
image_list: List[Image.Image] = []
131+
image_list: list[Image.Image] = []
136132
content = ""
137133
for message in messages:
138134
if message[1] == MessageType.IMAGE_JPG_BASE64:
@@ -147,12 +143,9 @@ def chat(self, messages: List[Tuple[str, MessageType]]):
147143
)
148144
response = self.client.step(usermsg)
149145
self.token_usage += response.info["usage"]["total_tokens"]
150-
tool_call_request = response.info.get("tool_call_request")
151-
152-
# TODO: delete this after record_message is refactored
153-
self.client.record_message(response.msg)
146+
tool_call_request = response.info.get("external_tool_request")
154147

155148
return BackendOutput(
156149
message=response.msg.content,
157-
action_list=self._convert_tool_calls_to_action_list([tool_call_request]),
150+
action_list=_convert_tool_calls_to_action_list([tool_call_request]),
158151
)

crab/core/agent_policy.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,25 +21,21 @@ class AgentPolicy(ABC):
2121
def chat(
2222
self,
2323
observation: dict[str, list[tuple[str, MessageType]]],
24-
) -> list[ActionOutput]:
25-
...
24+
) -> list[ActionOutput]: ...
2625

2726
@abstractmethod
2827
def reset(
2928
self,
3029
task_description: str,
3130
action_spaces: dict[str, list[Action]],
3231
env_descriptions: dict[str, str],
33-
) -> None:
34-
...
32+
) -> None: ...
3533

3634
@abstractmethod
37-
def get_token_usage(self):
38-
...
35+
def get_token_usage(self): ...
3936

4037
@abstractmethod
41-
def get_backend_model_name(self) -> str:
42-
...
38+
def get_backend_model_name(self) -> str: ...
4339

4440
@staticmethod
4541
def combine_multi_env_action_space(

crab/core/backend_model.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,17 +33,14 @@ def __init__(
3333
self.reset("You are a helpful assistant.", None)
3434

3535
@abstractmethod
36-
def chat(self, contents: list[tuple[str, MessageType]]) -> BackendOutput:
37-
...
36+
def chat(self, contents: list[tuple[str, MessageType]]) -> BackendOutput: ...
3837

3938
@abstractmethod
4039
def reset(
4140
self,
4241
system_message: str,
4342
action_space: list[Action] | None,
44-
):
45-
...
43+
): ...
4644

4745
@abstractmethod
48-
def get_token_usage(self):
49-
...
46+
def get_token_usage(self): ...

crab/core/models/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ class Evaluator(Action):
2222
@field_validator("returns", mode="after")
2323
@classmethod
2424
def must_return_bool(cls, v: type[BaseModel]) -> type[BaseModel]:
25-
if v.model_fields["returns"].annotation != bool:
25+
if v.model_fields["returns"].annotation is not bool:
2626
raise ValueError("Evaluator must return bool.")
2727
return v
2828

crab/utils/measure.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,16 @@
1+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
114
import logging
215
import time
316
from functools import wraps

0 commit comments

Comments
 (0)