Skip to content

Refactor server-side pt 2 - documentation #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@
* @author Alexandros Pappas
* @see ServerMcpTransport
* @see ServerSentEvent
* @deprecated This class will be removed in 0.9.0. Use
* {@link WebFluxSseServerTransportProvider}.
*/
@Deprecated
public class WebFluxSseServerTransport implements ServerMcpTransport {

private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class);
Expand Down Expand Up @@ -182,16 +185,16 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
try {// @formatter:off
String jsonText = objectMapper.writeValueAsString(message);
ServerSentEvent<Object> event = ServerSentEvent.builder()
.event(MESSAGE_EVENT_TYPE)
.data(jsonText)
.build();
.event(MESSAGE_EVENT_TYPE)
.data(jsonText)
.build();

logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());

List<String> failedSessions = sessions.values().stream()
.filter(session -> session.messageSink.tryEmitNext(event).isFailure())
.map(session -> session.id)
.toList();
.filter(session -> session.messageSink.tryEmitNext(event).isFailure())
.map(session -> session.id)
.toList();

if (failedSessions.isEmpty()) {
logger.debug("Successfully broadcast message to all sessions");
Expand Down Expand Up @@ -407,4 +410,4 @@ void close() {

}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.ServerMcpTransport;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;

/**
* Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
* Server-Sent Events (SSE). This implementation provides a bidirectional communication
* channel between MCP clients and servers using HTTP POST for client-to-server messages
* and SSE for server-to-client messages.
*
* <p>
* Key features:
* <ul>
* <li>Implements the {@link McpServerTransportProvider} interface that allows managing
* {@link McpServerSession} instances and enabling their communication with the
* {@link McpServerTransport} abstraction.</li>
* <li>Uses WebFlux for non-blocking request handling and SSE support</li>
* <li>Maintains client sessions for reliable message delivery</li>
* <li>Supports graceful shutdown with session cleanup</li>
* <li>Thread-safe message broadcasting to multiple clients</li>
* </ul>
*
* <p>
* The transport sets up two main endpoints:
* <ul>
* <li>SSE endpoint (/sse) - For establishing SSE connections with clients</li>
* <li>Message endpoint (configurable) - For receiving JSON-RPC messages from clients</li>
* </ul>
*
* <p>
* This implementation is thread-safe and can handle multiple concurrent client
* connections. It uses {@link ConcurrentHashMap} for session management and Project
* Reactor's non-blocking APIs for message processing and delivery.
*
* @author Christian Tzolov
* @author Alexandros Pappas
* @author Dariusz Jędrzejczyk
* @see McpServerTransport
* @see ServerSentEvent
*/
public class WebFluxSseServerTransportProvider implements McpServerTransportProvider {

private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class);

/**
* Event type for JSON-RPC messages sent through the SSE connection.
*/
public static final String MESSAGE_EVENT_TYPE = "message";

/**
* Event type for sending the message endpoint URI to clients.
*/
public static final String ENDPOINT_EVENT_TYPE = "endpoint";

/**
* Default SSE endpoint path as specified by the MCP transport specification.
*/
public static final String DEFAULT_SSE_ENDPOINT = "/sse";

private final ObjectMapper objectMapper;

private final String messageEndpoint;

private final String sseEndpoint;

private final RouterFunction<?> routerFunction;

private McpServerSession.Factory sessionFactory;

/**
* Map of active client sessions, keyed by session ID.
*/
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();

/**
* Flag indicating if the transport is shutting down.
*/
private volatile boolean isClosing = false;

/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");

this.objectMapper = objectMapper;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.build();
}

/**
* Constructs a new WebFlux SSE server transport provider instance with the default
* SSE endpoint.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
}

@Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}

/**
* Broadcasts a JSON-RPC message to all connected clients through their SSE
* connections. The message is serialized to JSON and sent as a server-sent event to
* each active session.
*
* <p>
* The method:
* <ul>
* <li>Serializes the message to JSON</li>
* <li>Creates a server-sent event with the message data</li>
* <li>Attempts to send the event to all active sessions</li>
* <li>Tracks and reports any delivery failures</li>
* </ul>
* @param method The JSON-RPC method to send to clients
* @param params The method parameters to send to clients
* @return A Mono that completes when the message has been sent to all sessions, or
* errors if any session fails to receive the message
*/
@Override
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
if (sessions.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
}

logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());

return Flux.fromStream(sessions.values().stream())
.flatMap(session -> session.sendNotification(method, params)
.doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(),
e.getMessage()))
.onErrorComplete())
.then();
}

// FIXME: This javadoc makes claims about using isClosing flag but it's not actually
// doing that.
/**
* Initiates a graceful shutdown of all the sessions. This method ensures all active
* sessions are properly closed and cleaned up.
*
* <p>
* The shutdown process:
* <ul>
* <li>Marks the transport as closing to prevent new connections</li>
* <li>Closes each active session</li>
* <li>Removes closed sessions from the sessions map</li>
* <li>Times out after 5 seconds if shutdown takes too long</li>
* </ul>
* @return A Mono that completes when all sessions have been closed
*/
@Override
public Mono<Void> closeGracefully() {
return Flux.fromIterable(sessions.values())
.doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()))
.flatMap(McpServerSession::closeGracefully)
.then();
}

/**
* Returns the WebFlux router function that defines the transport's HTTP endpoints.
* This router function should be integrated into the application's web configuration.
*
* <p>
* The router function defines two endpoints:
* <ul>
* <li>GET {sseEndpoint} - For establishing SSE connections</li>
* <li>POST {messageEndpoint} - For receiving client messages</li>
* </ul>
* @return The configured {@link RouterFunction} for handling HTTP requests
*/
public RouterFunction<?> getRouterFunction() {
return this.routerFunction;
}

/**
* Handles new SSE connection requests from clients. Creates a new session for each
* connection and sets up the SSE event stream.
* @param request The incoming server request
* @return A Mono which emits a response with the SSE event stream
*/
private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
if (isClosing) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(Flux.<ServerSentEvent<?>>create(sink -> {
WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);

McpServerSession session = sessionFactory.create(sessionTransport);
String sessionId = session.getId();

logger.debug("Created new SSE connection for session: {}", sessionId);
sessions.put(sessionId, session);

// Send initial endpoint event
logger.debug("Sending initial endpoint event to session: {}", sessionId);
sink.next(ServerSentEvent.builder()
.event(ENDPOINT_EVENT_TYPE)
.data(messageEndpoint + "?sessionId=" + sessionId)
.build());
sink.onCancel(() -> {
logger.debug("Session {} cancelled", sessionId);
sessions.remove(sessionId);
});
}), ServerSentEvent.class);
}

/**
* Handles incoming JSON-RPC messages from clients. Deserializes the message and
* processes it through the configured message handler.
*
* <p>
* The handler:
* <ul>
* <li>Deserializes the incoming JSON-RPC message</li>
* <li>Passes it through the message handler chain</li>
* <li>Returns appropriate HTTP responses based on processing results</li>
* <li>Handles various error conditions with appropriate error responses</li>
* </ul>
* @param request The incoming server request containing the JSON-RPC message
* @return A Mono emitting the response indicating the message processing result
*/
private Mono<ServerResponse> handleMessage(ServerRequest request) {
if (isClosing) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

if (request.queryParam("sessionId").isEmpty()) {
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
}

McpServerSession session = sessions.get(request.queryParam("sessionId").get());

return request.bodyToMono(String.class).flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
// TODO: instead of signalling the error, just respond with 200 OK
// - the error is signalled on the SSE connection
// return ServerResponse.ok().build();
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
}
catch (IllegalArgumentException | IOException e) {
logger.error("Failed to deserialize message: {}", e.getMessage());
return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
}
});
}

private class WebFluxMcpSessionTransport implements McpServerTransport {

private final FluxSink<ServerSentEvent<?>> sink;

public WebFluxMcpSessionTransport(FluxSink<ServerSentEvent<?>> sink) {
this.sink = sink;
}

@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.fromSupplier(() -> {
try {
return objectMapper.writeValueAsString(message);
}
catch (IOException e) {
throw Exceptions.propagate(e);
}
}).doOnNext(jsonText -> {
ServerSentEvent<Object> event = ServerSentEvent.builder()
.event(MESSAGE_EVENT_TYPE)
.data(jsonText)
.build();
sink.next(event);
}).doOnError(e -> {
// TODO log with sessionid
Throwable exception = Exceptions.unwrap(e);
sink.error(exception);
}).then();
}

@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return objectMapper.convertValue(data, typeRef);
}

@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(sink::complete);
}

@Override
public void close() {
sink.complete();
}

}

}
Loading