From eb7a1228f9e2a220ec99c530a7846b344b021750 Mon Sep 17 00:00:00 2001 From: MacOMNI <414294494@qq.com> Date: Sun, 24 Nov 2024 12:59:28 +0800 Subject: [PATCH] fix issue 226 (#230) * fix issue 226 * update test * update p2p --- .../Sources/MsQuicSwift/QuicConnection.swift | 9 +- Networking/Sources/Networking/Peer.swift | 24 ++-- .../NetworkingTests/MockPeerEventTests.swift | 15 ++- .../Tests/NetworkingTests/PeerTests.swift | 104 +++++++++++++++++- 4 files changed, 128 insertions(+), 24 deletions(-) diff --git a/Networking/Sources/MsQuicSwift/QuicConnection.swift b/Networking/Sources/MsQuicSwift/QuicConnection.swift index dddea21f..5218d34d 100644 --- a/Networking/Sources/MsQuicSwift/QuicConnection.swift +++ b/Networking/Sources/MsQuicSwift/QuicConnection.swift @@ -124,6 +124,12 @@ public final class QuicConnection: Sendable { } } + fileprivate func close() { + storage.write { storage in + storage = nil + } + } + public func shutdown(errorCode: QuicErrorCode = .success) throws { logger.debug("closing connection") try storage.write { storage in @@ -250,12 +256,13 @@ private class ConnectionHandle { } case QUIC_CONNECTION_EVENT_SHUTDOWN_COMPLETE: - logger.trace("Shutdown complete") + logger.debug("Shutdown complete") if let connection { connection.handler.shutdownComplete(connection) } if event.pointee.SHUTDOWN_COMPLETE.AppCloseInProgress == 0 { // avoid closing twice + connection?.close() api.call { api in api.pointee.ConnectionClose(ptr) } diff --git a/Networking/Sources/Networking/Peer.swift b/Networking/Sources/Networking/Peer.swift index 32829072..342cd881 100644 --- a/Networking/Sources/Networking/Peer.swift +++ b/Networking/Sources/Networking/Peer.swift @@ -294,7 +294,7 @@ final class PeerImpl: Sendable { } func reconnect(to address: NetAddr, role: PeerRole) throws { - let state = reconnectStates.read { reconnectStates in + var state = reconnectStates.read { reconnectStates in reconnectStates[address] ?? .init() } @@ -302,12 +302,9 @@ final class PeerImpl: Sendable { logger.warning("reconnecting to \(address) exceeded max attempts") return } - + state.applyBackoff() reconnectStates.write { reconnectStates in - if var state = reconnectStates[address] { - state.applyBackoff() - reconnectStates[address] = state - } + reconnectStates[address] = state } Task { try await Task.sleep(for: .seconds(state.delay)) @@ -336,7 +333,7 @@ final class PeerImpl: Sendable { } func reopenUpStream(connection: Connection, kind: Handler.PresistentHandler.StreamKind) { - let state = reopenStates.read { states in + var state = reopenStates.read { states in states[connection.id] ?? .init() } @@ -344,12 +341,9 @@ final class PeerImpl: Sendable { logger.warning("Reopen attempt for stream \(kind) on connection \(connection.id) exceeded max attempts") return } - + state.applyBackoff() reopenStates.write { states in - if var state = states[connection.id] { - state.applyBackoff() - states[connection.id] = state - } + states[connection.id] = state } Task { @@ -557,10 +551,10 @@ private struct PeerEventHandler: QuicEventHandler { false case let .transport(status, _): switch QuicStatusCode(rawValue: status.rawValue) { - case .badCert: - false + case .aborted, .outOfMemory, .connectionTimeout, .unreachable, .bufferTooSmall, .connectionRefused: + true default: - !status.isSucceeded + status.isSucceeded } case let .byPeer(code): // Do not reconnect if the closure was initiated by the peer. diff --git a/Networking/Tests/NetworkingTests/MockPeerEventTests.swift b/Networking/Tests/NetworkingTests/MockPeerEventTests.swift index 71b4793a..f837c576 100644 --- a/Networking/Tests/NetworkingTests/MockPeerEventTests.swift +++ b/Networking/Tests/NetworkingTests/MockPeerEventTests.swift @@ -7,6 +7,11 @@ import Utils final class MockPeerEventTests { final class MockPeerEventHandler: QuicEventHandler { + enum MockPeerAction { + case none + case mockHandshakeFailure + } + enum EventType { case newConnection(listener: QuicListener, connection: QuicConnection, info: ConnectionInfo) case shouldOpen(connection: QuicConnection, certificate: Data?) @@ -18,8 +23,11 @@ final class MockPeerEventTests { } let events: ThreadSafeContainer<[EventType]> = .init([]) + let mockAction: MockPeerAction - init() {} + init(_ action: MockPeerAction = .none) { + mockAction = action + } func newConnection( _ listener: QuicListener, connection: QuicConnection, info: ConnectionInfo @@ -32,6 +40,9 @@ final class MockPeerEventTests { } func shouldOpen(_: QuicConnection, certificate: Data?) -> QuicStatus { + if mockAction == .mockHandshakeFailure { + return .code(.handshakeFailure) + } guard let certificate else { return .code(.requiredCert) } @@ -169,7 +180,7 @@ final class MockPeerEventTests { func connected() async throws { let serverHandler = MockPeerEventHandler() let clientHandler = MockPeerEventHandler() - let privateKey1 = try Ed25519.SecretKey(from: Data32()) + let privateKey1 = try Ed25519.SecretKey(from: Data32.random()) let cert = try generateSelfSignedCertificate(privateKey: privateKey1) let serverConfiguration = try QuicConfiguration( registration: registration, diff --git a/Networking/Tests/NetworkingTests/PeerTests.swift b/Networking/Tests/NetworkingTests/PeerTests.swift index ca189d51..29b2f3dd 100644 --- a/Networking/Tests/NetworkingTests/PeerTests.swift +++ b/Networking/Tests/NetworkingTests/PeerTests.swift @@ -143,6 +143,98 @@ struct PeerTests { typealias EphemeralHandler = MockEphemeralStreamHandler } + @Test + func mockHandshakeFailure() async throws { + let mockPeerTest = try MockPeerEventTests() + let serverHandler = MockPeerEventTests.MockPeerEventHandler( + MockPeerEventTests.MockPeerEventHandler.MockPeerAction.mockHandshakeFailure + ) + let alpns = [ + PeerRole.validator: Alpn(genesisHeader: Data32(), builder: false).data, + PeerRole.builder: Alpn(genesisHeader: Data32(), builder: true).data, + ] + let allAlpns = Array(alpns.values) + // Server setup with bad certificate + let serverConfiguration = try QuicConfiguration( + registration: mockPeerTest.registration, + pkcs12: mockPeerTest.certData, + alpns: allAlpns, + client: false, + settings: QuicSettings.defaultSettings + ) + + let listener = try QuicListener( + handler: serverHandler, + registration: mockPeerTest.registration, + configuration: serverConfiguration, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + alpns: allAlpns + ) + + let listenAddress = try listener.listenAddress() + let peer1 = try Peer( + options: PeerOptions( + role: .validator, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + genesisHeader: Data32(), + secretKey: Ed25519.SecretKey(from: Data32.random()), + presistentStreamHandler: MockPresentStreamHandler(), + ephemeralStreamHandler: MockEphemeralStreamHandler(), + serverSettings: .defaultSettings, + clientSettings: .defaultSettings + ) + ) + + let connection1 = try peer1.connect(to: listenAddress, role: .validator) + try? await Task.sleep(for: .milliseconds(3000)) + #expect(connection1.isClosed == true) + } + + @Test + func mockShutdownBadCert() async throws { + let mockPeerTest = try MockPeerEventTests() + let serverHandler = MockPeerEventTests.MockPeerEventHandler() + let alpns = [ + PeerRole.validator: Alpn(genesisHeader: Data32(), builder: false).data, + PeerRole.builder: Alpn(genesisHeader: Data32(), builder: true).data, + ] + let allAlpns = Array(alpns.values) + // Server setup with bad certificate + let serverConfiguration = try QuicConfiguration( + registration: mockPeerTest.registration, + pkcs12: mockPeerTest.badCertData, + alpns: allAlpns, + client: false, + settings: QuicSettings.defaultSettings + ) + + let listener = try QuicListener( + handler: serverHandler, + registration: mockPeerTest.registration, + configuration: serverConfiguration, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + alpns: allAlpns + ) + + let listenAddress = try listener.listenAddress() + let peer1 = try Peer( + options: PeerOptions( + role: .validator, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + genesisHeader: Data32(), + secretKey: Ed25519.SecretKey(from: Data32.random()), + presistentStreamHandler: MockPresentStreamHandler(), + ephemeralStreamHandler: MockEphemeralStreamHandler(), + serverSettings: .defaultSettings, + clientSettings: .defaultSettings + ) + ) + + let connection1 = try peer1.connect(to: listenAddress, role: .validator) + try? await Task.sleep(for: .milliseconds(1000)) + #expect(connection1.isClosed == true) + } + @Test func reopenUpStream() async throws { let handler2 = MockPresentStreamHandler() @@ -197,7 +289,7 @@ struct PeerTests { peer1.broadcast( kind: .uniqueA, message: .init(kind: .uniqueA, data: messageData) ) - try await Task.sleep(for: .milliseconds(1000)) + try await Task.sleep(for: .milliseconds(2000)) let lastReceivedData2 = await handler2.lastReceivedData #expect(lastReceivedData2 == messageData) } @@ -290,15 +382,15 @@ struct PeerTests { let connection1 = try peer1.connect(to: peer2.listenAddress(), role: .validator) let connection2 = try peer2.connect(to: peer1.listenAddress(), role: .validator) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(1000)) if !connection1.isClosed { let data = try await connection1.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(500)) #expect(data == Data("hello world response".utf8)) } if !connection2.isClosed { let data = try await connection2.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(500)) #expect(data == Data("hello world response".utf8)) } } @@ -573,7 +665,7 @@ struct PeerTests { to: peer2.listenAddress(), role: .validator ) - try? await Task.sleep(for: .milliseconds(50)) + try? await Task.sleep(for: .milliseconds(500)) peer1.broadcast( kind: .uniqueA, message: .init(kind: .uniqueA, data: Data("hello world".utf8)) @@ -583,7 +675,7 @@ struct PeerTests { kind: .uniqueB, message: .init(kind: .uniqueB, data: Data("I am jam".utf8)) ) // Verify last received data - try? await Task.sleep(for: .milliseconds(200)) + try? await Task.sleep(for: .milliseconds(500)) await #expect(handler2.lastReceivedData == Data("hello world".utf8)) await #expect(handler1.lastReceivedData == Data("I am jam".utf8)) }