diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..c5cf2a8d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -2,6 +2,7 @@ import java.io.IOException; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -261,8 +262,8 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - McpServerSession session = sessionFactory.create(sessionTransport); - String sessionId = session.getId(); + String sessionId = UUID.randomUUID().toString(); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); logger.debug("Created new SSE connection for session: {}", sessionId); sessions.put(sessionId, session); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..79973a98 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -263,7 +263,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { }); WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a0..b89e844c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -340,9 +340,9 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider + .setSessionFactory((id, transport) -> new McpServerSession(id, requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..40215089 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -219,7 +219,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) writer); // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); // Send initial endpoint event diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..297c1b2f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -12,6 +12,7 @@ import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.UUID; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -94,7 +95,8 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection var transport = new StdioMcpSessionTransport(); - this.session = sessionFactory.create(transport); + String sessionId = UUID.randomUUID().toString(); + this.session = sessionFactory.create(sessionId, transport); transport.initProcessing(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 64315095..47607165 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -342,10 +342,11 @@ public interface Factory { /** * Creates a new 1:1 representation of the client-server interaction. + * @param sessionId the id of the session. * @param sessionTransport the transport to use for communication with the client. * @return a new server session. */ - McpServerSession create(McpServerTransport sessionTransport); + McpServerSession create(String sessionId, McpServerTransport sessionTransport); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..d60d7120 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -16,6 +16,7 @@ package io.modelcontextprotocol; import java.util.Map; +import java.util.UUID; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -43,7 +44,8 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { - session = sessionFactory.create(transport); + String sessionId = UUID.randomUUID().toString(); + session = sessionFactory.create(sessionId, transport); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..f447a1bd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -71,7 +71,7 @@ void setUp() { sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior - when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(sessionFactory.create(any(), any(McpServerTransport.class))).thenReturn(mockSession); when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); @@ -110,7 +110,7 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession.Factory realSessionFactory = (id, transport) -> { McpServerSession session = mock(McpServerSession.class); when(session.handle(any())).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0));