diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9aa..00cca331d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtection dnsRebindingProtection; + /** * Constructs a new WebFlux SSE server transport provider instance with the default * SSE endpoint. @@ -118,8 +124,10 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -131,10 +139,12 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -145,10 +155,31 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with optional DNS + * rebinding protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null). + * @throws IllegalArgumentException if required parameters are null + */ + private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +189,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtection = dnsRebindingProtection; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -256,6 +288,12 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Validate headers + Mono validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; + } + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -300,6 +338,19 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().contentType().map(MediaType::toString).orElse(null); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + Mono validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; + } + if (request.queryParam("sessionId").isEmpty()) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); } @@ -397,6 +448,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtection dnsRebindingProtection; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -447,6 +500,22 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the DNS rebinding protection configuration. + *

+ * When set, this configuration will be used to create a header validator that + * enforces DNS rebinding protection rules. This will override any previously set + * header validator. + * @param config The DNS rebinding protection configuration + * @return this builder instance + * @throws IllegalArgumentException if config is null + */ + public Builder dnsRebindingProtection(DnsRebindingProtection config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtection = config; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -457,9 +526,30 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtection); } } + /** + * Validates DNS rebinding protection for the given request. + * @param request The incoming server request + * @return A ServerResponse with forbidden status if validation fails, or null if + * validation passes + */ + private Mono validateDnsRebindingProtection(ServerRequest request) { + if (dnsRebindingProtection != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN) + .bodyValue("DNS rebinding protection validation failed"); + } + } + return null; + } + } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa0..ced54f62c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtection dnsRebindingProtection; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -115,8 +121,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -129,10 +137,12 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if any parameter is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); + this(objectMapper, "", messageEndpoint, sseEndpoint, null); } /** @@ -145,10 +155,32 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if any parameter is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding + * protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null). + * @throws IllegalArgumentException if any required parameter is null + */ + private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +190,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtection = dnsRebindingProtection; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -247,6 +280,12 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Validate headers + ServerResponse validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; + } + String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); @@ -300,6 +339,19 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().asHttpHeaders().getFirst("Content-Type"); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + ServerResponse validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; + } + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } @@ -417,4 +469,23 @@ public void close() { } + /** + * Validates DNS rebinding protection for the given request. + * @param request The incoming server request + * @return A ServerResponse with forbidden status if validation fails, or null if + * validation passes + */ + private ServerResponse validateDnsRebindingProtection(ServerRequest request) { + if (dnsRebindingProtection != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + return null; + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java new file mode 100644 index 000000000..2d1fd22d3 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java @@ -0,0 +1,267 @@ +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Configuration for DNS rebinding protection in SSE server transports. + *

+ * DNS Rebinding Attacks and Protection + *

+ * DNS rebinding attacks allow malicious websites to bypass same-origin policy and + * interact with local services by rebinding a domain name to localhost (127.0.0.1) or + * other local IP addresses. This protection is critical when running applications + * locally that are not meant to be accessible to the broader network. + * + *

+ * When to Use DNS Rebinding Protection + *

    + *
  • Local development servers: MCP servers running on localhost during + * development
  • + *
  • Enterprise environments: Internal services that should not be + * accessible from external networks
  • + *
+ * + *

+ * Allowlist Configuration + *

+ * This protection requires configuring an allowlist of accepted Host and Origin values: + *

    + *
  • Host values: The expected domain/IP and port where your server is + * hosted.
    + * These values are constructed from {@code HttpServletRequest.getServerName()} and + * {@code getServerPort()}.
    + * Examples: "localhost:8080", "127.0.0.1:3000", "app.mycompany.com" (standard ports + * 80/443 omitted)
  • + *
  • Origin headers: The domains allowed to make requests to your + * server.
    + * These are taken directly from the Origin HTTP header.
    + * Examples: "http://localhost:3000", "https://app.mycompany.com"
  • + *
+ * + *

+ * Validation Behavior + *

+ * When protection is enabled, validation succeeds when the Host and Origin are either + * null, or match values in the respective allowlists (case-insensitive). + * + *

+ * Example Usage

{@code
+ * // For a local development server:
+ * DnsRebindingProtection config = DnsRebindingProtection.builder()
+ *     .allowedHost("localhost:8080")
+ *     .allowedHost("127.0.0.1:8080")
+ *     .allowedOrigin("http://localhost:3000")
+ *     .build();
+ *
+ * // To disable protection entirely:
+ * DnsRebindingProtection config = DnsRebindingProtection.builder()
+ *     .enableDnsRebindingProtection(false)
+ *     .build();
+ * }
+ * + * @see DNS Rebinding Attack + */ +public class DnsRebindingProtection { + + private final Set allowedHosts; + + private final Set allowedOrigins; + + private final boolean enable; + + /** + * Constructs a new DNS rebinding protection configuration. + * @param allowedHosts The set of allowed host header values (case-insensitive) + * @param allowedOrigins The set of allowed origin header values (case-insensitive) + * @param enable Whether DNS rebinding protection is enabled + */ + private DnsRebindingProtection(Set allowedHosts, Set allowedOrigins, boolean enable) { + this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(allowedHosts)); + this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(allowedOrigins)); + this.enable = enable; + } + + /** + * Validates Host and Origin headers for DNS rebinding protection. + *

+ * Validation Logic: + *

    + *
  • If protection is disabled ({@code enable=false}): always returns + * true
  • + *
  • If both headers are null: returns true (no validation + * needed)
  • + *
  • If allowlists are empty and headers are provided: returns + * false (reject all)
  • + *
  • If headers are provided and match allowlist values: returns + * true (case-insensitive)
  • + *
  • If any provided header doesn't match its allowlist: returns + * false
  • + *
+ * + *

+ * Important Behavior: An empty allowlist will reject + * any non-null header value. This is intentional - you must + * explicitly configure allowed values for protection to be meaningful. + * @param hostHeader The value of the Host header from the HTTP request (may be null) + * @param originHeader The value of the Origin header from the HTTP request (may be + * null) + * @return {@code true} if the headers are valid according to the protection rules, + * {@code false} if validation failed + */ + public boolean isValid(String hostHeader, String originHeader) { + // Skip validation if protection is not enabled + if (!enable) { + return true; + } + + // Validate Host header + if (hostHeader != null) { + String lowerHost = hostHeader.toLowerCase(); + if (!allowedHosts.contains(lowerHost)) { + return false; + } + } + + // Validate Origin header + if (originHeader != null) { + String lowerOrigin = originHeader.toLowerCase(); + if (!allowedOrigins.contains(lowerOrigin)) { + return false; + } + } + + return true; + } + + /** + * Creates a new builder for constructing DNS rebinding protection configurations. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for constructing {@link DnsRebindingProtection} instances with a fluent + * API. + *

+ * This builder provides a convenient way to configure DNS rebinding protection with + * explicit allowlists for host and origin values. The builder enforces safe + * construction patterns and provides sensible defaults. + * + *

+ * Default Behavior: + *

    + *
  • Protection is enabled by default ({@code enable = true})
  • + *
  • Both allowlists start empty, which means all non-null + * headers will be rejected
  • + *
  • This secure default forces explicit allowlist configuration
  • + *
+ * + *

+ * Usage Pattern: + *

    + *
  1. Create builder via {@link DnsRebindingProtection#builder()}
  2. + *
  3. Configure allowlists using {@link #allowedHost(String)} and + * {@link #allowedOrigin(String)}
  4. + *
  5. Optionally enable/disable protection via + * {@link #enableDnsRebindingProtection(boolean)}. Protection is default-enabled when + * creating a builder.
  6. + *
  7. Build final instance via {@link #build()}
  8. + *
+ * + *

+ * Thread Safety: This builder is not thread-safe. + * Each thread should use its own builder instance. + * + * @see DnsRebindingProtection#builder() + */ + public static class Builder { + + private final Set allowedHosts = new HashSet<>(); + + private final Set allowedOrigins = new HashSet<>(); + + private boolean enable = true; + + /** + * Private constructor to restrict instantiation to builder() method. + */ + private Builder() { + } + + /** + * Adds an allowed host header value. + * @param host The host header value to allow (case-insensitive, may be null) + * @return This builder instance for method chaining + */ + public Builder allowedHost(String host) { + if (host != null) { + this.allowedHosts.add(host.toLowerCase()); + } + return this; + } + + /** + * Adds multiple allowed host header values. + * @param hosts The set of host header values to allow (case-insensitive, may be + * null) + * @return This builder instance for method chaining + */ + public Builder allowedHosts(Set hosts) { + if (hosts != null) { + hosts.forEach(this::allowedHost); + } + return this; + } + + /** + * Adds an allowed origin header value. + * @param origin The origin header value to allow (case-insensitive, may be null) + * @return This builder instance for method chaining + */ + public Builder allowedOrigin(String origin) { + if (origin != null) { + this.allowedOrigins.add(origin.toLowerCase()); + } + return this; + } + + /** + * Adds multiple allowed origin header values. + * @param origins The set of origin header values to allow (case-insensitive, may + * be null) + * @return This builder instance for method chaining + */ + public Builder allowedOrigins(Set origins) { + if (origins != null) { + origins.forEach(this::allowedOrigin); + } + return this; + } + + /** + * Sets whether DNS rebinding protection is enabled. + * @param enable True to enable protection, false to + * disable it + * @return This builder instance for method chaining + */ + public Builder enableDnsRebindingProtection(boolean enable) { + this.enable = enable; + return this; + } + + /** + * Builds a new DNS rebinding protection configuration with the configured + * settings. + * @return A new DnsRebindingProtection instance + */ + public DnsRebindingProtection build() { + return new DnsRebindingProtection(allowedHosts, allowedOrigins, enable); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..12fd48c03 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -103,6 +103,9 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + /** DNS rebinding protection configuration */ + private final DnsRebindingProtection dnsRebindingProtection; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -110,10 +113,12 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement * serialization/deserialization * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -124,13 +129,12 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { - this.objectMapper = objectMapper; - this.baseUrl = baseUrl; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); } /** @@ -139,11 +143,33 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b * @param objectMapper The JSON object mapper to use for message * serialization/deserialization * @param messageEndpoint The endpoint path where clients will send their messages + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } + /** + * Creates a new HttpServletSseServerTransportProvider instance with optional DNS + * rebinding protection. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null) + */ + private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { + this.objectMapper = objectMapper; + this.baseUrl = baseUrl; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtection = dnsRebindingProtection; + } + /** * Sets the session factory for creating new sessions. * @param sessionFactory The session factory to use @@ -202,6 +228,11 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + // Validate headers if DNS rebinding protection is configured + if (!validateDnsRebindingProtection(request, response)) { + return; + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -252,6 +283,19 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + // Always validate Content-Type for POST requests + String contentType = request.getContentType(); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Content-Type must be application/json"); + return; + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (!validateDnsRebindingProtection(request, response)) { + return; + } + // Get the session ID from the request parameter String sessionId = request.getParameter("sessionId"); if (sessionId == null) { @@ -327,6 +371,37 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); } + /** + * Validates DNS rebinding protection for the given request and sends an error + * response if validation fails. + *

+ * Uses {@link HttpServletRequest#getServerName()} and + * {@link HttpServletRequest#getServerPort()} instead of the Host header for HTTP/2 + * compatibility. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @return true if validation passed (or no validation needed), false if validation + * failed and error was sent + * @throws IOException If an I/O error occurs while sending error response + */ + private boolean validateDnsRebindingProtection(HttpServletRequest request, HttpServletResponse response) + throws IOException { + if (dnsRebindingProtection != null) { + String serverName = request.getServerName(); + int serverPort = request.getServerPort(); + String hostValue = serverPort == 80 || serverPort == 443 ? serverName : serverName + ":" + serverPort; + + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtection.isValid(hostValue, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostValue, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return false; + } + } + return true; + } + /** * Sends an SSE event to a client. * @param writer The writer to send the event through @@ -475,6 +550,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtection dnsRebindingProtection; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -522,6 +599,17 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the DNS rebinding protection configuration. + * @param config The DNS rebinding protection configuration + * @return This builder instance for method chaining + */ + public Builder dnsRebindingProtection(DnsRebindingProtection config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtection = config; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -535,7 +623,8 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtection); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java new file mode 100644 index 000000000..87ae4b404 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java @@ -0,0 +1,157 @@ +package io.modelcontextprotocol.server.transport; + +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for DNS rebinding protection configuration. + */ +public class DnsRebindingProtectionTests { + + @Test + void testDefaultConfiguration() { + DnsRebindingProtection config = DnsRebindingProtection.builder().build(); + + // Test default behavior - when allowed lists are empty and headers are provided, + // validation fails because the headers are not in the (empty) allowed lists + assertThat(config.isValid("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.isValid("localhost", null)).isFalse(); + assertThat(config.isValid(null, "http://example.com")).isFalse(); + // Null values are allowed when lists are empty + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testDisableDnsRebindingProtection() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .enableDnsRebindingProtection(false) + .allowedHost("localhost") // Should be ignored when protection is disabled + .allowedOrigin("http://localhost") // Should be ignored when protection is + // disabled + .build(); + + // When protection is disabled, all hosts and origins should be allowed + assertThat(config.isValid("evil.com", "http://evil.com")).isTrue(); + assertThat(config.isValid("any.host", "http://any.origin")).isTrue(); + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testHostValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedHost("127.0.0.1") + .build(); + + // Valid hosts + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid("127.0.0.1", null)).isTrue(); + + // Invalid hosts + assertThat(config.isValid("evil.com", null)).isFalse(); + + // Null host is allowed when no specific hosts are being checked + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testOriginValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("https://app.example.com") + .build(); + + // Valid origins + assertThat(config.isValid(null, "http://localhost:8080")).isTrue(); + assertThat(config.isValid(null, "https://app.example.com")).isTrue(); + + // Invalid origins + assertThat(config.isValid(null, "http://evil.com")).isFalse(); + + // Null origin is allowed when no specific origins are being checked + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testCombinedHostAndOriginValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost:8080") + .build(); + + // Both valid + assertThat(config.isValid("localhost", "http://localhost:8080")).isTrue(); + + // Host valid, origin invalid + assertThat(config.isValid("localhost", "http://evil.com")).isFalse(); + + // Host invalid, origin valid + assertThat(config.isValid("evil.com", "http://localhost:8080")).isFalse(); + + // Both invalid + assertThat(config.isValid("evil.com", "http://evil.com")).isFalse(); + } + + @Test + void testCaseInsensitiveHostAndOrigin() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("LOCALHOST") + .allowedOrigin("HTTP://LOCALHOST:8080") + .build(); + + // Case insensitive matching + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid("LOCALHOST", null)).isTrue(); + assertThat(config.isValid("LoCaLhOsT", null)).isTrue(); + + assertThat(config.isValid(null, "http://localhost:8080")).isTrue(); + assertThat(config.isValid(null, "HTTP://LOCALHOST:8080")).isTrue(); + } + + @Test + void testEmptyAllowedListsDenyNonNull() { + DnsRebindingProtection config = DnsRebindingProtection.builder().build(); + + // When allowed lists are empty and headers are provided, validation fails + assertThat(config.isValid("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.isValid("random.host", "http://random.origin")).isFalse(); + // But null values are allowed + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testBuilderWithSets() { + Set hosts = Set.of("host1.com", "host2.com"); + Set origins = Set.of("http://origin1.com", "http://origin2.com"); + + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHosts(hosts) + .allowedOrigins(origins) + .build(); + + assertThat(config.isValid("host1.com", null)).isTrue(); + assertThat(config.isValid("host2.com", null)).isTrue(); + assertThat(config.isValid("host3.com", null)).isFalse(); + + assertThat(config.isValid(null, "http://origin1.com")).isTrue(); + assertThat(config.isValid(null, "http://origin2.com")).isTrue(); + assertThat(config.isValid(null, "http://origin3.com")).isFalse(); + } + + @Test + void testNullValuesWithConfiguredLists() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost") + .build(); + + // Null values should be allowed when no check is needed for that header + assertThat(config.isValid(null, "http://localhost")).isTrue(); + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid(null, null)).isTrue(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java new file mode 100644 index 000000000..c10d83b0f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java @@ -0,0 +1,240 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.concurrent.CompletionException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Integration tests for header validation in + * {@link HttpServletSseServerTransportProvider}. + */ +class HttpServletSseHeaderValidationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String SSE_ENDPOINT = "/sse"; + + private Tomcat tomcat; + + private HttpServletSseServerTransportProvider transportProvider; + + private McpSyncServer server; + + @AfterEach + void tearDown() { + if (server != null) { + server.close(); + } + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testConnectionSucceedsWithValidHeaders() { + // Create DNS rebinding protection config that validates API key + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() + .enableDnsRebindingProtection(false) // Disable Host/Origin validation for + // this test + .build(); + + // For this test, we'll need to use a custom transport provider implementation + // since DnsRebindingProtection doesn't support custom header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtection(dnsRebindingProtection) + .build(); + + startServer(); + + // Create client - should succeed since DNS rebinding protection is disabled + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + // Connection should succeed + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("test-server"); + } + } + + @Test + void testConnectionFailsWithInvalidHeaders() { + // Create DNS rebinding protection config with restricted hosts + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() + .allowedHost("valid-host.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtection(dnsRebindingProtection) + .build(); + + startServer(); + + // Create client with localhost which won't match the allowed host + // The Host header will be "localhost:PORT" which won't match "valid-host.com" + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // Connection should fail during initialization + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testConnectionFailsWithEmptyAllowedHostsButProvidedHost() { + // Create DNS rebinding protection config with specific allowed origin but no + // allowed hosts + // This means any non-null host will be rejected + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() + .allowedOrigin("http://allowed-origin.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtection(dnsRebindingProtection) + .build(); + + startServer(); + + // Create client - the client will send a Host header like "localhost:PORT" + // Since allowedHosts is empty, any non-null host will be rejected + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // With the new behavior, a non-null Host header is rejected when allowedHosts is + // empty + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testComplexHeaderValidation() { + // Create DNS rebinding protection config with specific allowed hosts and origins + // Note: The Host header will include the port, so we need to allow + // "localhost:PORT" + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() + .allowedHost("localhost:" + PORT) + .allowedOrigin("http://localhost:" + PORT) + .build(); + + // Create server with DNS rebinding protection + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtection(dnsRebindingProtection) + .build(); + + startServer(); + + // Test with valid headers (localhost is allowed) + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + + // Test with different host (should fail) + var invalidHostTransport = HttpClientSseClientTransport.builder("http://127.0.0.1:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + assertThatThrownBy(() -> { + try (var client = McpClient.sync(invalidHostTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testDefaultValidatorAllowsAllHeaders() { + // Create server without specifying a DNS rebinding protection config (no + // validation) + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + startServer(); + + // Create client with arbitrary headers + try (var client = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .customizeRequest(requestBuilder -> { + requestBuilder.header("X-Random-Header", "random-value"); + requestBuilder.header("X-Another-Header", "another-value"); + }) + .build()).build()) { + + // Connection should succeed with any headers + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + } + + private void startServer() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + server = McpServer.sync(transportProvider).serverInfo("test-server", "1.0.0").build(); + } + +}