diff --git a/.changeset/cold-humans-work.md b/.changeset/cold-humans-work.md new file mode 100644 index 000000000..5088845ac --- /dev/null +++ b/.changeset/cold-humans-work.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": minor +--- + +Added support for execute-level hooks on agents. This includes `onStep(action)`, `onSuccess(reuslt)`, and `onFailure(error)` inputs. diff --git a/examples/cua-example.ts b/examples/cua-example.ts index c66a28019..637ba471b 100644 --- a/examples/cua-example.ts +++ b/examples/cua-example.ts @@ -27,7 +27,7 @@ async function main() { const agent = stagehand.agent({ provider: "openai", - model: "computer-use-preview-2025-02-04", + model: "computer-use-preview", instructions: `You are a helpful assistant that can use a web browser. You are currently on the following page: ${page.url()}. Do not ask follow up questions, the user will trust your judgement.`, diff --git a/lib/agent/AgentProvider.ts b/lib/agent/AgentProvider.ts index cc8110fa4..5d09096b5 100644 --- a/lib/agent/AgentProvider.ts +++ b/lib/agent/AgentProvider.ts @@ -1,12 +1,12 @@ -import { LogLine } from "@/types/log"; -import { AgentClient } from "./AgentClient"; import { AgentType } from "@/types/agent"; -import { OpenAICUAClient } from "./OpenAICUAClient"; -import { AnthropicCUAClient } from "./AnthropicCUAClient"; +import { LogLine } from "@/types/log"; import { UnsupportedModelError, UnsupportedModelProviderError, } from "@/types/stagehandErrors"; +import { AgentClient } from "./AgentClient"; +import { AnthropicCUAClient } from "./AnthropicCUAClient"; +import { OpenAICUAClient } from "./OpenAICUAClient"; // Map model names to their provider types const modelToAgentProviderMap: Record = { @@ -22,7 +22,6 @@ const modelToAgentProviderMap: Record = { */ export class AgentProvider { private logger: (message: LogLine) => void; - /** * Create a new agent provider */ diff --git a/lib/agent/AnthropicCUAClient.ts b/lib/agent/AnthropicCUAClient.ts index fbbc97d82..a747564e0 100644 --- a/lib/agent/AnthropicCUAClient.ts +++ b/lib/agent/AnthropicCUAClient.ts @@ -1,18 +1,19 @@ -import Anthropic from "@anthropic-ai/sdk"; -import { LogLine } from "@/types/log"; import { AgentAction, + AgentExecuteOptions, + AgentExecutionOptions, AgentResult, AgentType, - AgentExecutionOptions, - ToolUseItem, - AnthropicMessage, AnthropicContentBlock, + AnthropicMessage, AnthropicTextBlock, AnthropicToolResult, + ToolUseItem, } from "@/types/agent"; -import { AgentClient } from "./AgentClient"; +import { LogLine } from "@/types/log"; import { AgentScreenshotProviderError } from "@/types/stagehandErrors"; +import Anthropic from "@anthropic-ai/sdk"; +import { AgentClient } from "./AgentClient"; export type ResponseInputItem = AnthropicMessage | AnthropicToolResult; @@ -116,7 +117,7 @@ export class AnthropicCUAClient extends AgentClient { level: 2, }); - const result = await this.executeStep(inputItems, logger); + const result = await this.executeStep(inputItems, logger, options); // Add actions to the list if (result.actions.length > 0) { @@ -153,12 +154,16 @@ export class AnthropicCUAClient extends AgentClient { }); // Return the final result - return { + const result = { success: completed, actions, message: finalMessage, completed, }; + + options.onSuccess?.(result); + + return result; } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -168,6 +173,8 @@ export class AnthropicCUAClient extends AgentClient { level: 0, }); + options.onFailure?.(error); + return { success: false, actions, @@ -180,6 +187,7 @@ export class AnthropicCUAClient extends AgentClient { async executeStep( inputItems: ResponseInputItem[], logger: (message: LogLine) => void, + options: AgentExecuteOptions, ): Promise<{ actions: AgentAction[]; message: string; @@ -270,6 +278,8 @@ export class AnthropicCUAClient extends AgentClient { message: `Executing action: ${action.type}`, level: 1, }); + + await options.onStep?.(action); await this.actionHandler(action); } catch (error) { const errorMessage = diff --git a/lib/agent/OpenAICUAClient.ts b/lib/agent/OpenAICUAClient.ts index 6a494300b..c3d21330e 100644 --- a/lib/agent/OpenAICUAClient.ts +++ b/lib/agent/OpenAICUAClient.ts @@ -1,17 +1,18 @@ -import OpenAI from "openai"; -import { LogLine } from "../../types/log"; import { AgentAction, + AgentExecuteOptions, + AgentExecutionOptions, AgentResult, AgentType, - AgentExecutionOptions, - ResponseInputItem, - ResponseItem, ComputerCallItem, FunctionCallItem, + ResponseInputItem, + ResponseItem, } from "@/types/agent"; -import { AgentClient } from "./AgentClient"; import { AgentScreenshotProviderError } from "@/types/stagehandErrors"; +import OpenAI from "openai"; +import { LogLine } from "../../types/log"; +import { AgentClient } from "./AgentClient"; /** * Client for OpenAI's Computer Use Assistant API @@ -111,6 +112,7 @@ export class OpenAICUAClient extends AgentClient { inputItems, previousResponseId, logger, + options, ); // Add actions to the list @@ -137,13 +139,17 @@ export class OpenAICUAClient extends AgentClient { currentStep++; } - // Return the final result - return { + const result = { success: completed, actions, message: finalMessage, completed, }; + + options.onSuccess?.(result); + + // Return the final result + return result; } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -153,6 +159,8 @@ export class OpenAICUAClient extends AgentClient { level: 0, }); + options.onFailure?.(error as Error); + return { success: false, actions, @@ -170,6 +178,7 @@ export class OpenAICUAClient extends AgentClient { inputItems: ResponseInputItem[], previousResponseId: string | undefined, logger: (message: LogLine) => void, + options: AgentExecuteOptions, ): Promise<{ actions: AgentAction[]; message: string; @@ -224,7 +233,7 @@ export class OpenAICUAClient extends AgentClient { } // Take actions and get results - const nextInputItems = await this.takeAction(output, logger); + const nextInputItems = await this.takeAction(output, logger, options); // Check if completed const completed = @@ -334,11 +343,14 @@ export class OpenAICUAClient extends AgentClient { async takeAction( output: ResponseItem[], logger: (message: LogLine) => void, + options: AgentExecuteOptions, ): Promise { const nextInputItems: ResponseInputItem[] = []; // Add any computer calls to process for (const item of output) { + await options.onStep?.(item); + if (item.type === "computer_call" && this.isComputerCallItem(item)) { // Execute the action try { diff --git a/lib/handlers/agentHandler.ts b/lib/handlers/agentHandler.ts index c15ab7052..3704ff9c6 100644 --- a/lib/handlers/agentHandler.ts +++ b/lib/handlers/agentHandler.ts @@ -1,15 +1,15 @@ -import { StagehandPage } from "../StagehandPage"; -import { AgentProvider } from "../agent/AgentProvider"; -import { StagehandAgent } from "../agent/StagehandAgent"; -import { AgentClient } from "../agent/AgentClient"; -import { LogLine } from "../../types/log"; import { - AgentExecuteOptions, + ActionExecutionResult, AgentAction, - AgentResult, + AgentExecuteOptions, AgentHandlerOptions, - ActionExecutionResult, + AgentResult, } from "@/types/agent"; +import { LogLine } from "../../types/log"; +import { StagehandPage } from "../StagehandPage"; +import { AgentClient } from "../agent/AgentClient"; +import { AgentProvider } from "../agent/AgentProvider"; +import { StagehandAgent } from "../agent/StagehandAgent"; export class StagehandAgentHandler { private stagehandPage: StagehandPage; @@ -27,7 +27,6 @@ export class StagehandAgentHandler { this.stagehandPage = stagehandPage; this.logger = logger; this.options = options; - // Initialize the provider this.provider = new AgentProvider(logger); diff --git a/lib/handlers/operatorHandler.ts b/lib/handlers/operatorHandler.ts index ec5a7d5e5..2faaa9698 100644 --- a/lib/handlers/operatorHandler.ts +++ b/lib/handlers/operatorHandler.ts @@ -6,30 +6,33 @@ import { OperatorSummary, operatorSummarySchema, } from "@/types/operator"; -import { LLMParsedResponse } from "../inference"; -import { ChatMessage, LLMClient } from "../llm/LLMClient"; -import { buildOperatorSystemPrompt } from "../prompt"; -import { StagehandPage } from "../StagehandPage"; -import { ObserveResult } from "@/types/stagehand"; +import { AgentConfig, ObserveResult } from "@/types/stagehand"; import { StagehandError, StagehandMissingArgumentError, } from "@/types/stagehandErrors"; +import { LLMParsedResponse } from "../inference"; +import { ChatMessage, LLMClient } from "../llm/LLMClient"; +import { buildOperatorSystemPrompt } from "../prompt"; +import { StagehandPage } from "../StagehandPage"; export class StagehandOperatorHandler { private stagehandPage: StagehandPage; private logger: (message: LogLine) => void; private llmClient: LLMClient; private messages: ChatMessage[]; + private options: AgentConfig; constructor( stagehandPage: StagehandPage, logger: (message: LogLine) => void, llmClient: LLMClient, + options: AgentConfig, ) { this.stagehandPage = stagehandPage; this.logger = logger; this.llmClient = llmClient; + this.options = options; } public async execute( @@ -46,106 +49,121 @@ export class StagehandOperatorHandler { const maxSteps = options.maxSteps || 10; const actions: AgentAction[] = []; - while (!completed && currentStep < maxSteps) { - const url = this.stagehandPage.page.url(); + try { + while (!completed && currentStep < maxSteps) { + const url = this.stagehandPage.page.url(); - if (!url || url === "about:blank") { - this.messages.push({ - role: "user", - content: [ - { - type: "text", - text: "No page is currently loaded. The first step should be a 'goto' action to navigate to a URL.", - }, - ], - }); - } else { - const screenshot = await this.stagehandPage.page.screenshot({ - type: "png", - fullPage: false, - }); - - const base64Image = screenshot.toString("base64"); - - let messageText = `Here is a screenshot of the current page (URL: ${url}):`; - - messageText = `Previous actions were: ${actions - .map((action) => { - let result: string = ""; - if (action.type === "act") { - const args = action.playwrightArguments as ObserveResult; - result = `Performed a "${args.method}" action ${args.arguments.length > 0 ? `with arguments: ${args.arguments.map((arg) => `"${arg}"`).join(", ")}` : ""} on "${args.description}"`; - } else if (action.type === "extract") { - result = `Extracted data: ${action.extractionResult}`; - } - return `[${action.type}] ${action.reasoning}. Result: ${result}`; - }) - .join("\n")}\n\n${messageText}`; - - this.messages.push({ - role: "user", - content: [ - { - type: "text", - text: messageText, - }, - this.llmClient.type === "anthropic" - ? { - type: "image", - source: { - type: "base64", - media_type: "image/png", - data: base64Image, + if (!url || url === "about:blank") { + this.messages.push({ + role: "user", + content: [ + { + type: "text", + text: "No page is currently loaded. The first step should be a 'goto' action to navigate to a URL.", + }, + ], + }); + } else { + const screenshot = await this.stagehandPage.page.screenshot({ + type: "png", + fullPage: false, + }); + + const base64Image = screenshot.toString("base64"); + + let messageText = `Here is a screenshot of the current page (URL: ${url}):`; + + messageText = `Previous actions were: ${actions + .map((action) => { + let result: string = ""; + if (action.type === "act") { + const args = action.playwrightArguments as ObserveResult; + result = `Performed a "${args.method}" action ${args.arguments.length > 0 ? `with arguments: ${args.arguments.map((arg) => `"${arg}"`).join(", ")}` : ""} on "${args.description}"`; + } else if (action.type === "extract") { + result = `Extracted data: ${action.extractionResult}`; + } + return `[${action.type}] ${action.reasoning}. Result: ${result}`; + }) + .join("\n")}\n\n${messageText}`; + + this.messages.push({ + role: "user", + content: [ + { + type: "text", + text: messageText, + }, + this.llmClient.type === "anthropic" + ? { + type: "image", + source: { + type: "base64", + media_type: "image/png", + data: base64Image, + }, + text: "the screenshot of the current page", + } + : { + type: "image_url", + image_url: { url: `data:image/png;base64,${base64Image}` }, }, - text: "the screenshot of the current page", - } - : { - type: "image_url", - image_url: { url: `data:image/png;base64,${base64Image}` }, - }, - ], - }); - } + ], + }); + } - const result = await this.getNextStep(currentStep); + const result = await this.getNextStep(currentStep); - if (result.method === "close") { - completed = true; - } + if (result.method === "close") { + completed = true; + } - let playwrightArguments: ObserveResult | undefined; - if (result.method === "act") { - [playwrightArguments] = await this.stagehandPage.page.observe( - result.parameters, - ); - } - let extractionResult: unknown | undefined; - if (result.method === "extract") { - extractionResult = await this.stagehandPage.page.extract( - result.parameters, - ); + let playwrightArguments: ObserveResult | undefined; + if (result.method === "act") { + [playwrightArguments] = await this.stagehandPage.page.observe( + result.parameters, + ); + } + let extractionResult: unknown | undefined; + if (result.method === "extract") { + extractionResult = await this.stagehandPage.page.extract( + result.parameters, + ); + } + + const action: AgentAction = { + type: result.method, + reasoning: result.reasoning, + taskCompleted: result.taskComplete, + parameters: result.parameters, + playwrightArguments, + extractionResult, + }; + + await options.onStep?.(action); + + await this.executeAction(result, playwrightArguments, extractionResult); + + actions.push(action); + currentStep++; } - await this.executeAction(result, playwrightArguments, extractionResult); + const finalResult: AgentResult = { + success: true, + message: await this.getSummary(options.instruction), + actions, + completed: + actions.length > 0 + ? (actions[actions.length - 1].taskCompleted as boolean) + : false, + }; - actions.push({ - type: result.method, - reasoning: result.reasoning, - taskCompleted: result.taskComplete, - parameters: result.parameters, - playwrightArguments, - extractionResult, - }); + await options.onSuccess?.(finalResult); - currentStep++; + return finalResult; + } catch (error) { + await options.onFailure?.(error as Error); + throw error; } - - return { - success: true, - message: await this.getSummary(options.instruction), - actions, - completed: actions[actions.length - 1].taskCompleted as boolean, - }; } private async getNextStep(currentStep: number): Promise { diff --git a/lib/index.ts b/lib/index.ts index 49754f701..df79f4599 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -824,6 +824,7 @@ export class Stagehand { this.stagehandPage, this.logger, this.llmClient, + options, ).execute(instructionOrOptions); }, }; diff --git a/types/agent.ts b/types/agent.ts index 8cc062012..e8fb6dbe3 100644 --- a/types/agent.ts +++ b/types/agent.ts @@ -22,6 +22,9 @@ export interface AgentOptions { export interface AgentExecuteOptions extends AgentOptions { instruction: string; + onStep?: (action: AgentAction) => void | Promise; + onSuccess?: (result: AgentResult) => void | Promise; + onFailure?: (error: Error) => void | Promise; } export type AgentProviderType = "openai" | "anthropic"; diff --git a/types/stagehand.ts b/types/stagehand.ts index 0682569bb..285a13dc9 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -1,11 +1,10 @@ import Browserbase from "@browserbasehq/sdk"; +import { Cookie } from "@playwright/test"; import { z } from "zod"; +import { LLMClient } from "../lib/llm/LLMClient"; import { LLMProvider } from "../lib/llm/LLMProvider"; import { LogLine } from "./log"; import { AvailableModel, ClientOptions } from "./model"; -import { LLMClient } from "../lib/llm/LLMClient"; -import { Cookie } from "@playwright/test"; -import { AgentProviderType } from "./agent"; export interface ConstructorParams { /** @@ -231,18 +230,24 @@ export interface AgentExecuteParams { context?: string; } -/** - * Configuration for agent functionality - */ -export interface AgentConfig { +interface OpenAIAgentConfig { + provider: "openai"; /** - * The provider to use for agent functionality + * The model to use for agent functionality */ - provider?: AgentProviderType; + + model: "computer-use-preview"; +} + +type AnthropicAgentConfig = { + provider: "anthropic"; /** * The model to use for agent functionality */ - model?: string; + model: "claude-3-5-sonnet-20240620" | "claude-3-7-sonnet-20250219"; +}; + +type GenericAgentConfig = { /** * Custom instructions to provide to the agent */ @@ -251,7 +256,13 @@ export interface AgentConfig { * Additional options to pass to the agent client */ options?: Record; -} +}; + +/** + * Configuration for agent functionality + */ +export type AgentConfig = (OpenAIAgentConfig | AnthropicAgentConfig) & + GenericAgentConfig; export enum StagehandFunctionName { ACT = "ACT",