diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransport.java new file mode 100644 index 00000000..20f59bc7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransport.java @@ -0,0 +1,402 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.SynchronousSink; +import reactor.core.scheduler.Schedulers; +import reactor.util.retry.Retry; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Function; + + +/** + *

+ * The connection process follows these steps: + *

    + *
  1. The client sends an initialization POST request to the message endpoint, + * and the server may optionally generate and return a session ID to the client. + * This process is not required, only for stateful connections
  2. + * + *
  3. All messages are sent via HTTP POST requests to the message endpoint. + * If a session ID exists, it MUST be included in the request
  4. + * + *
  5. The client handles responses from the server accordingly based on the protocol specifications.
  6. + *
+ * + * This implementation uses {@link WebClient} for HTTP communications + * @author Yeaury + */ +public class WebFluxStreamableClientTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableClientTransport.class); + + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + private static final String DEFAULT_MCP_ENDPOINT = "/message"; + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + + private static final String LAST_EVENT_ID = "Last-Event-ID"; + + private static final ParameterizedTypeReference> SSE_TYPE = new ParameterizedTypeReference<>() { + }; + + private final WebFluxSseClientTransport sseClientTransport; + + private final WebClient webClient; + + protected ObjectMapper objectMapper; + + private Disposable inboundSubscription; + + private volatile boolean isClosing = false; + + protected final Sinks.One messageEndpointSink = Sinks.one(); + + private final Sinks.Many> sseSink = Sinks.many().multicast().onBackpressureBuffer(); + + private final AtomicReference sessionId = new AtomicReference<>(null); + + private final AtomicReference lastEventId = new AtomicReference<>(null); + + private final AtomicBoolean fallbackToSse = new AtomicBoolean(false); + + private Function, Mono> handler; + + /** + * The endpoint URI provided by the server. + * Used to send outbound messages via HTTP POST requests, and is also used when initializing a connection. + */ + private final String endpoint; + + public WebFluxStreamableClientTransport(WebClient.Builder webClientBuilder) { + this(webClientBuilder, new ObjectMapper()); + } + + public WebFluxStreamableClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { + this(webClientBuilder, objectMapper, DEFAULT_MCP_ENDPOINT); + } + + public WebFluxStreamableClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, String defaultMcpEndpoint) { + this(webClientBuilder, objectMapper, defaultMcpEndpoint, Function.identity(), WebFluxSseClientTransport.builder(webClientBuilder).objectMapper(objectMapper).build()); + } + + public WebFluxStreamableClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, + String sseEndpoint, Function, Mono> handler, + WebFluxSseClientTransport sseClientTransport) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); + Assert.notNull(handler, "Handler must not be null"); + + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = sseEndpoint; + this.handler = handler; + this.sseClientTransport = sseClientTransport; + } + + /** + * Connect to the MCP server using the WebClient. + * + *

+ * The connection is only used as an initialization operation, and the Mcp-Session-ID and Last-Event-ID are set. + *

+ * The initialization process is not required in Streamable HTTP, only when stateful management is implemented, + * and after Initializated, all subsequent requests need to be accompanied by the mcp-session-id. + */ + @Override + public Mono connect(Function, Mono> handler) { + if (fallbackToSse.get()) { + return sseClientTransport.connect(handler); + } + this.handler = handler; + + WebClient.RequestBodySpec request = webClient.post() + .uri(endpoint) + .contentType(MediaType.APPLICATION_JSON) + .header("Last-Event-ID", lastEventId.get() != null ? lastEventId.get() : null) + .header("MCP-SESSION-ID", sessionId.get() != null ? sessionId.get() : null); + + return request + .exchangeToMono(response -> { + if (response.statusCode().value() == 405 || response.statusCode().value() == 404) { + fallbackToSse.set(true); + return Mono.error(new IllegalStateException("Operation not allowed, falling back to SSE")); + } + if (!response.statusCode().is2xxSuccessful()) { + logger.error("Session initialization failed with status: {}", response.statusCode()); + return response.createException().flatMap(Mono::error); + } + setSessionId(response); + + return response.bodyToMono(Void.class) + .onErrorResume(e -> Mono.empty()); + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(3)) + .filter(err -> err instanceof IOException || isServerError(err)) + .doAfterRetry(signal -> logger.debug("Retrying session initialization: {}", signal.totalRetries()))) + .doOnError(error -> logger.error("Error initializing session: {}", error.getMessage())) + .onErrorResume(e -> { + logger.error("WebClient transport connection error", e); + if (fallbackToSse.get()) { + return sseClientTransport.connect(handler); + } + return Mono.empty(); + }) + .then() + .subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (isClosing) { + logger.debug("Client is closing, ignoring sendMessage request"); + return Mono.empty(); + } + try { + String jsonText = objectMapper.writeValueAsString(message); + WebClient.RequestBodySpec request = webClient.post() + .uri(endpoint) + .contentType(MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM, new MediaType("application", "json-seq")); + + if (lastEventId.get() != null) { + request.header(LAST_EVENT_ID, lastEventId.get()); + } + if (sessionId.get() != null) { + request.header(MCP_SESSION_ID, sessionId.get()); + } + return request + .bodyValue(jsonText) + .exchangeToMono(response -> { + if (!response.statusCode().is2xxSuccessful()) { + logger.error("Connect request failed with status: {}", response.statusCode()); + return response.createException().flatMap(Mono::error); + } + setSessionId(response); + + var contentType = response.headers().contentType().orElse(MediaType.APPLICATION_JSON); + if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + Flux> sseFlux = response.bodyToFlux(SSE_TYPE); + sseFlux.subscribe(sseSink::tryEmitNext); + return processSseEvents(handler); + } else if (contentType.isCompatibleWith(new MediaType("application", "json-seq"))) { + return handleJsonSeqResponse(response, handler); + } else { + return handleHttpResponse(response, handler); + } + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(3)) + .filter(err -> err instanceof IOException || isServerError(err)) + .doAfterRetry(signal -> logger.debug("Retrying message send: {}", signal.totalRetries()))) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); + } catch (IOException e) { + if (!isClosing) { + logger.error("Failed to serialize message", e); + return Mono.error(new RuntimeException("Failed to serialize message", e)); + } + return Mono.empty(); + } + } + + private void setSessionId(ClientResponse response) { + String sessionIdHeader = response.headers().header(MCP_SESSION_ID).stream().findFirst().orElse(null); + if (sessionIdHeader != null && !sessionIdHeader.trim().isEmpty()) { + sessionId.set(sessionIdHeader); + logger.debug("Received and stored sessionId: {}", sessionIdHeader); + } + } + + private Mono processSseEvents(Function, Mono> handler) { + this.inboundSubscription = sseSink.asFlux().concatMap(event -> Mono.just(event).handle((e, s) -> { + if (ENDPOINT_EVENT_TYPE.equals(event.event())) { + String messageEndpointUri = event.data(); + if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + logger.debug("Received endpoint URI: {}", messageEndpointUri); + s.complete(); + } else { + s.error(new McpError("Failed to handle SSE endpoint event")); + } + } else if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + s.next(message); + } catch (IOException ioException) { + s.error(ioException); + } + } else { + logger.warn("Received unrecognized SSE event type: {}", event.event()); + s.complete(); + } + }).transform(handler).flatMap(this::sendMessage)).subscribe(); + + return messageEndpointSink.asMono().then(); + } + + private Mono handleJsonSeqResponse(ClientResponse response, Function, Mono> handler) { + return response.bodyToFlux(String.class) + .flatMap(jsonLine -> { + try { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine); + return handler.apply(Mono.just(message)) + .onErrorResume(e -> { + logger.error("Error processing message", e); + return Mono.error(e); + }); + } catch (IOException e) { + logger.error("Error processing JSON-seq line: {}", jsonLine, e); + return Mono.empty(); + } + }) + .then(); + } + + private Mono handleHttpResponse(ClientResponse response, Function, Mono> handler) { + return response.bodyToMono(String.class) + .flatMap(body -> { + try { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + return handler.apply(Mono.just(message)) + .then(Mono.empty()); + } catch (IOException e) { + logger.error("Error processing HTTP response body: {}", body, e); + return Mono.error(e); + } + }); + } + + /** + * It is used for the client to actively establish a SSE connection with the server + */ + public Mono connectWithGet(Function, Mono> handler) { + sessionId.set(null); + Flux> getSseFlux = this.webClient + .get() + .uri(endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .retrieve() + .bodyToFlux(SSE_TYPE) + .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); + + getSseFlux.subscribe(sseSink::tryEmitNext); + return processSseEvents(handler); + } + + private boolean isServerError(Throwable err) { + return err instanceof WebClientResponseException + && ((WebClientResponseException) err).getStatusCode().is5xxServerError(); + } + + private BiConsumer> inboundRetryHandler = (retrySpec, sink) -> { + if (isClosing) { + logger.debug("SSE connection closed during shutdown"); + sink.error(retrySpec.failure()); + return; + } + if (retrySpec.failure() instanceof IOException || isServerError(retrySpec.failure())) { + logger.debug("Retrying SSE connection after error"); + sink.next(retrySpec); + return; + } + logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage()); + sink.error(retrySpec.failure()); + }; + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + this.sessionId.set(null); + this.lastEventId.set(null); + if (this.inboundSubscription != null) { + this.inboundSubscription.dispose(); + } + this.sseSink.tryEmitComplete(); + }).then().subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + public static class Builder { + + private final WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_MCP_ENDPOINT; + + private ObjectMapper objectMapper = new ObjectMapper(); + + private Function, Mono> handler = Function.identity(); + + public Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.endpoint = sseEndpoint; + return this; + } + + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder handler(Function, Mono> handler) { + Assert.notNull(handler, "handler must not be null"); + this.handler = handler; + return this; + } + + public WebFluxStreamableClientTransport build() { + final WebFluxSseClientTransport.Builder builder = WebFluxSseClientTransport.builder(webClientBuilder) + .objectMapper(objectMapper); + + if (!endpoint.equals(DEFAULT_MCP_ENDPOINT)) { + builder.sseEndpoint(endpoint); + } + + return new WebFluxStreamableClientTransport(webClientBuilder, objectMapper, endpoint, + handler, builder.build()); + } + + } +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransportTests.java new file mode 100644 index 00000000..3fa0cf4d --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxStreamableClientTransportTests.java @@ -0,0 +1,394 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Tests for the {@link WebFluxStreamableClientTransport} class. + */ +@Timeout(15) +class WebFluxStreamableClientTransportTests { + + // Set up a gateway or server that adapts to Streamable HTTP Transport + private static final String HOST = "http://yourServer:port"; + + private TestStreamableClientTransport transport; + + private WebClient.Builder webClientBuilder; + + private ObjectMapper objectMapper; + + // Test class to access protected methods and simulate events + static class TestStreamableClientTransport extends WebFluxStreamableClientTransport { + + private final AtomicInteger inboundMessageCount = new AtomicInteger(0); + + private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); + + public TestStreamableClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { + super(webClientBuilder, objectMapper); + } + + public String getLastEndpoint() { + return messageEndpointSink.asMono().block(); + } + + public int getInboundMessageCount() { + return inboundMessageCount.get(); + } + + public void simulateEndpointEvent(String jsonMessage) { + events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); + inboundMessageCount.incrementAndGet(); + } + + public void simulateMessageEvent(String jsonMessage) { + events.tryEmitNext(ServerSentEvent.builder().event("message").data(jsonMessage).build()); + inboundMessageCount.incrementAndGet(); + } + } + + @BeforeEach + void setUp() { + webClientBuilder = WebClient.builder().baseUrl(HOST); + objectMapper = new ObjectMapper(); + transport = new TestStreamableClientTransport(webClientBuilder, objectMapper); + } + + @AfterEach + void afterEach() { + if (transport != null) { + assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + } + + @Test + void testBuilderPattern() { + // Test default builder + WebFluxStreamableClientTransport transport1 = WebFluxStreamableClientTransport.builder(webClientBuilder).build(); + assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom ObjectMapper + ObjectMapper customMapper = new ObjectMapper(); + WebFluxStreamableClientTransport transport2 = WebFluxStreamableClientTransport.builder(webClientBuilder) + .objectMapper(customMapper) + .build(); + assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom SSE endpoint + WebFluxStreamableClientTransport transport3 = WebFluxStreamableClientTransport.builder(webClientBuilder) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with all custom parameters + WebFluxStreamableClientTransport transport4 = WebFluxStreamableClientTransport.builder(webClientBuilder) + .objectMapper(customMapper) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); + } + + @Test + void testMessageProcessing() { + // Create a test message + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Simulate receiving the message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "test-method", + "id": "test-id", + "params": {"key": "value"} + } + """); + + // Subscribe to messages and verify + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testResponseMessageProcessing() { + // Simulate receiving a response message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "id": "test-id", + "result": {"status": "success"} + } + """); + + // Create and send a request message + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testErrorMessageProcessing() { + // Simulate receiving an error message + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "id": "test-id", + "error": { + "code": -32600, + "message": "Invalid Request" + } + } + """); + + // Create and send a request message + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testNotificationMessageProcessing() { + // Simulate receiving a notification message (no id) + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "update", + "params": {"status": "processing"} + } + """); + + // Verify the notification was processed + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testGracefulShutdown() { + // Test graceful shutdown + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + // Create a test message + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Verify message is not processed after shutdown + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Message count should remain 0 after shutdown + assertThat(transport.getInboundMessageCount()).isEqualTo(0); + } + + @Test + void testRetryBehavior() { + // Create a WebClient that simulates connection failures + WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host"); + + WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build(); + + // Verify that the transport attempts to reconnect + StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); + + // Clean up + failingTransport.closeGracefully().block(); + } + + @Test + void testMultipleMessageProcessing() { + // Simulate receiving multiple messages in sequence + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "method1", + "id": "id1", + "params": {"key": "value1"} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "method2", + "id": "id2", + "params": {"key": "value2"} + } + """); + + // Create and send corresponding messages + McpSchema.JSONRPCRequest message1 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + Map.of("key", "value1")); + + McpSchema.JSONRPCRequest message2 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + Map.of("key", "value2")); + + // Verify both messages are processed + StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); + + // Verify message count + assertThat(transport.getInboundMessageCount()).isEqualTo(2); + } + + @Test + void testMessageOrderPreservation() { + // Simulate receiving messages in a specific order + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "first", + "id": "1", + "params": {"sequence": 1} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "second", + "id": "2", + "params": {"sequence": 2} + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "third", + "id": "3", + "params": {"sequence": 3} + } + """); + + // Verify message count and order + assertThat(transport.getInboundMessageCount()).isEqualTo(3); + } + + @Test + void testSendMessageWithSseResponse() { + // the mock server returns an sse flow response + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "test-method", "test-id", Map.of("key", "value")); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "method": "test-method", + "id": "test-id", + "params": {"key": "value"} + } + """); + + // send a message and verify processing + StepVerifier.create(transport.sendMessage(testMessage)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testSendMessageWithJsonSeqResponse() { + // Simulate a JSON-seq response with multiple rows of data + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "test-method", "test-id", Map.of("key", "value")); + + // Here you need to simulate a ClientResponse to return type application/json-seq and contain multiple lines of JSON + // Let's say that when JSON-seq is processed in TestStreamableClientTransport, + // the handler is called to process each JSON line + // Here, the processing of multiple rows of data is simulated by calling handler.apply directly + StepVerifier.create(transport.sendMessage(testMessage)) + .then(() -> { + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "result": "ok", + "id": "1" + } + """); + + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "result": "done", + "id": "2" + } + """); + + }) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + assertThat(transport.getInboundMessageCount()).isEqualTo(2); + } + + @Test + void testSendMessageWithHttpResponse() { + // simulate a normal http json response + McpSchema.JSONRPCRequest testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "test-method", "test-id", Map.of("key", "value")); + + // Let's say the server returns a single JSON response + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "result": "success", + "id": "test-id" + } + """); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + + + @Test + void testConnectWithGetSuccessfully() { + // Test connectWithGet to successfully establish an SSE connection and process events + transport.simulateMessageEvent(""" + { + "jsonrpc": "2.0", + "result": "update" + } + """); + + StepVerifier.create(transport.connectWithGet(msg -> Mono.empty())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + assertThat(transport.getLastEndpoint()).isEqualTo("http://new-endpoint"); + assertThat(transport.getInboundMessageCount()).isEqualTo(2); + } + + +}