Skip to content

feat(mcp): WebFlux server support filter with CallToolRequest #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package io.modelcontextprotocol.server.filter;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.reactive.function.server.HandlerFilterFunction;
import org.springframework.web.reactive.function.server.HandlerFunction;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.publisher.Mono;

import java.util.Objects;
import java.util.function.Function;

public abstract class CallToolHandlerFilter implements HandlerFilterFunction<ServerResponse, ServerResponse> {

private static final Logger logger = LoggerFactory.getLogger(CallToolHandlerFilter.class);

private final ObjectMapper objectMapper = new ObjectMapper();

private static final String DEFAULT_MESSAGE_PATH = "/mcp/message";

public final static McpSchema.CallToolResult PASS = null;

private Function<String, McpServerSession> sessionFunction;

/**
* Filter incoming requests to handle tool calls. Processes
* {@linkplain io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest JSONRPCRequest}
* that match the configured path and method.
* @param request The incoming server request
* @param next The next handler in the chain
* @return The filtered response
*/
@Override
public Mono<ServerResponse> filter(ServerRequest request, HandlerFunction<ServerResponse> next) {
if (!Objects.equals(request.path(), matchPath())) {
return next.handle(request);
}

return request.bodyToMono(McpSchema.JSONRPCRequest.class)
.flatMap(jsonrpcRequest -> handleJsonRpcRequest(request, jsonrpcRequest, next));
}

private Mono<ServerResponse> handleJsonRpcRequest(ServerRequest request, McpSchema.JSONRPCRequest jsonrpcRequest,
HandlerFunction<ServerResponse> next) {
ServerRequest newRequest;
try {
newRequest = ServerRequest.from(request).body(objectMapper.writeValueAsString(jsonrpcRequest)).build();
}
catch (JsonProcessingException e) {
return Mono.error(e);
}

if (skipFilter(jsonrpcRequest)) {
return next.handle(newRequest);
}

return handleToolCallRequest(newRequest, jsonrpcRequest, next);
}

private Mono<ServerResponse> handleToolCallRequest(ServerRequest newRequest,
McpSchema.JSONRPCRequest jsonrpcRequest, HandlerFunction<ServerResponse> next) {
McpServerSession session = newRequest.queryParam("sessionId")
.map(sessionId -> sessionFunction.apply(sessionId))
.orElse(null);

if (Objects.isNull(session)) {
return next.handle(newRequest);
}

McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(jsonrpcRequest.params(),
McpSchema.CallToolRequest.class);
McpSchema.CallToolResult callToolResult = doFilter(newRequest, callToolRequest);
if (Objects.equals(PASS, callToolResult)) {
return next.handle(newRequest);
}
else {
return session.sendResponse(jsonrpcRequest.id(), callToolResult, null).then(ServerResponse.ok().build());
}
}

private boolean skipFilter(McpSchema.JSONRPCRequest jsonrpcRequest) {
if (!Objects.equals(jsonrpcRequest.method(), matchMethod())) {
return true;
}

if (Objects.isNull(sessionFunction)) {
logger.error("No session function provided, skip CallToolRequest filter");
return true;
}

return false;
}

/**
* Abstract method to be implemented by subclasses to handle tool call requests.
* @param request The incoming server request. Contains HTTP information such as:
* request path, request headers, request parameters. Note that the request body has
* already been extracted and deserialized into the callToolRequest parameter, so
* there's no need to extract the body from the ServerRequest again.
* @param callToolRequest The deserialized call tool request object
* @return A CallToolResult object if not pass current filter (subsequent filters will
* not be executed), or {@linkplain CallToolHandlerFilter#PASS PASS} pass current
* filter(subsequent filters will be executed)
*/
public abstract McpSchema.CallToolResult doFilter(ServerRequest request, McpSchema.CallToolRequest callToolRequest);

/**
* Returns the method name to match for handling tool calls.
* @return The method name to match
*/
public String matchMethod() {
return McpSchema.METHOD_TOOLS_CALL;
}

/**
* Returns the path to match for handling tool calls.
* @return The path to match
*/
public String matchPath() {
return DEFAULT_MESSAGE_PATH;
}

/**
* Set the session provider function.
* @param transportProvider The SSE server transport provider used to obtain sessions
*/
public void applySession(WebFluxSseServerTransportProvider transportProvider) {
this.sessionFunction = transportProvider::getSession;
}

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -14,6 +13,7 @@
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.reactive.function.server.*;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
Expand All @@ -22,10 +22,6 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;

/**
* Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
Expand Down Expand Up @@ -84,6 +80,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv

public static final String DEFAULT_BASE_URL = "";

/**
* Default filter function for handling requests, do nothing
*/
public static final HandlerFilterFunction<ServerResponse, ServerResponse> DEFAULT_REQUEST_FILTER = ((request,
next) -> next.handle(request));

private final ObjectMapper objectMapper;

/**
Expand Down Expand Up @@ -149,10 +151,28 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, DEFAULT_REQUEST_FILTER);
}

/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @param requestFilter The filter function to apply to incoming requests, which may
* be sse or message request.
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.notNull(requestFilter, "Request filter must not be null");

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
Expand All @@ -161,6 +181,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.filter(requestFilter)
.build();
}

Expand Down Expand Up @@ -245,6 +266,14 @@ public RouterFunction<?> getRouterFunction() {
return this.routerFunction;
}

/**
* Returns the McpServerSession associated with the given session ID.
* @return session The McpServerSession associated with the given session ID, or null
*/
public McpServerSession getSession(String sessionId) {
return sessions.get(sessionId);
}

/**
* Handles new SSE connection requests from clients. Creates a new session for each
* connection and sets up the SSE event stream.
Expand Down Expand Up @@ -397,6 +426,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter = DEFAULT_REQUEST_FILTER;

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -447,6 +478,12 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}

public Builder requestFilter(HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter) {
Assert.notNull(requestFilter, "requestFilter must not be null");
this.requestFilter = requestFilter;
return this;
}

/**
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
* configured settings.
Expand All @@ -457,7 +494,8 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(objectMapper, "ObjectMapper must be set");
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
requestFilter);
}

}
Expand Down
Loading