diff --git a/Networking/Sources/Networking/Connection.swift b/Networking/Sources/Networking/Connection.swift index 9a0e4a99..bc7b54cb 100644 --- a/Networking/Sources/Networking/Connection.swift +++ b/Networking/Sources/Networking/Connection.swift @@ -20,12 +20,14 @@ enum ConnectionError: Error { case invalidLength case unexpectedState case closed + case reconnect } enum ConnectionState { case connecting(continuations: [CheckedContinuation]) case connected(publicKey: Data) case closed + case reconnect(publicKey: Data) } public final class Connection: Sendable, ConnectionInfoProtocol { @@ -50,6 +52,8 @@ public final class Connection: Sendable, ConnectionInfoP publicKey case .closed: nil + case let .reconnect(publicKey): + publicKey } } } @@ -91,6 +95,18 @@ public final class Connection: Sendable, ConnectionInfoP } } + func reconnect(publicKey: Data) { + state.write { state in + if case let .connecting(continuations) = state { + for continuation in continuations { + continuation.resume(throwing: ConnectionError.reconnect) + } + state = .reconnect(publicKey: publicKey) + } + state = .reconnect(publicKey: publicKey) + } + } + public var isClosed: Bool { state.read { switch $0 { @@ -100,7 +116,18 @@ public final class Connection: Sendable, ConnectionInfoP false case .closed: true + case .reconnect: + false + } + } + } + + public var needReconnect: Bool { + state.read { + if case .reconnect = $0 { + return true } + return false } } @@ -113,6 +140,8 @@ public final class Connection: Sendable, ConnectionInfoP true case .closed: true + case .reconnect: + true } } diff --git a/Networking/Sources/Networking/Peer.swift b/Networking/Sources/Networking/Peer.swift index e1dcf626..fe5d5911 100644 --- a/Networking/Sources/Networking/Peer.swift +++ b/Networking/Sources/Networking/Peer.swift @@ -75,12 +75,14 @@ public final class Peer: Sendable { let registration = try QuicRegistration() let serverConfiguration = try QuicConfiguration( - registration: registration, pkcs12: pkcs12, alpns: allAlpns, client: false, settings: options.serverSettings + registration: registration, pkcs12: pkcs12, alpns: allAlpns, client: false, + settings: options.serverSettings ) let clientAlpn = alpns[options.role]! let clientConfiguration = try QuicConfiguration( - registration: registration, pkcs12: pkcs12, alpns: [clientAlpn], client: true, settings: options.clientSettings + registration: registration, pkcs12: pkcs12, alpns: [clientAlpn], client: true, + settings: options.clientSettings ) publicKey = options.secretKey.publicKey.data.data @@ -104,11 +106,14 @@ public final class Peer: Sendable { alpns: allAlpns ) - logger.debug("Peer initialized", metadata: [ - "listenAddress": "\(options.listenAddress)", - "role": "\(options.role)", - "publicKey": "\(options.secretKey.publicKey.data.toHexString())", - ]) + logger.debug( + "Peer initialized", + metadata: [ + "listenAddress": "\(options.listenAddress)", + "role": "\(options.role)", + "publicKey": "\(options.secretKey.publicKey.data.toHexString())", + ] + ) } public func listenAddress() throws -> NetAddr { @@ -120,30 +125,33 @@ public final class Peer: Sendable { let conn = impl.connections.read { connections in connections.byAddr[address] } - return try conn ?? impl.connections.write { connections in - if let curr = connections.byAddr[address] { - return curr - } - - logger.debug("connecting to peer", metadata: ["address": "\(address)", "role": "\(role)"]) + return try conn + ?? impl.connections.write { connections in + if let curr = connections.byAddr[address] { + return curr + } - let quicConn = try QuicConnection( - handler: PeerEventHandler(self.impl), - registration: self.impl.clientConfiguration.registration, - configuration: self.impl.clientConfiguration - ) - try quicConn.connect(to: address) - let conn = Connection( - quicConn, - impl: self.impl, - role: role, - remoteAddress: address, - initiatedByLocal: true - ) - connections.byAddr[address] = conn - connections.byId[conn.id] = conn - return conn - } + logger.debug( + "connecting to peer", metadata: ["address": "\(address)", "role": "\(role)"] + ) + + let quicConn = try QuicConnection( + handler: PeerEventHandler(self.impl), + registration: self.impl.clientConfiguration.registration, + configuration: self.impl.clientConfiguration + ) + try quicConn.connect(to: address) + let conn = Connection( + quicConn, + impl: self.impl, + role: role, + remoteAddress: address, + initiatedByLocal: true + ) + connections.byAddr[address] = conn + connections.byId[conn.id] = conn + return conn + } } public func getConnection(publicKey: Data) -> Connection? { @@ -152,11 +160,12 @@ public final class Peer: Sendable { } } - public func broadcast(kind: Handler.PresistentHandler.StreamKind, message: Handler.PresistentHandler.Message) { + public func broadcast( + kind: Handler.PresistentHandler.StreamKind, message: Handler.PresistentHandler.Message + ) { let connections = impl.connections.read { connections in connections.byId.values } - guard let messageData = try? message.encode() else { impl.logger.warning("Failed to encode message: \(message)") return @@ -168,11 +177,14 @@ public final class Peer: Sendable { case .success: break case let .failure(error): - impl.logger.warning("Failed to send message", metadata: [ - "connectionId": "\(connection.id)", - "kind": "\(kind)", - "error": "\(error)", - ]) + impl.logger.warning( + "Failed to send message", + metadata: [ + "connectionId": "\(connection.id)", + "kind": "\(kind)", + "error": "\(error)", + ] + ) } } } @@ -259,6 +271,32 @@ final class PeerImpl: Sendable { } } + // TODO: Add reconnection attempts & Apply exponential backoff delay + func reconnect(to address: NetAddr, role: PeerRole) throws { + logger.debug("reconnecting", metadata: ["to address": "\(address)", "role": "\(role)"]) + try connections.write { connections in + if connections.byAddr[address] != nil { + logger.warning("reconnecting to \(address) already connected") + return + } + let quicConn = try QuicConnection( + handler: PeerEventHandler(self), + registration: clientConfiguration.registration, + configuration: clientConfiguration + ) + try quicConn.connect(to: address) + let conn = Connection( + quicConn, + impl: self, + role: role, + remoteAddress: address, + initiatedByLocal: true + ) + connections.byAddr[address] = conn + connections.byId[conn.id] = conn + } + } + func addStream(_ stream: Stream) { streams.write { streams in if streams[stream.id] != nil { @@ -280,11 +318,15 @@ private struct PeerEventHandler: QuicEventHandler { self.impl = impl } - func newConnection(_: QuicListener, connection: QuicConnection, info: ConnectionInfo) -> QuicStatus { + func newConnection(_: QuicListener, connection: QuicConnection, info: ConnectionInfo) + -> QuicStatus + { let addr = info.remoteAddress let role = impl.alpnLookup[info.negotiatedAlpn] guard let role else { - logger.warning("unknown alpn: \(String(data: info.negotiatedAlpn, encoding: .utf8) ?? info.negotiatedAlpn.toDebugHexString())") + logger.warning( + "unknown alpn: \(String(data: info.negotiatedAlpn, encoding: .utf8) ?? info.negotiatedAlpn.toDebugHexString())" + ) return .code(.alpnNegFailure) } logger.debug("new connection: \(addr) role: \(role)") @@ -303,20 +345,28 @@ private struct PeerEventHandler: QuicEventHandler { connections.byId[connection.id] } guard let conn else { - logger.warning("Attempt to open but connection is absent", metadata: ["connectionId": "\(connection.id)"]) + logger.warning( + "Attempt to open but connection is absent", + metadata: ["connectionId": "\(connection.id)"] + ) return .code(.connectionRefused) } do { let (publicKey, alternativeName) = try parseCertificate(data: certificate, type: .x509) - logger.trace("Certificate parsed", metadata: [ - "connectionId": "\(connection.id)", - "publicKey": "\(publicKey.toHexString())", - "alternativeName": "\(alternativeName)", - ]) + logger.trace( + "Certificate parsed", + metadata: [ + "connectionId": "\(connection.id)", + "publicKey": "\(publicKey.toHexString())", + "alternativeName": "\(alternativeName)", + ] + ) if publicKey == impl.publicKey { // Self connection detected - logger.trace("Rejecting self-connection", metadata: ["connectionId": "\(connection.id)"]) + logger.trace( + "Rejecting self-connection", metadata: ["connectionId": "\(connection.id)"] + ) return .code(.connectionRefused) } if alternativeName != generateSubjectAlternativeName(pubkey: publicKey) { @@ -333,9 +383,13 @@ private struct PeerEventHandler: QuicEventHandler { try conn.opened(publicKey: publicKey) return .code(.success) } else { - logger.debug("Rejecting duplicate connection by rule", metadata: [ - "connectionId": "\(connection.id)", "publicKey": "\(publicKey.toHexString())", - ]) + logger.debug( + "Rejecting duplicate connection by rule", + metadata: [ + "connectionId": "\(connection.id)", + "publicKey": "\(publicKey.toHexString())", + ] + ) return .code(.connectionRefused) } } else { @@ -345,7 +399,10 @@ private struct PeerEventHandler: QuicEventHandler { } } } catch { - logger.warning("Certificate parsing failed", metadata: ["connectionId": "\(connection.id)", "error": "\(error)"]) + logger.warning( + "Certificate parsing failed", + metadata: ["connectionId": "\(connection.id)", "error": "\(error)"] + ) return .code(.badCert) } } @@ -355,7 +412,9 @@ private struct PeerEventHandler: QuicEventHandler { connections.byId[connection.id] } guard let conn else { - logger.warning("Connected but connection is gone?", metadata: ["connectionId": "\(connection.id)"]) + logger.warning( + "Connected but connection is gone?", metadata: ["connectionId": "\(connection.id)"] + ) return } @@ -377,19 +436,70 @@ private struct PeerEventHandler: QuicEventHandler { func shutdownComplete(_ connection: QuicConnection) { logger.debug("connection shutdown complete", metadata: ["connectionId": "\(connection.id)"]) - impl.connections.write { connections in + let conn = impl.connections.read { connections in + connections.byId[connection.id] + } + let needReconnect = impl.connections.write { connections in if let conn = connections.byId[connection.id] { - // remove publickey first,func closed will change state to closed + let needReconnect = conn.needReconnect if let publicKey = conn.publicKey { connections.byPublicKey.removeValue(forKey: publicKey) } - conn.closed() connections.byId.removeValue(forKey: connection.id) connections.byAddr.removeValue(forKey: conn.remoteAddress) + conn.closed() + return needReconnect + } + return false + } + if needReconnect, let address = conn?.remoteAddress, let role = conn?.role { + do { + try impl.reconnect(to: address, role: role) + } catch { + logger.error("reconnect failed", metadata: ["error": "\(error)"]) } } } + func shutdownInitiated(_ connection: QuicConnection, reason: ConnectionCloseReason) { + logger.debug( + "Shutdown initiated", + metadata: ["connectionId": "\(connection.id)", "reason": "\(reason)"] + ) + if shouldReconnect(basedOn: reason) { + impl.connections.write { connections in + if let conn = connections.byId[connection.id] { + if let publicKey = conn.publicKey { + connections.byPublicKey.removeValue(forKey: publicKey) + conn.reconnect(publicKey: publicKey) + } + } + } + } + } + + // TODO: Add all the cases about reconnects + private func shouldReconnect(basedOn reason: ConnectionCloseReason) -> Bool { + switch reason { + case .idle: + // Do not reconnect for idle closures. + false + case let .transport(status, _): + switch QuicStatusCode(rawValue: status.rawValue) { + case .badCert: + false + default: + !status.isSucceeded + } + case let .byPeer(code): + // Do not reconnect if the closure was initiated by the peer. + code != .success + case let .byLocal(code): + // Do not reconnect if the local side initiated the closure. + code != .success + } + } + func streamStarted(_ connection: QuicConnection, stream: QuicStream) { let conn = impl.connections.read { connections in connections.byId[connection.id] @@ -419,10 +529,14 @@ private struct PeerEventHandler: QuicEventHandler { if let connection { connection.streamClosed(stream: stream, abort: !status.isSucceeded) } else { - logger.warning("Stream closed but connection is gone?", metadata: ["streamId": "\(stream.id)"]) + logger.warning( + "Stream closed but connection is gone?", metadata: ["streamId": "\(stream.id)"] + ) } } else { - logger.warning("Stream closed but stream is gone?", metadata: ["streamId": "\(quicStream.id)"]) + logger.warning( + "Stream closed but stream is gone?", metadata: ["streamId": "\(quicStream.id)"] + ) } } } diff --git a/Networking/Tests/NetworkingTests/PeerTests.swift b/Networking/Tests/NetworkingTests/PeerTests.swift index f776ad9c..0ec28edf 100644 --- a/Networking/Tests/NetworkingTests/PeerTests.swift +++ b/Networking/Tests/NetworkingTests/PeerTests.swift @@ -87,7 +87,7 @@ struct PeerTests { struct MockEphemeralStreamHandler: EphemeralStreamHandler { typealias StreamKind = EphemeralStreamKind typealias Request = MockRequest - private let dataStorage = DataStorage() + private let dataStorage: PeerTests.DataStorage = DataStorage() var lastReceivedData: Data? { get async { await dataStorage.data.last } @@ -175,10 +175,12 @@ struct PeerTests { try? await Task.sleep(for: .milliseconds(50)) if !connection1.isClosed { let data = try await connection1.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) + try? await Task.sleep(for: .milliseconds(50)) #expect(data == Data("hello world response".utf8)) } if !connection2.isClosed { let data = try await connection2.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) + try? await Task.sleep(for: .milliseconds(50)) #expect(data == Data("hello world response".utf8)) } } @@ -232,6 +234,7 @@ struct PeerTests { let receivedData1 = try await connection1.request( MockRequest(kind: .typeA, data: largeData) ) + try? await Task.sleep(for: .milliseconds(100)) // Verify that the received data matches the original large data #expect(receivedData1 == largeData + Data(" response".utf8)) @@ -239,19 +242,73 @@ struct PeerTests { peer1.broadcast( kind: .uniqueA, message: .init(kind: .uniqueA, data: largeData) ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) peer2.broadcast( kind: .uniqueB, message: .init(kind: .uniqueB, data: largeData) ) // Verify last received data - try? await Task.sleep(for: .milliseconds(1000)) + try? await Task.sleep(for: .milliseconds(2000)) await #expect(handler2.lastReceivedData == largeData) await #expect(handler1.lastReceivedData == largeData) } @Test - func peerFailureRecovery() async throws { + func connectionNeedToReconnect() async throws { + let handler2 = MockPresentStreamHandler() + let messageData = Data("Post-recovery message".utf8) + + let peer1 = try Peer( + options: PeerOptions( + role: .validator, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + genesisHeader: Data32(), + secretKey: Ed25519.SecretKey(from: Data32.random()), + presistentStreamHandler: MockPresentStreamHandler(), + ephemeralStreamHandler: MockEphemeralStreamHandler(), + serverSettings: .defaultSettings, + clientSettings: .defaultSettings + ) + ) + let peer2 = try Peer( + options: PeerOptions( + role: .validator, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + genesisHeader: Data32(), + secretKey: Ed25519.SecretKey(from: Data32.random()), + presistentStreamHandler: handler2, + ephemeralStreamHandler: MockEphemeralStreamHandler(), + serverSettings: .defaultSettings, + clientSettings: .defaultSettings + ) + ) + try? await Task.sleep(for: .milliseconds(100)) + + let connection = try peer1.connect( + to: peer2.listenAddress(), role: .validator + ) + try? await Task.sleep(for: .milliseconds(100)) + + let receivedData = try await connection.request( + MockRequest(kind: .typeA, data: messageData) + ) + + #expect(receivedData == messageData + Data(" response".utf8)) + try? await Task.sleep(for: .milliseconds(100)) + // Simulate abnormal shutdown of connections + connection.close(abort: true) + // Wait to simulate downtime + try? await Task.sleep(for: .milliseconds(200)) + peer1.broadcast( + kind: .uniqueC, message: .init(kind: .uniqueC, data: messageData) + ) + try? await Task.sleep(for: .milliseconds(1000)) + let lastReceivedData = await handler2.lastReceivedData + #expect(lastReceivedData == messageData) + } + + @Test + func connectionNoNeedToReconnect() async throws { let handler2 = MockPresentStreamHandler() let messageData = Data("Post-recovery message".utf8) @@ -281,7 +338,42 @@ struct PeerTests { ) ) - let peer3 = try Peer( + try? await Task.sleep(for: .milliseconds(100)) + + let connection = try peer1.connect( + to: peer2.listenAddress(), role: .validator + ) + try? await Task.sleep(for: .milliseconds(100)) + // Simulate regular shutdown of connections + connection.close(abort: false) + // Wait to simulate downtime + try? await Task.sleep(for: .milliseconds(200)) + peer1.broadcast( + kind: .uniqueC, message: .init(kind: .uniqueC, data: messageData) + ) + try? await Task.sleep(for: .milliseconds(1000)) + await #expect(handler2.lastReceivedData == nil) + } + + @Test + func connectionManualReconnect() async throws { + let handler2 = MockPresentStreamHandler() + let messageData = Data("Post-recovery message".utf8) + + let peer1 = try Peer( + options: PeerOptions( + role: .validator, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + genesisHeader: Data32(), + secretKey: Ed25519.SecretKey(from: Data32.random()), + presistentStreamHandler: MockPresentStreamHandler(), + ephemeralStreamHandler: MockEphemeralStreamHandler(), + serverSettings: .defaultSettings, + clientSettings: .defaultSettings + ) + ) + + let peer2 = try Peer( options: PeerOptions( role: .validator, listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, @@ -294,50 +386,39 @@ struct PeerTests { ) ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) let connection = try peer1.connect( to: peer2.listenAddress(), role: .validator ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) let receivedData = try await connection.request( MockRequest(kind: .typeA, data: messageData) ) #expect(receivedData == messageData + Data(" response".utf8)) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) // Simulate a peer failure by disconnecting one peer - connection.close(abort: true) + connection.close(abort: false) // Wait to simulate downtime try? await Task.sleep(for: .milliseconds(200)) - // check the peer is usable & connect to another peer - let connection2 = try peer1.connect( - to: peer3.listenAddress(), - role: .validator - ) - try? await Task.sleep(for: .milliseconds(50)) - let receivedData2 = try await connection2.request( - MockRequest(kind: .typeA, data: messageData) - ) - try? await Task.sleep(for: .milliseconds(50)) - #expect(receivedData2 == messageData + Data(" response".utf8)) // Reconnect the failing peer let reconnection = try peer1.connect( to: peer2.listenAddress(), role: .validator ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) let recoverData = try await reconnection.request( MockRequest(kind: .typeA, data: messageData) ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(100)) #expect(recoverData == messageData + Data(" response".utf8)) peer1.broadcast( - kind: .uniqueC, message: .init(kind: .uniqueC, data: messageData) + kind: .uniqueC, message: .init(kind: .uniqueC, data: recoverData) ) - try? await Task.sleep(for: .milliseconds(50)) - await #expect(handler2.lastReceivedData == messageData) + try? await Task.sleep(for: .milliseconds(1000)) + await #expect(handler2.lastReceivedData == recoverData) } @Test @@ -384,7 +465,7 @@ struct PeerTests { kind: .uniqueB, message: .init(kind: .uniqueB, data: Data("I am jam".utf8)) ) // Verify last received data - try? await Task.sleep(for: .milliseconds(100)) + try? await Task.sleep(for: .milliseconds(200)) await #expect(handler2.lastReceivedData == Data("hello world".utf8)) await #expect(handler1.lastReceivedData == Data("I am jam".utf8)) } @@ -487,7 +568,7 @@ struct PeerTests { var peers: [Peer] = [] var handlers: [MockPresentStreamHandler] = [] // Create 100 peer nodes - for i in 0 ..< 100 { + for _ in 0 ..< 100 { let handler = MockPresentStreamHandler() handlers.append(handler) let peer = try Peer( @@ -580,6 +661,7 @@ struct PeerTests { to: otherPeer.listenAddress(), role: .validator ).request(MockRequest(kind: type, data: messageData)) + try? await Task.sleep(for: .milliseconds(100)) #expect(response == messageData + Data(" response".utf8), "Peer \(i) should receive correct response") }) } @@ -642,6 +724,7 @@ struct PeerTests { ) .unwrap() .request(MockRequest(kind: type, data: messageData)) + try? await Task.sleep(for: .milliseconds(50)) #expect(response == messageData + Data(" response".utf8), "Peer should receive correct response") } }