diff --git a/gui_agents/common/__init__.py b/gui_agents/common/__init__.py new file mode 100644 index 00000000..6c7e54e7 --- /dev/null +++ b/gui_agents/common/__init__.py @@ -0,0 +1,28 @@ +from gui_agents.common.agent_action_schema import ( + AGENT_ACTION_JSON_SCHEMA, + AGENT_ACTION_RESPONSE_FORMAT, + SCHEMA_PROMPT_FRAGMENT, + ActionType, + AgentAction, + AgentActionParseError, + parse_agent_action, + agent_action_to_dict, +) + +from gui_agents.common.agent_action_dispatcher import ( + ACTION_METHOD_BY_TYPE, + execute_agent_action, +) + +__all__ = [ + "AGENT_ACTION_JSON_SCHEMA", + "AGENT_ACTION_RESPONSE_FORMAT", + "SCHEMA_PROMPT_FRAGMENT", + "ActionType", + "AgentAction", + "AgentActionParseError", + "parse_agent_action", + "agent_action_to_dict", + "ACTION_METHOD_BY_TYPE", + "execute_agent_action", +] diff --git a/gui_agents/common/agent_action_dispatcher.py b/gui_agents/common/agent_action_dispatcher.py new file mode 100644 index 00000000..69d8a419 --- /dev/null +++ b/gui_agents/common/agent_action_dispatcher.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from gui_agents.common.agent_action_schema import ( + ActionType, + AgentAction, + ClickArgs, + DragAndDropArgs, + HighlightTextSpanArgs, + HoldAndPressArgs, + HotkeyArgs, + OpenAppArgs, + SaveToKnowledgeArgs, + ScrollArgs, + SetCellValuesArgs, + SwitchAppArgs, + TargetSelector, + TypeArgs, + WaitArgs, +) + +logger = logging.getLogger(__name__) + + +ACTION_METHOD_BY_TYPE: Dict[ActionType, str] = { + ActionType.CLICK: "click", + ActionType.DBLCLICK: "click", + ActionType.TYPE: "type", + ActionType.HOTKEY: "hotkey", + ActionType.WAIT: "wait", + ActionType.SCROLL: "scroll", + ActionType.SWITCH_APP: "switch_applications", + ActionType.FOCUS_APP: "switch_applications", + ActionType.SWITCH_APPLICATIONS: "switch_applications", + ActionType.OPEN_APP: "open", + ActionType.OPEN: "open", + ActionType.DRAG_AND_DROP: "drag_and_drop", + ActionType.SAVE_TO_KNOWLEDGE: "save_to_knowledge", + ActionType.HIGHLIGHT_TEXT_SPAN: "highlight_text_span", + ActionType.SET_CELL_VALUES: "set_cell_values", + ActionType.HOLD_AND_PRESS: "hold_and_press", + ActionType.DONE: "done", + ActionType.FAIL: "fail", +} + + +def execute_agent_action(agent: Any, action: AgentAction, obs: Dict[str, Any]) -> str: + method_name = ACTION_METHOD_BY_TYPE.get(action.type) + if method_name is None: + raise RuntimeError(f"Unsupported action type: {action.type}") + + _assign_coordinates_if_supported(agent, action, obs) + + method = getattr(agent, method_name, None) + if method is None: + raise RuntimeError(f"Grounding agent is missing method '{method_name}' for action {action.type.value}") + + positional_args, keyword_args = _build_call_arguments(action) + return method(*positional_args, **keyword_args) + + +def _assign_coordinates_if_supported(agent: Any, action: AgentAction, obs: Dict[str, Any]) -> None: + assign_fn = getattr(agent, "assign_coordinates", None) + if callable(assign_fn): + try: + assign_fn(action, obs) + except TypeError: + # Legacy signature (str, obs). Fall back to manual assignment below. 
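+ # Newer ACIs accept the AgentAction payload directly; a TypeError signals the legacy (plan_str, obs) signature, so log it and let the bbox/description fallback below supply coordinates.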
+ logger.debug("assign_coordinates did not accept AgentAction payload", exc_info=True) + except Exception: + logger.debug("assign_coordinates raised during structured dispatch", exc_info=True) + else: + if not _coordinates_available(agent, action): + logger.debug("assign_coordinates returned without coordinates; falling back") + else: + return + _fallback_coordinate_assignment(agent, action, obs) + + + +def _coordinates_available(agent: Any, action: AgentAction) -> bool: + if action.type in {ActionType.CLICK, ActionType.DBLCLICK, ActionType.SCROLL}: + return getattr(agent, "coords1", None) is not None + if action.type == ActionType.TYPE: + target = getattr(action.args, "target", None) + if target is None: + return True + return getattr(agent, "coords1", None) is not None + if action.type == ActionType.DRAG_AND_DROP: + return getattr(agent, "coords1", None) is not None and getattr(agent, "coords2", None) is not None + if action.type == ActionType.HIGHLIGHT_TEXT_SPAN: + return getattr(agent, "coords1", None) is not None and getattr(agent, "coords2", None) is not None + return True + + +def _fallback_coordinate_assignment(agent: Any, action: AgentAction, obs: Dict[str, Any]) -> None: + # Provide basic bbox-driven coordinate assignment as a safety net when the agent does not support structured payloads yet. + try: + if action.type in {ActionType.CLICK, ActionType.DBLCLICK, ActionType.SCROLL}: + target = getattr(action.args, "target", None) + point = _resolve_target_to_point(agent, target, obs) + if point is not None: + setattr(agent, "coords1", point) + elif action.type == ActionType.TYPE: + target = getattr(action.args, "target", None) + if target is not None: + point = _resolve_target_to_point(agent, target, obs) + if point is not None: + setattr(agent, "coords1", point) + elif action.type == ActionType.DRAG_AND_DROP: + start = getattr(action.args, "start", None) + end = getattr(action.args, "end", None) + point1 = _resolve_target_to_point(agent, start, obs) + point2 = _resolve_target_to_point(agent, end, obs) + if point1 is not None: + setattr(agent, "coords1", point1) + if point2 is not None: + setattr(agent, "coords2", point2) + except Exception: + logger.debug("Fallback coordinate assignment failed", exc_info=True) + + +def _resolve_target_to_point(agent: Any, target: Optional[TargetSelector], obs: Dict[str, Any]) -> Optional[List[int]]: + if target is None: + return None + if target.bbox is not None: + return _center_from_bounds(target.bbox) + if target.a11y_id is not None: + bounds = _find_bounds_by_id(obs, target.a11y_id) + if bounds is not None: + return _center_from_bounds(bounds) + if target.description is not None: + resolver = getattr(agent, "generate_coords", None) + if callable(resolver): + try: + coords = resolver(target.description, obs) + if isinstance(coords, Iterable) and len(coords) == 2: + return [int(coords[0]), int(coords[1])] + except Exception: + logger.debug("generate_coords failed while resolving description", exc_info=True) + return None + + +def _find_bounds_by_id(obs: Dict[str, Any], target_id: str) -> Optional[Tuple[int, int, int, int]]: + for tree in _candidate_trees(obs): + for node in _iter_nodes(tree): + if not isinstance(node, dict): + continue + node_id = node.get("id") or node.get("nodeId") or node.get("node_id") or node.get("a11y_id") + if node_id is None: + continue + if str(node_id) != str(target_id): + continue + bounds = ( + node.get("bounds") + or node.get("bounding_box") + or node.get("bbox") + or node.get("frame") + ) + if isinstance(bounds, 
(list, tuple)) and len(bounds) == 4: + try: + return tuple(int(v) for v in bounds) + except Exception: + continue + return None + + +def _candidate_trees(obs: Dict[str, Any]) -> Iterable[Any]: + for key in ("a11y_tree", "serialized_a11y_tree", "tree", "som", "nodes"): + tree = obs.get(key) + if tree: + yield tree + + +def _iter_nodes(node: Any) -> Iterable[Any]: + stack = [node] + while stack: + current = stack.pop() + yield current + if isinstance(current, dict): + children = current.get("children") or current.get("nodes") or current.get("children_list") + if isinstance(children, dict): + stack.extend(children.values()) + elif isinstance(children, list): + stack.extend(children) + elif isinstance(current, list): + stack.extend(current) + + +def _center_from_bounds(bounds: Tuple[int, int, int, int]) -> List[int]: + x0, y0, x1, y1 = bounds + if x1 >= x0 and y1 >= y0: + return [int((x0 + x1) / 2), int((y0 + y1) / 2)] + # Treat as (x, y, width, height) + return [int(x0 + x1 / 2), int(y0 + y1 / 2)] + + +def _target_description(target: Optional[TargetSelector]) -> str: + if target is None: + return "" + if target.description is not None: + return target.description + if target.a11y_id is not None: + return f"a11y_id:{target.a11y_id}" + if target.bbox is not None: + return f"bbox:{','.join(str(v) for v in target.bbox)}" + return "" + + +def _build_call_arguments(action: AgentAction) -> Tuple[List[Any], Dict[str, Any]]: + args = action.args + if action.type in {ActionType.CLICK, ActionType.DBLCLICK} and isinstance(args, ClickArgs): + desc = _target_description(args.target) + num_clicks = args.count if action.type == ActionType.CLICK else max(2, args.count) + return [desc], { + "num_clicks": num_clicks, + "button_type": args.button, + "hold_keys": list(args.hold_keys), + } + if action.type == ActionType.TYPE and isinstance(args, TypeArgs): + kwargs: Dict[str, Any] = { + "text": args.text, + "overwrite": args.overwrite, + "enter": args.enter, + } + if args.target is not None: + kwargs["element_description"] = _target_description(args.target) + return [], kwargs + if action.type == ActionType.HOTKEY and isinstance(args, HotkeyArgs): + return [list(args.keys)], {} + if action.type == ActionType.WAIT and isinstance(args, WaitArgs): + return [args.seconds], {} + if action.type == ActionType.SCROLL and isinstance(args, ScrollArgs): + return [ + _target_description(args.target) + ], { + "clicks": args.clicks, + "shift": args.shift, + } + if action.type in {ActionType.SWITCH_APP, ActionType.FOCUS_APP, ActionType.SWITCH_APPLICATIONS} and isinstance(args, SwitchAppArgs): + return [args.app_id], {} + if action.type in {ActionType.OPEN_APP, ActionType.OPEN} and isinstance(args, OpenAppArgs): + return [args.app_name], {} + if action.type == ActionType.DRAG_AND_DROP and isinstance(args, DragAndDropArgs): + return [ + _target_description(args.start), + _target_description(args.end), + ], {"hold_keys": list(args.hold_keys)} + if action.type == ActionType.SAVE_TO_KNOWLEDGE and isinstance(args, SaveToKnowledgeArgs): + return [list(args.text)], {} + if action.type == ActionType.HIGHLIGHT_TEXT_SPAN and isinstance(args, HighlightTextSpanArgs): + return [args.start_phrase, args.end_phrase], {"button": args.button} + if action.type == ActionType.SET_CELL_VALUES and isinstance(args, SetCellValuesArgs): + return [ + dict(args.cell_values), + args.app_name, + args.sheet_name, + ], {} + if action.type == ActionType.HOLD_AND_PRESS and isinstance(args, HoldAndPressArgs): + return [], { + "hold_keys": list(args.hold_keys), + 
"press_keys": list(args.press_keys), + } + if action.type == ActionType.DONE: + return [], {"return_value": getattr(args, "return_value", None)} + if action.type == ActionType.FAIL: + return [], {} + raise RuntimeError(f"Unsupported or mismatched arguments for action type: {action.type}") + diff --git a/gui_agents/common/agent_action_schema.py b/gui_agents/common/agent_action_schema.py new file mode 100644 index 00000000..198e953f --- /dev/null +++ b/gui_agents/common/agent_action_schema.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +import json +import textwrap +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field, field_validator, model_validator + + +class AgentActionParseError(ValueError): + """Raised when an agent action payload cannot be parsed or validated.""" + + +class TargetSelector(BaseModel): + a11y_id: Optional[str] = None + bbox: Optional[Tuple[int, int, int, int]] = None + description: Optional[str] = None + + class Config: + extra = "forbid" + + @field_validator("a11y_id", "description") + @classmethod + def _strip_text(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return None + cleaned = value.strip() + if not cleaned: + raise ValueError("Text selectors must be non-empty when provided") + return cleaned + + @field_validator("bbox") + @classmethod + def _validate_bbox(cls, value: Optional[Tuple[int, int, int, int]]) -> Optional[Tuple[int, int, int, int]]: + if value is None: + return None + if len(value) != 4: + raise ValueError("Bounding boxes must contain exactly four integers") + return tuple(int(v) for v in value) + + @model_validator(mode="after") + def _check_exactly_one_selector(self) -> "TargetSelector": + provided = [ + name + for name in ("a11y_id", "bbox", "description") + if getattr(self, name) not in (None, "") + ] + if len(provided) != 1: + raise ValueError("Provide exactly one of a11y_id, bbox, or description in target selector") + return self + + +class ArgsModel(BaseModel): + class Config: + extra = "forbid" + + +class ClickArgs(ArgsModel): + target: TargetSelector + button: str = Field(default="left") + count: int = Field(default=1, ge=1, le=4) + hold_keys: List[str] = Field(default_factory=list) + + @field_validator("button") + @classmethod + def _validate_button(cls, value: str) -> str: + allowed = {"left", "right", "middle"} + lowered = value.lower().strip() + if lowered not in allowed: + raise ValueError(f"Button must be one of {sorted(allowed)}") + return lowered + + @field_validator("hold_keys", mode="after") + @classmethod + def _clean_keys(cls, values: List[str]) -> List[str]: + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Hold key entries must be non-empty strings") + if len(cleaned) > 5: + raise ValueError("Hold keys cannot exceed 5 entries") + return cleaned + + +class DragAndDropArgs(ArgsModel): + start: TargetSelector + end: TargetSelector + hold_keys: List[str] = Field(default_factory=list) + + @field_validator("hold_keys", mode="after") + @classmethod + def _clean_hold_keys(cls, values: List[str]) -> List[str]: + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Hold key entries must be non-empty strings") + if len(cleaned) > 4: + raise ValueError("Hold keys cannot exceed 4 entries") + return cleaned + + +class TypeArgs(ArgsModel): + text: str = Field(..., max_length=512) + target: Optional[TargetSelector] = None + 
overwrite: bool = False + enter: bool = False + + @field_validator("text") + @classmethod + def _non_empty_text(cls, value: str) -> str: + if not value.strip("\n"): + raise ValueError("Typing text must be non-empty") + return value + + +class HotkeyArgs(ArgsModel): + keys: List[str] + + @field_validator("keys") + @classmethod + def _ensure_bounds(cls, values: List[str]) -> List[str]: + if not values: + raise ValueError("At least one key must be provided") + if len(values) > 5: + raise ValueError("Hotkey sequences cannot exceed 5 keys") + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Hotkey values must be non-empty strings") + return cleaned + + +class HoldAndPressArgs(ArgsModel): + hold_keys: List[str] = Field(default_factory=list) + press_keys: List[str] + + @field_validator("hold_keys") + @classmethod + def _clean_hold(cls, values: List[str]) -> List[str]: + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Hold key entries must be non-empty strings") + if len(cleaned) > 5: + raise ValueError("Hold keys cannot exceed 5 entries") + return cleaned + + @field_validator("press_keys") + @classmethod + def _clean_press(cls, values: List[str]) -> List[str]: + if not values: + raise ValueError("At least one key must be pressed") + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Press key entries must be non-empty strings") + if len(cleaned) > 5: + raise ValueError("Press sequences cannot exceed 5 keys") + return cleaned + + +class ScrollArgs(ArgsModel): + target: TargetSelector + clicks: int + shift: bool = False + + +class SwitchAppArgs(ArgsModel): + app_id: str + + @field_validator("app_id") + @classmethod + def _clean_app(cls, value: str) -> str: + cleaned = value.strip() + if not cleaned: + raise ValueError("Application identifier must be non-empty") + return cleaned + + +class OpenAppArgs(ArgsModel): + app_name: str + + @field_validator("app_name") + @classmethod + def _clean_name(cls, value: str) -> str: + cleaned = value.strip() + if not cleaned: + raise ValueError("Application name must be non-empty") + return cleaned + + +class SaveToKnowledgeArgs(ArgsModel): + text: List[str] + + @field_validator("text") + @classmethod + def _validate_saved_text(cls, values: List[str]) -> List[str]: + if not values: + raise ValueError("At least one string must be provided") + if len(values) > 20: + raise ValueError("Saved knowledge entries cannot exceed 20 strings") + cleaned = [item.strip() for item in values if item.strip()] + if len(cleaned) != len(values): + raise ValueError("Saved knowledge entries must be non-empty strings") + return cleaned + + +class HighlightTextSpanArgs(ArgsModel): + start_phrase: str + end_phrase: str + button: str = Field(default="left") + + @field_validator("start_phrase", "end_phrase") + @classmethod + def _non_empty_phrase(cls, value: str) -> str: + cleaned = value.strip() + if not cleaned: + raise ValueError("Highlight phrases must be non-empty") + return cleaned + + @field_validator("button") + @classmethod + def _validate_button(cls, value: str) -> str: + allowed = {"left", "right", "middle"} + lowered = value.lower().strip() + if lowered not in allowed: + raise ValueError(f"Button must be one of {sorted(allowed)}") + return lowered + + +class SetCellValuesArgs(ArgsModel): + cell_values: Dict[str, Any] + app_name: str + sheet_name: str + + @field_validator("app_name", 
"sheet_name") + @classmethod + def _non_empty(cls, value: str) -> str: + cleaned = value.strip() + if not cleaned: + raise ValueError("Spreadsheet metadata must be non-empty strings") + return cleaned + + +class WaitArgs(ArgsModel): + seconds: float = Field(..., ge=0.0, le=30.0) + + +class DoneArgs(ArgsModel): + return_value: Optional[Any] = None + + +class FailArgs(ArgsModel): + reason: Optional[str] = None + + +class ActionType(str, Enum): + CLICK = "click" + DBLCLICK = "dblclick" + TYPE = "type" + HOTKEY = "hotkey" + WAIT = "wait" + SCROLL = "scroll" + SWITCH_APP = "switch_app" + FOCUS_APP = "focus_app" + SWITCH_APPLICATIONS = "switch_applications" + OPEN_APP = "open_app" + OPEN = "open" + DRAG_AND_DROP = "drag_and_drop" + SAVE_TO_KNOWLEDGE = "save_to_knowledge" + HIGHLIGHT_TEXT_SPAN = "highlight_text_span" + SET_CELL_VALUES = "set_cell_values" + HOLD_AND_PRESS = "hold_and_press" + DONE = "done" + FAIL = "fail" + + +ACTION_ARGS_BY_TYPE: Dict[ActionType, type[ArgsModel]] = { + ActionType.CLICK: ClickArgs, + ActionType.DBLCLICK: ClickArgs, + ActionType.TYPE: TypeArgs, + ActionType.HOTKEY: HotkeyArgs, + ActionType.WAIT: WaitArgs, + ActionType.SCROLL: ScrollArgs, + ActionType.SWITCH_APP: SwitchAppArgs, + ActionType.FOCUS_APP: SwitchAppArgs, + ActionType.SWITCH_APPLICATIONS: SwitchAppArgs, + ActionType.OPEN_APP: OpenAppArgs, + ActionType.OPEN: OpenAppArgs, + ActionType.DRAG_AND_DROP: DragAndDropArgs, + ActionType.SAVE_TO_KNOWLEDGE: SaveToKnowledgeArgs, + ActionType.HIGHLIGHT_TEXT_SPAN: HighlightTextSpanArgs, + ActionType.SET_CELL_VALUES: SetCellValuesArgs, + ActionType.HOLD_AND_PRESS: HoldAndPressArgs, + ActionType.DONE: DoneArgs, + ActionType.FAIL: FailArgs, +} + + +class AgentActionMeta(BaseModel): + idempotency_key: str + roi_hash: Optional[str] = None + explanation: Optional[str] = Field(default=None, max_length=280) + + class Config: + extra = "forbid" + + @field_validator("idempotency_key") + @classmethod + def _clean_idempotency(cls, value: str) -> str: + cleaned = value.strip() + if not cleaned: + raise ValueError("idempotency_key must be a non-empty string") + return cleaned + + @field_validator("roi_hash", "explanation") + @classmethod + def _trim_optional(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return None + cleaned = value.strip() + return cleaned or None + + +class AgentAction(BaseModel): + type: ActionType + args: ArgsModel + meta: AgentActionMeta + + class Config: + extra = "forbid" + + @model_validator(mode="before") + @classmethod + def _apply_args_model(cls, data: Any) -> Any: + if not isinstance(data, dict): + raise AgentActionParseError("AgentAction payload must be a JSON object") + raw_type = data.get("type") + if raw_type is None: + raise AgentActionParseError("Missing action type") + try: + action_type = ActionType(raw_type) + except ValueError as exc: + raise AgentActionParseError(f"Unsupported action type: {raw_type}") from exc + + args_model = ACTION_ARGS_BY_TYPE.get(action_type) + if args_model is None: + raise AgentActionParseError(f"No argument model registered for action type: {action_type.value}") + + args_payload = data.get("args") or {} + try: + args_instance = args_model.model_validate(args_payload) + except Exception as exc: + raise AgentActionParseError(str(exc)) from exc + + data["type"] = action_type + data["args"] = args_instance + return data + + @model_validator(mode="after") + def _normalize_args(self) -> "AgentAction": + if self.type is ActionType.DBLCLICK and isinstance(self.args, ClickArgs): + if self.args.count < 2: + 
self.args.count = 2 + return self + + def to_json(self) -> str: + return json.dumps(self.model_dump(mode="json"), ensure_ascii=False) + + +_JSON_BLOCK_PREFIX = "```json" +_JSON_BLOCK_SUFFIX = "```" + + +def extract_agent_action_json(payload: str) -> str: + """Return the first JSON object inside a fenced block or raise.""" + if not payload: + raise AgentActionParseError("Empty response from model") + lowered = payload.lower() + start = lowered.find(_JSON_BLOCK_PREFIX) + if start != -1: + start = payload.find("{", start) + if start == -1: + raise AgentActionParseError("JSON block started but no object found") + end = payload.find(_JSON_BLOCK_SUFFIX, start) + if end == -1: + raise AgentActionParseError("JSON block not properly terminated") + return payload[start:end].strip() + stripped = payload.strip() + if stripped.startswith("{") and stripped.endswith("}"): + return stripped + raise AgentActionParseError("No JSON object found in model response") + + +def parse_agent_action(payload: str) -> AgentAction: + """Parse and validate an AgentAction from model output.""" + json_blob = extract_agent_action_json(payload) + try: + data = json.loads(json_blob) + except json.JSONDecodeError as exc: + raise AgentActionParseError(f"Invalid JSON action: {exc}") from exc + try: + return AgentAction.model_validate(data) + except AgentActionParseError: + raise + except Exception as exc: + raise AgentActionParseError(str(exc)) from exc + + +_TARGET_SELECTOR_SCHEMA: Dict[str, Any] = { + "type": "object", + "additionalProperties": False, + "properties": { + "a11y_id": {"type": "string"}, + "bbox": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 4, + "maxItems": 4, + }, + "description": {"type": "string"}, + }, + "oneOf": [ + {"required": ["a11y_id"]}, + {"required": ["bbox"]}, + {"required": ["description"]}, + ], +} + + +ARGUMENT_SCHEMAS: Dict[str, Dict[str, Any]] = { + ActionType.CLICK.value: { + "type": "object", + "required": ["target"], + "additionalProperties": False, + "properties": { + "target": _TARGET_SELECTOR_SCHEMA, + "button": {"enum": ["left", "right", "middle"]}, + "count": {"type": "integer", "minimum": 1, "maximum": 4}, + "hold_keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "maxItems": 5, + }, + }, + }, + ActionType.DBLCLICK.value: { + "type": "object", + "required": ["target"], + "additionalProperties": False, + "properties": { + "target": _TARGET_SELECTOR_SCHEMA, + "button": {"enum": ["left", "right", "middle"]}, + "count": {"type": "integer", "minimum": 1, "maximum": 4}, + "hold_keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "maxItems": 5, + }, + }, + }, + ActionType.TYPE.value: { + "type": "object", + "required": ["text"], + "additionalProperties": False, + "properties": { + "text": {"type": "string", "maxLength": 512}, + "target": _TARGET_SELECTOR_SCHEMA, + "overwrite": {"type": "boolean"}, + "enter": {"type": "boolean"}, + }, + }, + ActionType.HOTKEY.value: { + "type": "object", + "required": ["keys"], + "additionalProperties": False, + "properties": { + "keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + "maxItems": 5, + } + }, + }, + ActionType.WAIT.value: { + "type": "object", + "required": ["seconds"], + "additionalProperties": False, + "properties": { + "seconds": {"type": "number", "minimum": 0.0, "maximum": 30.0} + }, + }, + ActionType.SCROLL.value: { + "type": "object", + "required": ["target", "clicks"], + "additionalProperties": False, + "properties": { + 
"target": _TARGET_SELECTOR_SCHEMA, + "clicks": {"type": "integer"}, + "shift": {"type": "boolean"}, + }, + }, + ActionType.SWITCH_APP.value: { + "type": "object", + "required": ["app_id"], + "additionalProperties": False, + "properties": {"app_id": {"type": "string", "minLength": 1}}, + }, + ActionType.FOCUS_APP.value: { + "type": "object", + "required": ["app_id"], + "additionalProperties": False, + "properties": {"app_id": {"type": "string", "minLength": 1}}, + }, + ActionType.SWITCH_APPLICATIONS.value: { + "type": "object", + "required": ["app_id"], + "additionalProperties": False, + "properties": {"app_id": {"type": "string", "minLength": 1}}, + }, + ActionType.OPEN_APP.value: { + "type": "object", + "required": ["app_name"], + "additionalProperties": False, + "properties": {"app_name": {"type": "string", "minLength": 1}}, + }, + ActionType.OPEN.value: { + "type": "object", + "required": ["app_name"], + "additionalProperties": False, + "properties": {"app_name": {"type": "string", "minLength": 1}}, + }, + ActionType.DRAG_AND_DROP.value: { + "type": "object", + "required": ["start", "end"], + "additionalProperties": False, + "properties": { + "start": _TARGET_SELECTOR_SCHEMA, + "end": _TARGET_SELECTOR_SCHEMA, + "hold_keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "maxItems": 4, + }, + }, + }, + ActionType.SAVE_TO_KNOWLEDGE.value: { + "type": "object", + "required": ["text"], + "additionalProperties": False, + "properties": { + "text": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + "maxItems": 20, + } + }, + }, + ActionType.HIGHLIGHT_TEXT_SPAN.value: { + "type": "object", + "required": ["start_phrase", "end_phrase"], + "additionalProperties": False, + "properties": { + "start_phrase": {"type": "string", "minLength": 1}, + "end_phrase": {"type": "string", "minLength": 1}, + "button": {"enum": ["left", "right", "middle"]}, + }, + }, + ActionType.SET_CELL_VALUES.value: { + "type": "object", + "required": ["cell_values", "app_name", "sheet_name"], + "additionalProperties": False, + "properties": { + "cell_values": { + "type": "object", + "additionalProperties": True, + "minProperties": 1, + }, + "app_name": {"type": "string", "minLength": 1}, + "sheet_name": {"type": "string", "minLength": 1}, + }, + }, + ActionType.HOLD_AND_PRESS.value: { + "type": "object", + "required": ["press_keys"], + "additionalProperties": False, + "properties": { + "hold_keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "maxItems": 5, + }, + "press_keys": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + "maxItems": 5, + }, + }, + }, + ActionType.DONE.value: { + "type": "object", + "additionalProperties": False, + "properties": { + "return_value": {} + }, + }, + ActionType.FAIL.value: { + "type": "object", + "additionalProperties": False, + "properties": { + "reason": {"type": "string"} + }, + }, +} + + +META_SCHEMA: Dict[str, Any] = { + "type": "object", + "required": ["idempotency_key"], + "additionalProperties": False, + "properties": { + "idempotency_key": {"type": "string", "minLength": 1}, + "roi_hash": {"type": "string", "minLength": 1}, + "explanation": {"type": "string", "maxLength": 280}, + }, +} + + +def _action_all_of() -> List[Dict[str, Any]]: + rules: List[Dict[str, Any]] = [] + for action, schema in ARGUMENT_SCHEMAS.items(): + rules.append( + { + "if": {"properties": {"type": {"const": action}}}, + "then": {"properties": {"args": schema}, "required": ["args"]}, + } + ) 
+ return rules + + +AGENT_ACTION_JSON_SCHEMA: Dict[str, Any] = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "AgentAction", + "type": "object", + "additionalProperties": False, + "required": ["type", "args", "meta"], + "properties": { + "type": {"enum": [member.value for member in ActionType]}, + "args": {"type": "object"}, + "meta": META_SCHEMA, + }, + "allOf": _action_all_of(), +} + + +AGENT_ACTION_RESPONSE_FORMAT: Dict[str, Any] = { + "type": "json_schema", + "json_schema": { + "name": "agent_action", + "schema": AGENT_ACTION_JSON_SCHEMA, + }, +} + + +SCHEMA_PROMPT_FRAGMENT = textwrap.dedent( + f""" + (Grounded Action) + Emit exactly one JSON object describing the next UI action. The JSON must appear inside a \"```json\" fenced block and conform to the AgentAction schema below. + ```json + {json.dumps(AGENT_ACTION_JSON_SCHEMA, indent=2)} + ``` + Do not add commentary before or after the JSON block. Never emit multiple JSON blocks. + """ +) + + +def agent_action_to_dict(action: AgentAction) -> Dict[str, Any]: + """Return a plain dict representation suitable for logging/telemetry.""" + return action.model_dump(mode="python") diff --git a/gui_agents/s1/core/AgentS.py b/gui_agents/s1/core/AgentS.py index 23e1569e..700c1d3a 100644 --- a/gui_agents/s1/core/AgentS.py +++ b/gui_agents/s1/core/AgentS.py @@ -251,7 +251,7 @@ def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]] if "FAIL" in actions: self.requires_replan = True # set the failure feedback to the evaluator feedback - self.failure_feedback = f"Completed subtasks: {self.completed_tasks}. The subtask {self.current_subtask} cannot be completed. Please try another approach. {executor_info['plan_code']}. Please replan." + self.failure_feedback = f"Completed subtasks: {self.completed_tasks}. The subtask {self.current_subtask} cannot be completed. Please try another approach. {executor_info.get('executor_plan', '')}. Please replan."
self.needs_next_subtask = True # reset the step count, executor, and evaluator diff --git a/gui_agents/s1/core/ProceduralMemory.py b/gui_agents/s1/core/ProceduralMemory.py index 72f2dc53..7144ac2e 100644 --- a/gui_agents/s1/core/ProceduralMemory.py +++ b/gui_agents/s1/core/ProceduralMemory.py @@ -1,6 +1,8 @@ import inspect import textwrap +from gui_agents.common import SCHEMA_PROMPT_FRAGMENT + class PROCEDURAL_MEMORY: @staticmethod @@ -30,7 +32,7 @@ def {attr_name}{signature}: """ procedural_memory += textwrap.dedent( - """ + f""" Your response should be formatted like this: (Previous action verification) Carefully analyze based on the screenshot and the accessibility tree if the previous action was successful. If the previous action was not successful, provide a reason for the failure. @@ -42,21 +44,21 @@ def {attr_name}{signature}: Based on the current screenshot, the accessibility tree and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task. (Grounded Action) - Translate the next action into code using the provided API methods. Format the code like this: - ```python - agent.click(123, 1, "left") - ``` - Note for the code: + Translate the next action into a JSON object that conforms to the shared AgentAction schema. + {SCHEMA_PROMPT_FRAGMENT} + Additional requirements: 1. Only perform one action at a time. - 2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block. - 3. You must use only the available methods provided above to interact with the UI, do not invent new methods. - 3. Only return one code block every time. There must be a single line of code in the code block. - 4. Please only use the available methods provided above to interact with the UI. - 5. If you think the task is already completed, you can return `agent.done()` in the code block. - 6. If you think the task cannot be completed, you can return `agent.fail()` in the code block. - 7. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the task is completed or `agent.fail()` if it cannot be completed. - 8. Whenever possible use hot-keys or typing rather than mouse clicks. - 9. My computer's password is 'password', feel free to use it when you need sudo rights + 2. Emit exactly one ```json fenced block without any commentary before or after it. + 3. Set `type` to the precise UI skill you intend to execute (e.g. "click", "drag_and_drop", "wait", "done", "fail"). + 4. Populate `args` with the minimal fields required by the schema. Prefer `target.a11y_id`, fall back to a tight `bbox` or a high-confidence description only when necessary. + 5. Populate `meta.idempotency_key` with a unique value for this turn so the executor can safely deduplicate retries. + 6. Summarize the intent in `meta.explanation` (<=280 characters) so humans can audit the trajectory quickly. + 7. Use `hotkey` or `hold_and_press` instead of mouse interactions whenever a reliable shortcut exists. + 8. Emit `done` when the task is fully complete and `fail` when it is impossible to continue. + 9. Use `wait` with a small positive `seconds` value whenever the UI needs time to respond. + 10. Use `save_to_knowledge` when you need to persist text for later turns. + 11. My computer's password is 'password'; feel free to use it when sudo rights are required. + 12. Do not use the "command" + "tab" hotkey on MacOS. 
""" ) return procedural_memory.strip() diff --git a/gui_agents/s1/core/Worker.py b/gui_agents/s1/core/Worker.py index 2f6a8879..c4d42bbc 100644 --- a/gui_agents/s1/core/Worker.py +++ b/gui_agents/s1/core/Worker.py @@ -1,3 +1,4 @@ +import json import logging import os import re @@ -8,7 +9,13 @@ from gui_agents.s1.core.BaseModule import BaseModule from gui_agents.s1.core.Knowledge import KnowledgeBase from gui_agents.s1.core.ProceduralMemory import PROCEDURAL_MEMORY -from gui_agents.s1.utils import common_utils +from gui_agents.common import ( + AGENT_ACTION_RESPONSE_FORMAT, + AgentActionParseError, + agent_action_to_dict, + execute_agent_action, + parse_agent_action, +) from gui_agents.s1.utils.common_utils import Node, calculate_tokens, call_llm_safe logger = logging.getLogger("desktopenv.agent") @@ -208,7 +215,11 @@ def generate_next_action( generator_message, image_content=obs["screenshot"] ) - plan = call_llm_safe(self.generator_agent) + response_kwargs = {} + engine_type = self.engine_params.get("engine_type") + if engine_type in {"openai", "azure"}: + response_kwargs["response_format"] = AGENT_ACTION_RESPONSE_FORMAT + plan = call_llm_safe(self.generator_agent, **response_kwargs) self.planner_history.append(plan) logger.info("PLAN: %s", plan) @@ -222,19 +233,25 @@ def generate_next_action( self.cost_this_turn += cost logger.info("EXECTUOR COST: %s", self.cost_this_turn) - # Extract code block from the plan - plan_code = common_utils.parse_single_code_from_string( - plan.split("Grounded Action")[-1] - ) - plan_code = common_utils.sanitize_code(plan_code) - plan_code = common_utils.extract_first_agent_function(plan_code) - exec_code = eval(plan_code) + action_payload = None + action_json = None + parse_error = None + try: + agent_action = parse_agent_action(plan) + action_payload = agent_action_to_dict(agent_action) + action_json = json.dumps(action_payload, ensure_ascii=False) + exec_code = execute_agent_action(agent, agent_action, obs) + except AgentActionParseError as err: + parse_error = str(err) + logger.error("Error parsing AgentAction: %s", err) + exec_code = agent.wait(1.0) + except Exception as err: + parse_error = str(err) + logger.error("Error executing AgentAction: %s", err) + exec_code = agent.wait(1.0) - # If agent selects an element that was out of range, it should not be executed just send a WAIT command. - # TODO: should provide this as code feedback to the agent? 
if agent.index_out_of_range_flag: - plan_code = "agent.wait(1.0)" - exec_code = eval(plan_code) + exec_code = agent.wait(1.0) agent.index_out_of_range_flag = False executor_info = { @@ -242,7 +259,9 @@ def generate_next_action( "current_subtask_info": subtask_info, "executor_plan": plan, "linearized_accessibility_tree": tree_input, - "plan_code": plan_code, + "action": action_payload, + "action_json": action_json, + "parse_error": parse_error, "reflection": reflection, "num_input_tokens_executor": input_tokens, "num_output_tokens_executor": output_tokens, diff --git a/gui_agents/s1/utils/common_utils.py b/gui_agents/s1/utils/common_utils.py index a0bc52fd..30059b5b 100644 --- a/gui_agents/s1/utils/common_utils.py +++ b/gui_agents/s1/utils/common_utils.py @@ -55,14 +55,14 @@ class Dag(BaseModel): NUM_IMAGE_TOKEN = 1105 # Value set of screen of size 1920x1080 for openai vision -def call_llm_safe(agent) -> Union[str, Dag]: +def call_llm_safe(agent, **generation_kwargs) -> Union[str, Dag]: # Retry if fails max_retries = 3 # Set the maximum number of retries attempt = 0 response = "" while attempt < max_retries: try: - response = agent.get_response() + response = agent.get_response(**generation_kwargs) break # If successful, break out of the loop except Exception as e: attempt += 1 diff --git a/gui_agents/s2/agents/agent_s.py b/gui_agents/s2/agents/agent_s.py index c8e7d18e..8ecae358 100644 --- a/gui_agents/s2/agents/agent_s.py +++ b/gui_agents/s2/agents/agent_s.py @@ -256,7 +256,8 @@ def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]] self.subtask_status = "Done" executor_info = { "executor_plan": "agent.done()", - "plan_code": "agent.done()", + "action": {"type": "done"}, + "parse_error": None, "reflection": "agent.done()", } actions = ["DONE"] diff --git a/gui_agents/s2/agents/worker.py b/gui_agents/s2/agents/worker.py index 4f218cf9..1647b4c5 100644 --- a/gui_agents/s2/agents/worker.py +++ b/gui_agents/s2/agents/worker.py @@ -1,3 +1,4 @@ +import json import logging import re import textwrap @@ -8,13 +9,17 @@ from gui_agents.s2.core.module import BaseModule from gui_agents.s2.core.knowledge import KnowledgeBase from gui_agents.s2.memory.procedural_memory import PROCEDURAL_MEMORY +from gui_agents.common import ( + AGENT_ACTION_RESPONSE_FORMAT, + AgentActionParseError, + agent_action_to_dict, + execute_agent_action, + parse_agent_action, +) from gui_agents.s2.utils.common_utils import ( Node, calculate_tokens, call_llm_safe, - parse_single_code_from_string, - sanitize_code, - extract_first_agent_function, ) logger = logging.getLogger("desktopenv.agent") @@ -206,7 +211,11 @@ def generate_next_action( generator_message, image_content=obs["screenshot"], role="user" ) - plan = call_llm_safe(self.generator_agent) + response_kwargs = {} + engine_type = self.engine_params.get("engine_type") + if engine_type in {"openai", "azure"}: + response_kwargs["response_format"] = AGENT_ACTION_RESPONSE_FORMAT + plan = call_llm_safe(self.generator_agent, **response_kwargs) self.planner_history.append(plan) logger.info("PLAN: %s", plan) self.generator_agent.add_message(plan, role="assistant") @@ -217,23 +226,31 @@ def generate_next_action( self.cost_this_turn += cost logger.info("EXECTUOR COST: %s", self.cost_this_turn) - # Use the DescriptionBasedACI to convert agent_action("desc") into agent_action([x, y]) + # Parse the structured AgentAction payload and dispatch through the grounding interface + action_payload = None + action_json = None + parse_error = None try: - 
agent.assign_coordinates(plan, obs) - plan_code = parse_single_code_from_string(plan.split("Grounded Action")[-1]) - plan_code = sanitize_code(plan_code) - plan_code = extract_first_agent_function(plan_code) - exec_code = eval(plan_code) - except Exception as e: - logger.error("Error in parsing plan code: %s", e) - plan_code = "agent.wait(1.0)" - exec_code = eval(plan_code) + agent_action = parse_agent_action(plan) + action_payload = agent_action_to_dict(agent_action) + action_json = json.dumps(action_payload, ensure_ascii=False) + exec_code = execute_agent_action(agent, agent_action, obs) + except AgentActionParseError as err: + parse_error = str(err) + logger.error("Error parsing AgentAction: %s", err) + exec_code = agent.wait(1.0) + except Exception as err: + parse_error = str(err) + logger.error("Error executing AgentAction: %s", err) + exec_code = agent.wait(1.0) executor_info = { "current_subtask": subtask, "current_subtask_info": subtask_info, "executor_plan": plan, - "plan_code": plan_code, + "action": action_payload, + "action_json": action_json, + "parse_error": parse_error, "reflection": reflection, "num_input_tokens_executor": input_tokens, "num_output_tokens_executor": output_tokens, @@ -249,8 +266,13 @@ def generate_next_action( def clean_worker_generation_for_reflection(self, worker_generation: str) -> str: # Remove the previous action verification res = worker_generation[worker_generation.find("(Screenshot Analysis)") :] - action = extract_first_agent_function(worker_generation) - # Cut off extra grounded actions - res = res[: res.find("(Grounded Action)")] - res += f"(Grounded Action)\n```python\n{action}\n```\n" + try: + agent_action = parse_agent_action(worker_generation) + action_json = json.dumps(agent_action_to_dict(agent_action), ensure_ascii=False) + except AgentActionParseError: + action_json = "{}" + ground_idx = res.find("(Grounded Action)") + if ground_idx != -1: + res = res[:ground_idx] + res += f"(Grounded Action)\n```json\n{action_json}\n```\n" return res diff --git a/gui_agents/s2/memory/procedural_memory.py b/gui_agents/s2/memory/procedural_memory.py index 32739ca4..bfcc0b35 100644 --- a/gui_agents/s2/memory/procedural_memory.py +++ b/gui_agents/s2/memory/procedural_memory.py @@ -1,6 +1,8 @@ import inspect import textwrap +from gui_agents.common import SCHEMA_PROMPT_FRAGMENT + class PROCEDURAL_MEMORY: @@ -33,7 +35,7 @@ def {attr_name}{signature}: """ procedural_memory += textwrap.dedent( - """ + f""" Your response should be formatted like this: (Previous action verification) Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure. @@ -45,21 +47,21 @@ def {attr_name}{signature}: Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task. (Grounded Action) - Translate the next action into code using the provided API methods. Format the code like this: - ```python - agent.click("The menu button at the top right of the window", 1, "left") - ``` - Note for the code: + Translate the next action into a JSON object that conforms to the shared AgentAction schema. + {SCHEMA_PROMPT_FRAGMENT} + Additional requirements: 1. Only perform one action at a time. - 2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block. - 3. 
You must use only the available methods provided above to interact with the UI, do not invent new methods. - 4. Only return one code block every time. There must be a single line of code in the code block. - 5. If you think the task is already completed, return `agent.done()` in the code block. - 6. If you think the task cannot be completed, return `agent.fail()` in the code block. - 7. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the task is completed or `agent.fail()` if it cannot be completed. - 8. Whenever possible, your grounded action should use hot-keys with the agent.hotkey() action instead of clicking or dragging. - 9. My computer's password is 'password', feel free to use it when you need sudo rights. - 10. Do not use the "command" + "tab" hotkey on MacOS. + 2. Emit exactly one ```json fenced block without any commentary before or after it. + 3. Set `type` to the precise UI skill you intend to execute (e.g. "click", "drag_and_drop", "wait", "done", "fail"). + 4. Populate `args` with the minimal fields required by the schema. Prefer `target.a11y_id`, fall back to a tight `bbox` or a high-confidence description only when necessary. + 5. Populate `meta.idempotency_key` with a unique value for this turn so the executor can safely deduplicate retries. + 6. Summarize the intent in `meta.explanation` (<=280 characters) so humans can audit the trajectory quickly. + 7. Use `hotkey` or `hold_and_press` instead of mouse interactions whenever a reliable shortcut exists. + 8. Emit `done` when the task is fully complete and `fail` when it is impossible to continue. + 9. Use `wait` with a small positive `seconds` value whenever the UI needs time to respond. + 10. Use `save_to_knowledge` when you need to persist text for later turns. + 11. My computer's password is 'password'; feel free to use it when sudo rights are required. + 12. Do not use the "command" + "tab" hotkey on MacOS. 
""" ) diff --git a/gui_agents/s2/utils/common_utils.py b/gui_agents/s2/utils/common_utils.py index 1873aa1a..e3edefa3 100644 --- a/gui_agents/s2/utils/common_utils.py +++ b/gui_agents/s2/utils/common_utils.py @@ -24,14 +24,14 @@ class Dag(BaseModel): NUM_IMAGE_TOKEN = 1105 # Value set of screen of size 1920x1080 for openai vision -def call_llm_safe(agent) -> Union[str, Dag]: +def call_llm_safe(agent, **generation_kwargs) -> Union[str, Dag]: # Retry if fails max_retries = 3 # Set the maximum number of retries attempt = 0 response = "" while attempt < max_retries: try: - response = agent.get_response() + response = agent.get_response(**generation_kwargs) break # If successful, break out of the loop except Exception as e: attempt += 1 diff --git a/gui_agents/s2_5/agents/worker.py b/gui_agents/s2_5/agents/worker.py index 9681f714..7dcfa31c 100644 --- a/gui_agents/s2_5/agents/worker.py +++ b/gui_agents/s2_5/agents/worker.py @@ -1,3 +1,4 @@ +import json import logging import textwrap from typing import Dict, List, Tuple @@ -5,11 +6,15 @@ from gui_agents.s2_5.agents.grounding import ACI from gui_agents.s2_5.core.module import BaseModule from gui_agents.s2_5.memory.procedural_memory import PROCEDURAL_MEMORY +from gui_agents.common import ( + AGENT_ACTION_RESPONSE_FORMAT, + AgentActionParseError, + agent_action_to_dict, + execute_agent_action, + parse_agent_action, +) from gui_agents.s2_5.utils.common_utils import ( call_llm_safe, - extract_first_agent_function, - parse_single_code_from_string, - sanitize_code, split_thinking_response, ) @@ -167,10 +172,15 @@ def generate_next_action( generator_message, image_content=obs["screenshot"], role="user" ) + response_kwargs = {} + engine_type = self.engine_params.get("engine_type") + if engine_type in {"openai", "azure"}: + response_kwargs["response_format"] = AGENT_ACTION_RESPONSE_FORMAT full_plan = call_llm_safe( self.generator_agent, temperature=self.temperature, use_thinking=self.use_thinking, + **response_kwargs, ) plan, plan_thoughts = split_thinking_response(full_plan) # NOTE: currently dropping thinking tokens from context @@ -178,23 +188,30 @@ def generate_next_action( logger.info("FULL PLAN:\n %s", full_plan) self.generator_agent.add_message(plan, role="assistant") - # Use the grounding agent to convert agent_action("desc") into agent_action([x, y]) + action_payload = None + action_json = None + parse_error = None try: - agent.assign_coordinates(plan, obs) - plan_code = parse_single_code_from_string(plan.split("Grounded Action")[-1]) - plan_code = sanitize_code(plan_code) - plan_code = extract_first_agent_function(plan_code) - exec_code = eval(plan_code) - except Exception as e: - logger.error("Error in parsing plan code: %s", e) - plan_code = "agent.wait(1.0)" - exec_code = eval(plan_code) + agent_action = parse_agent_action(plan) + action_payload = agent_action_to_dict(agent_action) + action_json = json.dumps(action_payload, ensure_ascii=False) + exec_code = execute_agent_action(agent, agent_action, obs) + except AgentActionParseError as err: + parse_error = str(err) + logger.error("Error parsing AgentAction: %s", err) + exec_code = agent.wait(1.0) + except Exception as err: # Dispatch errors + parse_error = str(err) + logger.error("Error executing AgentAction: %s", err) + exec_code = agent.wait(1.0) executor_info = { "full_plan": full_plan, "executor_plan": plan, "plan_thoughts": plan_thoughts, - "plan_code": plan_code, + "action": action_payload, + "action_json": action_json, + "parse_error": parse_error, "reflection": reflection, 
"reflection_thoughts": reflection_thoughts, } diff --git a/gui_agents/s2_5/core/engine.py b/gui_agents/s2_5/core/engine.py index 3b17de60..ff06a533 100644 --- a/gui_agents/s2_5/core/engine.py +++ b/gui_agents/s2_5/core/engine.py @@ -11,6 +11,25 @@ RateLimitError, ) +def _message_to_text(message): + content = getattr(message, "content", message) + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + text = item.get("text") + if text: + parts.append(text) + elif item.get("type") == "output_text" and item.get("text"): + parts.append(item["text"]) + elif hasattr(item, "text") and getattr(item, "text"): + parts.append(getattr(item, "text")) + return "".join(parts) + return str(content) + + class LMMEngine: pass @@ -53,21 +72,16 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): self.llm_client = OpenAI( base_url=self.base_url, api_key=api_key, organization=organization ) - return ( - self.llm_client.chat.completions.create( - model=self.model, - messages=messages, - max_completion_tokens=max_new_tokens if max_new_tokens else 4096, - temperature=( - temperature if self.temperature is None else self.temperature - ), - **kwargs, - ) - .choices[0] - .message.content + completion = self.llm_client.chat.completions.create( + model=self.model, + messages=messages, + max_completion_tokens=max_new_tokens if max_new_tokens else 4096, + temperature=( + temperature if self.temperature is None else self.temperature + ), + **kwargs, ) - - + return _message_to_text(completion.choices[0].message) class LMMEngineAnthropic(LMMEngine): def __init__( self, @@ -182,21 +196,15 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): ) if not self.llm_client: self.llm_client = OpenAI(base_url=base_url, api_key=api_key) - # Use the temperature passed to generate, otherwise use the instance's temperature, otherwise default to 0.0 temp = self.temperature if temperature is None else temperature - return ( - self.llm_client.chat.completions.create( - model=self.model, - messages=messages, - max_tokens=max_new_tokens if max_new_tokens else 4096, - temperature=temp, - **kwargs, - ) - .choices[0] - .message.content + completion = self.llm_client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_new_tokens if max_new_tokens else 4096, + temperature=temp, + **kwargs, ) - - + return _message_to_text(completion.choices[0].message) class LMMEngineOpenRouter(LMMEngine): def __init__( self, @@ -219,33 +227,22 @@ def __init__( backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60 ) def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): - api_key = self.api_key or os.getenv("OPENROUTER_API_KEY") + api_key = self.api_key or os.getenv("OPEN_ROUTER_API_KEY") if api_key is None: raise ValueError( - "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY" - ) - base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL") - if base_url is None: - raise ValueError( - "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL" + "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPEN_ROUTER_API_KEY" ) if not self.llm_client: - self.llm_client = OpenAI(base_url=base_url, api_key=api_key) - # Use self.temperature if set, otherwise 
use the temperature argument - temp = self.temperature if self.temperature is not None else temperature - return ( - self.llm_client.chat.completions.create( - model=self.model, - messages=messages, - max_tokens=max_new_tokens if max_new_tokens else 4096, - temperature=temp, - **kwargs, - ) - .choices[0] - .message.content + self.llm_client = OpenAI(base_url=self.base_url, api_key=api_key) + temp = self.temperature if temperature is None else temperature + completion = self.llm_client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_new_tokens if max_new_tokens else 4096, + temperature=temp, + **kwargs, ) - - + return _message_to_text(completion.choices[0].message) class LMMEngineAzureOpenAI(LMMEngine): def __init__( self, @@ -293,7 +290,6 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): api_key=api_key, api_version=api_version, ) - # Use self.temperature if set, otherwise use the temperature argument temp = self.temperature if self.temperature is not None else temperature completion = self.llm_client.chat.completions.create( model=self.model, @@ -304,9 +300,7 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): ) total_tokens = completion.usage.total_tokens self.cost += 0.02 * ((total_tokens + 500) / 1000) - return completion.choices[0].message.content - - + return _message_to_text(completion.choices[0].message) class LMMEnginevLLM(LMMEngine): def __init__( self, @@ -349,7 +343,6 @@ def generate( ) if not self.llm_client: self.llm_client = OpenAI(base_url=base_url, api_key=api_key) - # Use self.temperature if set, otherwise use the temperature argument temp = self.temperature if self.temperature is not None else temperature completion = self.llm_client.chat.completions.create( model=self.model, @@ -359,9 +352,7 @@ def generate( top_p=top_p, extra_body={"repetition_penalty": repetition_penalty}, ) - return completion.choices[0].message.content - - + return _message_to_text(completion.choices[0].message) class LMMEngineHuggingFace(LMMEngine): def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs): self.base_url = base_url @@ -385,19 +376,14 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): ) if not self.llm_client: self.llm_client = OpenAI(base_url=base_url, api_key=api_key) - return ( - self.llm_client.chat.completions.create( - model="tgi", - messages=messages, - max_tokens=max_new_tokens if max_new_tokens else 4096, - temperature=temperature, - **kwargs, - ) - .choices[0] - .message.content + completion = self.llm_client.chat.completions.create( + model="tgi", + messages=messages, + max_tokens=max_new_tokens if max_new_tokens else 4096, + temperature=temperature, + **kwargs, ) - - + return _message_to_text(completion.choices[0].message) class LMMEngineParasail(LMMEngine): def __init__( self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs @@ -428,14 +414,11 @@ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs): base_url=base_url if base_url else "https://api.parasail.io/v1", api_key=api_key, ) - return ( - self.llm_client.chat.completions.create( - model=self.model, - messages=messages, - max_tokens=max_new_tokens if max_new_tokens else 4096, - temperature=temperature, - **kwargs, - ) - .choices[0] - .message.content + completion = self.llm_client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_new_tokens if max_new_tokens else 4096, + temperature=temperature, + 
**kwargs, ) + return _message_to_text(completion.choices[0].message) diff --git a/gui_agents/s2_5/memory/procedural_memory.py b/gui_agents/s2_5/memory/procedural_memory.py index 273c5111..e40fa7ac 100644 --- a/gui_agents/s2_5/memory/procedural_memory.py +++ b/gui_agents/s2_5/memory/procedural_memory.py @@ -1,6 +1,8 @@ import inspect import textwrap +from gui_agents.common import SCHEMA_PROMPT_FRAGMENT + class PROCEDURAL_MEMORY: @staticmethod @@ -31,7 +33,7 @@ def {attr_name}{signature}: """ procedural_memory += textwrap.dedent( - """ + f""" Your response should be formatted like this: (Previous action verification) Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure. @@ -43,21 +45,21 @@ def {attr_name}{signature}: Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task. (Grounded Action) - Translate the next action into code using the provided API methods. Format the code like this: - ```python - agent.click("The menu button at the top right of the window", 1, "left") - ``` - Note for the code: + Translate the next action into a JSON object that conforms to the shared AgentAction schema. + {SCHEMA_PROMPT_FRAGMENT} + Additional requirements: 1. Only perform one action at a time. - 2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block. - 3. You must use only the available methods provided above to interact with the UI, do not invent new methods. - 4. Only return one code block every time. There must be a single line of code in the code block. - 5. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the subtask is completed or `agent.fail()` if it cannot be completed. - 6. Whenever possible, your grounded action should use hot-keys with the agent.hotkey() action instead of clicking or dragging. - 7. My computer's password is 'osworld-public-evaluation', feel free to use it when you need sudo rights. - 8. Generate agent.fail() as your grounded action if you get exhaustively stuck on the task and believe it is impossible. - 9. Generate agent.done() as your grounded action when your believe the task is fully complete. - 10. Do not use the "command" + "tab" hotkey on MacOS. + 2. Emit exactly one ```json fenced block without any commentary before or after it. + 3. Set `type` to the precise UI skill you intend to execute (e.g. "click", "drag_and_drop", "wait", "done", "fail"). + 4. Populate `args` with the minimal fields required by the schema. Prefer `target.a11y_id`, fall back to a tight `bbox` or a high-confidence description only when necessary. + 5. Populate `meta.idempotency_key` with a unique value for this turn so the executor can safely deduplicate retries. + 6. Summarize the intent in `meta.explanation` (<=280 characters) so humans can audit the trajectory quickly. + 7. Use `hotkey` or `hold_and_press` instead of mouse interactions whenever a reliable shortcut exists. + 8. Emit `done` when the task is fully complete and `fail` when it is impossible to continue. + 9. Use `wait` with a small positive `seconds` value whenever the UI needs time to respond. + 10. Use `save_to_knowledge` when you need to persist text for later turns. + 11. 
My computer's password is 'osworld-public-evaluation'; feel free to use it when sudo rights are required. + 12. Do not use the "command" + "tab" hotkey on MacOS. """ ) diff --git a/gui_agents/s2_5/utils/common_utils.py b/gui_agents/s2_5/utils/common_utils.py index 0b37835a..bd8e7fea 100644 --- a/gui_agents/s2_5/utils/common_utils.py +++ b/gui_agents/s2_5/utils/common_utils.py @@ -4,7 +4,7 @@ from typing import Tuple -def call_llm_safe(agent, temperature: float = 0.0, use_thinking: bool = False) -> str: +def call_llm_safe(agent, temperature: float = 0.0, use_thinking: bool = False, **generation_kwargs) -> str: # Retry if fails max_retries = 3 # Set the maximum number of retries attempt = 0 @@ -12,7 +12,8 @@ def call_llm_safe(agent, temperature: float = 0.0, use_thinking: bool = False) - while attempt < max_retries: try: response = agent.get_response( - temperature=temperature, use_thinking=use_thinking + temperature=temperature, use_thinking=use_thinking, + **generation_kwargs ) assert response is not None, "Response from agent should not be None" print("Response success!") diff --git a/requirements.txt b/requirements.txt index 8d1dc9f9..22634461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,5 @@ google-genai # Platform-specific dependencies pyobjc; platform_system == "Darwin" pywinauto; platform_system == "Windows" -pywin32; platform_system == "Windows" \ No newline at end of file +pywin32; platform_system == "Windows" +pydantic>=2.6,<3 diff --git a/setup.py b/setup.py index 28ba61ed..9d99cddc 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "pyautogui", "toml", "pytesseract", + "pydantic>=2.6,<3", "google-genai", 'pywinauto; platform_system == "Windows"', # Only for Windows 'pywin32; platform_system == "Windows"', # Only for Windows @@ -54,3 +55,4 @@ }, python_requires=">=3.9, <=3.12", ) +
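For reviewers, the sketch below walks the new structured-action path end to end under two assumptions: the `gui_agents.common` package added in this diff is importable, and `StubGroundingAgent` is a hypothetical stand-in for a real grounding ACI whose `click`/`wait` methods simply mirror the positional and keyword names the dispatcher passes. It is illustration only, not part of the change.

```python
# A minimal sketch of the structured-action round trip this PR introduces:
# model text -> parse_agent_action -> execute_agent_action against a grounding agent.
# StubGroundingAgent is a hypothetical stand-in for a real ACI; only the method names and
# keyword names the dispatcher actually passes (see ACTION_METHOD_BY_TYPE and
# _build_call_arguments) are assumed here.
from gui_agents.common import (
    agent_action_to_dict,
    execute_agent_action,
    parse_agent_action,
)


class StubGroundingAgent:
    def click(self, element_description, num_clicks=1, button_type="left", hold_keys=None):
        # A real ACI would resolve the description to screen coordinates and emit pyautogui code.
        return f"# {button_type}-click x{num_clicks} on: {element_description}"

    def wait(self, seconds):
        return f"# wait {seconds}s"


FENCE = "`" * 3  # avoids a literal nested fence inside this example
model_response = (
    "(Grounded Action)\n"
    f"{FENCE}json\n"
    '{"type": "click", '
    '"args": {"target": {"description": "The Submit button at the bottom of the form"}}, '
    '"meta": {"idempotency_key": "turn-7-click-submit", '
    '"explanation": "Click Submit to send the completed form."}}\n'
    f"{FENCE}\n"
)

action = parse_agent_action(model_response)  # pydantic-validates type, args and meta
logged = agent_action_to_dict(action)        # plain dict, as workers store in executor_info["action"]
obs = {"screenshot": b"", "a11y_tree": []}   # placeholder observation
print(action.type.value, logged["meta"]["idempotency_key"])
print(execute_agent_action(StubGroundingAgent(), action, obs))
```

This mirrors what the updated workers do each turn: parse the plan, record the action dict and JSON in `executor_info`, and fall back to `agent.wait(1.0)` when parsing or dispatch fails.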