diff --git a/src/agentlab/agents/agent_args.py b/src/agentlab/agents/agent_args.py index b2cd0eb6..78036c31 100644 --- a/src/agentlab/agents/agent_args.py +++ b/src/agentlab/agents/agent_args.py @@ -1,5 +1,5 @@ import bgym -from bgym import AbstractAgentArgs +from bgym import AbstractAgentArgs, Benchmark class AgentArgs(AbstractAgentArgs): @@ -14,7 +14,7 @@ class MyAgentArgs(AgentArgs): Note: for working properly with AgentXRay, the arguments need to be serializable and hasable. """ - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool): + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): """Optional method to set benchmark specific flags. This allows the agent to have minor adjustments based on the benchmark. diff --git a/src/agentlab/agents/agent_utils.py b/src/agentlab/agents/agent_utils.py new file mode 100644 index 00000000..b91aab26 --- /dev/null +++ b/src/agentlab/agents/agent_utils.py @@ -0,0 +1,44 @@ +from PIL import Image, ImageDraw +from logging import warning + + +""" +This module contains utility functions for handling observations and actions in the context of agent interactions. +""" + + +def tag_screenshot_with_action(screenshot: Image, action: str) -> Image: + """ + If action is a coordinate action, try to render it on the screenshot. + + e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot + + Args: + screenshot: The screenshot to tag. + action: The action to tag the screenshot with. + + Returns: + The tagged screenshot. + + Raises: + ValueError: If the action parsing fails. + """ + if action.startswith("mouse_click"): + try: + coords = action[action.index("(") + 1 : action.index(")")].split(",") + coords = [c.strip() for c in coords] + if len(coords) not in [2, 3]: + raise ValueError(f"Invalid coordinate format: {coords}") + if coords[0].startswith("x="): + coords[0] = coords[0][2:] + if coords[1].startswith("y="): + coords[1] = coords[1][2:] + x, y = float(coords[0].strip()), float(coords[1].strip()) + draw = ImageDraw.Draw(screenshot) + radius = 5 + draw.ellipse( + (x - radius, y - radius, x + radius, y + radius), fill="blue", outline="blue" + ) + except (ValueError, IndexError) as e: + warning(f"Failed to parse action '{action}': {e}") + return screenshot diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 2cb474e9..92ad25b9 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -9,13 +9,9 @@ from warnings import warn import bgym +from bgym import HighLevelActionSetArgs from browsergym.core.action.base import AbstractActionSet -from browsergym.utils.obs import ( - flatten_axtree_to_str, - flatten_dom_to_str, - overlay_som, - prune_html, -) +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html from agentlab.llm.llm_utils import ( BaseMessage, @@ -99,7 +95,7 @@ class ObsFlags(Flags): @dataclass class ActionFlags(Flags): - action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method + action_set: HighLevelActionSetArgs = None # should be set by the set_benchmark method long_description: bool = True individual_examples: bool = False diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index 914e3249..f50367d8 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -3,6 +3,7 @@ """ import bgym +from bgym import 
HighLevelActionSetArgs from agentlab.agents import dynamic_prompting as dp from agentlab.experiments import args @@ -32,7 +33,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -80,7 +81,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -127,7 +128,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -177,7 +178,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=True, ), @@ -232,7 +233,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -323,7 +324,7 @@ filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]), ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=args.Choice([["bid"], ["bid", "coord"]]), multiaction=args.Choice([True, False], p=[0.7, 0.3]), ), diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index a65b3eb3..d1f48f76 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -10,9 +10,11 @@ from copy import deepcopy from dataclasses import asdict, dataclass +from functools import partial from warnings import warn import bgym +from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo from agentlab.agents import dynamic_prompting as dp @@ -22,7 +24,6 @@ from agentlab.llm.tracking import cost_tracker_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt -from functools import partial @dataclass @@ -37,7 +38,7 @@ def __post_init__(self): except AttributeError: pass - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + def set_benchmark(self, benchmark: Benchmark, demo_mode): """Override Some flags based on the benchmark.""" if benchmark.name.startswith("miniwob"): self.flags.obs.use_html = True diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index bf1f01c9..154aeae5 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -19,6 +19,7 @@ from pathlib import Path import bgym +from bgym import HighLevelActionSetArgs from browsergym.experiments.agent import AgentInfo from bs4 import BeautifulSoup @@ -144,7 +145,7 @@ def _make_backward_compatible(agent_args: GenericAgentArgs): if isinstance(action_set, str): action_set = action_set.split("+") - agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs( + agent_args.flags.action.action_set = HighLevelActionSetArgs( subsets=action_set, multiaction=agent_args.flags.action.multi_actions, ) diff --git a/src/agentlab/agents/tool_use_agent/__init__.py b/src/agentlab/agents/tool_use_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agentlab/agents/tool_use_agent/agent.py b/src/agentlab/agents/tool_use_agent/agent.py new 
file mode 100644 index 00000000..26949462 --- /dev/null +++ b/src/agentlab/agents/tool_use_agent/agent.py @@ -0,0 +1,305 @@ +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import bgym +import numpy as np +from PIL import Image, ImageDraw + +from agentlab.agents import agent_utils +from agentlab.agents.agent_args import AgentArgs +from agentlab.llm.llm_utils import image_to_png_base64_url +from agentlab.llm.response_api import ( + BaseModelArgs, + ClaudeResponseModelArgs, + MessageBuilder, + OpenAIChatModelArgs, + OpenAIResponseModelArgs, + OpenRouterModelArgs, + ResponseLLMOutput, + VLLMModelArgs, +) +from agentlab.llm.tracking import cost_tracker_decorator +from browsergym.core.observation import extract_screenshot + +if TYPE_CHECKING: + from openai.types.responses import Response + + +@dataclass +class ToolUseAgentArgs(AgentArgs): + model_args: OpenAIResponseModelArgs = None + use_first_obs: bool = True + tag_screenshot: bool = True + use_raw_page_output: bool = True + + def __post_init__(self): + try: + self.agent_name = f"ToolUse-{self.model_args.model_name}".replace("/", "_") + except AttributeError: + pass + + def make_agent(self) -> bgym.Agent: + return ToolUseAgent( + model_args=self.model_args, + use_first_obs=self.use_first_obs, + tag_screenshot=self.tag_screenshot, + ) + + def prepare(self): + return self.model_args.prepare_server() + + def close(self): + return self.model_args.close_server() + + +class ToolUseAgent(bgym.Agent): + def __init__( + self, + model_args: OpenAIResponseModelArgs, + use_first_obs: bool = True, + tag_screenshot: bool = True, + ): + self.chat = model_args.make_model() + self.model_args = model_args + self.use_first_obs = use_first_obs + self.tag_screenshot = tag_screenshot + self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False) + self.tools = self.action_set.to_tool_description(api=model_args.api) + + self.call_ids = [] + + # self.tools.append( + # { + # "type": "function", + # "name": "chain_of_thought", + # "description": "A tool that allows the agent to think step by step. 
Every other action must ALWAYS be preceded by a call to this tool.", + # "parameters": { + # "type": "object", + # "properties": { + # "thoughts": { + # "type": "string", + # "description": "The agent's reasoning process.", + # }, + # }, + # "required": ["thoughts"], + # }, + # } + # ) + + self.llm = model_args.make_model(extra_kwargs={"tools": self.tools}) + self.msg_builder = model_args.get_message_builder() + self.messages: list[MessageBuilder] = [] + + def obs_preprocessor(self, obs): + page = obs.pop("page", None) + if page is not None: + obs["screenshot"] = extract_screenshot(page) + if self.tag_screenshot: + screenshot = Image.fromarray(obs["screenshot"]) + screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"]) + obs["screenshot_tag"] = np.array(screenshot) + else: + raise ValueError("No page found in the observation.") + + return obs + + @cost_tracker_decorator + def get_action(self, obs: Any) -> tuple[str, bgym.AgentInfo]: + if len(self.messages) == 0: + self.initialize_messages(obs) + else: + if obs["last_action_error"] == "": # Check No error in the last action + screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" + tool_message = self.msg_builder.tool().add_image( + image_to_png_base64_url(obs[screenshot_key]) + ) + tool_message.update_last_raw_response(self.last_response) + tool_message.add_tool_id(self.previous_call_id) + self.messages.append(tool_message) + else: + tool_message = self.msg_builder.tool().add_text( + f"Function call failed: {obs['last_action_error']}" + ) + tool_message.add_tool_id(self.previous_call_id) + tool_message.update_last_raw_response(self.last_response) + self.messages.append(tool_message) + + response: ResponseLLMOutput = self.llm(messages=self.messages) + + action = response.action + think = response.think + self.last_response = response + self.previous_call_id = response.last_computer_call_id + self.messages.append(response.assistant_message) # this is tool call + + return ( + action, + bgym.AgentInfo( + think=think, + chat_messages=self.messages, + stats={}, + ), + ) + + def initialize_messages(self, obs: Any) -> None: + system_message = self.msg_builder.system().add_text( + "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal." + ) + self.messages.append(system_message) + + goal_message = self.msg_builder.user() + for content in obs["goal_object"]: + if content["type"] == "text": + goal_message.add_text(content["text"]) + elif content["type"] == "image_url": + goal_message.add_image(content["image_url"]) + self.messages.append(goal_message) + + extra_info = [] + + extra_info.append( + """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n""" + ) + + self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info))) + + if self.use_first_obs: + msg = "Here is the first observation." + screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" + if self.tag_screenshot: + msg += " A blue dot on screenshots indicates the previous click action."
+ message = self.msg_builder.user().add_text(msg) + message.add_image(image_to_png_base64_url(obs[screenshot_key])) + self.messages.append(message) + + +OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs( + model_name="gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + +OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs( + model_name="gpt-4o-2024-08-06", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + +CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs( + model_name="claude-3-7-sonnet-20250219", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + + + + +# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs: +# default_model_args = { +# "max_total_tokens": 200_000, +# "max_input_tokens": 180_000, +# "max_new_tokens": 2_000, +# "temperature": 0.1, +# "vision_support": True, +# } +# merged_args = {**default_model_args, **open_router_args} + +# return OpenRouterModelArgs(model_name=model_name, **merged_args) + + +# def get_openrouter_tool_use_agent( +# model_name: str, +# model_args: dict = {}, +# use_first_obs=True, +# tag_screenshot=True, +# use_raw_page_output=True, +# ) -> ToolUseAgentArgs: +# # To Do : Check if OpenRouter endpoint specific args are working +# if not supports_tool_calling(model_name): +# raise ValueError(f"Model {model_name} does not support tool calling.") + +# model_args = get_openrouter_model(model_name, **model_args) + +# return ToolUseAgentArgs( +# model_args=model_args, +# use_first_obs=use_first_obs, +# tag_screenshot=tag_screenshot, +# use_raw_page_output=use_raw_page_output, +# ) + + +# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview") + + +AGENT_CONFIG = ToolUseAgentArgs( + model_args=CLAUDE_MODEL_CONFIG, +) + +# MT_TOOL_USE_AGENT = ToolUseAgentArgs( +# model_args=OPENROUTER_MODEL, +# ) +CHATAPI_AGENT_CONFIG = ToolUseAgentArgs( + model_args=OpenAIChatModelArgs( + model_name="gpt-4o-2024-11-20", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.7, + vision_support=True, + ), +) + + +OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs( + model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06") +) + + +PROVIDER_FACTORY_MAP = { + "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs}, + "openrouter": OpenRouterModelArgs, + "vllm": VLLMModelArgs, + "antrophic": ClaudeResponseModelArgs, +} + + +def get_tool_use_agent( + api_provider: str, + model_args: "BaseModelArgs", + tool_use_agent_args: dict = None, + api_provider_spec=None, +) -> ToolUseAgentArgs: + + if api_provider == "openai": + assert ( + api_provider_spec is not None + ), "Endpoint specification is required for OpenAI provider. Choose between 'chatcompletion' and 'response'." + + model_args_factory = ( + PROVIDER_FACTORY_MAP[api_provider] + if api_provider_spec is None + else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec] + ) + + # Create the agent with model arguments from the factory + agent = ToolUseAgentArgs( + model_args=model_args_factory(**model_args), **(tool_use_agent_args or {}) + ) + return agent + + +## We have three providers that we want to support. 
+# Anthropic +# OpenAI +# vllm (uses OpenAI API) diff --git a/src/agentlab/agents/visual_agent/agent_configs.py b/src/agentlab/agents/visual_agent/agent_configs.py index 404afaec..df8d819b 100644 --- a/src/agentlab/agents/visual_agent/agent_configs.py +++ b/src/agentlab/agents/visual_agent/agent_configs.py @@ -1,9 +1,11 @@ +import bgym +from bgym import HighLevelActionSetArgs + +import agentlab.agents.dynamic_prompting as dp from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT from .visual_agent import VisualAgentArgs from .visual_agent_prompts import PromptFlags -import agentlab.agents.dynamic_prompting as dp -import bgym # the other flags are ignored for this agent. DEFAULT_OBS_FLAGS = dp.ObsFlags( @@ -16,7 +18,7 @@ ) DEFAULT_ACTION_FLAGS = dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + action_set=HighLevelActionSetArgs(subsets=["coord"]), long_description=True, individual_examples=False, ) diff --git a/src/agentlab/agents/visual_agent/visual_agent.py b/src/agentlab/agents/visual_agent/visual_agent.py index 8efee11d..d76cedf3 100644 --- a/src/agentlab/agents/visual_agent/visual_agent.py +++ b/src/agentlab/agents/visual_agent/visual_agent.py @@ -11,6 +11,7 @@ from dataclasses import asdict, dataclass import bgym +from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo from agentlab.agents import dynamic_prompting as dp @@ -19,7 +20,7 @@ from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from .visual_agent_prompts import PromptFlags, MainPrompt +from .visual_agent_prompts import MainPrompt, PromptFlags @dataclass @@ -34,7 +35,7 @@ def __post_init__(self): except AttributeError: pass - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + def set_benchmark(self, benchmark: Benchmark, demo_mode): """Override Some flags based on the benchmark.""" self.flags.obs.use_tabs = benchmark.is_multi_tab diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 6154007e..45957601 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -14,7 +14,7 @@ from attr import dataclass from langchain.schema import BaseMessage, HumanMessage from openai import OpenAI -from PIL import Image, ImageDraw +from PIL import Image from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR @@ -23,6 +23,8 @@ from agentlab.llm.chat_api import make_system_message, make_user_message from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage from agentlab.llm.llm_utils import Discussion +from agentlab.llm.response_api import MessageBuilder +from agentlab.agents import agent_utils select_dir_instructions = "Select Experiment Directory" AGENT_NAME_KEY = "agent.agent_name" @@ -530,47 +532,12 @@ def wrapper(*args, **kwargs): return decorator -def tag_screenshot_with_action(screenshot: Image, action: str) -> Image: - """ - If action is a coordinate action, try to render it on the screenshot. - - e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot - - Args: - screenshot: The screenshot to tag. - action: The action to tag the screenshot with. - - Returns: - The tagged screenshot. - - Raises: - ValueError: If the action parsing fails. 
- """ - if action.startswith("mouse_click"): - try: - coords = action[action.index("(") + 1 : action.index(")")].split(",") - coords = [c.strip() for c in coords] - if len(coords) not in [2, 3]: - raise ValueError(f"Invalid coordinate format: {coords}") - if coords[0].startswith("x="): - coords[0] = coords[0][2:] - if coords[1].startswith("y="): - coords[1] = coords[1][2:] - x, y = float(coords[0].strip()), float(coords[1].strip()) - draw = ImageDraw.Draw(screenshot) - radius = 5 - draw.ellipse( - (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red" - ) - except (ValueError, IndexError) as e: - warning(f"Failed to parse action '{action}': {e}") - return screenshot - - def update_screenshot(som_or_not: str): global info action = info.exp_result.steps_info[info.step].action - return tag_screenshot_with_action(get_screenshot(info, som_or_not=som_or_not), action) + return agent_utils.tag_screenshot_with_action( + get_screenshot(info, som_or_not=som_or_not), action + ) def get_screenshot(info: Info, step: int = None, som_or_not: str = "Raw Screenshots"): @@ -589,7 +556,9 @@ def update_screenshot_pair(som_or_not: str): s2 = get_screenshot(info, info.step + 1, som_or_not) if s1 is not None: - s1 = tag_screenshot_with_action(s1, info.exp_result.steps_info[info.step].action) + s1 = agent_utils.tag_screenshot_with_action( + s1, info.exp_result.steps_info[info.step].action + ) return s1, s2 @@ -627,12 +596,51 @@ def update_axtree(): return get_obs(key="axtree_txt", default="No AXTree") +def dict_to_markdown(d: dict): + """ + Convert a dictionary to a clean markdown representation, recursively. + + Args: + d (dict): A dictionary where keys are strings and values can be strings, + lists of dictionaries, or nested dictionaries. + + Returns: + str: A markdown-formatted string representation of the dictionary. 
+ """ + if not isinstance(d, dict): + warning(f"Expected dict, got {type(d)}") + return repr(d) + if not d: + return "No Data" + res = "" + for k, v in d.items(): + if isinstance(v, dict): + res += f"## {k}\n{dict_to_markdown(v)}\n" + elif isinstance(v, list): + res += f"## {k}\n" + for i, item in enumerate(v): + if isinstance(item, dict): + res += f"### Item {i}\n{dict_to_markdown(item)}\n" + else: + res += f"- {item}\n" + else: + res += f"- **{k}**: {v}\n" + return res + + def update_chat_messages(): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) if isinstance(chat_messages, Discussion): return chat_messages.to_markdown() + + if isinstance(chat_messages, list) and isinstance(chat_messages[0], MessageBuilder): + chat_messages = [ + m.to_markdown() if isinstance(m, MessageBuilder) else dict_to_markdown(m) + for m in chat_messages + ] + return "\n\n".join(chat_messages) messages = [] # TODO(ThibaultLSDC) remove this at some point for i, m in enumerate(chat_messages): if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index 5a9580ca..528b684c 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -25,7 +25,7 @@ from PIL import Image from tqdm import tqdm -from agentlab.agents.tapeagent import TapeAgent, save_tape +# from agentlab.agents.tapeagent import TapeAgent, save_tape logger = logging.getLogger(__name__) @@ -45,7 +45,9 @@ class EnvArgs(DataClassJsonMixin): storage_state: Optional[str | Path | dict] = None task_kwargs: Optional[dict] = None # use default value from BrowserGym - def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}): + def make_env( + self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=True + ): """ Instantiates the BrowserGym environment corresponding to the arguments (with some tweaks). 
@@ -85,6 +87,7 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}): headless=self.headless, wait_for_user_message=self.wait_for_user_message, action_mapping=action_mapping, # action mapping is provided by the agent + use_raw_page_output=use_raw_page_output, **extra_kwargs, ) @@ -233,9 +236,9 @@ def make_stats(self): stats = {} stats.update(self.agent_info.pop("stats", {})) - messages = self.agent_info.get("chat_messages", None) - if messages is not None: - stats["n_token_agent_messages"] = count_messages_token(messages) + # messages = self.agent_info.get("chat_messages", None) + # if messages is not None: + # stats["n_token_agent_messages"] = count_messages_token(messages) t = self.profiling stats["step_elapsed"] = t.env_stop - t.env_start @@ -398,9 +401,15 @@ def run(self): agent = self.agent_args.make_agent() logger.debug("Agent created.") + if hasattr(self.agent_args, "use_raw_page_output"): + use_raw_page_output = self.agent_args.use_raw_page_output + else: + use_raw_page_output = False + env = self.env_args.make_env( action_mapping=agent.action_set.to_python_code, exp_dir=self.exp_dir, + use_raw_page_output=use_raw_page_output, ) logger.debug("Environment created.") @@ -470,9 +479,9 @@ def run(self): err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}" logger.info("Saving experiment info.") self.save_summary_info(episode_info, Path(self.exp_dir), err_msg, stack_trace) - if isinstance(agent, TapeAgent): - task = getattr(env, "task", {}) - save_tape(self.exp_dir, episode_info, task, agent.final_tape) + # if isinstance(agent, TapeAgent): + # task = getattr(env, "task", {}) + # save_tape(self.exp_dir, episode_info, task, agent.final_tape) except Exception as e: logger.exception(f"Error while saving experiment info: {e}") try: @@ -875,7 +884,7 @@ def _move_old_exp(exp_dir): def _get_env_name(task_name: str): """Register tasks if needed (lazy import) and return environment name.""" - # lazy benchmark import + # lazy import if task_name.startswith("miniwob"): import browsergym.miniwob elif task_name.startswith("workarena"): diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index 01f3fdc9..0b0f91b4 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -8,6 +8,7 @@ import bgym import pandas as pd +from bgym import Benchmark from git import InvalidGitRepositoryError, Repo from git.config import GitConfigParser @@ -20,7 +21,7 @@ def _get_repo(module): def _get_benchmark_version( - benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False + benchmark: Benchmark, allow_bypass_benchmark_version: bool = False ) -> str: benchmark_name = benchmark.name @@ -178,7 +179,7 @@ def _get_git_info(module, changes_white_list=()) -> tuple[str, list[tuple[str, P def get_reproducibility_info( agent_names: str | list[str], - benchmark: bgym.Benchmark, + benchmark: Benchmark, study_id: str = "", comment=None, changes_white_list=( # Files that are often modified during experiments but do not affect reproducibility diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index 7de3db98..cc1b1a1c 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -6,13 +6,13 @@ import uuid from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass +from dataclasses import asdict, 
dataclass from datetime import datetime from multiprocessing import Manager, Pool, Queue from pathlib import Path import bgym -from bgym import Benchmark +from bgym import DEFAULT_BENCHMARKS, Benchmark from slugify import slugify from agentlab.agents.agent_args import AgentArgs @@ -32,7 +32,7 @@ def make_study( agent_args: list[AgentArgs] | AgentArgs, - benchmark: bgym.Benchmark | str, + benchmark: Benchmark | str, logging_level=logging.WARNING, logging_level_stdout=logging.WARNING, suffix="", @@ -47,8 +47,8 @@ def make_study( The agent configuration(s) to run. *IMPORTANT*: these objects will be pickled and unpickled. Make sure they are imported from a package that is accessible from PYTHONPATH. Otherwise, it won't load in agentlab-xray. - benchmark: bgym.Benchmark | str - The benchmark to run the agents on. See bgym.DEFAULT_BENCHMARKS for the main ones. You + benchmark: Benchmark | str + The benchmark to run the agents on. See DEFAULT_BENCHMARKS for the main ones. You can also make your own by modifying an existing one. logging_level: int The logging level for file log. @@ -89,7 +89,7 @@ def make_study( agent_args = [agent_args] if isinstance(benchmark, str): - benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]() + benchmark = DEFAULT_BENCHMARKS[benchmark.lower()]() if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None): logger.warning( @@ -184,8 +184,8 @@ class Study(AbstractStudy): The agent configuration(s) to run. *IMPORTANT*: these objects will be pickled and unpickled. Make sure they are imported from a package that is accessible from PYTHONPATH. Otherwise, it won't load in agentlab-xray. - benchmark: bgym.Benchmark | str - The benchmark to run the agents on. See bgym.DEFAULT_BENCHMARKS for the main ones. You + benchmark: Benchmark | str + The benchmark to run the agents on. See DEFAULT_BENCHMARKS for the main ones. You can also make your own by modifying an existing one. dir: Path The directory where the study will be saved. If None, a directory will be created in @@ -241,7 +241,10 @@ def __post_init__(self): """Initialize the study. Set the uuid, and generate the exp_args_list.""" self.uuid = uuid.uuid4() if isinstance(self.benchmark, str): - self.benchmark = bgym.DEFAULT_BENCHMARKS[self.benchmark.lower()]() + self.benchmark = DEFAULT_BENCHMARKS[self.benchmark.lower()]() + + self.benchmark.env_args_list = _convert_env_args(self.benchmark.env_args_list) + if isinstance(self.dir, str): self.dir = Path(self.dir) self.make_exp_args_list() @@ -328,28 +331,31 @@ def run( self._run(n_jobs, parallel_backend, strict_reproducibility) suffix = f"trial_{i + 1}_of_{n_relaunch}" - _, summary_df, _ = self.get_results(suffix=suffix) + _, summary_df, error_report = self.get_results(suffix=suffix) logger.info("\n" + str(summary_df)) n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors) if n_error / n_exp > 0.3: - logger.warning("More than 30% of the experiments errored. Stopping the study.") - return + logger.warning("More than 30% of the experiments errored. Stopping the retries.") + break if last_error_count is not None and n_error >= last_error_count: logger.warning( - "Last trial did not reduce the number of errors. Stopping the study." + "Last trial did not reduce the number of errors. Stopping the retries." ) - return + break if n_incomplete == 0: logger.info(f"Study {self.name} finished.") - return + break - logger.warning( - f"Study {self.name} did not finish after {n_relaunch} trials. 
There are {n_incomplete} incomplete experiments." - ) + logger.info("# Error Report:\n-------------\n\n" + error_report) + + if n_incomplete != 0: + logger.warning( + f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." + ) def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): """Run all experiments in the study in parallel when possible. @@ -436,7 +442,7 @@ def load_most_recent(root_dir: Path = None, contains=None) -> "Study": def agents_on_benchmark( self, agents: list[AgentArgs] | AgentArgs, - benchmark: bgym.Benchmark, + benchmark: Benchmark, demo_mode=False, logging_level: int = logging.INFO, logging_level_stdout: int = logging.INFO, @@ -447,7 +453,7 @@ def agents_on_benchmark( Args: agents: list[AgentArgs] | AgentArgs The agent configuration(s) to run. - benchmark: bgym.Benchmark + benchmark: Benchmark The benchmark to run the agents on. demo_mode: bool If True, the experiments will be run in demo mode. @@ -719,6 +725,26 @@ def set_demo_mode(env_args_list: list[EnvArgs]): env_args.slow_mo = 1000 +def _convert_env_args(env_args_list): + """Return a list where every element is the *new* EnvArgs. + + For backward compatibility, we need to convert the old EnvArgs to the new one. + """ + from bgym import EnvArgs as BGymEnvArgs + + new_list = [] + for ea in env_args_list: + # already new → keep as‑is + if isinstance(ea, EnvArgs): + new_list.append(ea) + # old → convert + elif isinstance(ea, BGymEnvArgs): + new_list.append(EnvArgs(**asdict(ea))) + else: + raise TypeError(f"Unexpected type: {type(ea)}") + return new_list + + # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): # if benchmark.name.startswith("visualwebarena"): # sequential_subset = benchmark.subset_from_glob("requires_reset", "True") diff --git a/src/agentlab/experiments/view_dep_graph.py b/src/agentlab/experiments/view_dep_graph.py index abbf7f87..2f4058ae 100644 --- a/src/agentlab/experiments/view_dep_graph.py +++ b/src/agentlab/experiments/view_dep_graph.py @@ -2,11 +2,12 @@ etc. You may have to detust it to make it work for you.""" import math + import bgym import matplotlib.pyplot as plt - import networkx as nx import numpy as np +from bgym import DEFAULT_BENCHMARKS def clean_dict(dependency_dict: dict[str, list[str]]) -> dict[str, list[str]]: @@ -308,8 +309,8 @@ def compress_chains(G): return G_compressed -# benchmark = bgym.DEFAULT_BENCHMARKS["webarena"]() -benchmark = bgym.DEFAULT_BENCHMARKS["visualwebarena"]() +# benchmark = DEFAULT_BENCHMARKS["webarena"]() +benchmark = DEFAULT_BENCHMARKS["visualwebarena"]() dep_graph = benchmark.dependency_graph_over_tasks() dep_graph = clean_dict(dep_graph) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 2536200e..9f70ca42 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -12,6 +12,7 @@ from warnings import warn import numpy as np +import openai import tiktoken import yaml from langchain.schema import BaseMessage @@ -90,6 +91,102 @@ def retry( raise ParseError(f"Could not parse a valid value after {n_retry} retries.") +def call_with_retries(client_function, api_params, max_retries=5): + """ + Makes a API call with retries for transient failures, + rate limiting, and invalid or error-containing responses. + + Args: + client_function (Callable): Function to call the API (e.g., openai.ChatCompletion.create). + api_params (dict): Parameters to pass to the API function. 
+ max_retries (int): Maximum number of retry attempts. + + Returns: + response: Valid API response object. + """ + for attempt in range(1, max_retries + 1): + try: + response = client_function(**api_params) + + # Check for explicit error field in response object + if getattr(response, "error", None): + logging.warning( + f"[Attempt {attempt}] API returned error: {response.error}. Retrying..." + ) + continue + + # Check for valid response with choices + if hasattr(response, "choices") and response.choices: + logging.info(f"[Attempt {attempt}] API call succeeded.") + return response + + logging.warning( + f"[Attempt {attempt}] API returned empty or malformed response. Retrying..." + ) + + except openai.APIError as e: + logging.error(f"[Attempt {attempt}] APIError: {e}") + if e.http_status == 429: + logging.warning("Rate limit exceeded. Retrying...") + elif e.http_status >= 500: + logging.warning("Server error encountered. Retrying...") + else: + logging.error("Non-retriable API error occurred.") + raise + + except Exception as e: + logging.exception(f"[Attempt {attempt}] Unexpected exception occurred: {e}") + raise + + logging.error("Exceeded maximum retry attempts. API call failed.") + raise RuntimeError("API call failed after maximum retries.") + + +def supports_tool_calling_for_openrouter( + model_name: str, +) -> bool: + """ + Check if the openrouter model supports tool calling. + + Args: + model_name (str): The name of the model. + + Returns: + bool: True if the model supports tool calling, False otherwise. + """ + import os + + import openai + + client = openai.Client( + api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1" + ) + try: + response = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Call the test tool"}], + tools=[ + { + "type": "function", + "function": { + "name": "dummy_tool", + "description": "Just a test tool", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + ], + tool_choice="required", + ) + response = response.to_dict() + return "tool_calls" in response["choices"][0]["message"] + except Exception as e: + print(f"Model '{model_name}' error: {e}") + return False + + def retry_multiple( chat: "ChatModel", messages: "Discussion", @@ -381,6 +478,17 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): return f"data:image/jpeg;base64,{image_base64}" +def image_to_png_base64_url(image: np.ndarray | Image.Image): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") + buffered = io.BytesIO() + image.save(buffered, "PNG") + image_base64 = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/png;base64,{image_base64}" + + class BaseMessage(dict): def __init__(self, role: str, content: Union[str, list[dict]], **kwargs): allowed_attrs = {"log_probs"} @@ -401,7 +509,13 @@ def __str__(self, warn_if_image=False) -> str: else: logging.info(msg) - return "\n".join([elem["text"] for elem in self["content"] if elem["type"] == "text"]) + return "\n".join( + [ + elem["text"] + for elem in self["content"] + if elem["type"] == "text" or elem["type"] == "input_text" + ] + ) def add_content(self, type: str, content: Any): if isinstance(self["content"], str): diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py new file mode 100644 index 00000000..49a8a775 --- /dev/null +++ b/src/agentlab/llm/response_api.py @@ -0,0 +1,781 @@ +import json +import logging +import 
os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type, Union + +import openai +from anthropic import Anthropic +from openai import OpenAI +from .llm_utils import call_with_retries, supports_tool_calling_for_openrouter +from agentlab.llm import tracking + +from .base_api import BaseModelArgs + +"""This module contains utlity classes for building input messages and interacting with LLM APIs. +It includes: + 1. Message Builder for building input messages + 2. Base Reponse class for different LLM APIs (OpenAI, Anthropic, etc.) + 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models. +""" + + +type ContentItem = Dict[str, Any] +type Message = Dict[str, Union[str, List[ContentItem]]] + + +@dataclass +class ResponseLLMOutput: + """Serializable object for the output of a response LLM.""" + + raw_response: Any + think: str + action: str + last_computer_call_id: str + assistant_message: Any + + +class MessageBuilder: + def __init__(self, role: str): + self.role = role + self.content: List[ContentItem] = [] + self.last_response: ResponseLLMOutput = None + self.tool_call_id: Optional[str] = None + + @classmethod + def system(cls) -> "MessageBuilder": + return cls("system") + + @classmethod + def user(cls) -> "MessageBuilder": + return cls("user") + + @classmethod + def assistant(cls) -> "MessageBuilder": + return cls("assistant") + + @classmethod + def tool(cls) -> "MessageBuilder": + return cls("tool") + + def update_last_raw_response(self, raw_response: Any) -> "MessageBuilder": + self.last_response = raw_response + return self + + def add_tool_id(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self + + def add_text(self, text: str) -> "MessageBuilder": + self.content.append({"text": text}) + return self + + def add_image(self, image: str) -> "MessageBuilder": + self.content.append({"image": image}) + return self + + def to_markdown(self) -> str: + parts = [] + for item in self.content: + if "text" in item: + parts.append(item["text"]) + elif "image" in item: + parts.append(f"![Image]({item['image']})") + + markdown = f"## {self.role.capitalize()} Message\n\n" + markdown += "\n\n---\n\n".join(parts) + + if self.role == "tool": + assert self.tool_call_id is not None, "Tool call ID is required for tool messages" + markdown += f"\n\n---\n\n**Tool Call ID:** `{self.tool_call_id}`" + + return markdown + + +class OpenAIResponseAPIMessageBuilder(MessageBuilder): + + def __init__(self, role: str): + super().__init__(role) + self.tool_call_id = None + + def add_tool_id(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self + + def prepare_message(self) -> List[Message]: + content = [] + for item in self.content: + if "text" in item: + content.append({"type": "input_text", "text": item["text"]}) + elif "image" in item: + content.append({"type": "input_image", "image_url": item["image"]}) + res = [{"role": self.role, "content": content}] + + if self.role == "tool": + assert self.tool_call_id is not None, "Tool call ID is required for tool messages" + # tool messages can only take text with openai + # we need to split the first content element if it's text and use it + # then open a new (user) message with the rest + # a function_call_output dict has keys "call_id", "type" and "output" + res[0]["call_id"] = self.tool_call_id + res[0]["type"] = "function_call_output" + res[0].pop("role", None) # make sure to remove role + text_content = ( + 
content.pop(0)["text"] + if "text" in content[0] + else "Tool call answer in next message" + ) + res[0]["output"] = text_content + res[0].pop("content", None) # make sure to remove content + res.append({"role": "user", "content": content}) + + return res + + +class AnthropicAPIMessageBuilder(MessageBuilder): + + def __init__(self, role: str): + super().__init__(role) + self.tool_call_id = None + + def add_tool_id(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self + + def prepare_message(self) -> List[Message]: + content = [] + + if self.role == "system": + logging.info( + "Treating system message as 'user'. In the Anthropic API, system messages should be passed as a direct input to the client." + ) + return [{"role": "user", "content": content}] + + for item in self.content: + if "text" in item: + content.append({"type": "text", "text": item["text"]}) + elif "image" in item: + img_str: str = item["image"] + # make sure to get rid of the image type for anthropic + # e.g. "data:image/png;base64" + if img_str.startswith("data:image/png;base64,"): + img_str = img_str[len("data:image/png;base64,") :] + content.append( + { + "type": "image", + "source": { + "type": "base64", # currently only base64 is supported + "media_type": "image/png", # currently only png is supported + "data": img_str, + }, + } + ) + res = [{"role": self.role, "content": content}] + + if self.role == "tool": + assert self.tool_call_id is not None, "Tool call ID is required for tool messages" + res[0]["role"] = "user" + res[0]["content"] = [ + { + "type": "tool_result", + "tool_use_id": self.tool_call_id, + "content": res[0]["content"], + } + ] + return res + + +class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder): + + def __init__(self, role: str): + super().__init__(role) + self.tool_call_id = None + self.tool_name = None + self.last_response = None + + def update_tool_info(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self + + def prepare_message(self) -> List[Message]: + """Prepare the message for the OpenAI API.""" + content = [] + for item in self.content: + if "text" in item: + content.append({"type": "text", "text": item["text"]}) + elif "image" in item: + content.append({"type": "image_url", "image_url": {"url": item["image"]}}) + res = [{"role": self.role, "content": content}] + + if self.role == "tool": + assert self.tool_call_id is not None, "Tool call ID is required for tool messages" + # tool messages can only take text with openai + # we need to split the first content element if it's text and use it + # then open a new (user) message with the rest + # a function_call_output dict has keys "call_id", "type" and "output" + res[0]["tool_call_id"] = self.tool_call_id + res[0]["type"] = "function_call_output" + message = self.last_response.raw_response.choices[0].message.to_dict() + res[0]["tool_name"] = message["tool_calls"][0]["function"]["name"] + text_content = ( + content.pop(0)["text"] + if "text" in content[0] + else "Tool call answer in next message" + ) + res[0]["content"] = text_content + res.append({"role": "user", "content": content}) + return res + + +class OpenRouterAPIMessageBuilder(MessageBuilder): + + def __init__(self, role: str): + super().__init__(role) + self.tool_call_id = None + self.tool_name = None + self.last_response = None + + def update_tool_info(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self + + def prepare_message(self) -> List[Message]: + """Prepare the message for the OpenAI API.""" + content = [] + 
for item in self.content: + if "text" in item: + content.append({"type": "text", "text": item["text"]}) + elif "image" in item: + content.append({"type": "image_url", "image_url": {"url": item["image"]}}) + res = [{"role": self.role, "content": content}] + + if self.role == "tool": + assert self.tool_call_id is not None, "Tool call ID is required for tool messages" + # tool messages can only take text with openai + # we need to split the first content element if it's text and use it + # then open a new (user) message with the rest + # a function_call_output dict has keys "call_id", "type" and "output" + res[0]["tool_call_id"] = self.tool_call_id + res[0]["type"] = "function_call_output" + message = self.last_response.raw_response.choices[0].message.to_dict() + res[0]["tool_name"] = message["tool_calls"][0]["function"]["name"] + text_content = ( + content.pop(0)["text"] + if "text" in content[0] + else "Tool call answer in next message" + ) + res[0]["content"] = text_content + res.append({"role": "user", "content": content}) + return res + + +# # Base class for all API Endpoints +class BaseResponseModel(ABC): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + ): + self.model_name = model_name + self.api_key = api_key + self.temperature = temperature + self.max_tokens = max_tokens + self.extra_kwargs = extra_kwargs or {} + + def __call__(self, messages: list[dict | MessageBuilder]) -> dict: + """Make a call to the model and return the parsed response.""" + response = self._call_api(messages) + return self._parse_response(response) + + @abstractmethod + def _call_api(self, messages: list[dict | MessageBuilder]) -> Any: + """Make a call to the model API and return the raw response.""" + pass + + @abstractmethod + def _parse_response(self, response: Any) -> ResponseLLMOutput: + """Parse the raw response from the model API and return a structured response.""" + pass + + +class OpenAIResponseModel(BaseResponseModel): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + model_name=model_name, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + ) + self.client = OpenAI(api_key=api_key) + + def _call_api(self, messages: list[Any | MessageBuilder]) -> dict: + input = [] + for msg in messages: + if isinstance(msg, MessageBuilder): + input += msg.prepare_message() + else: + input.append(msg) + try: + response = self.client.responses.create( + model=self.model_name, + input=input, + temperature=self.temperature, + # previous_response_id=content.get("previous_response_id", None), + max_output_tokens=self.max_tokens, + **self.extra_kwargs, + tool_choice="required", + # reasoning={ + # "effort": "low", + # "summary": "detailed", + # }, + ) + + return response + except openai.OpenAIError as e: + logging.error(f"Failed to get a response from the API: {e}") + raise e + + def _parse_response(self, response: dict) -> dict: + result = ResponseLLMOutput( + raw_response=response, + think="", + action="noop()", + last_computer_call_id=None, + assistant_message=None, + ) + for output in response.output: + if output.type == "function_call": + arguments = json.loads(output.arguments) + result.action = ( + f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})" + ) + 
result.last_computer_call_id = output.call_id + result.assistant_message = output + break + elif output.type == "reasoning": + if len(output.summary) > 0: + result.think += output.summary[0].text + "\n" + return result + + +class OpenAIChatCompletionModel(BaseResponseModel): + def __init__( + self, + model_name: str, + client_args: Optional[Dict[str, Any]] = {}, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + ) + self.extra_kwargs["tools"] = self.format_tools_for_chat_completion( + self.extra_kwargs.get("tools", []) + ) + self.client = OpenAI( + **client_args + ) # Ensures client_args is a dict or defaults to an empty dict + + def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.ChatCompletion: + chat_messages: List[Message] = [] + for msg in messages: + if isinstance(msg, MessageBuilder): + chat_messages.extend(msg.prepare_message()) + else: + # Assuming msg is already in OpenAI Chat Completion message format + chat_messages.append(msg) # type: ignore + + api_params: Dict[str, Any] = { + "model": self.model_name, + "messages": chat_messages, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "tool_choice": "auto", + **self.extra_kwargs, # Pass tools, tool_choice, etc. here + } + + response = call_with_retries(self.client.chat.completions.create, api_params) + # Basic token tracking (if usage information is available) + if response.usage: + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + # Cost calculation would require pricing data + # cost = ... + # if hasattr(tracking.TRACKER, "instance") and isinstance( + # tracking.TRACKER.instance, tracking.LLMTracker + # ): + # tracking.TRACKER.instance(input_tokens, output_tokens, cost) # Placeholder for cost + + return response + + def _parse_response(self, response: openai.types.chat.ChatCompletion) -> ResponseLLMOutput: + + output = ResponseLLMOutput( + raw_response=response, + think="", + action="noop()", # Default if no tool call + last_computer_call_id=None, + assistant_message={ + "role": "assistant", + "content": response.choices[0].message.content, + }, + ) + message = response.choices[0].message.to_dict() + + if tool_calls := message.get("tool_calls", None): + for tool_call in tool_calls: + function = tool_call["function"] + arguments = json.loads(function["arguments"]) + output.action = ( + f"{function['name']}({', '.join([f'{k}={v}' for k, v in arguments.items()])})" + ) + output.last_computer_call_id = tool_call["id"] + output.assistant_message = { + "role": "assistant", + "tool_calls": message["tool_calls"], + } + break # only first tool call is used + + elif "content" in message and message["content"]: + output.think = message["content"] + + return output + + @staticmethod + def format_tools_for_chat_completion(tools_flat): + """Formats response tools format for OpenAI Chat Completion API. + Why we need this? + Ans: actionset.to_tool_description() in bgym only returns description + format valid for OpenAI Response API. 
+ """ + return [ + { + "type": tool["type"], + "function": {k: tool[k] for k in ("name", "description", "parameters")}, + } + for tool in tools_flat + ] + + + +class ClaudeResponseModel(BaseResponseModel): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + model_name=model_name, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + ) + + # Get pricing information + + try: + pricing = tracking.get_pricing_anthropic() + self.input_cost = float(pricing[model_name]["prompt"]) + self.output_cost = float(pricing[model_name]["completion"]) + except KeyError: + logging.warning( + f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community." + ) + self.input_cost = 0.0 + self.output_cost = 0.0 + + self.client = Anthropic(api_key=api_key) + + def _call_api(self, messages: list[dict | MessageBuilder]) -> dict: + input = [] + for msg in messages: + if isinstance(msg, MessageBuilder): + input += msg.prepare_message() + else: + input.append(msg) + try: + response = self.client.messages.create( + model=self.model_name, + messages=input, + temperature=self.temperature, + max_tokens=self.max_tokens, + **self.extra_kwargs, + ) + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + cost = input_tokens * self.input_cost + output_tokens * self.output_cost + + print(f"response.usage: {response.usage}") + + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): + tracking.TRACKER.instance(input_tokens, output_tokens, cost) + + return response + except Exception as e: + logging.error(f"Failed to get a response from the API: {e}") + raise e + + def _parse_response(self, response: dict) -> dict: + result = ResponseLLMOutput( + raw_response=response, + think="", + action=None, + last_computer_call_id=None, + assistant_message={ + "role": "assistant", + "content": response.content, + }, + ) + for output in response.content: + if output.type == "tool_use": + result.action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})" + result.last_computer_call_id = output.id + elif output.type == "text": + result.think += output.text + return result + + +def cua_response_to_text(action): + """ + Given a computer action (e.g., click, double_click, scroll, etc.), + convert it to a text description. + """ + action_type = action.type + + try: + match action_type: + + case "click": + x, y = action.x, action.y + button = action.button + print(f"Action: click at ({x}, {y}) with button '{button}'") + # Not handling things like middle click, etc. + if button != "left" and button != "right": + button = "left" + return f"mouse_click({x}, {y}, button='{button}')" + + case "scroll": + x, y = action.x, action.y + scroll_x, scroll_y = action.scroll_x, action.scroll_y + print( + f"Action: scroll at ({x}, {y}) with offsets (scroll_x={scroll_x}, scroll_y={scroll_y})" + ) + return f"mouse_move({x}, {y})\nscroll({scroll_x}, {scroll_y})" + + case "keypress": + keys = action.keys + for k in keys: + print(f"Action: keypress '{k}'") + # A simple mapping for common keys; expand as needed. 
+ if k.lower() == "enter": + return "keyboard_press('Enter')" + elif k.lower() == "space": + return "keyboard_press(' ')" + else: + return f"keyboard_press('{k}')" + + case "type": + text = action.text + print(f"Action: type text: {text}") + return f"keyboard_type('{text}')" + + case "wait": + print(f"Action: wait") + return "noop()" + + case "screenshot": + # Nothing to do as screenshot is taken at each turn + print(f"Action: screenshot") + + # Handle other actions here + + case "drag": + x1, y1 = action.path[0].x, action.path[0].y + x2, y2 = action.path[1].x, action.path[1].y + print(f"Action: drag from ({x1}, {y1}) to ({x2}, {y2})") + return f"mouse_drag_and_drop({x1}, {y1}, {x2}, {y2})" + + case _: + print(f"Unrecognized action: {action}") + + except Exception as e: + print(f"Error handling action {action}: {e}") + + +# Factory classes to create the appropriate model based on the API endpoint. +@dataclass +class OpenAIResponseModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None): + return OpenAIResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIResponseAPIMessageBuilder + + +@dataclass +class ClaudeResponseModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "anthropic" + + def make_model(self, extra_kwargs=None): + return ClaudeResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return AnthropicAPIMessageBuilder + + +@dataclass +class OpenAIChatModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None): + return OpenAIChatCompletionModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIChatCompletionAPIMessageBuilder + + +@dataclass +class OpenRouterModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenRouter + model.""" + + api: str = "openai" # tool description format used by actionset.to_tool_description() in bgym + + def make_model(self, extra_kwargs=None): + return OpenAIChatCompletionModel( + client_args={ + "base_url": "https://openrouter.ai/api/v1", + "api_key": os.getenv("OPENROUTER_API_KEY"), + }, + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenRouterAPIMessageBuilder + + def __post_init__(self): + # Some runtime checks + assert supports_tool_calling_for_openrouter( + self.model_name + ), f"Model {self.model_name} does not support tool calling." + +class VLLMModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with a VLLM + model.""" + + api = "openai" # tool description format used by actionset.to_tool_description() in bgym + + def __post_init__(self): + # tests + assert self.is_model_available( + self.model_name + ), f"Model {self.model_name} is not available on the VLLM server. 
\ + Please check the model name or server configuration." + + def make_model(self, extra_kwargs=None): + return OpenAIChatCompletionModel( + client_args={ + "base_url": "http://localhost:8000/v1", + "api_key": os.getenv("VLLM_API_KEY", "EMPTY"), + }, + model_name=self.model_name, # this needs to be set + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIChatCompletionAPIMessageBuilder + + ## Some Tests for VLLM server in the works! + def test_vllm_server_reachability(self): + import requests + + try: + response = requests.get( + f"{self.client_args['base_url']}/v1/models", + headers={"Authorization": f"Bearer {self.client_args['api_key']}"}, + ) + if response.status_code == 200: + return True + else: + return False + except requests.RequestException as e: + logging.error(f"Error checking VLLM server reachability: {e}") + return False + + def is_model_available(self, model_name: str) -> bool: + # import requests + + # """Check if the model is available on the VLLM server.""" + # if not self.test_vllm_server_reachability(): + # logging.error("VLLM server is not reachable.") + # return False + # try: + # response = requests.get( + # f"{self.client_args['base_url']}/v1/models", + # headers={"Authorization": f"Bearer {self.client_args['api_key']}"}, + # ) + # if response.status_code == 200: + # models = response.json().get("data", []) + # return any(model.get("id") == model_name for model in models) + # else: + # logging.error( + # f"Failed to fetch vllm hosted models: {response.status_code} - {response.text}" + # ) + # return False + # except requests.RequestException as e: + # logging.error(f"Error checking model availability: {e}") + # return False + return True diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 6a08839b..6c2b0c24 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,10 +1,11 @@ import os +import re import threading from contextlib import contextmanager from functools import cache import requests -from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS +from langchain_community.callbacks import bedrock_anthropic_callback, openai_info TRACKER = threading.local() @@ -85,7 +86,7 @@ def get_pricing_openrouter(): def get_pricing_openai(): - cost_dict = MODEL_COST_PER_1K_TOKENS + cost_dict = openai_info.MODEL_COST_PER_1K_TOKENS cost_dict = {k: v / 1000 for k, v in cost_dict.items()} res = {} for k in cost_dict: @@ -99,3 +100,25 @@ def get_pricing_openai(): "completion": cost_dict[completion_key], } return res + + +def _remove_version_suffix(model_name): + no_version = re.sub(r"-v\d+(?:[.:]\d+)?$", "", model_name) + return re.sub(r"anthropic.", "", no_version) + + +def get_pricing_anthropic(): + input_cost_dict = bedrock_anthropic_callback.MODEL_COST_PER_1K_INPUT_TOKENS + output_cost_dict = bedrock_anthropic_callback.MODEL_COST_PER_1K_OUTPUT_TOKENS + + res = {} + for k, v in input_cost_dict.items(): + k = _remove_version_suffix(k) + res[k] = {"prompt": v / 1000} + + for k, v in output_cost_dict.items(): + k = _remove_version_suffix(k) + if k not in res: + res[k] = {} + res[k]["completion"] = v / 1000 + return res diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index cc1f9036..5e89799a 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -2,13 +2,11 @@ import bgym import pytest +from bgym import 
HighLevelActionSet, HighLevelActionSetArgs from agentlab.agents import dynamic_prompting as dp from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 -from agentlab.agents.generic_agent.generic_agent_prompt import ( - GenericPromptFlags, - MainPrompt, -) +from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags, MainPrompt from agentlab.llm.llm_utils import count_tokens html_template = """ @@ -76,7 +74,7 @@ filter_visible_elements_only=True, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=True, ), @@ -171,7 +169,7 @@ def test_shrinking_observation(): flags.obs.use_html = True prompt_maker = MainPrompt( - action_set=bgym.HighLevelActionSet(), + action_set=HighLevelActionSet(), obs_history=OBS_HISTORY, actions=ACTIONS, memories=MEMORIES, @@ -237,7 +235,7 @@ def test_main_prompt_elements_present(): # Initialize MainPrompt prompt = str( MainPrompt( - action_set=bgym.HighLevelActionSet(), + action_set=HighLevelActionSet(), obs_history=OBS_HISTORY, actions=ACTIONS, memories=MEMORIES, diff --git a/tests/experiments/test_reproducibility_util.py b/tests/experiments/test_reproducibility_util.py index aa10ff47..e5e771b3 100644 --- a/tests/experiments/test_reproducibility_util.py +++ b/tests/experiments/test_reproducibility_util.py @@ -5,6 +5,7 @@ import bgym import pytest +from bgym import DEFAULT_BENCHMARKS from agentlab.agents.generic_agent import AGENT_4o_MINI from agentlab.analyze import inspect_results @@ -17,7 +18,7 @@ ) def test_get_reproducibility_info(benchmark_name): - benchmark = bgym.DEFAULT_BENCHMARKS[benchmark_name]() + benchmark = DEFAULT_BENCHMARKS[benchmark_name]() info = reproducibility_util.get_reproducibility_info( "test_agent", benchmark, "test_id", ignore_changes=True
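# --- Reviewer sketch (not part of the patch) --------------------------------
# Minimal usage of the new agentlab.agents.agent_utils.tag_screenshot_with_action
# helper added in this diff, assuming the agentlab package from this branch is
# importable. The helper only reacts to mouse_click(...) actions; any other
# action string returns the screenshot unchanged. The dot is drawn with the blue
# fill used in the new module. The image and action strings below are made up
# for illustration.
from PIL import Image

from agentlab.agents.agent_utils import tag_screenshot_with_action

screenshot = Image.new("RGB", (320, 240), color="white")  # stand-in screenshot
tagged = tag_screenshot_with_action(screenshot, "mouse_click(x=120, y=130)")  # draws a dot at (120, 130)
tagged = tag_screenshot_with_action(tagged, "noop()")  # not a coordinate action: returned as-is
tagged.save("tagged_screenshot.png")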
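# --- Reviewer sketch (not part of the patch) --------------------------------
# What OpenAIChatCompletionModel.format_tools_for_chat_completion does: it wraps
# the flat, Responses-API-style tool descriptions returned by
# action_set.to_tool_description(api="openai") into the nested schema the Chat
# Completions API expects. The single tool below is a hypothetical example, not
# actual bgym output.
flat_tool = {
    "type": "function",
    "name": "mouse_click",
    "description": "Click at the given viewport coordinates.",
    "parameters": {
        "type": "object",
        "properties": {"x": {"type": "number"}, "y": {"type": "number"}},
        "required": ["x", "y"],
    },
}

# Same transformation as the static method in this diff: name/description/parameters
# are nested under a "function" key.
chat_tool = {
    "type": flat_tool["type"],
    "function": {k: flat_tool[k] for k in ("name", "description", "parameters")},
}
assert chat_tool["function"]["name"] == "mouse_click"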
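# --- Reviewer sketch (not part of the patch) --------------------------------
# Behaviour of the new agentlab.llm.tracking._remove_version_suffix helper as
# written in this diff: it strips a trailing "-vN", "-vN.M" or "-vN:M" suffix
# and the "anthropic." prefix before the Anthropic pricing table is keyed. The
# function body is copied verbatim; the first id below is a representative
# Bedrock-style id, not taken from the patch. Note the second pattern's
# unescaped "." matches any character after "anthropic".
import re


def _remove_version_suffix(model_name):
    no_version = re.sub(r"-v\d+(?:[.:]\d+)?$", "", model_name)
    return re.sub(r"anthropic.", "", no_version)


assert _remove_version_suffix("anthropic.claude-3-haiku-20240307-v1:0") == "claude-3-haiku-20240307"
assert _remove_version_suffix("claude-3-7-sonnet-20250219") == "claude-3-7-sonnet-20250219"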