Commit 48f2452

refactor(agent): clean agent part code (#40)
Co-authored-by: Isaac Jin <[email protected]>
1 parent 71e95fb · commit 48f2452

38 files changed: +1997 −1157 lines

crab-benchmark-v0/README.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -29,3 +29,7 @@ After setting up the environment, you can start the experiment. A brief overview
 2. Start the CRAB server in the Ubuntu environment and get its IP address and port. Let's say they are `192.168.122.72` and `8000`.
 3. Choose a task. As an example, we take the task with ID `a3476778-e512-40ca-b1c0-d7aab0c7f18b` from [handmade_tasks](./dataset/handmade_tasks.py). The task is: "Open the 'Tasks' app on Android, check the first incomplete task, then perform the task according to its description."
 4. Run [main.py](./main.py) with the command `poetry run python -m crab-benchmark-v0.main --model gpt4o --policy single --remote-url http://192.168.122.72:8000 --task-id a3476778-e512-40ca-b1c0-d7aab0c7f18b`. In this command, `--model gpt4o` and `--policy single` determine the agent system, `--remote-url` specifies the Ubuntu environment interface, and `--task-id` indicates the task to be performed.
+
+#### Model
+
+For open-source models, we use [vLLM](https://github.com/vllm-project/vllm) to host the Pixtral model (see [the vLLM docs](https://docs.vllm.ai/en/latest/models/vlm.html#online-inference) for setup commands) and [SGLang](https://github.com/sgl-project/sglang) to host the LLaVA-OneVision model (see [the SGLang README](https://github.com/sgl-project/sglang?tab=readme-ov-file#supported-models) for setup commands).
```
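Both servers expose an OpenAI-compatible endpoint, which is how `main.py` (changed below) connects to them via `--model-base-url`. As a minimal sketch, not part of this commit, here is how such an endpoint can be queried with the `openai` Python client, assuming the defaults this commit uses (`http://127.0.0.1:8000/v1`, API key `EMPTY`) and a vLLM-hosted Pixtral:

```python
# Hypothetical smoke test for a locally hosted model endpoint; the URL, key,
# and model name mirror the defaults and model ids used in main.py below.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:8000/v1",  # matches the --model-base-url default
    api_key="EMPTY",                      # matches the --model-api-key default
)

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```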

crab-benchmark-v0/android_env.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -14,6 +14,7 @@
 from crab import EnvironmentConfig
 from crab.actions.android_actions import (
     key_press,
+    long_tap,
     open_app_drawer,
     screenshot,
     setup,
@@ -24,7 +25,7 @@
 
 ANDROID_ENV = EnvironmentConfig(
     name="android",
-    action_space=[tap, key_press, write_text, swipe, open_app_drawer],
+    action_space=[tap, key_press, long_tap, write_text, swipe, open_app_drawer],
     observation_space=[screenshot],
     description="""A Google Pixel smartphone runs on the Android operating system. \
 The interface displays a current screenshot at each step and primarily \
```
Lines changed: 15 additions & 0 deletions

```diff
@@ -0,0 +1,15 @@
+{
+    "description": "In Android, Using Google Map app, Find the city name of corresponding post code \"1010021\" in the country \"Japan\".",
+    "tasks": [
+        {
+            "task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
+            "attribute": {
+                "country": "Japan",
+                "number": "101-0021"
+            },
+            "output": "Tokyo"
+        }
+    ],
+    "adjlist": "0",
+    "id": "4190c90c-b28c-4bb3-ab5c-af3c4fde0a3d"
+}
```
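This file and the two that follow share the same dataset schema: a natural-language `description`, a `tasks` list whose entries reference a subtask template by `task` id and fill in its `attribute` values together with the expected `output`, an `adjlist` (apparently an adjacency list over the subtasks), and a file-level `id`. A minimal sketch of reading one of these files, assuming only the standard library (the filename is a placeholder):

```python
import json

# Placeholder filename for one of the new dataset entries.
with open("4190c90c-b28c-4bb3-ab5c-af3c4fde0a3d.json") as f:
    task = json.load(f)

print(task["description"])
for sub in task["tasks"]:
    # Each entry points at a subtask template and supplies concrete attributes.
    print(sub["task"], sub["attribute"], "->", sub["output"])
```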
Lines changed: 14 additions & 0 deletions

```diff
@@ -0,0 +1,14 @@
+{
+    "description": "In the Android system, use the calendar app to find the title of an event on the date \"16 July 2024,\".",
+    "tasks": [
+        {
+            "task": "2394b768-2ca7-45e9-b41e-2aa4e9573192",
+            "attribute": {
+                "date": "16 July 2024"
+            },
+            "output": "Japan"
+        }
+    ],
+    "adjlist": "0",
+    "id": "4893a9b0-6477-495d-a73c-32503326e24a"
+}
```
Lines changed: 15 additions & 0 deletions

```diff
@@ -0,0 +1,15 @@
+{
+    "description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postcode \"110151\" in Colombia.",
+    "tasks": [
+        {
+            "task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
+            "attribute": {
+                "number": "110151",
+                "country": "Columbia"
+            },
+            "output": "Bogota"
+        }
+    ],
+    "adjlist": "0",
+    "id": "e55d7a39-7b6b-4852-8711-844cebc88cb8"
+}
```

crab-benchmark-v0/dataset/android_subtasks.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -361,6 +361,8 @@ def check_event(date: str, env) -> bool:
     event_nodes = root.xpath('//node[@class="android.support.v7.widget.RecyclerView"]')
     if event_nodes is None:
         return False
+    if not event_nodes:
+        return False
     for node in event_nodes[0]:
         text = node.get("content-desc")
         if date in text:
```
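The added guard covers the case the existing `is None` check cannot: lxml's `xpath()` returns a list, and an unmatched query yields an empty list rather than `None`, so indexing `event_nodes[0]` would raise `IndexError`. A minimal illustration, assuming `lxml` is installed:

```python
from lxml import etree

root = etree.fromstring("<hierarchy><node class='android.widget.TextView'/></hierarchy>")
nodes = root.xpath('//node[@class="android.support.v7.widget.RecyclerView"]')

print(nodes is None)  # False -- xpath() returns a list, never None
print(not nodes)      # True  -- no match yields an empty list, so nodes[0] would raise
```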

crab-benchmark-v0/dataset/cross/05a7633d-b966-471c-8848-e18e69ad265f.json

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,5 +1,5 @@
 {
-    "description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postal code \"1010021\" in Japan, then paste the name into LibreOffice Writer on an Ubuntu system and save it as an ODT file at \"/home/crab/Desktop\".",
+    "description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postal code \"1010021\" in Japan, then paste the name into LibreOffice Writer on an Ubuntu system and save it as an ODT file at \"/home/crab/Desktop/target.opt\".",
     "tasks": [
         {
             "task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
```

crab-benchmark-v0/dataset/handmade_tasks.py

Lines changed: 200 additions & 24 deletions
(Large diff not rendered.)

crab-benchmark-v0/main.py

Lines changed: 81 additions & 13 deletions

```diff
@@ -24,12 +24,12 @@
     TaskGenerator,
     create_benchmark,
 )
-from crab.actions.crab_actions import complete
+from crab.actions.crab_actions import complete, wait
 from crab.actions.visual_prompt_actions import (
     get_elements_prompt,
     groundingdino_easyocr,
 )
-from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
+from crab.agents.backend_models import BackendModelConfig
 from crab.agents.policies import (
     MultiAgentByEnvPolicy,
     MultiAgentByFuncPolicy,
@@ -96,7 +96,7 @@ def get_benchmark(env: str, ubuntu_url: str):
             tasks=[],
             environments=[ubuntu_env],
             prompting_tools=prompting_tools,
-            root_action_space=[complete],
+            root_action_space=[complete, wait],
             multienv=True,
         )
     elif env == "android":
@@ -106,7 +106,7 @@ def get_benchmark(env: str, ubuntu_url: str):
             tasks=[],
             environments=[ANDROID_ENV],
             prompting_tools=prompting_tools,
-            root_action_space=[complete],
+            root_action_space=[complete, wait],
             multienv=True,
         )
     elif env == "cross":
@@ -119,7 +119,7 @@ def get_benchmark(env: str, ubuntu_url: str):
             tasks=[],
             environments=[ubuntu_env, ANDROID_ENV],
             prompting_tools=prompting_tools,
-            root_action_space=[complete],
+            root_action_space=[complete, wait],
             multienv=True,
         )
     else:
@@ -137,7 +137,7 @@ def get_benchmark(env: str, ubuntu_url: str):
     # Load from handmade tasks
     benchmark_config.tasks.extend(handmade_tasks)
 
-    benchmark_config.step_limit = 15
+    benchmark_config.step_limit = 20
     return create_benchmark(benchmark_config)
 
 
@@ -158,7 +158,7 @@ def get_benchmark(env: str, ubuntu_url: str):
         default="single",
     )
     parser.add_argument(
-        "--remote-url",
+        "--ubuntu-url",
         type=str,
         help="remote url of Ubunutu environment",
         default="http://127.0.0.1:8000",
@@ -170,29 +170,97 @@ def get_benchmark(env: str, ubuntu_url: str):
         default="cross",
     )
     parser.add_argument("--task-id", type=str, help="task id")
+    parser.add_argument(
+        "--model-base-url",
+        type=str,
+        help="URL of the model API",
+        default="http://127.0.0.1:8000/v1",
+    )
+    parser.add_argument(
+        "--model-api-key",
+        type=str,
+        help="API key of the model API",
+        default="EMPTY",
+    )
     parser.add_argument(
         "--loglevel",
         type=str,
         help="logger level, debug, info, warning, or error",
         default="warning",
     )
+    parser.add_argument(
+        "--history-messages-len",
+        type=int,
+        help="The number of rounds of chat history to provide to the model",
+        default=2,
+    )
     args = parser.parse_args()
     loglevel = args.loglevel
     numeric_level = getattr(logging, loglevel.upper(), None)
     if not isinstance(numeric_level, int):
         raise ValueError("Invalid log level: %s" % loglevel)
     logging.basicConfig(level=numeric_level)
 
-    benchmark = get_benchmark(args.env, args.remote_url)
+    benchmark = get_benchmark(args.env, args.ubuntu_url)
+
+    if args.model == "human":
+        expeirment = CrabBenchmarkV0(
+            benchmark=benchmark,
+            task_id=args.task_id,
+            agent_policy="human",
+        )
+        expeirment.start_benchmark()
+        exit()
 
     if args.model == "gpt4o":
-        model = OpenAIModel(model="gpt-4o", history_messages_len=2)
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="gpt-4o",
+            history_messages_len=args.history_messages_len,
+        )
     elif args.model == "gpt4turbo":
-        model = OpenAIModel(model="gpt-4-turbo", history_messages_len=2)
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="gpt-4-turbo",
+            history_messages_len=args.history_messages_len,
+        )
     elif args.model == "gemini":
-        model = GeminiModel(model="gemini-1.5-pro-latest", history_messages_len=2)
+        model = BackendModelConfig(
+            model_class="gemini",
+            model_name="gemini-1.5-pro-latest",
+            history_messages_len=args.history_messages_len,
+        )
     elif args.model == "claude":
-        model = ClaudeModel(model="claude-3-opus-20240229", history_messages_len=2)
+        model = BackendModelConfig(
+            model_class="claude",
+            model_name="claude-3-opus-20240229",
+            history_messages_len=args.history_messages_len,
+        )
+    elif args.model == "pixtral":
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="mistralai/Pixtral-12B-2409",
+            json_structre_output=True,
+            history_messages_len=args.history_messages_len,
+            base_url=args.model_base_url,
+            api_key=args.model_api_key,
+        )
+    elif args.model == "gpt4o-wofc":
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="gpt-4o",
+            json_structre_output=True,
+            history_messages_len=args.history_messages_len,
+        )
+    elif args.model == "llava-ov72b":
+        model = BackendModelConfig(
+            model_class="sglang",
+            model_name="lmms-lab/llava-onevision-qwen2-72b-ov-chat",
+            json_structre_output=True,
+            history_messages_len=args.history_messages_len,
+            base_url=args.model_base_url,
+            api_key=args.model_api_key,
+        )
     else:
         print("Unsupported model: ", args.model)
         exit()
@@ -211,7 +279,7 @@ def get_benchmark(env: str, ubuntu_url: str):
         print("Unsupported policy: ", args.policy)
         exit()
 
-    log_dir = (Path(__file__).parent / "logs").resolve()
+    log_dir = (Path(__file__).parent / "tianqi_logs").resolve()
     expeirment = CrabBenchmarkV0(
         benchmark=benchmark,
         task_id=args.task_id,
```
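Taken together, these flags change the shape of the README's example invocation; a hypothetical run against a self-hosted Pixtral endpoint might look like `poetry run python -m crab-benchmark-v0.main --model pixtral --policy single --ubuntu-url http://192.168.122.72:8000 --model-base-url http://127.0.0.1:8000/v1 --model-api-key EMPTY --task-id a3476778-e512-40ca-b1c0-d7aab0c7f18b`, reusing the task id from the README example.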

crab-benchmark-v0/ubuntu_env.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@
 # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
 from crab.actions.desktop_actions import (
     click,
+    double_click,
     key_press,
     press_hotkey,
     right_click,
@@ -31,6 +32,7 @@
         press_hotkey,
         search_application,
         right_click,
+        double_click,
     ],
     observation_space=[screenshot],
     description="""An Ubuntu 22.04 Linux desktop operating system. The interface \
```
