diff --git a/packages/dartantic_ai/lib/dartantic_ai.dart b/packages/dartantic_ai/lib/dartantic_ai.dart index 1c547cda..e60b4ff8 100644 --- a/packages/dartantic_ai/lib/dartantic_ai.dart +++ b/packages/dartantic_ai/lib/dartantic_ai.dart @@ -8,6 +8,7 @@ export 'package:dartantic_interface/dartantic_interface.dart'; export 'src/agent/agent.dart'; export 'src/agent/model_string_parser.dart'; export 'src/agent/orchestrators/orchestrators.dart'; +export 'src/agent/tool_middleware.dart'; export 'src/chat_models/chat_models.dart'; export 'src/embeddings_models/embeddings_models.dart'; export 'src/logging_options.dart'; diff --git a/packages/dartantic_ai/lib/src/agent/agent.dart b/packages/dartantic_ai/lib/src/agent/agent.dart index 560ac9c2..bd133405 100644 --- a/packages/dartantic_ai/lib/src/agent/agent.dart +++ b/packages/dartantic_ai/lib/src/agent/agent.dart @@ -20,6 +20,7 @@ import 'media_response_accumulator.dart'; import 'model_string_parser.dart'; import 'orchestrators/default_streaming_orchestrator.dart'; import 'streaming_state.dart'; +import 'tool_middleware.dart'; /// An agent that manages chat models and provides tool execution and message /// collection capabilities. @@ -42,6 +43,7 @@ class Agent { /// - [tools]: List of tools the agent can use /// - [temperature]: Model temperature (0.0 to 1.0) /// - [enableThinking]: Enable extended thinking/reasoning (default: false) + /// - [middleware]: List of middleware to intercept tool calls /// - [chatModelOptions]: Provider-specific chat model configuration /// - [embeddingsModelOptions]: Provider-specific embeddings configuration /// - [mediaModelOptions]: Provider-specific media generation configuration @@ -51,6 +53,7 @@ class Agent { double? temperature, bool enableThinking = false, String? displayName, + List? middleware, this.chatModelOptions, this.embeddingsModelOptions, this.mediaModelOptions, @@ -86,10 +89,12 @@ class Agent { _tools = tools; _temperature = temperature; _enableThinking = enableThinking; + _middleware = middleware; _logger.fine( 'Agent created successfully with ${tools?.length ?? 0} tools, ' - 'temperature: $temperature, enableThinking: $enableThinking', + 'temperature: $temperature, enableThinking: $enableThinking, ' + 'middleware: ${middleware?.length ?? 0}', ); } @@ -103,6 +108,7 @@ class Agent { double? temperature, bool enableThinking = false, String? displayName, + List? middleware, this.chatModelOptions, this.embeddingsModelOptions, this.mediaModelOptions, @@ -129,10 +135,12 @@ class Agent { _tools = tools; _temperature = temperature; _enableThinking = enableThinking; + _middleware = middleware; _logger.fine( 'Agent created from provider with ${tools?.length ?? 0} tools, ' - 'temperature: $temperature, enableThinking: $enableThinking', + 'temperature: $temperature, enableThinking: $enableThinking, ' + 'middleware: ${middleware?.length ?? 0}', ); } @@ -189,6 +197,7 @@ class Agent { late final double? _temperature; late final bool _enableThinking; late final String? _displayName; + late final List? _middleware; static final Logger _logger = Logger('dartantic.chat_agent'); @@ -320,6 +329,7 @@ class Agent { final state = StreamingState( conversationHistory: conversationHistory, toolMap: {for (final tool in toolsToUse ?? []) tool.name: tool}, + middleware: _middleware, ); orchestrator.initialize(state); diff --git a/packages/dartantic_ai/lib/src/agent/streaming_state.dart b/packages/dartantic_ai/lib/src/agent/streaming_state.dart index 975fde42..8f364822 100644 --- a/packages/dartantic_ai/lib/src/agent/streaming_state.dart +++ b/packages/dartantic_ai/lib/src/agent/streaming_state.dart @@ -4,6 +4,7 @@ import 'package:logging/logging.dart'; import '../chat_models/helpers/tool_id_helpers.dart'; import 'message_accumulator.dart'; import 'tool_executor.dart'; +import 'tool_middleware.dart'; /// Encapsulates all mutable state required during streaming operations class StreamingState { @@ -11,8 +12,10 @@ class StreamingState { StreamingState({ required List conversationHistory, required Map toolMap, + List? middleware, }) : _conversationHistory = conversationHistory, - _toolMap = toolMap; + _toolMap = toolMap, + executor = ToolExecutor(middleware: middleware); /// Logger for state.streaming operations. static final Logger _logger = Logger('dartantic.state.streaming'); @@ -34,7 +37,7 @@ class StreamingState { final MessageAccumulator accumulator = const MessageAccumulator(); /// Tool executor for provider-specific tool execution - final ToolExecutor executor = const ToolExecutor(); + final ToolExecutor executor; /// Coordinator for managing tool IDs across the conversation final ToolIdCoordinator toolIdCoordinator = ToolIdCoordinator(); diff --git a/packages/dartantic_ai/lib/src/agent/tool_executor.dart b/packages/dartantic_ai/lib/src/agent/tool_executor.dart index 45f101df..2dfe94bf 100644 --- a/packages/dartantic_ai/lib/src/agent/tool_executor.dart +++ b/packages/dartantic_ai/lib/src/agent/tool_executor.dart @@ -4,6 +4,8 @@ import 'dart:convert'; import 'package:dartantic_interface/dartantic_interface.dart'; import 'package:logging/logging.dart'; +import 'tool_middleware.dart'; + /// Result of executing a single tool class ToolExecutionResult { /// Creates a new ToolExecutionResult @@ -36,9 +38,15 @@ class ToolExecutionResult { /// - Executes tools sequentially /// - Formats results as JSON strings /// - Includes error details in results for LLM consumption +/// - Supports middleware for intercepting tool calls class ToolExecutor { /// Creates a new ToolExecutor - const ToolExecutor(); + /// + /// [middleware] - Optional list of middleware to intercept tool calls + const ToolExecutor({this.middleware}); + + /// Optional middleware to intercept tool calls + final List? middleware; static final _logger = Logger('dartantic.executor.tool'); @@ -74,13 +82,46 @@ class ToolExecutor { ToolPart toolCall, Map toolMap, ) async { + // Look up the tool first final tool = toolMap[toolCall.name]; + // If middleware exists, chain through it + if (middleware != null && middleware!.isNotEmpty) { + return _executeWithMiddleware(toolCall, tool, toolMap); + } + + // Otherwise, execute directly (existing behavior) + return _executeDirectly(toolCall, tool); + } + + /// Executes a tool call through the middleware chain. + Future _executeWithMiddleware( + ToolPart toolCall, + Tool? tool, + Map toolMap, + ) async { + int index = 0; + + Future next() { + if (index < middleware!.length) { + final current = middleware![index++]; + return current.intercept(toolCall, tool, next); + } else { + // Last middleware - execute actual tool (if found) + return _executeDirectly(toolCall, tool); + } + } + + return next(); + } + + /// Executes a tool call directly without middleware. + Future _executeDirectly( + ToolPart toolCall, + Tool? tool, + ) async { if (tool == null) { - _logger.warning( - 'Tool ${toolCall.name} not found in available tools: ' - '${toolMap.keys.join(', ')}', - ); + _logger.warning('Tool ${toolCall.name} not found in available tools'); final error = Exception('Tool ${toolCall.name} not found'); return ToolExecutionResult( diff --git a/packages/dartantic_ai/lib/src/agent/tool_middleware.dart b/packages/dartantic_ai/lib/src/agent/tool_middleware.dart new file mode 100644 index 00000000..2a10d20c --- /dev/null +++ b/packages/dartantic_ai/lib/src/agent/tool_middleware.dart @@ -0,0 +1,82 @@ +import 'dart:async'; + +import 'package:dartantic_interface/dartantic_interface.dart'; + +import 'tool_executor.dart'; + +/// Middleware that can intercept tool calls before and after execution. +/// +/// Middleware allows you to: +/// - Log tool calls and results +/// - Modify tool arguments before execution +/// - Skip tool execution and return custom results +/// - Modify or replace tool results after execution +/// - Implement custom error handling +/// +/// Example: +/// ```dart +/// class LoggingMiddleware implements ToolMiddleware { +/// @override +/// Future intercept( +/// ToolPart toolCall, +/// Tool? tool, +/// Future Function() next, +/// ) async { +/// print('Before: ${toolCall.name}'); +/// final result = await next(); +/// print('After: ${toolCall.name} -> ${result.isSuccess}'); +/// return result; +/// } +/// } +/// ``` +abstract class ToolMiddleware { + /// Creates a ToolMiddleware + const ToolMiddleware(); + + /// Intercepts a tool call before and/or after execution. + /// + /// [toolCall] - The tool call to intercept + /// [tool] - The matched tool instance, or null if the tool was not found + /// [next] - Callback to continue to the next middleware or actual execution + /// + /// Returns the ToolExecutionResult (can be modified or replaced) + Future intercept( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ); +} + +/// Adapter that wraps a function to implement ToolMiddleware. +/// +/// This allows you to use a simple function as middleware instead of +/// creating a class. +/// +/// Example: +/// ```dart +/// final middleware = FunctionToolMiddleware( +/// (toolCall, tool, next) async { +/// print('Executing ${toolCall.name}'); +/// return next(); +/// }, +/// ); +/// ``` +class FunctionToolMiddleware implements ToolMiddleware { + /// Creates a FunctionToolMiddleware that wraps the given handler function. + const FunctionToolMiddleware(this.handler); + + /// The function that handles the middleware logic + final Future Function( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) + handler; + + @override + Future intercept( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) => handler(toolCall, tool, next); +} diff --git a/packages/dartantic_ai/test/tool_middleware_test.dart b/packages/dartantic_ai/test/tool_middleware_test.dart new file mode 100644 index 00000000..f58d29d0 --- /dev/null +++ b/packages/dartantic_ai/test/tool_middleware_test.dart @@ -0,0 +1,458 @@ +// ignore_for_file: avoid_dynamic_calls + +import 'package:dartantic_ai/dartantic_ai.dart'; +import 'package:dartantic_ai/src/agent/tool_executor.dart'; +import 'package:test/test.dart'; + +import 'test_tools.dart'; + +void main() { + group('Tool Middleware', () { + group('ToolMiddleware interface', () { + test('FunctionToolMiddleware wraps a function correctly', () async { + var called = false; + final middleware = FunctionToolMiddleware((toolCall, tool, next) async { + called = true; + expect(toolCall.name, equals('string_tool')); + expect(tool, isNotNull); + return next(); + }); + + final executor = ToolExecutor(middleware: [middleware]); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(called, isTrue); + expect(result.isSuccess, isTrue); + expect(result.resultPart.result, contains('test')); + }); + + test('class-based middleware works', () async { + var beforeCalled = false; + var afterCalled = false; + + final testMiddleware = _TestMiddleware( + onIntercept: (toolCall, tool, next) async { + beforeCalled = true; + expect(toolCall.name, equals('string_tool')); + expect(tool, isNotNull); + final result = await next(); + afterCalled = true; + return result; + }, + ); + + final executor = ToolExecutor(middleware: [testMiddleware]); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(beforeCalled, isTrue); + expect(afterCalled, isTrue); + expect(result.isSuccess, isTrue); + }); + }); + + group('Middleware chain execution', () { + test('middleware executes in order', () async { + final executionOrder = []; + + final middleware1 = FunctionToolMiddleware(( + toolCall, + tool, + next, + ) async { + executionOrder.add('middleware1-before'); + final result = await next(); + executionOrder.add('middleware1-after'); + return result; + }); + + final middleware2 = FunctionToolMiddleware(( + toolCall, + tool, + next, + ) async { + executionOrder.add('middleware2-before'); + final result = await next(); + executionOrder.add('middleware2-after'); + return result; + }); + + final executor = ToolExecutor(middleware: [middleware1, middleware2]); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + await executor.executeSingle(toolCall, toolMap); + + expect( + executionOrder, + equals([ + 'middleware1-before', + 'middleware2-before', + 'middleware2-after', + 'middleware1-after', + ]), + ); + }); + + test('multiple middleware can modify results', () async { + final middleware1 = FunctionToolMiddleware(( + toolCall, + tool, + next, + ) async { + final result = await next(); + // Modify result by wrapping it + final modifiedResult = ToolExecutionResult( + toolPart: result.toolPart, + resultPart: ToolPart.result( + id: result.resultPart.id, + name: result.resultPart.name, + result: '[Middleware1] ${result.resultPart.result}', + ), + ); + return modifiedResult; + }); + + final middleware2 = FunctionToolMiddleware(( + toolCall, + tool, + next, + ) async { + final result = await next(); + // Modify result by wrapping it + final modifiedResult = ToolExecutionResult( + toolPart: result.toolPart, + resultPart: ToolPart.result( + id: result.resultPart.id, + name: result.resultPart.name, + result: '[Middleware2] ${result.resultPart.result}', + ), + ); + return modifiedResult; + }); + + final executor = ToolExecutor(middleware: [middleware1, middleware2]); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + // Middleware2 wraps first, then middleware1 wraps that + expect( + result.resultPart.result, + equals('[Middleware1] [Middleware2] String result: test'), + ); + }); + }); + + group('Middleware skipping execution', () { + test( + 'middleware can skip tool execution and return custom result', + () async { + var toolExecuted = false; + + final mockTool = Tool>( + name: 'mock_tool', + description: 'Mock tool', + onCall: (_) { + toolExecuted = true; + return 'should not execute'; + }, + ); + + final middleware = FunctionToolMiddleware(( + toolCall, + tool, + next, + ) async { + // Skip execution, return custom result + return ToolExecutionResult( + toolPart: toolCall, + resultPart: ToolPart.result( + id: toolCall.id, + name: toolCall.name, + result: '{"skipped": true, "reason": "middleware override"}', + ), + ); + }); + + final executor = ToolExecutor(middleware: [middleware]); + final toolMap = {'mock_tool': mockTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'mock_tool', + arguments: {}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(toolExecuted, isFalse); + expect(result.isSuccess, isTrue); + expect(result.resultPart.result, contains('skipped')); + }, + ); + }); + + group('Middleware error handling', () { + test('middleware can catch and modify errors', () async { + final failingTool = Tool>( + name: 'failing_tool', + description: 'Tool that throws', + onCall: (_) { + throw Exception('Tool execution failed'); + }, + ); + + final middleware = FunctionToolMiddleware((toolCall, tool, next) async { + try { + return await next(); + } on Exception catch (e) { + // Catch and return custom error result + return ToolExecutionResult( + toolPart: toolCall, + resultPart: ToolPart.result( + id: toolCall.id, + name: toolCall.name, + result: '{"error": "Caught by middleware: $e"}', + ), + error: Exception('Middleware handled error'), + ); + } + }); + + final executor = ToolExecutor(middleware: [middleware]); + final toolMap = {'failing_tool': failingTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'failing_tool', + arguments: {}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + // The tool executor should still catch the error, but middleware + // could theoretically intercept it + expect(result.error, isNotNull); + }); + + test('middleware handles missing tools', () async { + var middlewareCalled = false; + + final middleware = FunctionToolMiddleware((toolCall, tool, next) async { + middlewareCalled = true; + expect(tool, isNull); + // Middleware can handle missing tool case + return ToolExecutionResult( + toolPart: toolCall, + resultPart: ToolPart.result( + id: toolCall.id, + name: toolCall.name, + result: '{"error": "Tool not found, handled by middleware"}', + ), + error: Exception('Tool not found'), + ); + }); + + final executor = ToolExecutor(middleware: [middleware]); + final toolMap = {}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'nonexistent_tool', + arguments: {}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(middlewareCalled, isTrue); + expect(result.error, isNotNull); + }); + }); + + group('Backward compatibility', () { + test('ToolExecutor works without middleware', () async { + const executor = ToolExecutor(); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(result.isSuccess, isTrue); + expect(result.resultPart.result, contains('test')); + }); + + test('ToolExecutor works with empty middleware list', () async { + const executor = ToolExecutor(middleware: []); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + final result = await executor.executeSingle(toolCall, toolMap); + + expect(result.isSuccess, isTrue); + expect(result.resultPart.result, contains('test')); + }); + }); + + group('Integration with Agent', () { + test('Agent accepts middleware parameter', () { + final middleware = FunctionToolMiddleware( + (toolCall, tool, next) => next(), + ); + + final agent = Agent( + 'openai:gpt-4o-mini', + tools: [stringTool], + middleware: [middleware], + ); + + expect(agent, isNotNull); + }); + + test('Agent.forProvider accepts middleware parameter', () { + final middleware = FunctionToolMiddleware( + (toolCall, tool, next) => next(), + ); + + final provider = Agent.getProvider('openai'); + final agent = Agent.forProvider( + provider, + tools: [stringTool], + middleware: [middleware], + ); + + expect(agent, isNotNull); + }); + }); + + group('Logging middleware example', () { + test('logging middleware logs before and after execution', () async { + final logs = []; + + final loggingMiddleware = _LoggingMiddleware( + onIntercept: (toolCall, tool, next) async { + logs.add('Before: ${toolCall.name}(${toolCall.arguments})'); + final result = await next(); + logs.add('After: ${toolCall.name} -> ${result.isSuccess}'); + return result; + }, + ); + + final executor = ToolExecutor(middleware: [loggingMiddleware]); + final toolMap = {'string_tool': stringTool}; + const toolCall = ToolPart.call( + id: 'test-id', + name: 'string_tool', + arguments: {'input': 'test'}, + ); + + await executor.executeSingle(toolCall, toolMap); + + expect(logs, hasLength(2)); + expect(logs[0], contains('Before:')); + expect(logs[1], contains('After:')); + expect(logs[1], contains('true')); // isSuccess + }); + }); + + group('Batch execution with middleware', () { + test('middleware applies to each tool in batch', () async { + final callCount = {}; + + final middleware = FunctionToolMiddleware((toolCall, tool, next) async { + callCount[toolCall.name] = (callCount[toolCall.name] ?? 0) + 1; + return next(); + }); + + final executor = ToolExecutor(middleware: [middleware]); + final toolMap = {'string_tool': stringTool, 'int_tool': intTool}; + final toolCalls = [ + const ToolPart.call( + id: 'id1', + name: 'string_tool', + arguments: {'input': 'test1'}, + ), + const ToolPart.call( + id: 'id2', + name: 'int_tool', + arguments: {'value': 42}, + ), + const ToolPart.call( + id: 'id3', + name: 'string_tool', + arguments: {'input': 'test2'}, + ), + ]; + + await executor.executeBatch(toolCalls, toolMap); + + expect(callCount['string_tool'], equals(2)); + expect(callCount['int_tool'], equals(1)); + }); + }); + }); +} + +// Helper classes for testing +class _TestMiddleware implements ToolMiddleware { + _TestMiddleware({required this.onIntercept}); + + final Future Function( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) + onIntercept; + + @override + Future intercept( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) => onIntercept(toolCall, tool, next); +} + +class _LoggingMiddleware implements ToolMiddleware { + _LoggingMiddleware({required this.onIntercept}); + + final Future Function( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) + onIntercept; + + @override + Future intercept( + ToolPart toolCall, + Tool? tool, + Future Function() next, + ) => onIntercept(toolCall, tool, next); +}