diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models/converse.ts similarity index 99% rename from libs/langchain-aws/src/chat_models.ts rename to libs/langchain-aws/src/chat_models/converse.ts index 1496d9115951..758d6f2d96ff 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models/converse.ts @@ -46,12 +46,12 @@ import { handleConverseStreamMetadata, handleConverseStreamContentBlockStart, BedrockConverseToolChoice, -} from "./common.js"; +} from "../common.js"; import { ChatBedrockConverseToolType, ConverseCommandParams, CredentialType, -} from "./types.js"; +} from "../types.js"; /** * Inputs for ChatBedrockConverse. diff --git a/libs/langchain-aws/src/chat_models/index.ts b/libs/langchain-aws/src/chat_models/index.ts new file mode 100644 index 000000000000..c04f8c3cb2c5 --- /dev/null +++ b/libs/langchain-aws/src/chat_models/index.ts @@ -0,0 +1,2 @@ +export * from "./converse.js"; +export * from "./invoke_model.js"; diff --git a/libs/langchain-aws/src/chat_models/invoke_model.ts b/libs/langchain-aws/src/chat_models/invoke_model.ts new file mode 100644 index 000000000000..9e3c2a3e63eb --- /dev/null +++ b/libs/langchain-aws/src/chat_models/invoke_model.ts @@ -0,0 +1,294 @@ +import type { BaseMessage } from "@langchain/core/messages"; +import { + DefaultProviderInit, + defaultProvider, +} from "@aws-sdk/credential-provider-node"; +import type { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { + BaseChatModel, + BaseChatModelCallOptions, + BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { Runnable } from "@langchain/core/runnables"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { AIMessageChunk } from "@langchain/core/messages"; +import { ChatResult } from "@langchain/core/outputs"; +import type { DocumentType as __DocumentType } from "@smithy/types"; +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from "@aws-sdk/client-bedrock-runtime"; +import { ChatBedrockConverseInput } from "./converse.js"; +import { + convertToInvokeModelTools, + BedrockConverseToolChoice, + convertToBedrockInvokeModelToolChoice, + convertToInvokeModelMessages, + convertInvokeModelMessageToLangChainMessage, +} from "../common.js"; +import { ChatBedrockInvokeModelToolType } from "../types.js"; + +export interface ChatBedrockInvokeModelInput + extends BaseChatModelParams, + Pick< + ChatBedrockConverseInput, + | "model" + | "credentials" + | "region" + | "client" + | "supportsToolChoiceValues" + | "streaming" + >, + Partial { + contentType?: string; + + trace?: "DISABLED" | "ENABLED"; + + guardrailIdentifier?: string; + + guardrailVersion?: string; + + performanceConfigLatency?: "standard" | "optimized"; + + anthropicVersion?: string; +} + +export interface ChatBedrockInvokeModelCallOptions + extends BaseChatModelCallOptions, + ChatBedrockInvokeModelInput { + body?: Record; + + /** + * A list of stop sequences. A stop sequence is a sequence of characters that causes + * the model to stop generating the response. + */ + stop?: string[]; + + tools?: ChatBedrockInvokeModelToolType[]; + + /** + * Tool choice for the model. If passing a string, it must be "any", "auto" or the + * name of the tool to use. Or, pass a BedrockToolChoice object. + * + * If "any" is passed, the model must request at least one tool. + * If "auto" is passed, the model automatically decides if a tool should be called + * or whether to generate text instead. + * If a tool name is passed, it will force the model to call that specific tool. + */ + tool_choice?: BedrockConverseToolChoice; +} + +export class ChatBedrockInvokeModel + extends BaseChatModel + implements ChatBedrockInvokeModelInput +{ + streaming = false; + + model = "anthropic.claude-3-haiku-20240307-v1:0"; + + region: string; + + client: BedrockRuntimeClient; + + contentType: string; + + trace?: "DISABLED" | "ENABLED"; + + guardrailIdentifier?: string; + + guardrailVersion?: string; + + body?: Record; + + performanceConfigLatency?: "standard" | "optimized"; + + /** + * Which types of `tool_choice` values the model supports. + * + * Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3' + * model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise. + */ + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + + constructor(fields?: ChatBedrockInvokeModelInput) { + super(fields ?? {}); + + const { + profile, + filepath, + configFilepath, + ignoreCache, + mfaCodeProvider, + roleAssumer, + roleArn, + webIdentityTokenFile, + roleAssumerWithWebIdentity, + ...rest + } = fields ?? {}; + + const credentials = + rest?.credentials ?? + defaultProvider({ + profile, + filepath, + configFilepath, + ignoreCache, + mfaCodeProvider, + roleAssumer, + roleArn, + webIdentityTokenFile, + roleAssumerWithWebIdentity, + }); + + const region = rest?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION"); + if (!region) { + throw new Error( + "Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field." + ); + } + + this.client = + fields?.client ?? + new BedrockRuntimeClient({ + region, + credentials, + }); + this.region = region; + this.model = rest?.model ?? this.model; + this.trace = rest?.trace ?? this.trace; + this.contentType = rest?.contentType ?? this.contentType; + this.streaming = rest?.streaming ?? this.streaming; + this.guardrailVersion = rest?.guardrailVersion ?? this.guardrailVersion; + this.guardrailIdentifier = + rest?.guardrailIdentifier ?? this.guardrailIdentifier; + this.performanceConfigLatency = + rest?.performanceConfigLatency ?? this.performanceConfigLatency; + if (rest?.supportsToolChoiceValues === undefined) { + if (this.model.includes("claude-3")) { + this.supportsToolChoiceValues = ["auto", "any", "tool"]; + } else if (this.model.includes("mistral-large")) { + this.supportsToolChoiceValues = ["auto", "any"]; + } else { + this.supportsToolChoiceValues = undefined; + } + } else { + this.supportsToolChoiceValues = rest.supportsToolChoiceValues; + } + } + + // Used for tracing, replace with the same name as your class + static lc_name() { + return "ChatBedrockInvokeModel"; + } + + /** + * Replace with any secrets this class passes to `super`. + * See {@link ../../langchain-cohere/src/chat_model.ts} for + * an example. + */ + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "API_KEY_NAME", + }; + } + + get lc_aliases(): { [key: string]: string } | undefined { + return { + apiKey: "API_KEY_NAME", + }; + } + + _llmType() { + return "chat_bedrock_invoke_model"; + } + + async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + if (this.streaming) { + throw new Error("streaming not supported yet."); + } + console.log(messages[0]) + return this._generateNonStreaming(messages, options, runManager); + } + + async _generateNonStreaming( + messages: BaseMessage[], + options: Partial, + _runManager?: CallbackManagerForLLMRun + ): Promise { + const { invokeModelMessages, invokeModelSystem } = + convertToInvokeModelMessages(messages); + const { toolConfig, ...rest } = this.invocationParams(options); + const body = options?.body ?? {}; + + const command = new InvokeModelCommand({ + ...rest, + body: JSON.stringify({ + ...body, + ...toolConfig, + messages: invokeModelMessages, + system: invokeModelSystem, + }), + }); + const response = await this.client.send(command, { + abortSignal: options.signal, + }); + const { body: output, ...responseMetadata } = response; + if (!output) { + throw new Error("No message found in Bedrock response."); + } + const message = convertInvokeModelMessageToLangChainMessage( + output, + responseMetadata + ); + return { + generations: [ + { + text: typeof message.content === "string" ? message.content : "", + message, + }, + ], + }; + } + + invocationParams(options?: this["ParsedCallOptions"]) { + let toolConfig; + if (options?.tools && options.tools.length) { + const tools = convertToInvokeModelTools(options.tools); + toolConfig = { + tools, + toolChoice: options.tool_choice + ? convertToBedrockInvokeModelToolChoice(options.tool_choice, tools, { + model: this.model, + supportsToolChoiceValues: this.supportsToolChoiceValues, + }) + : undefined, + }; + } + + return { + modelId: this.model, + contentType: this.contentType, + trace: this.trace, + guardrailVersion: this.guardrailVersion, + guardrailIdentifier: this.guardrailIdentifier, + performanceConfigLatency: this.performanceConfigLatency, + toolConfig, + }; + } + + override bindTools( + tools: ChatBedrockInvokeModelToolType[], + kwargs?: Partial + ): Runnable< + BaseLanguageModelInput, + AIMessageChunk, + this["ParsedCallOptions"] + > { + return this.bind({ tools: convertToInvokeModelTools(tools), ...kwargs }); + } +} diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 1810b4f907cc..bcc5c28d742e 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -20,9 +20,11 @@ import type { ContentBlockDeltaEvent, ConverseStreamMetadataEvent, ContentBlockStartEvent, + InvokeModelResponse, ReasoningContentBlock, ReasoningContentBlockDelta, ReasoningTextBlock, + ConversationRole, } from "@aws-sdk/client-bedrock-runtime"; import type { DocumentType as __DocumentType } from "@smithy/types"; import { isLangChainTool } from "@langchain/core/utils/function_calling"; @@ -30,11 +32,14 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatGenerationChunk } from "@langchain/core/outputs"; import { ChatBedrockConverseToolType, + ChatBedrockInvokeModelToolType, BedrockToolChoice, + BedrockInvokeModelTool, MessageContentReasoningBlock, MessageContentReasoningBlockReasoningText, MessageContentReasoningBlockReasoningTextPartial, MessageContentReasoningBlockRedacted, + InvokeModelBodyResponse, } from "./types.js"; export function extractImageInfo(base64: string): ContentBlock.ImageMember { @@ -68,6 +73,212 @@ export function extractImageInfo(base64: string): ContentBlock.ImageMember { }; } +interface ExtendTextMember extends ContentBlock.TextMember { + type: string; +} + +interface BedrockInvokeModelMessage { + role: ConversationRole | undefined; + content: Array | undefined; +} + +export function convertToInvokeModelMessages(messages: BaseMessage[]): { + invokeModelMessages: BedrockInvokeModelMessage[]; + invokeModelSystem: BedrockSystemContentBlock[]; +} { + const invokeModelSystem: BedrockSystemContentBlock[] = messages + .filter((msg) => msg._getType() === "system") + .map((msg) => { + if (typeof msg.content === "string") { + return { text: msg.content, type: "text" }; + } else if (msg.content.length === 1 && msg.content[0].type === "text") { + return { text: msg.content[0].text, type: "text" }; + } + throw new Error( + "System message content must be either a string, or a content array containing a single text object." + ); + }); + const invokeModelMessages: BedrockInvokeModelMessage[] = messages + .filter((msg) => msg._getType() !== "system") + .map((msg) => { + if (msg._getType() === "ai") { + const castMsg = msg as AIMessage; + const assistantMsg: BedrockInvokeModelMessage = { + role: "assistant", + content: [], + }; + if (typeof castMsg.content === "string" && castMsg.content !== "") { + assistantMsg.content?.push({ + text: castMsg.content, + type: "text", + }); + } else if (Array.isArray(castMsg.content)) { + const concatenatedBlocks = concatenateLangchainReasoningBlocks( + castMsg.content + ); + const contentBlocks: ContentBlock[] = concatenatedBlocks.map( + (block) => { + if (block.type === "text" && block.text !== "") { + return { + text: block.text, + type: "text", + }; + } else if (block.type === "reasoning_content") { + return { + reasoningContent: + langchainReasoningBlockToBedrockReasoningBlock( + block as MessageContentReasoningBlock + ), + }; + } else { + const blockValues = Object.fromEntries( + Object.entries(block).filter(([key]) => key !== "type") + ); + throw new Error( + `Unsupported content block type: ${ + block.type + } with content of ${JSON.stringify(blockValues, null, 2)}` + ); + } + } + ); + + assistantMsg.content = [ + ...(assistantMsg.content ? assistantMsg.content : []), + ...contentBlocks, + ]; + } + + // Important: this must be placed after any reasoning content blocks, else claude models will return an error. + if (castMsg.tool_calls && castMsg.tool_calls.length) { + assistantMsg.content = [ + ...(assistantMsg.content ? assistantMsg.content : []), + ...castMsg.tool_calls.map((tc) => ({ + toolUse: { + toolUseId: tc.id, + name: tc.name, + input: tc.args, + }, + })), + ]; + } + + return assistantMsg; + } else if (msg._getType() === "human" || msg._getType() === "generic") { + if (typeof msg.content === "string" && msg.content !== "") { + return { + role: "user" as const, + content: [ + { + text: msg.content, + type: "text", + }, + ], + }; + } else if (Array.isArray(msg.content)) { + const contentBlocks: ContentBlock[] = msg.content.flatMap((block) => { + if (block.type === "image_url") { + const base64: string = + typeof block.image_url === "string" + ? block.image_url + : block.image_url.url; + return extractImageInfo(base64); + } else if (block.type === "text") { + return { + text: block.text, + type: "text", + }; + } else if ( + block.type === "document" && + block.document !== undefined + ) { + return { + document: block.document, + }; + } else if (block.type === "image" && block.image !== undefined) { + return { + image: block.image, + }; + } else { + throw new Error(`Unsupported content block type: ${block.type}`); + } + }); + return { + role: "user" as const, + content: contentBlocks, + }; + } else { + throw new Error( + `Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.` + ); + } + } else if (msg._getType() === "tool") { + const castMsg = msg as ToolMessage; + if (typeof castMsg.content === "string") { + return { + // Tool use messages are always from the user + role: "user" as const, + content: [ + { + toolResult: { + toolUseId: castMsg.tool_call_id, + content: [ + { + text: castMsg.content, + }, + ], + }, + }, + ], + }; + } else { + return { + // Tool use messages are always from the user + role: "user" as const, + content: [ + { + toolResult: { + toolUseId: castMsg.tool_call_id, + content: [ + { + json: castMsg.content, + }, + ], + }, + }, + ], + }; + } + } else { + throw new Error(`Unsupported message type: ${msg._getType()}`); + } + }); + + // Combine consecutive user tool result messages into a single message + const combinedInvokeModelMessages = invokeModelMessages.reduce( + (acc, curr) => { + const lastMessage = acc[acc.length - 1]; + + if ( + lastMessage && + lastMessage.role === "user" && + lastMessage.content?.some((c) => "toolResult" in c) && + curr.role === "user" && + curr.content?.some((c) => "toolResult" in c) + ) { + lastMessage.content = lastMessage.content.concat(curr.content); + } else { + acc.push(curr); + } + + return acc; + }, + [] + ); + + return { invokeModelMessages: combinedInvokeModelMessages, invokeModelSystem }; +} + export function convertToConverseMessages(messages: BaseMessage[]): { converseMessages: BedrockMessage[]; converseSystem: BedrockSystemContentBlock[]; @@ -269,6 +480,43 @@ export function isBedrockTool(tool: unknown): tool is BedrockTool { return false; } +export function isInvokeModelTool( + tool: unknown +): tool is BedrockInvokeModelTool { + if (typeof tool === "object" && tool && "input_schema" in tool) { + return true; + } + return false; +} + +export function convertToInvokeModelTools( + tools: ChatBedrockInvokeModelToolType[] +): BedrockInvokeModelTool[] { + if (tools.every(isOpenAITool)) { + return tools.map((tool) => ({ + name: tool.function.name, + description: tool.function.description, + input_schema: { + ...tool.function.parameters, + }, + })); + } else if (tools.every(isLangChainTool)) { + return tools.map((tool) => ({ + name: tool.name, + description: tool.description, + input_schema: { + ...zodToJsonSchema(tool.schema), + }, + })); + } else if (tools.every(isInvokeModelTool)) { + return tools; + } + + throw new Error( + "Invalid tools passed. Must be an array of StructuredToolInterface, ToolDefinition, or BedrockInvokeModelTool." + ); +} + export function convertToConverseTools( tools: ChatBedrockConverseToolType[] ): BedrockTool[] { @@ -307,6 +555,72 @@ export type BedrockConverseToolChoice = | string | BedrockToolChoice; +export function convertToBedrockInvokeModelToolChoice( + toolChoice: BedrockConverseToolChoice, + tools: BedrockInvokeModelTool[], + fields: { + model: string; + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + } +) { + const supportsToolChoiceValues = fields.supportsToolChoiceValues ?? []; + + let bedrockToolChoice: BedrockToolChoice; + if (typeof toolChoice === "string") { + switch (toolChoice) { + case "any": + bedrockToolChoice = { + any: {}, + }; + break; + case "auto": + bedrockToolChoice = { + auto: {}, + }; + break; + default: { + const foundTool = tools.find((tool) => tool?.name === toolChoice); + if (!foundTool) { + throw new Error( + `Tool with name ${toolChoice} not found in tools list.` + ); + } + bedrockToolChoice = { + tool: { + name: toolChoice, + }, + }; + } + } + } else { + bedrockToolChoice = toolChoice; + } + + const toolChoiceType = Object.keys(bedrockToolChoice)[0] as + | "auto" + | "any" + | "tool"; + if (!supportsToolChoiceValues.includes(toolChoiceType)) { + let supportedTxt = ""; + if (supportsToolChoiceValues.length) { + supportedTxt = + `Model ${fields.model} does not currently support 'tool_choice' ` + + `of type ${toolChoiceType}. The following 'tool_choice' types ` + + `are supported: ${supportsToolChoiceValues.join(", ")}.`; + } else { + supportedTxt = `Model ${fields.model} does not currently support 'tool_choice'.`; + } + + throw new Error( + `${supportedTxt} Please see` + + "https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html" + + "for the latest documentation on models that support tool choice." + ); + } + + return bedrockToolChoice; +} + export function convertToBedrockToolChoice( toolChoice: BedrockConverseToolChoice, tools: BedrockTool[], @@ -375,6 +689,85 @@ export function convertToBedrockToolChoice( return bedrockToolChoice; } +export function convertInvokeModelMessageToLangChainMessage( + body: Uint8Array, + responseMetadata: Omit +) { + const bodyResponse = ( + JSON.parse(new TextDecoder().decode(body)) + ); + + const { id, content, role, usage, ...rest } = bodyResponse; + if (!content) { + throw new Error("No message content found in response."); + } + if (role !== "assistant") { + throw new Error( + `Unsupported message role received in InvokeModel response: ${bodyResponse.role}` + ); + } + + let tokenUsage: UsageMetadata | undefined; + if (usage) { + const input_tokens = usage.input_tokens ?? 0; + const output_tokens = usage.output_tokens ?? 0; + tokenUsage = { + input_tokens, + output_tokens, + total_tokens: input_tokens + output_tokens, + }; + } + + if ( + content?.length === 1 && + "text" in content[0] && + typeof content[0].text === "string" + ) { + return new AIMessage({ + content: content[0].text, + response_metadata: { + ...responseMetadata, + ...rest, + }, + usage_metadata: tokenUsage, + id, + }); + } else { + const toolCalls: ToolCall[] = []; + const complexContent: MessageContentComplex[] = []; + content.forEach((c) => { + if (c.type === "tool_use" && c.input && typeof c.input === "object") { + toolCalls.push({ + id: c.id, + name: c.name, + args: c.input, + type: "tool_call", + }); + } else if ("text" in c && typeof c.text === "string") { + complexContent.push({ type: "text", text: c.text }); + } else if ("reasoningContent" in c) { + complexContent.push( + bedrockReasoningBlockToLangchainReasoningBlock( + c.reasoningContent as ReasoningContentBlock + ) + ); + } else { + complexContent.push(c); + } + }); + return new AIMessage({ + content: complexContent, + tool_calls: toolCalls, + response_metadata: { + ...responseMetadata, + ...rest, + }, + usage_metadata: tokenUsage, + id, + }); + } +} + export function convertConverseMessageToLangChainMessage( message: BedrockMessage, responseMetadata: Omit diff --git a/libs/langchain-aws/src/index.ts b/libs/langchain-aws/src/index.ts index 6100947a9e49..a9f6622b66bc 100644 --- a/libs/langchain-aws/src/index.ts +++ b/libs/langchain-aws/src/index.ts @@ -1,4 +1,4 @@ -export * from "./chat_models.js"; +export * from "./chat_models/index.js"; export * from "./types.js"; export * from "./retrievers/index.js"; export * from "./embeddings.js"; diff --git a/libs/langchain-aws/src/tests/chat_models.int.test.ts b/libs/langchain-aws/src/tests/chat_models.int.test.ts index 954e919c36c0..72b0d9c94062 100644 --- a/libs/langchain-aws/src/tests/chat_models.int.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.int.test.ts @@ -13,7 +13,7 @@ import { import { tool } from "@langchain/core/tools"; import { z } from "zod"; import { concat } from "@langchain/core/utils/stream"; -import { ChatBedrockConverse } from "../chat_models.js"; +import { ChatBedrockConverse } from "../chat_models/converse.js"; import { concatenateLangchainReasoningBlocks } from "../common.js"; import { MessageContentReasoningBlockReasoningText } from "../types.js"; diff --git a/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts b/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts index bfc099864956..c2bfefd86b88 100644 --- a/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts @@ -5,7 +5,7 @@ import { AIMessageChunk } from "@langchain/core/messages"; import { ChatBedrockConverse, ChatBedrockConverseCallOptions, -} from "../chat_models.js"; +} from "../chat_models/converse.js"; class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTests< ChatBedrockConverseCallOptions, diff --git a/libs/langchain-aws/src/tests/chat_models.standard.test.ts b/libs/langchain-aws/src/tests/chat_models.standard.test.ts index b22b59e932c2..afacae649a32 100644 --- a/libs/langchain-aws/src/tests/chat_models.standard.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.standard.test.ts @@ -5,7 +5,7 @@ import { AIMessageChunk } from "@langchain/core/messages"; import { ChatBedrockConverse, ChatBedrockConverseCallOptions, -} from "../chat_models.js"; +} from "../chat_models/converse.js"; class ChatBedrockConverseStandardUnitTests extends ChatModelUnitTests< ChatBedrockConverseCallOptions, diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index e82faf216627..701cf50f8411 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -17,7 +17,7 @@ import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; -import { ChatBedrockConverse } from "../chat_models.js"; +import { ChatBedrockConverse } from "../chat_models/converse.js"; describe("convertToConverseMessages", () => { const testCases: { diff --git a/libs/langchain-aws/src/types.ts b/libs/langchain-aws/src/types.ts index 1758f2dea935..a6d7f80f311b 100644 --- a/libs/langchain-aws/src/types.ts +++ b/libs/langchain-aws/src/types.ts @@ -1,9 +1,13 @@ import type { ToolChoice, Tool as BedrockTool, + StopReason, } from "@aws-sdk/client-bedrock-runtime"; import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types"; -import { ConverseCommand } from "@aws-sdk/client-bedrock-runtime"; +import { + ConverseCommand, + InvokeModelCommand, +} from "@aws-sdk/client-bedrock-runtime"; import { BindToolsInput } from "@langchain/core/language_models/chat_models"; export type CredentialType = @@ -14,11 +18,37 @@ export type ConverseCommandParams = ConstructorParameters< typeof ConverseCommand >[0]; +export type InvokeModelCommandParams = ConstructorParameters< + typeof InvokeModelCommand +>[0]; + +export type InvokeModelBodyResponse = { + id: string; + type: string; + role: string; + model: string; + content: Record[]; + stop_reason: StopReason; + stop_sequence: string[] | null; + usage: { + input_tokens: number; + output_tokens: number; + }; +}; +export type BedrockInvokeModelTool = { + name: string | undefined; + description?: string | undefined; + input_schema: Record; +}; + export type BedrockToolChoice = | ToolChoice.AnyMember | ToolChoice.AutoMember | ToolChoice.ToolMember; export type ChatBedrockConverseToolType = BindToolsInput | BedrockTool; +export type ChatBedrockInvokeModelToolType = + | BindToolsInput + | BedrockInvokeModelTool; export type MessageContentReasoningBlockReasoningText = { type: "reasoning_content";