Skip to content

Commit 3d58477

Browse files
committed
update the code
1 parent bfd778d commit 3d58477

File tree

2 files changed

+92
-13
lines changed

2 files changed

+92
-13
lines changed

gui/gui_experiment.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def __init__(
2626
log_dir: Path | None = None,
2727
) -> None:
2828
super().__init__(benchmark, task_id, agent_policy, log_dir)
29+
self.display_callback = None
30+
31+
def set_display_callback(self, callback):
32+
self.display_callback = callback
2933

3034
def get_prompt(self):
3135
observation, ob_prompt = self.benchmark.observe_with_prompt()
@@ -47,3 +51,64 @@ def get_prompt(self):
4751
(marked_screenshot, MessageType.IMAGE_JPG_BASE64),
4852
]
4953
return result_prompt
54+
55+
def step(self, it) -> bool:
56+
if self.display_callback:
57+
self.display_callback(f"Step {self.step_cnt}:", "ai")
58+
59+
prompt = self.get_prompt()
60+
self.log_prompt(prompt)
61+
62+
try:
63+
response = self.agent_policy.chat(prompt)
64+
if self.display_callback:
65+
self.display_callback(f"Planning next action...", "ai")
66+
except Exception as e:
67+
if self.display_callback:
68+
self.display_callback(f"Error: {str(e)}", "ai")
69+
self.write_main_csv_row("agent_exception")
70+
return True
71+
72+
if self.display_callback:
73+
self.display_callback(f"Executing: {response}", "ai")
74+
return self.execute_action(response)
75+
76+
def execute_action(self, response: list[ActionOutput]) -> bool:
77+
for action in response:
78+
benchmark_result = self.benchmark.step(
79+
action=action.name,
80+
parameters=action.arguments,
81+
env_name=action.env,
82+
)
83+
self.metrics = benchmark_result.evaluation_results
84+
85+
if benchmark_result.terminated:
86+
if self.display_callback:
87+
self.display_callback(
88+
f"✓ Task completed! Results: {self.metrics}", "ai"
89+
)
90+
self.write_current_log_row(action)
91+
self.write_current_log_row(benchmark_result.info["terminate_reason"])
92+
return True
93+
94+
if self.display_callback:
95+
self.display_callback(
96+
f'Action "{action.name}" completed in {action.env}. '
97+
f"Progress: {self.metrics}", "ai"
98+
)
99+
self.write_current_log_row(action)
100+
self.step_cnt += 1
101+
return False
102+
103+
def start_benchmark(self):
104+
if self.display_callback:
105+
self.display_callback("Starting benchmark...", "ai")
106+
try:
107+
super().start_benchmark()
108+
except KeyboardInterrupt:
109+
if self.display_callback:
110+
self.display_callback("Experiment interrupted.", "ai")
111+
self.write_main_csv_row("experiment_interrupted")
112+
finally:
113+
if self.display_callback:
114+
self.display_callback("Experiment finished.", "ai")

gui/main.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,33 @@ def assign_task():
5353
input_entry.delete(0, "end")
5454
display_message(task_description)
5555

56-
model = get_model_instance(model_dropdown.get())
57-
agent_policy = SingleAgentPolicy(model_backend=model)
58-
59-
task_id = str(uuid4())
60-
benchmark = get_benchmark(task_id, task_description)
61-
experiment = GuiExperiment(
62-
benchmark=benchmark,
63-
task_id=task_id,
64-
agent_policy=agent_policy,
65-
log_dir=log_dir,
66-
)
67-
# TODO: redirect the output to the GUI
68-
experiment.start_benchmark()
56+
try:
57+
model = get_model_instance(model_dropdown.get())
58+
agent_policy = SingleAgentPolicy(model_backend=model)
59+
60+
task_id = str(uuid4())
61+
benchmark = get_benchmark(task_id, task_description)
62+
experiment = GuiExperiment(
63+
benchmark=benchmark,
64+
task_id=task_id,
65+
agent_policy=agent_policy,
66+
log_dir=log_dir,
67+
)
68+
69+
experiment.set_display_callback(display_message)
70+
71+
def run_experiment():
72+
try:
73+
experiment.start_benchmark()
74+
except Exception as e:
75+
display_message(f"Error: {str(e)}", "ai")
76+
77+
import threading
78+
thread = threading.Thread(target=run_experiment, daemon=True)
79+
thread.start()
80+
81+
except Exception as e:
82+
display_message(f"Error: {str(e)}", "ai")
6983

7084

7185
def display_message(message, sender="user"):

0 commit comments

Comments
 (0)