Skip to content

add hooks to agent #615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/cold-humans-work.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

Added support for execute-level hooks on agents. This includes `onStep(action)`, `onSuccess(result)`, and `onFailure(error)` inputs.
2 changes: 1 addition & 1 deletion examples/cua-example.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async function main() {

const agent = stagehand.agent({
provider: "openai",
model: "computer-use-preview-2025-02-04",
model: "computer-use-preview",
instructions: `You are a helpful assistant that can use a web browser.
You are currently on the following page: ${page.url()}.
Do not ask follow up questions, the user will trust your judgement.`,
Expand Down
9 changes: 4 additions & 5 deletions lib/agent/AgentProvider.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { LogLine } from "@/types/log";
import { AgentClient } from "./AgentClient";
import { AgentType } from "@/types/agent";
import { OpenAICUAClient } from "./OpenAICUAClient";
import { AnthropicCUAClient } from "./AnthropicCUAClient";
import { LogLine } from "@/types/log";
import {
UnsupportedModelError,
UnsupportedModelProviderError,
} from "@/types/stagehandErrors";
import { AgentClient } from "./AgentClient";
import { AnthropicCUAClient } from "./AnthropicCUAClient";
import { OpenAICUAClient } from "./OpenAICUAClient";

// Map model names to their provider types
const modelToAgentProviderMap: Record<string, AgentType> = {
Expand All @@ -22,7 +22,6 @@ const modelToAgentProviderMap: Record<string, AgentType> = {
*/
export class AgentProvider {
private logger: (message: LogLine) => void;

/**
* Create a new agent provider
*/
Expand Down
26 changes: 18 additions & 8 deletions lib/agent/AnthropicCUAClient.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import Anthropic from "@anthropic-ai/sdk";
import { LogLine } from "@/types/log";
import {
AgentAction,
AgentExecuteOptions,
AgentExecutionOptions,
AgentResult,
AgentType,
AgentExecutionOptions,
ToolUseItem,
AnthropicMessage,
AnthropicContentBlock,
AnthropicMessage,
AnthropicTextBlock,
AnthropicToolResult,
ToolUseItem,
} from "@/types/agent";
import { AgentClient } from "./AgentClient";
import { LogLine } from "@/types/log";
import { AgentScreenshotProviderError } from "@/types/stagehandErrors";
import Anthropic from "@anthropic-ai/sdk";
import { AgentClient } from "./AgentClient";

export type ResponseInputItem = AnthropicMessage | AnthropicToolResult;

Expand Down Expand Up @@ -116,7 +117,7 @@ export class AnthropicCUAClient extends AgentClient {
level: 2,
});

const result = await this.executeStep(inputItems, logger);
const result = await this.executeStep(inputItems, logger, options);

// Add actions to the list
if (result.actions.length > 0) {
Expand Down Expand Up @@ -153,12 +154,16 @@ export class AnthropicCUAClient extends AgentClient {
});

// Return the final result
return {
const result = {
success: completed,
actions,
message: finalMessage,
completed,
};

options.onSuccess?.(result);

return result;
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
Expand All @@ -168,6 +173,8 @@ export class AnthropicCUAClient extends AgentClient {
level: 0,
});

options.onFailure?.(error);

return {
success: false,
actions,
Expand All @@ -180,6 +187,7 @@ export class AnthropicCUAClient extends AgentClient {
async executeStep(
inputItems: ResponseInputItem[],
logger: (message: LogLine) => void,
options: AgentExecuteOptions,
): Promise<{
actions: AgentAction[];
message: string;
Expand Down Expand Up @@ -270,6 +278,8 @@ export class AnthropicCUAClient extends AgentClient {
message: `Executing action: ${action.type}`,
level: 1,
});

await options.onStep?.(action);
await this.actionHandler(action);
} catch (error) {
const errorMessage =
Expand Down
30 changes: 21 additions & 9 deletions lib/agent/OpenAICUAClient.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import OpenAI from "openai";
import { LogLine } from "../../types/log";
import {
AgentAction,
AgentExecuteOptions,
AgentExecutionOptions,
AgentResult,
AgentType,
AgentExecutionOptions,
ResponseInputItem,
ResponseItem,
ComputerCallItem,
FunctionCallItem,
ResponseInputItem,
ResponseItem,
} from "@/types/agent";
import { AgentClient } from "./AgentClient";
import { AgentScreenshotProviderError } from "@/types/stagehandErrors";
import OpenAI from "openai";
import { LogLine } from "../../types/log";
import { AgentClient } from "./AgentClient";

/**
* Client for OpenAI's Computer Use Assistant API
Expand Down Expand Up @@ -111,6 +112,7 @@ export class OpenAICUAClient extends AgentClient {
inputItems,
previousResponseId,
logger,
options,
);

// Add actions to the list
Expand All @@ -137,13 +139,17 @@ export class OpenAICUAClient extends AgentClient {
currentStep++;
}

// Return the final result
return {
const result = {
success: completed,
actions,
message: finalMessage,
completed,
};

options.onSuccess?.(result);

// Return the final result
return result;
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
Expand All @@ -153,6 +159,8 @@ export class OpenAICUAClient extends AgentClient {
level: 0,
});

options.onFailure?.(error as Error);

return {
success: false,
actions,
Expand All @@ -170,6 +178,7 @@ export class OpenAICUAClient extends AgentClient {
inputItems: ResponseInputItem[],
previousResponseId: string | undefined,
logger: (message: LogLine) => void,
options: AgentExecuteOptions,
): Promise<{
actions: AgentAction[];
message: string;
Expand Down Expand Up @@ -224,7 +233,7 @@ export class OpenAICUAClient extends AgentClient {
}

// Take actions and get results
const nextInputItems = await this.takeAction(output, logger);
const nextInputItems = await this.takeAction(output, logger, options);

// Check if completed
const completed =
Expand Down Expand Up @@ -334,11 +343,14 @@ export class OpenAICUAClient extends AgentClient {
async takeAction(
output: ResponseItem[],
logger: (message: LogLine) => void,
options: AgentExecuteOptions,
): Promise<ResponseInputItem[]> {
const nextInputItems: ResponseInputItem[] = [];

// Add any computer calls to process
for (const item of output) {
await options.onStep?.(item);

if (item.type === "computer_call" && this.isComputerCallItem(item)) {
// Execute the action
try {
Expand Down
17 changes: 8 additions & 9 deletions lib/handlers/agentHandler.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import { StagehandPage } from "../StagehandPage";
import { AgentProvider } from "../agent/AgentProvider";
import { StagehandAgent } from "../agent/StagehandAgent";
import { AgentClient } from "../agent/AgentClient";
import { LogLine } from "../../types/log";
import {
AgentExecuteOptions,
ActionExecutionResult,
AgentAction,
AgentResult,
AgentExecuteOptions,
AgentHandlerOptions,
ActionExecutionResult,
AgentResult,
} from "@/types/agent";
import { LogLine } from "../../types/log";
import { StagehandPage } from "../StagehandPage";
import { AgentClient } from "../agent/AgentClient";
import { AgentProvider } from "../agent/AgentProvider";
import { StagehandAgent } from "../agent/StagehandAgent";

export class StagehandAgentHandler {
private stagehandPage: StagehandPage;
Expand All @@ -27,7 +27,6 @@ export class StagehandAgentHandler {
this.stagehandPage = stagehandPage;
this.logger = logger;
this.options = options;

// Initialize the provider
this.provider = new AgentProvider(logger);

Expand Down
Loading