@@ -26,6 +26,10 @@ def __init__(
2626 log_dir : Path | None = None ,
2727 ) -> None :
2828 super ().__init__ (benchmark , task_id , agent_policy , log_dir )
29+ self .display_callback = None
30+
31+ def set_display_callback (self , callback ):
32+ self .display_callback = callback
2933
3034 def get_prompt (self ):
3135 observation , ob_prompt = self .benchmark .observe_with_prompt ()
@@ -47,3 +51,64 @@ def get_prompt(self):
4751 (marked_screenshot , MessageType .IMAGE_JPG_BASE64 ),
4852 ]
4953 return result_prompt
54+
55+ def step (self , it ) -> bool :
56+ if self .display_callback :
57+ self .display_callback (f"Step { self .step_cnt } :" , "ai" )
58+
59+ prompt = self .get_prompt ()
60+ self .log_prompt (prompt )
61+
62+ try :
63+ response = self .agent_policy .chat (prompt )
64+ if self .display_callback :
65+ self .display_callback (f"Planning next action..." , "ai" )
66+ except Exception as e :
67+ if self .display_callback :
68+ self .display_callback (f"Error: { str (e )} " , "ai" )
69+ self .write_main_csv_row ("agent_exception" )
70+ return True
71+
72+ if self .display_callback :
73+ self .display_callback (f"Executing: { response } " , "ai" )
74+ return self .execute_action (response )
75+
76+ def execute_action (self , response : list [ActionOutput ]) -> bool :
77+ for action in response :
78+ benchmark_result = self .benchmark .step (
79+ action = action .name ,
80+ parameters = action .arguments ,
81+ env_name = action .env ,
82+ )
83+ self .metrics = benchmark_result .evaluation_results
84+
85+ if benchmark_result .terminated :
86+ if self .display_callback :
87+ self .display_callback (
88+ f"✓ Task completed! Results: { self .metrics } " , "ai"
89+ )
90+ self .write_current_log_row (action )
91+ self .write_current_log_row (benchmark_result .info ["terminate_reason" ])
92+ return True
93+
94+ if self .display_callback :
95+ self .display_callback (
96+ f'Action "{ action .name } " completed in { action .env } . '
97+ f"Progress: { self .metrics } " , "ai"
98+ )
99+ self .write_current_log_row (action )
100+ self .step_cnt += 1
101+ return False
102+
103+ def start_benchmark (self ):
104+ if self .display_callback :
105+ self .display_callback ("Starting benchmark..." , "ai" )
106+ try :
107+ super ().start_benchmark ()
108+ except KeyboardInterrupt :
109+ if self .display_callback :
110+ self .display_callback ("Experiment interrupted." , "ai" )
111+ self .write_main_csv_row ("experiment_interrupted" )
112+ finally :
113+ if self .display_callback :
114+ self .display_callback ("Experiment finished." , "ai" )
0 commit comments