33import json
44import base64
55import re
6+ import io
7+ import asyncio
8+ import aiohttp
9+
610from PIL import Image
11+ from ultralytics import YOLO
712import google .generativeai as genai
8- from operate .config .settings import Config
9- from operate .exceptions .exceptions import ModelNotRecognizedException
10- from operate .utils .screenshot_util import capture_screen_with_cursor , add_grid_to_image , capture_mini_screenshot_with_cursor
11- from operate .utils .action_util import get_last_assistant_message
12- from operate .utils .prompt_util import format_vision_prompt , format_accurate_mode_vision_prompt ,format_summary_prompt
13+ from operate .settings import Config
14+ from operate .exceptions import ModelNotRecognizedException
15+ from operate .utils .screenshot import (
16+ capture_screen_with_cursor ,
17+ add_grid_to_image ,
18+ capture_mini_screenshot_with_cursor ,
19+ )
20+ from operate .utils .os import get_last_assistant_message
21+ from operate .prompts import (
22+ format_vision_prompt ,
23+ format_accurate_mode_vision_prompt ,
24+ format_summary_prompt ,
25+ format_decision_prompt ,
26+ format_label_prompt ,
27+ )
28+
29+
30+ from operate .utils .label import (
31+ add_labels ,
32+ parse_click_content ,
33+ get_click_position_in_percent ,
34+ get_label_coordinates ,
35+ )
36+ from operate .utils .style import (
37+ ANSI_GREEN ,
38+ ANSI_RED ,
39+ ANSI_RESET ,
40+ )
41+
1342
# Load configuration
config = Config()  # project settings object (provides debug flag, API keys, client factory)

client = config.initialize_openai_client()  # OpenAI SDK client used by the synchronous calls

# NOTE(review): relative weights path assumes the process CWD is the repo root — confirm.
yolo_model = YOLO("./operate/model/weights/best.pt")  # Load your trained model
1849
19- def get_next_action (model , messages , objective , accurate_mode ):
20- if model == "gpt-4-vision-preview" :
21- content = get_next_action_from_openai (
22- messages , objective , accurate_mode )
23- return content
50+
51+ async def get_next_action (model , messages , objective ):
52+ if model == "gpt-4" :
53+ return call_gpt_4_v (messages , objective )
54+ if model == "gpt-4-with-som" :
55+ return await call_gpt_4_v_labeled (messages , objective )
2456 elif model == "agent-1" :
2557 return "coming soon"
2658 elif model == "gemini-pro-vision" :
27- content = get_next_action_from_gemini_pro_vision (
28- messages , objective
29- )
30- return content
59+ return call_gemini_pro_vision (messages , objective )
3160
3261 raise ModelNotRecognizedException (model )
3362
3463
35- def get_next_action_from_openai (messages , objective , accurate_mode ):
64+ def call_gpt_4_v (messages , objective ):
3665 """
3766 Get the next action for Self-Operating Computer
3867 """
@@ -95,32 +124,14 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
95124
96125 content = response .choices [0 ].message .content
97126
98- if accurate_mode :
99- if content .startswith ("CLICK" ):
100- # Adjust pseudo_messages to include the accurate_mode_message
101-
102- click_data = re .search (r"CLICK \{ (.+) \}" , content ).group (1 )
103- click_data_json = json .loads (f"{{{ click_data } }}" )
104- prev_x = click_data_json ["x" ]
105- prev_y = click_data_json ["y" ]
106-
107- if config .debug :
108- print (
109- f"Previous coords before accurate tuning: prev_x { prev_x } prev_y { prev_y } "
110- )
111- content = accurate_mode_double_check (
112- "gpt-4-vision-preview" , pseudo_messages , prev_x , prev_y
113- )
114- assert content != "ERROR" , "ERROR: accurate_mode_double_check failed"
115-
116127 return content
117128
118129 except Exception as e :
119130 print (f"Error parsing JSON: { e } " )
120131 return "Failed take action after looking at the screenshot"
121132
122133
123- def get_next_action_from_gemini_pro_vision (messages , objective ):
134+ def call_gemini_pro_vision (messages , objective ):
124135 """
125136 Get the next action for Self-Operating Computer using Gemini Pro Vision
126137 """
@@ -172,14 +183,13 @@ def get_next_action_from_gemini_pro_vision(messages, objective):
172183 return "Failed take action after looking at the screenshot"
173184
174185
186+ # This function is not used. `-accurate` mode was removed for now until a new PR fixes it.
175187def accurate_mode_double_check (model , pseudo_messages , prev_x , prev_y ):
176188 """
177189 Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
178190 """
179- print ("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check" )
180191 try :
181- screenshot_filename = os .path .join (
182- "screenshots" , "screenshot_mini.png" )
192+ screenshot_filename = os .path .join ("screenshots" , "screenshot_mini.png" )
183193 capture_mini_screenshot_with_cursor (
184194 file_path = screenshot_filename , x = prev_x , y = prev_y
185195 )
@@ -191,8 +201,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
191201 with open (new_screenshot_filename , "rb" ) as img_file :
192202 img_base64 = base64 .b64encode (img_file .read ()).decode ("utf-8" )
193203
194- accurate_vision_prompt = format_accurate_mode_vision_prompt (
195- prev_x , prev_y )
204+ accurate_vision_prompt = format_accurate_mode_vision_prompt (prev_x , prev_y )
196205
197206 accurate_mode_message = {
198207 "role" : "user" ,
@@ -234,7 +243,7 @@ def summarize(model, messages, objective):
234243 capture_screen_with_cursor (screenshot_filename )
235244
236245 summary_prompt = format_summary_prompt (objective )
237-
246+
238247 if model == "gpt-4-vision-preview" :
239248 with open (screenshot_filename , "rb" ) as img_file :
240249 img_base64 = base64 .b64encode (img_file .read ()).decode ("utf-8" )
@@ -266,7 +275,135 @@ def summarize(model, messages, objective):
266275 )
267276 content = summary_message .text
268277 return content
269-
278+
270279 except Exception as e :
271280 print (f"Error in summarize: { e } " )
272- return "Failed to summarize the workflow"
281+ return "Failed to summarize the workflow"
282+
283+
async def call_gpt_4_v_labeled(messages, objective):
    """
    Ask GPT-4V for the next action using a YOLO-labeled screenshot (Set-of-Mark).

    Captures the screen, overlays numbered labels on detected UI elements, then
    issues two OpenAI vision requests concurrently:
      * a "decision" request against the original (unlabeled) screenshot, and
      * a "click" request against the labeled screenshot.
    If the decision is a CLICK, the chosen label is resolved back to screen
    coordinates expressed as percentages of the image size. Any failure falls
    back to the plain ``call_gpt_4_v`` flow.
    """
    # Fix: time.sleep(1) would block the event loop; yield with asyncio.sleep
    # so other scheduled coroutines can run during the pause.
    await asyncio.sleep(1)
    try:
        screenshots_dir = "screenshots"
        if not os.path.exists(screenshots_dir):
            os.makedirs(screenshots_dir)

        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
        # Call the function to capture the screen with the cursor
        capture_screen_with_cursor(screenshot_filename)

        with open(screenshot_filename, "rb") as img_file:
            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

        previous_action = get_last_assistant_message(messages)

        # Overlay numbered labels on detected elements; the unmodified image is
        # kept so the decision prompt sees the screen as the user does.
        img_base64_labeled, img_base64_original, label_coordinates = add_labels(
            img_base64, yolo_model
        )

        decision_prompt = format_decision_prompt(objective, previous_action)
        labeled_click_prompt = format_label_prompt(objective)

        def _vision_message(text, image_b64):
            # One user turn pairing a text prompt with a base64 screenshot.
            return {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                    },
                ],
            }

        click_messages = messages.copy()
        click_messages.append(_vision_message(labeled_click_prompt, img_base64_labeled))
        decision_messages = messages.copy()
        decision_messages.append(_vision_message(decision_prompt, img_base64_original))

        # Fire both requests concurrently.
        click_response, decision_response = await asyncio.gather(
            fetch_openai_response_async(click_messages),
            fetch_openai_response_async(decision_messages),
        )

        # Extracting the message content from the raw JSON responses; a missing
        # key (e.g. an API error body) raises and is handled by the fallback.
        click_content = click_response.get("choices")[0].get("message").get("content")
        decision_content = (
            decision_response.get("choices")[0].get("message").get("content")
        )

        # Non-CLICK decisions (TYPE, SEARCH, DONE, ...) need no coordinates.
        if not decision_content.startswith("CLICK"):
            return decision_content

        label_data = parse_click_content(click_content)

        if label_data and "label" in label_data:
            coordinates = get_label_coordinates(label_data["label"], label_coordinates)
            image = Image.open(
                io.BytesIO(base64.b64decode(img_base64))
            )  # Load the image to get its size
            image_size = image.size  # Get the size of the image (width, height)
            click_position_percent = get_click_position_in_percent(
                coordinates, image_size
            )
            if not click_position_percent:
                print(
                    f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Failed to get click position in percent. Trying another method {ANSI_RESET}"
                )
                return call_gpt_4_v(messages, objective)

            x_percent = f"{click_position_percent[0]:.2f}%"
            y_percent = f"{click_position_percent[1]:.2f}%"
            click_action = f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "{label_data["decision"]}", "reason": "{label_data["reason"]}" }}'

        else:
            print(
                f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] No label found. Trying another method {ANSI_RESET}"
            )
            return call_gpt_4_v(messages, objective)

        return click_action

    except Exception as e:
        # Best-effort fallback: any failure (API error, malformed response,
        # YOLO/labeling issue) degrades to the unlabeled GPT-4V flow.
        print(
            f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Something went wrong. Trying another method {ANSI_RESET}"
        )
        return call_gpt_4_v(messages, objective)
388+
389+
async def fetch_openai_response_async(messages):
    """
    POST the chat messages to the OpenAI chat-completions endpoint and return
    the decoded JSON response body.
    """
    endpoint = "https://api.openai.com/v1/chat/completions"
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.openai_api_key}",
    }
    # Serialize the request body up front; the sampling/penalty settings match
    # the project's other GPT-4V calls.
    payload = json.dumps(
        {
            "model": "gpt-4-vision-preview",
            "messages": messages,
            "frequency_penalty": 1,
            "presence_penalty": 1,
            "temperature": 0.7,
            "max_tokens": 300,
        }
    )

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=request_headers, data=payload) as resp:
            return await resp.json()
0 commit comments