Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor server side to handle multiple clients #31

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.ServerMcpSession;
import io.modelcontextprotocol.spec.ServerMcpTransport;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

Expand Down Expand Up @@ -88,18 +88,30 @@ public class WebFluxSseServerTransport implements ServerMcpTransport {

private final RouterFunction<?> routerFunction;

private ServerMcpSession.Factory sessionFactory;

/**
* Map of active client sessions, keyed by session ID.
*/
private final ConcurrentHashMap<String, ClientSession> sessions = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, ServerMcpSession> sessions = new ConcurrentHashMap<>();

// FIXME: This is a bit clumsy. The McpAsyncServer handles global notifications
// using the transport and we need access to child transports for each session to
// use the sendMessage method. Ideally, the particular transport would be an
// abstraction of a specialized session that can handle only notifications and we
// could delegate to all child sessions without directly going through the transport.
// The conversion from a notification to message happens both in McpAsyncServer
// and in ServerMcpSession and it would be beneficial to have a unified interface
// for both. An MCP server implementation can use both McpServerExchange and
// Mcp(Sync|Async)Server to send notifications so the capability needs to lie in
// both places.
private final ConcurrentHashMap<String, WebFluxMcpSessionTransport> sessionTransports = new ConcurrentHashMap<>();

/**
* Flag indicating if the transport is shutting down.
*/
private volatile boolean isClosing = false;

private Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> connectHandler;

/**
* Constructs a new WebFlux SSE server transport instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
Expand Down Expand Up @@ -137,21 +149,9 @@ public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoi
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
}

/**
* Configures the message handler for this transport. In the WebFlux SSE
* implementation, this method stores the handler for processing incoming messages but
* doesn't establish any connections since the server accepts connections rather than
* initiating them.
* @param handler A function that processes incoming JSON-RPC messages and returns
* responses. This handler will be called for each message received through the
* message endpoint.
* @return An empty Mono since the server doesn't initiate connections
*/
@Override
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
this.connectHandler = handler;
// Server-side transport doesn't initiate connections
return Mono.empty().then();
public void setSessionFactory(ServerMcpSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}

/**
Expand All @@ -178,36 +178,14 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.empty();
}

return Mono.<Void>create(sink -> {
try {// @formatter:off
String jsonText = objectMapper.writeValueAsString(message);
ServerSentEvent<Object> event = ServerSentEvent.builder()
.event(MESSAGE_EVENT_TYPE)
.data(jsonText)
.build();

logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());

List<String> failedSessions = sessions.values().stream()
.filter(session -> session.messageSink.tryEmitNext(event).isFailure())
.map(session -> session.id)
.toList();

if (failedSessions.isEmpty()) {
logger.debug("Successfully broadcast message to all sessions");
sink.success();
}
else {
String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessions);
logger.error(error);
sink.error(new RuntimeException(error));
} // @formatter:on
}
catch (IOException e) {
logger.error("Failed to serialize message: {}", e.getMessage());
sink.error(e);
}
});
return Flux.fromStream(sessionTransports.values().stream())
.flatMap(session -> session.sendMessage(message)
.doOnError(e -> logger.error("Failed to " + "send message to session {}: {}", session.sessionId,
e.getMessage()))
.onErrorComplete())
.then();
}

/**
Expand Down Expand Up @@ -241,18 +219,10 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
*/
@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(() -> {
isClosing = true;
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
}).then(Mono.when(sessions.values().stream().map(session -> {
String sessionId = session.id;
return Mono.fromRunnable(() -> session.close())
.then(Mono.delay(Duration.ofMillis(100)))
.then(Mono.fromRunnable(() -> sessions.remove(sessionId)));
}).toList()))
.timeout(Duration.ofSeconds(5))
.doOnSuccess(v -> logger.debug("Graceful shutdown completed"))
.doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage()));
return Flux.fromIterable(sessions.values())
.doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()))
.flatMap(ServerMcpSession::closeGracefully)
.then();
}

/**
Expand Down Expand Up @@ -291,38 +261,24 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
if (isClosing) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);
ClientSession session = new ClientSession(sessionId);
this.sessions.put(sessionId, session);

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(Flux.<ServerSentEvent<?>>create(sink -> {
String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);
WebFluxMcpSessionTransport
sessionTransport = new WebFluxMcpSessionTransport(sessionId, sink);

sessions.put(sessionId, sessionFactory.create(sessionTransport));
sessionTransports.put(sessionId, sessionTransport);

// Send initial endpoint event
logger.debug("Sending initial endpoint event to session: {}", sessionId);
sink.next(ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(messageEndpoint).build());

// Subscribe to session messages
session.messageSink.asFlux()
.doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", sessionId))
.doOnComplete(() -> {
logger.debug("Session {} completed", sessionId);
sessions.remove(sessionId);
})
.doOnError(error -> {
logger.error("Error in session {}: {}", sessionId, error.getMessage());
sessions.remove(sessionId);
})
.doOnCancel(() -> {
logger.debug("Session {} cancelled", sessionId);
sessions.remove(sessionId);
})
.subscribe(event -> {
logger.debug("Forwarding event to session {}: {}", sessionId, event);
sink.next(event);
}, sink::error, sink::complete);

sink.next(ServerSentEvent.builder()
.event(ENDPOINT_EVENT_TYPE)
.data(messageEndpoint + "?sessionId=" + sessionId)
.build());
sink.onCancel(() -> {
logger.debug("Session {} cancelled", sessionId);
sessions.remove(sessionId);
Expand Down Expand Up @@ -350,17 +306,20 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

if (request.queryParam("sessionId").isEmpty()) {
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
}

ServerMcpSession session = sessions.get(request.queryParam("sessionId").get());

return request.bodyToMono(String.class).flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
return Mono.just(message)
.transform(this.connectHandler)
.flatMap(response -> ServerResponse.ok().build())
.onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
}
catch (IllegalArgumentException | IOException e) {
logger.error("Failed to deserialize message: {}", e.getMessage());
Expand All @@ -369,40 +328,52 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
});
}

/**
* Represents an active client SSE connection session. Manages the message sink for
* sending events to the client and handles session lifecycle.
*
* <p>
* Each session:
* <ul>
* <li>Has a unique identifier</li>
* <li>Maintains its own message sink for event broadcasting</li>
* <li>Supports clean shutdown through the close method</li>
* </ul>
*/
private static class ClientSession {
private class WebFluxMcpSessionTransport implements ServerMcpTransport.Child {

private final String id;
final String sessionId;

private final Sinks.Many<ServerSentEvent<?>> messageSink;
private final FluxSink<ServerSentEvent<?>> sink;

ClientSession(String id) {
this.id = id;
logger.debug("Creating new session: {}", id);
this.messageSink = Sinks.many().replay().latest();
logger.debug("Session {} initialized with replay sink", id);
public WebFluxMcpSessionTransport(String sessionId, FluxSink<ServerSentEvent<?>> sink) {
this.sessionId = sessionId;
this.sink = sink;
}

void close() {
logger.debug("Closing session: {}", id);
Sinks.EmitResult result = messageSink.tryEmitComplete();
if (result.isFailure()) {
logger.warn("Failed to complete message sink for session {}: {}", id, result);
}
else {
logger.debug("Successfully completed message sink for session {}", id);
}
@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.fromSupplier(() -> {
try {
return objectMapper.writeValueAsString(message);
}
catch (IOException e) {
throw Exceptions.propagate(e);
}
}).doOnNext(jsonText -> {
ServerSentEvent<Object> event = ServerSentEvent.builder()
.event(MESSAGE_EVENT_TYPE)
.data(jsonText)
.build();
sink.next(event);
}).doOnError(e -> {
// TODO log with sessionid
Throwable exception = Exceptions.unwrap(e);
sink.error(exception);
}).then();
}

@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return WebFluxSseServerTransport.this.unmarshalFrom(data, typeRef);
}

@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(sink::complete);
}

@Override
public void close() {
sink.complete();
}

}
Expand Down
Loading