diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..9dad8cd4 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,13 +1,14 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.server.McpServerAuthParam; +import io.modelcontextprotocol.server.McpServerAuthProvider; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; @@ -60,7 +61,9 @@ * @author Christian Tzolov * @author Alexandros Pappas * @author Dariusz Jędrzejczyk + * @author lambochen * @see McpServerTransport + * @see McpServerAuthProvider * @see ServerSentEvent */ public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { @@ -100,6 +103,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv private McpServerSession.Factory sessionFactory; + /** + * auth provider + */ + private final McpServerAuthProvider authProvider; + /** * Map of active client sessions, keyed by session ID. */ @@ -149,6 +157,22 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param authProvider auth provider + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, McpServerAuthProvider authProvider) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.authProvider = authProvider; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -256,6 +281,10 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + if (null != authProvider && !authProvider.authenticate(assemblyAuthParam(request))) { + return ServerResponse.status(HttpStatus.UNAUTHORIZED).bodyValue("Unauthorized"); + } + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -280,6 +309,14 @@ private Mono handleSseConnection(ServerRequest request) { }), ServerSentEvent.class); } + private McpServerAuthParam assemblyAuthParam(ServerRequest request) { + return McpServerAuthParam.builder() + .sseEndpoint(this.sseEndpoint) + .uri(request.uri().toString()) + .params(request.queryParams().toSingleValueMap()) + .build(); + } + /** * Handles incoming JSON-RPC messages from clients. Deserializes the message and * processes it through the configured message handler. @@ -397,6 +434,18 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpServerAuthProvider authProvider = null; + + /** + * Sets the authentication provider. + * @param authProvider the authentication provider + * @return this builder instance + */ + public Builder authProvider(McpServerAuthProvider authProvider) { + this.authProvider = authProvider; + return this; + } + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -457,7 +506,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + authProvider); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..453ea372 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,7 +6,6 @@ import java.io.IOException; import java.time.Duration; -import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -14,6 +13,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.server.McpServerAuthParam; +import io.modelcontextprotocol.server.McpServerAuthProvider; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; @@ -63,7 +64,9 @@ * * @author Christian Tzolov * @author Alexandros Pappas + * @author lambochen * @see McpServerTransportProvider + * @see McpServerAuthProvider * @see RouterFunction */ public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { @@ -107,6 +110,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + private final McpServerAuthProvider authProvider; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -149,6 +154,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param authProvider The authentication provider to use for authentication. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, McpServerAuthProvider authProvider) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +181,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.authProvider = authProvider; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -247,6 +271,10 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + if (null != authProvider && !authProvider.authenticate(assemblyAuthParam(request))) { + return ServerResponse.status(HttpStatus.UNAUTHORIZED).body("Unauthorized"); + } + String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); @@ -284,6 +312,14 @@ private ServerResponse handleSseConnection(ServerRequest request) { } } + private McpServerAuthParam assemblyAuthParam(ServerRequest request) { + return McpServerAuthParam.builder() + .sseEndpoint(this.sseEndpoint) + .uri(request.uri().toString()) + .params(request.params().toSingleValueMap()) + .build(); + } + /** * Handles incoming JSON-RPC messages from clients. This method: *
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 99cf2a62..63e2941b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -9,6 +9,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -75,6 +76,11 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** SSE endpoint path */ private final String sseEndpoint; + /** + * Additional parameters for the connect( client to server). + */ + private final Map params; + /** SSE client for handling server-sent events. Uses the /sse endpoint */ private final FlowSseClient sseClient; @@ -174,6 +180,11 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, null); + } + + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper, Map params) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); @@ -181,6 +192,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; + this.params = params; this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -206,6 +218,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private Map params; + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_1_1) .connectTimeout(Duration.ofSeconds(10)); @@ -257,6 +271,16 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the request params. + * @param params the request params + * @return this builder + */ + public Builder params(Map params) { + this.params = params; + return this; + } + /** * Sets the HTTP client builder. * @param clientBuilder the HTTP client builder @@ -318,7 +342,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { */ public HttpClientSseClientTransport build() { return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); + objectMapper, params); } } @@ -342,7 +366,7 @@ public Mono connect(Function, Mono> h connectionFuture.set(future); URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { + sseClient.subscribe(assemblyUri(clientUri.toString(), this.params), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { @@ -382,6 +406,23 @@ public void onError(Throwable error) { return Mono.fromFuture(future); } + /** + * Assembles the full URI with the base URI and the params. + * @param baseUri baseUri + sseEndpoint + * @param params additional params + * @return full uri: baseUri + sseEndpoint + params + */ + private String assemblyUri(String baseUri, Map params) { + if (null == params || params.isEmpty()) { + return baseUri; + } + StringBuilder uri = new StringBuilder(baseUri); + uri.append("?"); + params.forEach((k, v) -> uri.append(k).append("=").append(v).append("&")); + uri.deleteCharAt(uri.length() - 1); + return uri.toString(); + } + /** * Sends a JSON-RPC message to the server. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthParam.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthParam.java new file mode 100644 index 00000000..e4d16b5e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthParam.java @@ -0,0 +1,78 @@ +package io.modelcontextprotocol.server; + +import java.util.Map; + +/** + * MCP server authentication context. + * + * @author lambochen + */ +public class McpServerAuthParam { + + /** + * mcp server sse endpoint + */ + private final String sseEndpoint; + + /** + * mcp server url + */ + private final String uri; + + /** + * mcp server params + */ + private final Map params; + + private McpServerAuthParam(String sseEndpoint, String uri, Map params) { + this.sseEndpoint = sseEndpoint; + this.uri = uri; + this.params = params; + } + + public String getSseEndpoint() { + return sseEndpoint; + } + + public String getUri() { + return uri; + } + + public Map getParams() { + return params; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uri; + + private String sseEndpoint; + + private Map params; + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public Builder sseEndpoint(String sseEndpoint) { + this.sseEndpoint = sseEndpoint; + return this; + } + + public Builder params(Map params) { + this.params = params; + return this; + } + + public McpServerAuthParam build() { + return new McpServerAuthParam(sseEndpoint, uri, params); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthProvider.java new file mode 100644 index 00000000..572bdde8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerAuthProvider.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.server; + +/** + * MCP server authentication provider by request params + * + * @author lambochen + */ +public interface McpServerAuthProvider { + + /** + * Authenticate by request params + * @param param request params and other information + * @return true if authenticate success, otherwise false + */ + boolean authenticate(McpServerAuthParam param); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/CustomMcpServerAuthProvider.java b/mcp/src/test/java/io/modelcontextprotocol/auth/CustomMcpServerAuthProvider.java new file mode 100644 index 00000000..dad5a3e5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/auth/CustomMcpServerAuthProvider.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth; + +import io.modelcontextprotocol.server.McpServerAuthParam; +import io.modelcontextprotocol.server.McpServerAuthProvider; +// import org.springframework.stereotype.Component; + +/** + * MCP Client(eg: Claude) config: ``` { "mcpServers": { "mcp-remote-server-example": { + * "url": "http://{your mcp server}/sse?token=xxx" } } } ``` + * + * {@link McpServerAuthProvider} example + * + * @author lambochen + */ +// @Component +public class CustomMcpServerAuthProvider implements McpServerAuthProvider { + + @Override + public boolean authenticate(McpServerAuthParam param) { + // example: get param + String token = param.getParams().get("token"); + // TODO do something + return true; + } + +}