From 89d10ca1a93eee8b052fc6385c6cfca695545e01 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Sat, 11 Oct 2025 07:34:42 -0400 Subject: [PATCH] new typing --- .../langchain/agents/middleware/__init__.py | 10 ++ .../langchain/agents/middleware/types.py | 114 +++++++++++++++--- libs/langchain_v1/langchain/tools/__init__.py | 16 ++- .../langchain_v1/langchain/tools/tool_node.py | 113 +++++++++++------ 4 files changed, 197 insertions(+), 56 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index de89191aa04ec..2e93d6ce4f897 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -20,7 +20,11 @@ from .types import ( AgentMiddleware, AgentState, + ModelCallHandler, + ModelCallResult, + ModelCallWrapper, ModelRequest, + ModelResponse, after_agent, after_model, before_agent, @@ -28,6 +32,7 @@ dynamic_prompt, hook_config, wrap_model_call, + wrap_tool_call, ) __all__ = [ @@ -41,9 +46,13 @@ "InterruptOnConfig", "LLMToolEmulator", "LLMToolSelectorMiddleware", + "ModelCallHandler", "ModelCallLimitMiddleware", + "ModelCallResult", + "ModelCallWrapper", "ModelFallbackMiddleware", "ModelRequest", + "ModelResponse", "PIIDetectionError", "PIIMiddleware", "PlanningMiddleware", @@ -56,4 +65,5 @@ "dynamic_prompt", "hook_config", "wrap_model_call", + "wrap_tool_call", ] diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 801b213112829..6400fe158861e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -17,9 +17,11 @@ ) if TYPE_CHECKING: - from collections.abc import Awaitable - - from langchain.tools.tool_node import ToolCallRequest + from langchain.tools.tool_node import ( + AsyncToolCallHandler, + ToolCallHandler, + ToolCallRequest, + ) # Needed as top level import for Pydantic schema generation on AgentState from typing import TypeAlias @@ -43,6 +45,9 @@ "AgentMiddleware", "AgentState", "ContextT", + "ModelCallHandler", + "ModelCallResult", + "ModelCallWrapper", "ModelRequest", "ModelResponse", "OmitFromSchema", @@ -53,6 +58,7 @@ "before_model", "dynamic_prompt", "hook_config", + "wrap_model_call", "wrap_tool_call", ] @@ -102,6 +108,82 @@ class ModelResponse: """ +ModelCallHandler = Callable[[ModelRequest], ModelResponse] +"""Type alias for the handler callback passed to wrap_model_call hooks. + +The handler executes the model request and returns a ModelResponse. It can be called +multiple times for retry logic or skipped entirely to short-circuit execution. + +Examples: + Simple passthrough: + ```python + def my_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult: + return handler(request) + ``` + + Retry logic: + ```python + def retry_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult: + for attempt in range(3): + try: + return handler(request) + except Exception: + if attempt == 2: + raise + ``` +""" + +AsyncModelCallHandler = Callable[[ModelRequest], Awaitable[ModelResponse]] +"""Type alias for the async handler callback passed to wrap_model_call hooks. + +The async handler executes the model request and returns a ModelResponse. It can be +called multiple times for retry logic or skipped entirely to short-circuit execution. +""" + +ModelCallWrapper = Callable[[ModelRequest, ModelCallHandler], ModelCallResult] +"""Type alias for synchronous model call wrapper functions. + +A wrapper receives a ModelRequest and a handler callback. It can modify the request, +call the handler (potentially multiple times), modify the response, or short-circuit +entirely. + +Args: + request: Model request containing state, runtime, messages, tools, etc. + handler: Callback to execute the model. Can be called multiple times. + +Returns: + ModelCallResult (either ModelResponse or AIMessage) + +Examples: + Basic retry pattern: + ```python + def retry_on_error(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult: + for attempt in range(3): + try: + return handler(request) + except Exception: + if attempt == 2: + raise + ``` + + Access runtime context: + ```python + def use_runtime(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult: + user_id = request.runtime.context.get("user_id") + # Modify request based on context + return handler(request) + ``` +""" + +AsyncModelCallWrapper = Callable[[ModelRequest, AsyncModelCallHandler], Awaitable[ModelCallResult]] +"""Type alias for asynchronous model call wrapper functions. + +A wrapper receives a ModelRequest and an async handler callback. It can modify the +request, call the handler (potentially multiple times), modify the response, or +short-circuit entirely. +""" + + @dataclass class OmitFromSchema: """Annotation used to mark state attributes as omitted from input or output schemas.""" @@ -195,7 +277,7 @@ async def aafter_model( def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + handler: ModelCallHandler, ) -> ModelCallResult: """Intercept and control model execution via handler callback. @@ -278,7 +360,7 @@ def wrap_model_call(self, request, handler): async def awrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + handler: AsyncModelCallHandler, ) -> ModelCallResult: """Intercept and control async model execution via handler callback. @@ -331,7 +413,7 @@ async def aafter_agent( def wrap_tool_call( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: """Intercept tool execution for retries, monitoring, or modification. @@ -395,7 +477,7 @@ def wrap_tool_call(self, request, handler): async def awrap_tool_call( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + handler: AsyncToolCallHandler, ) -> ToolMessage | Command: """Intercept and control async tool execution via handler callback. @@ -480,7 +562,7 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ def __call__( self, request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + handler: ModelCallHandler, ) -> ModelCallResult: """Intercept model execution via handler callback.""" ... @@ -495,7 +577,7 @@ class _CallableReturningToolResponse(Protocol): def __call__( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: """Intercept tool execution via handler callback.""" ... @@ -1174,7 +1256,7 @@ def decorator( async def async_wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + handler: AsyncModelCallHandler, ) -> ModelCallResult: prompt = await func(request) # type: ignore[misc] request.system_prompt = prompt @@ -1195,7 +1277,7 @@ async def async_wrapped( def wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + handler: ModelCallHandler, ) -> ModelCallResult: prompt = cast("str", func(request)) request.system_prompt = prompt @@ -1204,7 +1286,7 @@ def wrapped( async def async_wrapped_from_sync( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + handler: AsyncModelCallHandler, ) -> ModelCallResult: # Delegate to sync function prompt = cast("str", func(request)) @@ -1337,7 +1419,7 @@ def decorator( async def async_wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + handler: AsyncModelCallHandler, ) -> ModelCallResult: return await func(request, handler) # type: ignore[misc, arg-type] @@ -1358,7 +1440,7 @@ async def async_wrapped( def wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + handler: ModelCallHandler, ) -> ModelCallResult: return func(request, handler) @@ -1480,7 +1562,7 @@ def decorator( async def async_wrapped( self: AgentMiddleware, # noqa: ARG001 request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + handler: AsyncToolCallHandler, ) -> ToolMessage | Command: return await func(request, handler) # type: ignore[arg-type,misc] @@ -1501,7 +1583,7 @@ async def async_wrapped( def wrapped( self: AgentMiddleware, # noqa: ARG001 request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: return func(request, handler) diff --git a/libs/langchain_v1/langchain/tools/__init__.py b/libs/langchain_v1/langchain/tools/__init__.py index 6d74894c319e3..6361f81a8badf 100644 --- a/libs/langchain_v1/langchain/tools/__init__.py +++ b/libs/langchain_v1/langchain/tools/__init__.py @@ -8,14 +8,28 @@ tool, ) -from langchain.tools.tool_node import InjectedState, InjectedStore, ToolInvocationError +from langchain.tools.tool_node import ( + AsyncToolCallHandler, + AsyncToolCallWrapper, + InjectedState, + InjectedStore, + ToolCallHandler, + ToolCallRequest, + ToolCallWrapper, + ToolInvocationError, +) __all__ = [ + "AsyncToolCallHandler", + "AsyncToolCallWrapper", "BaseTool", "InjectedState", "InjectedStore", "InjectedToolArg", "InjectedToolCallId", + "ToolCallHandler", + "ToolCallRequest", + "ToolCallWrapper", "ToolException", "ToolInvocationError", "tool", diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index e5d06f99f1ab6..ffe3e581ec1e4 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -118,26 +118,55 @@ class ToolCallRequest: tool_call: ToolCall tool: BaseTool state: Any - runtime: Any + runtime: Any # Runtime[Any] | None, but using Any for simplicity and to avoid circular imports -ToolCallWrapper = Callable[ - [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]], - ToolMessage | Command, -] -"""Wrapper for tool call execution with multi-call support. +ToolCallHandler = Callable[[ToolCallRequest], ToolMessage | Command] +"""Type alias for the handler callback passed to wrap_tool_call hooks. + +The handler executes the tool call and returns a ToolMessage or Command. It can be +called multiple times for retry logic or skipped entirely to short-circuit execution. + +Examples: + Simple passthrough: + ```python + def my_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: + return handler(request) + ``` + + Retry logic: + ```python + def retry_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: + for attempt in range(3): + try: + return handler(request) + except Exception: + if attempt == 2: + raise + ``` +""" + +AsyncToolCallHandler = Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]] +"""Type alias for the async handler callback passed to wrap_tool_call hooks. + +The async handler executes the tool call and returns a ToolMessage or Command. It can +be called multiple times for retry logic or skipped entirely to short-circuit execution. +""" + +ToolCallWrapper = Callable[[ToolCallRequest, ToolCallHandler], ToolMessage | Command] +"""Type alias for synchronous tool call wrapper functions. + +A wrapper receives a ToolCallRequest and a handler callback. It can modify the request, +call the handler (potentially multiple times), modify the response, or short-circuit +entirely. -Wrapper receives: - request: ToolCallRequest with tool_call, tool, state, and runtime. - execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES). +Args: + request: Tool call request with tool_call, tool, state, and runtime. + handler: Callback to execute the tool. Can be called multiple times. Returns: ToolMessage or Command (the final result). -The execute callable can be invoked multiple times for retry logic, -with potentially modified requests each time. Each call to execute -is independent and stateless. - Note: When implementing middleware for `create_agent`, use `AgentMiddleware.wrap_tool_call` which provides properly typed @@ -145,55 +174,61 @@ class ToolCallRequest: Examples: Passthrough (execute once): - - def handler(request, execute): - return execute(request) + ```python + def passthrough(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: + return handler(request) + ``` Modify request before execution: - - def handler(request, execute): + ```python + def modify_args(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: request.tool_call["args"]["value"] *= 2 - return execute(request) + return handler(request) + ``` Retry on error (execute multiple times): - - def handler(request, execute): + ```python + def retry_on_error(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: for attempt in range(3): try: - result = execute(request) + result = handler(request) if is_valid(result): return result except Exception: if attempt == 2: raise return result + ``` - Conditional retry based on response: - - def handler(request, execute): - for attempt in range(3): - result = execute(request) - if isinstance(result, ToolMessage) and result.status != "error": - return result - if attempt < 2: - continue - return result - - Cache/short-circuit without calling execute: + Access runtime context: + ```python + def use_runtime(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: + if request.runtime is not None: + thread_id = request.runtime.context.get("thread_id") + # Use runtime context + return handler(request) + ``` - def handler(request, execute): + Cache/short-circuit without calling handler: + ```python + def with_cache(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command: if cached := get_cache(request): return ToolMessage(content=cached, tool_call_id=request.tool_call["id"]) - result = execute(request) + result = handler(request) save_cache(request, result) return result + ``` """ AsyncToolCallWrapper = Callable[ - [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]], - Awaitable[ToolMessage | Command], + [ToolCallRequest, AsyncToolCallHandler], Awaitable[ToolMessage | Command] ] -"""Async wrapper for tool call execution with multi-call support.""" +"""Type alias for asynchronous tool call wrapper functions. + +A wrapper receives a ToolCallRequest and an async handler callback. It can modify the +request, call the handler (potentially multiple times), modify the response, or +short-circuit entirely. +""" class ToolCallWithContext(TypedDict):