diff --git a/Sources/MCP/Base/Transports/InMemoryTransport.swift b/Sources/MCP/Base/Transports/InMemoryTransport.swift new file mode 100644 index 0000000..9234dfc --- /dev/null +++ b/Sources/MCP/Base/Transports/InMemoryTransport.swift @@ -0,0 +1,197 @@ +import Foundation +import Logging + +/// An in-memory transport implementation for direct communication within the same process. +/// +/// - Example: +/// ```swift +/// // Create a connected pair of transports +/// let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() +/// +/// // Use with client and server +/// let client = Client(name: "MyApp", version: "1.0.0") +/// let server = Server(name: "MyServer", version: "1.0.0") +/// +/// try await client.connect(transport: clientTransport) +/// try await server.connect(transport: serverTransport) +/// ``` +public actor InMemoryTransport: Transport { + /// Logger instance for transport-related events + public nonisolated let logger: Logger + + private var isConnected = false + private var pairedTransport: InMemoryTransport? + + // Message queues + private var incomingMessages: [Data] = [] + private var messageContinuation: AsyncThrowingStream.Continuation? + + /// Creates a new in-memory transport + /// + /// - Parameter logger: Optional logger instance for transport events + public init(logger: Logger? = nil) { + self.logger = + logger + ?? Logger( + label: "mcp.transport.in-memory", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + } + + /// Creates a connected pair of in-memory transports + /// + /// This is the recommended way to create transports for client-server communication + /// within the same process. The returned transports are already paired and ready + /// to be connected. + /// + /// - Parameter logger: Optional logger instance shared by both transports + /// - Returns: A tuple of (clientTransport, serverTransport) ready for use + public static func createConnectedPair( + logger: Logger? = nil + ) async -> (client: InMemoryTransport, server: InMemoryTransport) { + let clientLogger: Logger + let serverLogger: Logger + + if let providedLogger = logger { + // If a logger is provided, use it directly for both transports + clientLogger = providedLogger + serverLogger = providedLogger + } else { + // Create default loggers with appropriate labels + clientLogger = Logger( + label: "mcp.transport.in-memory.client", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + serverLogger = Logger( + label: "mcp.transport.in-memory.server", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + } + + let clientTransport = InMemoryTransport(logger: clientLogger) + let serverTransport = InMemoryTransport(logger: serverLogger) + + // Perform pairing + await clientTransport.pair(with: serverTransport) + await serverTransport.pair(with: clientTransport) + + return (clientTransport, serverTransport) + } + + /// Pairs this transport with another for bidirectional communication + /// + /// - Parameter other: The transport to pair with + /// - Important: This method should typically not be called directly. + /// Use `createConnectedPair()` instead. + private func pair(with other: InMemoryTransport) { + self.pairedTransport = other + } + + /// Establishes connection with the transport + /// + /// For in-memory transports, this validates that the transport is properly + /// paired and sets up the message stream. + /// + /// - Throws: MCPError.internalError if the transport is not paired + public func connect() async throws { + guard !isConnected else { + logger.debug("Transport already connected") + return + } + + guard pairedTransport != nil else { + throw MCPError.internalError( + "Transport not paired. Use createConnectedPair() to create paired transports.") + } + + isConnected = true + logger.info("Transport connected successfully") + } + + /// Disconnects from the transport + /// + /// This closes the message stream and marks the transport as disconnected. + public func disconnect() async { + guard isConnected else { return } + + isConnected = false + messageContinuation?.finish() + messageContinuation = nil + + // Notify paired transport of disconnection + if let paired = pairedTransport { + await paired.handlePeerDisconnection() + } + + logger.info("Transport disconnected") + } + + /// Handles disconnection from the paired transport + private func handlePeerDisconnection() { + if isConnected { + messageContinuation?.finish(throwing: MCPError.connectionClosed) + messageContinuation = nil + isConnected = false + logger.info("Peer transport disconnected") + } + } + + /// Sends a message to the paired transport + /// + /// Messages are delivered directly to the paired transport's receive queue + /// without any additional encoding or framing. + /// + /// - Parameter data: The message data to send + /// - Throws: MCPError.internalError if not connected or no paired transport + public func send(_ data: Data) async throws { + guard isConnected else { + throw MCPError.internalError("Transport not connected") + } + + guard let paired = pairedTransport else { + throw MCPError.internalError("No paired transport") + } + + logger.debug("Sending message", metadata: ["size": "\(data.count)"]) + + // Deliver message to paired transport + await paired.deliverMessage(data) + } + + /// Delivers a message from the paired transport + private func deliverMessage(_ data: Data) { + guard isConnected else { + logger.warning("Received message while disconnected") + return + } + + logger.debug("Message received", metadata: ["size": "\(data.count)"]) + + if let continuation = messageContinuation { + continuation.yield(data) + } else { + // Queue message if stream not yet created + incomingMessages.append(data) + } + } + + /// Receives messages from the paired transport + /// + /// - Returns: An AsyncThrowingStream of Data objects representing messages + public func receive() -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + self.messageContinuation = continuation + + // Deliver any queued messages + for message in self.incomingMessages { + continuation.yield(message) + } + self.incomingMessages.removeAll() + + // Check if already disconnected + if !self.isConnected { + continuation.finish() + } + } + } +} diff --git a/Tests/MCPTests/InMemoryTransportTests.swift b/Tests/MCPTests/InMemoryTransportTests.swift new file mode 100644 index 0000000..d8228e7 --- /dev/null +++ b/Tests/MCPTests/InMemoryTransportTests.swift @@ -0,0 +1,409 @@ +import Foundation +import Logging +import Testing + +@testable import MCP + +@Suite("InMemory Transport Tests") +struct InMemoryTransportTests { + + @Test("Create connected pair") + func testCreateConnectedPair() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + // Verify both transports can be connected + try await clientTransport.connect() + try await serverTransport.connect() + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Connect without pairing throws error") + func testConnectWithoutPairing() async throws { + let transport = InMemoryTransport() + + // Attempt to connect without pairing should throw + do { + try await transport.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch let error as MCPError { + if case .internalError(let message) = error { + #expect( + message + == "Transport not paired. Use createConnectedPair() to create paired transports." + ) + } else { + #expect(Bool(false), "Expected MCPError.internalError") + } + } + } + + @Test("Multiple connect calls are idempotent") + func testMultipleConnectCalls() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + // Connect multiple times should not throw + try await clientTransport.connect() + try await clientTransport.connect() // Should be safe + try await clientTransport.connect() // Should be safe + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Send and receive messages") + func testSendAndReceiveMessages() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Start receiving on server + let serverReceiveTask = Task { + var messages: [Data] = [] + for try await message in await serverTransport.receive() { + messages.append(message) + if messages.count >= 3 { + break + } + } + return messages + } + + // Send messages from client + let message1 = "Hello".data(using: .utf8)! + let message2 = "World".data(using: .utf8)! + let message3 = "!".data(using: .utf8)! + + try await clientTransport.send(message1) + try await clientTransport.send(message2) + try await clientTransport.send(message3) + + // Wait for messages to be received + let receivedMessages = try await serverReceiveTask.value + + #expect(receivedMessages.count == 3) + #expect(receivedMessages[0] == message1) + #expect(receivedMessages[1] == message2) + #expect(receivedMessages[2] == message3) + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Bidirectional communication") + func testBidirectionalCommunication() async throws { + let (transport1, transport2) = await InMemoryTransport.createConnectedPair() + + try await transport1.connect() + try await transport2.connect() + + // Set up receivers + let receive1Task = Task { + var messages: [String] = [] + for try await data in await transport1.receive() { + if let message = String(data: data, encoding: .utf8) { + messages.append(message) + if messages.count >= 2 { + break + } + } + } + return messages + } + + let receive2Task = Task { + var messages: [String] = [] + for try await data in await transport2.receive() { + if let message = String(data: data, encoding: .utf8) { + messages.append(message) + if messages.count >= 2 { + break + } + } + } + return messages + } + + // Send messages in both directions + try await transport1.send("From transport 1 - message 1".data(using: .utf8)!) + try await transport2.send("From transport 2 - message 1".data(using: .utf8)!) + try await transport1.send("From transport 1 - message 2".data(using: .utf8)!) + try await transport2.send("From transport 2 - message 2".data(using: .utf8)!) + + // Verify both sides received messages + let messages1 = try await receive1Task.value + let messages2 = try await receive2Task.value + + #expect(messages1.count == 2) + #expect(messages1[0] == "From transport 2 - message 1") + #expect(messages1[1] == "From transport 2 - message 2") + + #expect(messages2.count == 2) + #expect(messages2[0] == "From transport 1 - message 1") + #expect(messages2[1] == "From transport 1 - message 2") + + // Clean up + await transport1.disconnect() + await transport2.disconnect() + } + + @Test("Send without connection throws error") + func testSendWithoutConnection() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + // Try to send without connecting + do { + try await clientTransport.send("test".data(using: .utf8)!) + #expect(Bool(false), "Expected send to throw an error") + } catch let error as MCPError { + if case .internalError(let message) = error { + #expect(message == "Transport not connected") + } else { + #expect(Bool(false), "Expected MCPError.internalError") + } + } + + // Clean up (connect server to avoid dangling connections) + try await serverTransport.connect() + await serverTransport.disconnect() + } + + @Test("Disconnect stops message stream") + func testDisconnectStopsMessageStream() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Start receiving + let receiveTask = Task { + var messageCount = 0 + do { + for try await _ in await serverTransport.receive() { + messageCount += 1 + } + } catch { + // Expected when disconnected + } + return messageCount + } + + // Send a message + try await clientTransport.send("message".data(using: .utf8)!) + + // Give some time for message to be received + try await Task.sleep(for: .milliseconds(100)) + + // Disconnect + await serverTransport.disconnect() + + // The receive stream should complete + let messageCount = await receiveTask.value + #expect(messageCount >= 1) + } + + @Test("Peer disconnection handling") + func testPeerDisconnectionHandling() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Start receiving on server + let receiveTask = Task { () -> Swift.Error? in + do { + for try await _ in await serverTransport.receive() { + // Keep receiving + } + return nil + } catch { + return error + } + } + + // Give a moment for the receive stream to be set up + try await Task.sleep(for: .milliseconds(10)) + + // Disconnect client (peer) + await clientTransport.disconnect() + + receiveTask.cancel() + } + + @Test("Multiple disconnects are safe") + func testMultipleDisconnects() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Multiple disconnects should be safe + await clientTransport.disconnect() + await clientTransport.disconnect() + await clientTransport.disconnect() + + // Clean up server + await serverTransport.disconnect() + } + + @Test("Message queueing before stream creation") + func testMessageQueueingBeforeStreamCreation() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Send messages before receive stream is created + try await clientTransport.send("message1".data(using: .utf8)!) + try await clientTransport.send("message2".data(using: .utf8)!) + try await clientTransport.send("message3".data(using: .utf8)!) + + // Now create receive stream + let messages = await serverTransport.receive() + var receivedMessages: [String] = [] + + for try await data in messages { + if let message = String(data: data, encoding: .utf8) { + receivedMessages.append(message) + if receivedMessages.count >= 3 { + break + } + } + } + + #expect(receivedMessages.count == 3) + #expect(receivedMessages[0] == "message1") + #expect(receivedMessages[1] == "message2") + #expect(receivedMessages[2] == "message3") + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Receive after disconnect returns completed stream") + func testReceiveAfterDisconnect() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Disconnect before receiving + await serverTransport.disconnect() + + // Create receive stream after disconnect + let messages = await serverTransport.receive() + var messageCount = 0 + + for try await _ in messages { + messageCount += 1 + } + + // Stream should complete immediately + #expect(messageCount == 0) + + // Clean up + await clientTransport.disconnect() + } + + @Test("Large message handling") + func testLargeMessageHandling() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Create a large message (1MB) + let largeData = Data(repeating: 0xFF, count: 1024 * 1024) + + // Start receiving + let receiveTask = Task { + for try await data in await serverTransport.receive() { + return data + } + return Data() + } + + // Send large message + try await clientTransport.send(largeData) + + // Verify it was received correctly + let receivedData = try await receiveTask.value + #expect(receivedData == largeData) + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Concurrent send operations") + func testConcurrentSendOperations() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + try await clientTransport.connect() + try await serverTransport.connect() + + // Start receiving + let receiveTask = Task { + var messages: [String] = [] + for try await data in await serverTransport.receive() { + if let message = String(data: data, encoding: .utf8) { + messages.append(message) + if messages.count >= 10 { + break + } + } + } + return messages + } + + // Send messages concurrently + await withTaskGroup(of: Void.self) { group in + for i in 0..<10 { + group.addTask { + try? await clientTransport.send("Message \(i)".data(using: .utf8)!) + } + } + } + + // Verify all messages were received + let receivedMessages = try await receiveTask.value + #expect(receivedMessages.count == 10) + + // Check that all messages are present (order may vary due to concurrency) + let expectedMessages = Set((0..<10).map { "Message \($0)" }) + let actualMessages = Set(receivedMessages) + #expect(actualMessages == expectedMessages) + + // Clean up + await clientTransport.disconnect() + await serverTransport.disconnect() + } + + @Test("Custom logger usage") + func testCustomLoggerUsage() async throws { + // Create a custom logger (in real tests, you might use a test logger that captures output) + let logger = Logger(label: "test.in-memory.transport") + + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair( + logger: logger) + + // Verify loggers are set correctly - when a custom logger is provided, it's used for both + #expect(clientTransport.logger.label == "test.in-memory.transport") + #expect(serverTransport.logger.label == "test.in-memory.transport") + + // Test basic operations with custom logger + try await clientTransport.connect() + try await serverTransport.connect() + + try await clientTransport.send("test".data(using: .utf8)!) + + await clientTransport.disconnect() + await serverTransport.disconnect() + } +}