From a5c07ab67e1772b7197eaf1abf4ea911b163ef67 Mon Sep 17 00:00:00 2001 From: taobaorun Date: Fri, 18 Apr 2025 01:29:01 +0800 Subject: [PATCH 1/2] tool call support context --- .../WebFluxSseServerTransportProvider.java | 18 ++++---- .../WebFluxSseIntegrationTests.java | 20 ++++----- .../WebMvcSseServerTransportProvider.java | 2 +- .../server/WebMvcSseIntegrationTests.java | 16 ++++---- .../server/AbstractMcpAsyncServerTests.java | 13 +++--- .../server/AbstractMcpSyncServerTests.java | 8 ++-- .../server/McpAsyncServer.java | 22 +++++----- .../server/McpServer.java | 5 ++- .../server/McpServerFeatures.java | 34 +++++++-------- .../server/TriFunction.java | 19 +++++++++ ...HttpServletSseServerTransportProvider.java | 2 +- .../StdioServerTransportProvider.java | 2 +- .../spec/McpServerSession.java | 21 +++++++--- .../spec/RequestContext.java | 41 +++++++++++++++++++ .../MockMcpServerTransportProvider.java | 2 +- .../server/AbstractMcpAsyncServerTests.java | 13 +++--- .../server/AbstractMcpSyncServerTests.java | 8 ++-- ...rverTransportProviderIntegrationTests.java | 18 ++++---- .../StdioServerTransportProviderTests.java | 2 +- 19 files changed, 171 insertions(+), 95 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..747edb2a 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -314,14 +314,16 @@ private Mono handleMessage(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - // TODO: instead of signalling the error, just respond with 200 OK - // - the error is signalled on the SSE connection - // return ServerResponse.ok().build(); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); + return session.handle(request, message) + .flatMap(response -> ServerResponse.ok().build()) + .onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 08619bd3..9b9bce70 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -103,7 +103,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), - (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + (exchange, request, context) -> exchange.createMessage(mock(CreateMessageRequest.class)) .thenReturn(mock(CallToolResult.class))); var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); @@ -144,7 +144,7 @@ void testCreateMessageSuccess(String clientType) { AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -220,7 +220,7 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -296,7 +296,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -386,7 +386,7 @@ void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { exchange.listRoots(); // try to list roots @@ -525,7 +525,7 @@ void testToolCallSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -565,7 +565,7 @@ void testToolListChangeHandlingSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -617,7 +617,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); + (exchange, request, context) -> callResponse); mcpServer.addTool(tool2); @@ -660,7 +660,7 @@ void testLoggingNotification(String clientType) { // Create server with a tool that sends logging notifications McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request) -> { + (exchange, request, context) -> { // Create and send notifications with different levels @@ -802,4 +802,4 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { mcpServer.close(); } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..baca3c98 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -316,7 +316,7 @@ private ServerResponse handleMessage(ServerRequest request) { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); // Process the message through the session's handle method - session.handle(message).block(); // Block for WebMVC compatibility + session.handle(request, message).block(); // Block for WebMVC compatibility return ServerResponse.ok().build(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d6843..10bb7f7b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -120,7 +120,7 @@ public void after() { void testCreateMessageWithoutSamplingCapabilities() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); @@ -167,7 +167,7 @@ void testCreateMessageSuccess() { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -243,7 +243,7 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -315,7 +315,7 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -408,7 +408,7 @@ void testRootsSuccess() { void testRootsWithoutCapability() { McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { exchange.listRoots(); // try to list roots @@ -535,7 +535,7 @@ void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -571,7 +571,7 @@ void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -623,7 +623,7 @@ void testToolListChangeHandlingSuccess() { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); + (exchange, request, context) -> callResponse); mcpServer.addTool(tool2); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index cdd43e7e..52a910e7 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -109,8 +109,9 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -123,12 +124,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -144,7 +145,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -173,7 +174,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index c81e638c..aa9fd3db 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -117,7 +117,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -130,11 +130,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -148,7 +148,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a0..05c5af3f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -292,7 +292,7 @@ private static class AsyncServerImpl extends McpAsyncServer { // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params, context) -> Mono.just(Map.of())); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -472,7 +472,7 @@ public Mono notifyToolsListChanged() { } private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); @@ -480,7 +480,7 @@ private McpServerSession.RequestHandler toolsListRequ } private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -493,7 +493,7 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments(), context)) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -553,7 +553,7 @@ public Mono notifyResourcesListChanged() { } private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { var resourceList = this.resources.values() .stream() .map(McpServerFeatures.AsyncResourceSpecification::resource) @@ -563,13 +563,13 @@ private McpServerSession.RequestHandler resources } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono + return (exchange, params, context) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -646,7 +646,7 @@ public Mono notifyPromptsListChanged() { } private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, // new TypeReference() { @@ -662,7 +662,7 @@ private McpServerSession.RequestHandler promptsList } private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -697,7 +697,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { return Mono.defer(() -> { SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, @@ -716,7 +716,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { } private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 84089703..82b84712 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.RequestContext; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -306,7 +307,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * @throws IllegalArgumentException if tool or handler is null */ public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + TriFunction, RequestContext, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -751,7 +752,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { + TriFunction, RequestContext, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d4..4f245237 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -4,6 +4,13 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.RequestContext; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -11,12 +18,6 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - /** * MCP server features specification that a particular server can choose to support. * @@ -73,9 +74,9 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s : new McpSchema.ServerCapabilities(null, // completions null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default + // logging + // by + // default !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, @@ -181,9 +182,9 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se : new McpSchema.ServerCapabilities(null, // completions null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default + // logging + // by + // default !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, @@ -237,7 +238,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { + TriFunction, RequestContext, Mono> call) { static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented @@ -245,8 +246,8 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { return null; } return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) + (exchange, map, context) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map, context)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -413,7 +414,8 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + TriFunction, RequestContext, McpSchema.CallToolResult> call) { + } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java b/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java new file mode 100644 index 00000000..fb164ad1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import java.util.Objects; +import java.util.function.Function; + +/** + * @author taobaorun + */ +@FunctionalInterface +public interface TriFunction { + + R apply(T t, U u, V v); + + default TriFunction andThen(Function after) { + Objects.requireNonNull(after); + return (T t, U u, V v) -> after.apply(apply(t, u, v)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..4d942890 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -289,7 +289,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + session.handle(request, message).block(); // Block for Servlet compatibility response.setStatus(HttpServletResponse.SC_OK); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..7d556b86 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -186,7 +186,7 @@ private void initProcessing() { } private void handleIncomingMessages() { - this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + this.inboundSink.asFlux().flatMap(message -> session.handle(null, message)).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion this.outboundSink.tryEmitComplete(); this.inboundScheduler.dispose(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46c356cd..63de7c5b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.spec; import java.time.Duration; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -150,9 +151,10 @@ public Mono sendNotification(String method, Object params) { * {@link io.modelcontextprotocol.server.McpSyncServer}) via * {@link McpServerSession.Factory} that the server creates. * @param message the incoming JSON-RPC message + * @param originalRequest the original request that triggered this message * @return a Mono that completes when the message is processed */ - public Mono handle(McpSchema.JSONRPCMessage message) { + public Mono handle(Object originalRequest, McpSchema.JSONRPCMessage message) { return Mono.defer(() -> { // TODO handle errors for communication to without initialization happening // first @@ -169,7 +171,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(originalRequest, request).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -195,9 +197,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param originalRequest the original request that triggered this message * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(Object originalRequest, + McpSchema.JSONRPCRequest request) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -220,8 +224,12 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, error.message(), error.data()))); } - - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + Map context = new HashMap<>(); + context.put("sessionId", this.id); + RequestContext requestContext = new RequestContext(context); + requestContext.setRequest(originalRequest); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, request.params(), requestContext)); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -333,9 +341,10 @@ public interface RequestHandler { * @param exchange the exchange associated with the client that allows calling * back to the connected client or inspecting its capabilities. * @param params the parameters of the request. + * @param context the request context * @return a Mono that will emit the response to the request. */ - Mono handle(McpAsyncServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, Object params, RequestContext context); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java new file mode 100644 index 00000000..245b3841 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java @@ -0,0 +1,41 @@ + +package io.modelcontextprotocol.spec; + +import java.util.Collections; +import java.util.Map; + +/** + * Represents the context for MCP request execution. + * + *

+ * The context map can contain any information that is relevant to the request. + *

+ * + * @author taobaorun + */ +public class RequestContext { + + /** + * the original request object + */ + private Object request; + + private final Map context; + + public RequestContext(Map context) { + this.context = Collections.unmodifiableMap(context); + } + + public Map getContext() { + return context; + } + + public Object getRequest() { + return request; + } + + public void setRequest(Object request) { + this.request = request; + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..74f26c01 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -57,7 +57,7 @@ public Mono closeGracefully() { } public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { - session.handle(message).subscribe(); + session.handle(null, message).subscribe(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index df0b0c72..0f9dfd70 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -108,8 +108,9 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -122,12 +123,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -143,7 +144,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -172,7 +173,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 0b38da85..0149aacd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -116,7 +116,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -129,11 +129,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -147,7 +147,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a..f0f16bd0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -108,7 +108,7 @@ public void after() { void testCreateMessageWithoutSamplingCapabilities() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); @@ -150,7 +150,7 @@ void testCreateMessageSuccess() { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -225,7 +225,7 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -297,7 +297,7 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -390,7 +390,7 @@ void testRootsSuccess() { void testRootsWithoutCapability() { McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { exchange.listRoots(); // try to list roots @@ -514,7 +514,7 @@ void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -550,7 +550,7 @@ void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -602,7 +602,7 @@ void testToolListChangeHandlingSuccess() { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); + (exchange, request, context) -> callResponse); mcpServer.addTool(tool2); @@ -638,7 +638,7 @@ void testLoggingNotification() { // Create server with a tool that sends logging notifications McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request) -> { + (exchange, request, context) -> { // Create and send notifications with different levels diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..2f5523b8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -112,7 +112,7 @@ void shouldHandleIncomingMessages() throws Exception { McpServerSession.Factory realSessionFactory = transport -> { McpServerSession session = mock(McpServerSession.class); - when(session.handle(any())).thenAnswer(invocation -> { + when(session.handle(any(), any())).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0)); messageLatch.countDown(); return Mono.empty(); From a8e69a03033c912a253a28336ca80f7e36b06918 Mon Sep 17 00:00:00 2001 From: taobaorun Date: Fri, 18 Apr 2025 09:06:45 +0800 Subject: [PATCH 2/2] remove TriFunction.Add McpRequest record encapsulate the params and context --- .../WebFluxSseIntegrationTests.java | 18 ++++----- .../server/WebMvcSseIntegrationTests.java | 16 ++++---- .../server/AbstractMcpAsyncServerTests.java | 13 +++---- .../server/AbstractMcpSyncServerTests.java | 8 ++-- .../server/McpAsyncServer.java | 37 ++++++++++--------- .../server/McpServer.java | 5 ++- .../server/McpServerFeatures.java | 9 +++-- .../server/TriFunction.java | 19 ---------- .../modelcontextprotocol/spec/McpRequest.java | 9 +++++ .../spec/McpServerSession.java | 7 ++-- .../server/AbstractMcpAsyncServerTests.java | 13 +++---- .../server/AbstractMcpSyncServerTests.java | 8 ++-- ...rverTransportProviderIntegrationTests.java | 18 ++++----- 13 files changed, 86 insertions(+), 94 deletions(-) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 9b9bce70..8a5bb43e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -103,7 +103,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), - (exchange, request, context) -> exchange.createMessage(mock(CreateMessageRequest.class)) + (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) .thenReturn(mock(CallToolResult.class))); var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); @@ -144,7 +144,7 @@ void testCreateMessageSuccess(String clientType) { AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -220,7 +220,7 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -296,7 +296,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -386,7 +386,7 @@ void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { exchange.listRoots(); // try to list roots @@ -525,7 +525,7 @@ void testToolCallSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -565,7 +565,7 @@ void testToolListChangeHandlingSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -617,7 +617,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request, context) -> callResponse); + (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -660,7 +660,7 @@ void testLoggingNotification(String clientType) { // Create server with a tool that sends logging notifications McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request, context) -> { + (exchange, request) -> { // Create and send notifications with different levels diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 10bb7f7b..b12d6843 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -120,7 +120,7 @@ public void after() { void testCreateMessageWithoutSamplingCapabilities() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); @@ -167,7 +167,7 @@ void testCreateMessageSuccess() { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -243,7 +243,7 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -315,7 +315,7 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -408,7 +408,7 @@ void testRootsSuccess() { void testRootsWithoutCapability() { McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { exchange.listRoots(); // try to list roots @@ -535,7 +535,7 @@ void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -571,7 +571,7 @@ void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -623,7 +623,7 @@ void testToolListChangeHandlingSuccess() { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request, context) -> callResponse); + (exchange, request) -> callResponse); mcpServer.addTool(tool2); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 52a910e7..d480b0f2 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -109,9 +109,8 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier - .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -124,12 +123,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -145,7 +144,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -174,7 +173,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index aa9fd3db..9657f590 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -117,7 +117,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -130,11 +130,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -148,7 +148,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 05c5af3f..8798a36b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpRequest; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -292,7 +293,7 @@ private static class AsyncServerImpl extends McpAsyncServer { // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params, context) -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (exchange, request) -> Mono.just(Map.of())); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -472,7 +473,7 @@ public Mono notifyToolsListChanged() { } private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params, context) -> { + return (exchange, request) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); @@ -480,9 +481,9 @@ private McpServerSession.RequestHandler toolsListRequ } private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params, context) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + return (exchange, request) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(request.params(), + new TypeReference<>() { }); Optional toolSpecification = this.tools.stream() @@ -493,7 +494,9 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments(), context)) + return toolSpecification + .map(tool -> tool.call() + .apply(exchange, new McpRequest(callToolRequest.arguments(), request.context()))) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -553,7 +556,7 @@ public Mono notifyResourcesListChanged() { } private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params, context) -> { + return (exchange, request) -> { var resourceList = this.resources.values() .stream() .map(McpServerFeatures.AsyncResourceSpecification::resource) @@ -563,14 +566,14 @@ private McpServerSession.RequestHandler resources } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params, context) -> Mono + return (exchange, request) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params, context) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + return (exchange, request) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(request.params(), new TypeReference() { }); var resourceUri = resourceRequest.uri(); @@ -646,7 +649,7 @@ public Mono notifyPromptsListChanged() { } private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params, context) -> { + return (exchange, request) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, // new TypeReference() { @@ -662,8 +665,8 @@ private McpServerSession.RequestHandler promptsList } private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params, context) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + return (exchange, request) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(request.params(), new TypeReference() { }); @@ -697,10 +700,10 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params, context) -> { + return (exchange, request) -> { return Mono.defer(() -> { - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(request.params(), new TypeReference() { }); @@ -716,8 +719,8 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { } private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params, context) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); + return (exchange, req) -> { + McpSchema.CompleteRequest request = parseCompletionParams(req.params()); if (request.ref() == null) { return Mono.error(new McpError("ref must not be null")); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 82b84712..642206c7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -14,6 +14,7 @@ import java.util.function.BiFunction; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpRequest; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; @@ -307,7 +308,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * @throws IllegalArgumentException if tool or handler is null */ public AsyncSpecification tool(McpSchema.Tool tool, - TriFunction, RequestContext, Mono> handler) { + BiFunction> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -752,7 +753,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - TriFunction, RequestContext, McpSchema.CallToolResult> handler) { + BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 4f245237..9e5b82ec 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.spec.McpRequest; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.RequestContext; import io.modelcontextprotocol.util.Assert; @@ -238,7 +239,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - TriFunction, RequestContext, Mono> call) { + BiFunction> call) { static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented @@ -246,8 +247,8 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { return null; } return new AsyncToolSpecification(tool.tool(), - (exchange, map, context) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map, context)) + (exchange, request) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), request)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -414,7 +415,7 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - TriFunction, RequestContext, McpSchema.CallToolResult> call) { + BiFunction call) { } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java b/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java deleted file mode 100644 index fb164ad1..00000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/TriFunction.java +++ /dev/null @@ -1,19 +0,0 @@ -package io.modelcontextprotocol.server; - -import java.util.Objects; -import java.util.function.Function; - -/** - * @author taobaorun - */ -@FunctionalInterface -public interface TriFunction { - - R apply(T t, U u, V v); - - default TriFunction andThen(Function after) { - Objects.requireNonNull(after); - return (T t, U u, V v) -> after.apply(apply(t, u, v)); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java new file mode 100644 index 00000000..99a27914 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +/** + * @param params the parameters of the request. + * @param context the request context + * @author taobaorun + */ +public record McpRequest(Object params, RequestContext context) { +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 63de7c5b..ddd79737 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -229,7 +229,7 @@ private Mono handleIncomingRequest(Object originalReq RequestContext requestContext = new RequestContext(context); requestContext.setRequest(originalRequest); resultMono = this.exchangeSink.asMono() - .flatMap(exchange -> handler.handle(exchange, request.params(), requestContext)); + .flatMap(exchange -> handler.handle(exchange, new McpRequest(request.params(), requestContext))); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -340,11 +340,10 @@ public interface RequestHandler { * Handles a request from the client. * @param exchange the exchange associated with the client that allows calling * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. - * @param context the request context + * @param request the parameters of the request with context * @return a Mono that will emit the response to the request. */ - Mono handle(McpAsyncServerExchange exchange, Object params, RequestContext context); + Mono handle(McpAsyncServerExchange exchange, McpRequest request); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 0f9dfd70..66e7a79e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -108,9 +108,8 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier - .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -123,12 +122,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -144,7 +143,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -173,7 +172,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args, requestContext) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 0149aacd..968d6e04 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -116,7 +116,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -129,11 +129,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args, requestContext) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -147,7 +147,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args, requestContext) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index f0f16bd0..2ff6325a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -108,7 +108,7 @@ public void after() { void testCreateMessageWithoutSamplingCapabilities() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); @@ -150,7 +150,7 @@ void testCreateMessageSuccess() { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -225,7 +225,7 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -297,7 +297,7 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { null); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -390,7 +390,7 @@ void testRootsSuccess() { void testRootsWithoutCapability() { McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { exchange.listRoots(); // try to list roots @@ -514,7 +514,7 @@ void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -550,7 +550,7 @@ void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request, context) -> { + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -602,7 +602,7 @@ void testToolListChangeHandlingSuccess() { // Add a new tool McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request, context) -> callResponse); + (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -638,7 +638,7 @@ void testLoggingNotification() { // Create server with a tool that sends logging notifications McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request, context) -> { + (exchange, request) -> { // Create and send notifications with different levels