Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 121 additions & 188 deletions core/src/agents/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

import {Content, createUserContent, FunctionCall, Part} from '@google/genai';
import {SpanStatusCode} from '@opentelemetry/api';
import {isEmpty} from 'lodash-es';

import {InvocationContext} from '../agents/invocation_context.js';
Expand Down Expand Up @@ -211,69 +212,6 @@ export function generateRequestConfirmationEvent({
});
}

async function callToolAsync(
tool: BaseTool,
args: Record<string, any>, // eslint-disable-line @typescript-eslint/no-explicit-any
toolContext: Context,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<any> {
return tracer.startActiveSpan(`execute_tool ${tool.name}`, async (span) => {
try {
logger.debug(`callToolAsync ${tool.name}`);
const result = await tool.runAsync({args, toolContext});
traceToolCall({
tool,
args,
functionResponseEvent: buildResponseEvent(
tool,
result,
toolContext,
toolContext.invocationContext,
),
});
return result;
} finally {
span.end();
}
});
}

function buildResponseEvent(
tool: BaseTool,
functionResult: unknown,
toolContext: Context,
invocationContext: InvocationContext,
): Event {
let responseResult: Record<string, unknown>;
if (typeof functionResult !== 'object' || functionResult == null) {
responseResult = {result: functionResult};
} else if (Array.isArray(functionResult)) {
responseResult = {results: functionResult};
} else {
responseResult = functionResult as Record<string, unknown>;
}

const partFunctionResponse: Part = {
functionResponse: {
name: tool.name,
response: responseResult,
id: toolContext.functionCallId,
},
};

const content: Content = {
role: 'user',
parts: [partFunctionResponse],
};

return createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: content,
actions: toolContext.actions,
branch: invocationContext.branch,
});
}
/**
* Handles function calls.
* Runtime behavior to pay attention to:
Expand Down Expand Up @@ -357,142 +295,142 @@ export async function handleFunctionCallList({
toolConfirmation,
});

// TODO - b/436079721: implement [tracer.start_as_current_span]
logger.debug(`execute_tool ${tool.name}`);
const functionArgs = functionCall.args ?? {};

// Step 1: Check if plugin before_tool_callback overrides the function
// response.
let functionResponse = null;
let functionResponseError: string | unknown | undefined;
functionResponse =
await invocationContext.pluginManager.runBeforeToolCallback({
tool: tool,
toolArgs: functionArgs,
toolContext: toolContext,
});

// Step 2: If no overrides are provided from the plugins, further run the
// canonical callback.
// TODO - b/425992518: validate the callback response type matches.
if (functionResponse == null) {
// Cover both null and undefined
for (const callback of beforeToolCallbacks) {
functionResponse = await callback({
tool: tool,
args: functionArgs,
context: toolContext,
});
if (functionResponse) {
break;
}
}
}

// Step 3: Otherwise, proceed calling the tool normally.
if (functionResponse == null) {
// Cover both null and undefined
await tracer.startActiveSpan(`execute_tool ${tool.name}`, async (span) => {
try {
functionResponse = await callToolAsync(tool, functionArgs, toolContext);
} catch (e: unknown) {
if (e instanceof Error) {
const onToolErrorResponse =
await invocationContext.pluginManager.runOnToolErrorCallback({
// Step 1: Check if plugin before_tool_callback overrides the function
// response.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let functionResponse: any =
await invocationContext.pluginManager.runBeforeToolCallback({
tool,
toolArgs: functionArgs,
toolContext,
});
let functionResponseError: string | unknown | undefined;

// Step 2: If no overrides are provided from the plugins, further run the
// canonical callback.
// TODO - b/425992518: validate the callback response type matches.
if (functionResponse == null) {
// Cover both null and undefined
for (const callback of beforeToolCallbacks) {
functionResponse = await callback({
tool: tool,
toolArgs: functionArgs,
toolContext: toolContext,
error: e,
args: functionArgs,
context: toolContext,
});
if (functionResponse) {
break;
}
}
}

// Set function response to the result of the error callback and
// continue execution, do not shortcut
if (onToolErrorResponse) {
functionResponse = onToolErrorResponse;
} else {
// If the error callback returns undefined, use the error message
// as the function response error.
functionResponseError = e.message;
// Step 3: Otherwise, proceed calling the tool normally.
if (functionResponse == null) {
// Cover both null and undefined
try {
logger.debug(`callToolAsync ${tool.name}`);
functionResponse = await tool.runAsync({
args: functionArgs,
toolContext,
});
} catch (e: unknown) {
const err = e instanceof Error ? e : new Error(String(e));
span.recordException(err);
if (e instanceof Error) {
functionResponse =
await invocationContext.pluginManager.runOnToolErrorCallback({
tool,
toolArgs: functionArgs,
toolContext,
error: e,
});
}
if (!functionResponse) {
functionResponseError = e instanceof Error ? e.message : e;
span.setStatus({
code: SpanStatusCode.ERROR,
message: err.message,
});
}
}
} else {
// If the error is not an Error, use the error object as the function
// response error.
functionResponseError = e;
}
}
}

// Step 4: Check if plugin after_tool_callback overrides the function
// response.
let alteredFunctionResponse =
await invocationContext.pluginManager.runAfterToolCallback({
tool: tool,
toolArgs: functionArgs,
toolContext: toolContext,
result: functionResponse,
});

// Step 5: If no overrides are provided from the plugins, further run the
// canonical after_tool_callbacks.
if (alteredFunctionResponse == null) {
// Cover both null and undefined
for (const callback of afterToolCallbacks) {
alteredFunctionResponse = await callback({
tool: tool,
args: functionArgs,
context: toolContext,
response: functionResponse,
});
if (alteredFunctionResponse) {
break;
// Step 4: Check if plugin after_tool_callback overrides the function
// response.
let alteredFunctionResponse =
await invocationContext.pluginManager.runAfterToolCallback({
tool: tool,
toolArgs: functionArgs,
toolContext: toolContext,
result: functionResponse,
});

// Step 5: If no overrides are provided from the plugins, further run the
// canonical after_tool_callbacks.
if (alteredFunctionResponse == null) {
// Cover both null and undefined
for (const callback of afterToolCallbacks) {
alteredFunctionResponse = await callback({
tool: tool,
args: functionArgs,
context: toolContext,
response: functionResponse,
});
if (alteredFunctionResponse) {
break;
}
}
}
}
}

// Step 6: If alternative response exists from after_tool_callback, use it
// instead of the original function response.
if (alteredFunctionResponse != null) {
functionResponse = alteredFunctionResponse;
}
// Step 6: If alternative response exists from after_tool_callback, use it
// instead of the original function response.
functionResponse = alteredFunctionResponse ?? functionResponse;

// TODO - b/425992518: state event polluting runtime, consider fix.
// Allow long running function to return None as response.
if (tool.isLongRunning && !functionResponse) {
continue;
}
// TODO - b/425992518: state event polluting runtime, consider fix.
// Allow long running function to return None as response.
if (tool.isLongRunning && !functionResponse) {
return;
}

if (functionResponseError) {
functionResponse = {error: functionResponseError};
} else if (
typeof functionResponse !== 'object' ||
functionResponse == null
) {
functionResponse = {result: functionResponse};
} else if (Array.isArray(functionResponse)) {
functionResponse = {results: functionResponse};
}
if (functionResponseError) {
functionResponse = {error: functionResponseError};
} else if (
typeof functionResponse !== 'object' ||
functionResponse == null
) {
functionResponse = {result: functionResponse};
} else if (Array.isArray(functionResponse)) {
functionResponse = {results: functionResponse};
}

// Builds the function response event.
const functionResponseEvent = createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: createUserContent({
functionResponse: {
id: toolContext.functionCallId,
name: tool.name,
response: functionResponse,
},
}),
actions: toolContext.actions,
branch: invocationContext.branch,
});
// Builds the function response event.
const functionResponseEvent = createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: createUserContent({
functionResponse: {
id: toolContext.functionCallId,
name: tool.name,
response: functionResponse,
},
}),
actions: toolContext.actions,
branch: invocationContext.branch,
});

// TODO - b/436079721: implement [traceToolCall]
logger.debug('traceToolCall', {
tool: tool.name,
args: functionArgs,
functionResponseEvent: functionResponseEvent.id,
traceToolCall({
tool,
args: functionArgs,
functionResponseEvent,
});
functionResponseEvents.push(functionResponseEvent);
} finally {
span.end();
}
});
functionResponseEvents.push(functionResponseEvent);
}

if (!functionResponseEvents.length) {
Expand All @@ -506,11 +444,6 @@ export async function handleFunctionCallList({
tracer.startActiveSpan('execute_tool (merged)', (span) => {
try {
logger.debug('execute_tool (merged)');
// TODO - b/436079721: implement [traceMergedToolCalls]
logger.debug('traceMergedToolCalls', {
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent.id,
});
traceMergedToolCalls({
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent,
Expand Down
Loading
Loading