diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..747edb2a 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -314,14 +314,16 @@ private Mono handleMessage(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - // TODO: instead of signalling the error, just respond with 200 OK - // - the error is signalled on the SSE connection - // return ServerResponse.ok().build(); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); + return session.handle(request, message) + .flatMap(response -> ServerResponse.ok().build()) + .onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 08619bd3..8a5bb43e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -802,4 +802,4 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { mcpServer.close(); } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..baca3c98 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -316,7 +316,7 @@ private ServerResponse handleMessage(ServerRequest request) { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); // Process the message through the session's handle method - session.handle(message).block(); // Block for WebMVC compatibility + session.handle(request, message).block(); // Block for WebMVC compatibility return ServerResponse.ok().build(); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index cdd43e7e..d480b0f2 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -110,7 +110,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -123,12 +123,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -144,7 +144,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -173,7 +173,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index c81e638c..9657f590 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -117,7 +117,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -130,11 +130,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -148,7 +148,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a0..8798a36b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpRequest; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -292,7 +293,7 @@ private static class AsyncServerImpl extends McpAsyncServer { // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (exchange, request) -> Mono.just(Map.of())); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -472,7 +473,7 @@ public Mono notifyToolsListChanged() { } private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { + return (exchange, request) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); @@ -480,9 +481,9 @@ private McpServerSession.RequestHandler toolsListRequ } private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + return (exchange, request) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(request.params(), + new TypeReference<>() { }); Optional toolSpecification = this.tools.stream() @@ -493,7 +494,9 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + return toolSpecification + .map(tool -> tool.call() + .apply(exchange, new McpRequest(callToolRequest.arguments(), request.context()))) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -553,7 +556,7 @@ public Mono notifyResourcesListChanged() { } private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { + return (exchange, request) -> { var resourceList = this.resources.values() .stream() .map(McpServerFeatures.AsyncResourceSpecification::resource) @@ -563,14 +566,14 @@ private McpServerSession.RequestHandler resources } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono + return (exchange, request) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + return (exchange, request) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(request.params(), new TypeReference() { }); var resourceUri = resourceRequest.uri(); @@ -646,7 +649,7 @@ public Mono notifyPromptsListChanged() { } private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { + return (exchange, request) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, // new TypeReference() { @@ -662,8 +665,8 @@ private McpServerSession.RequestHandler promptsList } private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + return (exchange, request) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(request.params(), new TypeReference() { }); @@ -697,10 +700,10 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { + return (exchange, request) -> { return Mono.defer(() -> { - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(request.params(), new TypeReference() { }); @@ -716,8 +719,8 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { } private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); + return (exchange, req) -> { + McpSchema.CompleteRequest request = parseCompletionParams(req.params()); if (request.ref() == null) { return Mono.error(new McpError("ref must not be null")); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 84089703..642206c7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -14,10 +14,12 @@ import java.util.function.BiFunction; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpRequest; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.RequestContext; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -306,7 +308,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * @throws IllegalArgumentException if tool or handler is null */ public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + BiFunction> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -751,7 +753,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { + BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d4..9e5b82ec 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -4,6 +4,14 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.spec.McpRequest; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.RequestContext; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -11,12 +19,6 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - /** * MCP server features specification that a particular server can choose to support. * @@ -73,9 +75,9 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s : new McpSchema.ServerCapabilities(null, // completions null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default + // logging + // by + // default !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, @@ -181,9 +183,9 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se : new McpSchema.ServerCapabilities(null, // completions null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default + // logging + // by + // default !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, @@ -237,7 +239,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { + BiFunction> call) { static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented @@ -245,8 +247,8 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { return null; } return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) + (exchange, request) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), request)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -413,7 +415,8 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + BiFunction call) { + } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..4d942890 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -289,7 +289,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + session.handle(request, message).block(); // Block for Servlet compatibility response.setStatus(HttpServletResponse.SC_OK); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..7d556b86 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -186,7 +186,7 @@ private void initProcessing() { } private void handleIncomingMessages() { - this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + this.inboundSink.asFlux().flatMap(message -> session.handle(null, message)).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion this.outboundSink.tryEmitComplete(); this.inboundScheduler.dispose(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java new file mode 100644 index 00000000..99a27914 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpRequest.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +/** + * @param params the parameters of the request. + * @param context the request context + * @author taobaorun + */ +public record McpRequest(Object params, RequestContext context) { +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46c356cd..ddd79737 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.spec; import java.time.Duration; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -150,9 +151,10 @@ public Mono sendNotification(String method, Object params) { * {@link io.modelcontextprotocol.server.McpSyncServer}) via * {@link McpServerSession.Factory} that the server creates. * @param message the incoming JSON-RPC message + * @param originalRequest the original request that triggered this message * @return a Mono that completes when the message is processed */ - public Mono handle(McpSchema.JSONRPCMessage message) { + public Mono handle(Object originalRequest, McpSchema.JSONRPCMessage message) { return Mono.defer(() -> { // TODO handle errors for communication to without initialization happening // first @@ -169,7 +171,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(originalRequest, request).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -195,9 +197,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param originalRequest the original request that triggered this message * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(Object originalRequest, + McpSchema.JSONRPCRequest request) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -220,8 +224,12 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, error.message(), error.data()))); } - - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + Map context = new HashMap<>(); + context.put("sessionId", this.id); + RequestContext requestContext = new RequestContext(context); + requestContext.setRequest(originalRequest); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, new McpRequest(request.params(), requestContext))); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -332,10 +340,10 @@ public interface RequestHandler { * Handles a request from the client. * @param exchange the exchange associated with the client that allows calling * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. + * @param request the parameters of the request with context * @return a Mono that will emit the response to the request. */ - Mono handle(McpAsyncServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, McpRequest request); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java new file mode 100644 index 00000000..245b3841 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/RequestContext.java @@ -0,0 +1,41 @@ + +package io.modelcontextprotocol.spec; + +import java.util.Collections; +import java.util.Map; + +/** + * Represents the context for MCP request execution. + * + *

+ * The context map can contain any information that is relevant to the request. + *

+ * + * @author taobaorun + */ +public class RequestContext { + + /** + * the original request object + */ + private Object request; + + private final Map context; + + public RequestContext(Map context) { + this.context = Collections.unmodifiableMap(context); + } + + public Map getContext() { + return context; + } + + public Object getRequest() { + return request; + } + + public void setRequest(Object request) { + this.request = request; + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..74f26c01 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -57,7 +57,7 @@ public Mono closeGracefully() { } public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { - session.handle(message).subscribe(); + session.handle(null, message).subscribe(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index df0b0c72..66e7a79e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -109,7 +109,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -122,12 +122,12 @@ void testAddDuplicateTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -143,7 +143,7 @@ void testRemoveTool() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -172,7 +172,7 @@ void testNotifyToolsListChanged() { var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 0b38da85..968d6e04 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -116,7 +116,7 @@ void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -129,11 +129,11 @@ void testAddDuplicateTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, request) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -147,7 +147,7 @@ void testRemoveTool() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..2f5523b8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -112,7 +112,7 @@ void shouldHandleIncomingMessages() throws Exception { McpServerSession.Factory realSessionFactory = transport -> { McpServerSession session = mock(McpServerSession.class); - when(session.handle(any())).thenAnswer(invocation -> { + when(session.handle(any(), any())).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0)); messageLatch.countDown(); return Mono.empty();