Skip to content

Commit

Permalink
fix issue 226 (#230)
Browse files Browse the repository at this point in the history
* fix issue 226

* update test

* update p2p
  • Loading branch information
MacOMNI authored Nov 24, 2024
1 parent 161eb10 commit eb7a122
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 24 deletions.
9 changes: 8 additions & 1 deletion Networking/Sources/MsQuicSwift/QuicConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ public final class QuicConnection: Sendable {
}
}

fileprivate func close() {
storage.write { storage in
storage = nil
}
}

public func shutdown(errorCode: QuicErrorCode = .success) throws {
logger.debug("closing connection")
try storage.write { storage in
Expand Down Expand Up @@ -250,12 +256,13 @@ private class ConnectionHandle {
}

case QUIC_CONNECTION_EVENT_SHUTDOWN_COMPLETE:
logger.trace("Shutdown complete")
logger.debug("Shutdown complete")
if let connection {
connection.handler.shutdownComplete(connection)
}
if event.pointee.SHUTDOWN_COMPLETE.AppCloseInProgress == 0 {
// avoid closing twice
connection?.close()
api.call { api in
api.pointee.ConnectionClose(ptr)
}
Expand Down
24 changes: 9 additions & 15 deletions Networking/Sources/Networking/Peer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -294,20 +294,17 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
}

func reconnect(to address: NetAddr, role: PeerRole) throws {
let state = reconnectStates.read { reconnectStates in
var state = reconnectStates.read { reconnectStates in
reconnectStates[address] ?? .init()
}

guard state.attempt < maxRetryAttempts else {
logger.warning("reconnecting to \(address) exceeded max attempts")
return
}

state.applyBackoff()
reconnectStates.write { reconnectStates in
if var state = reconnectStates[address] {
state.applyBackoff()
reconnectStates[address] = state
}
reconnectStates[address] = state
}
Task {
try await Task.sleep(for: .seconds(state.delay))
Expand Down Expand Up @@ -336,20 +333,17 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
}

func reopenUpStream(connection: Connection<Handler>, kind: Handler.PresistentHandler.StreamKind) {
let state = reopenStates.read { states in
var state = reopenStates.read { states in
states[connection.id] ?? .init()
}

guard state.attempt < maxRetryAttempts else {
logger.warning("Reopen attempt for stream \(kind) on connection \(connection.id) exceeded max attempts")
return
}

state.applyBackoff()
reopenStates.write { states in
if var state = states[connection.id] {
state.applyBackoff()
states[connection.id] = state
}
states[connection.id] = state
}

Task {
Expand Down Expand Up @@ -557,10 +551,10 @@ private struct PeerEventHandler<Handler: StreamHandler>: QuicEventHandler {
false
case let .transport(status, _):
switch QuicStatusCode(rawValue: status.rawValue) {
case .badCert:
false
case .aborted, .outOfMemory, .connectionTimeout, .unreachable, .bufferTooSmall, .connectionRefused:
true
default:
!status.isSucceeded
status.isSucceeded
}
case let .byPeer(code):
// Do not reconnect if the closure was initiated by the peer.
Expand Down
15 changes: 13 additions & 2 deletions Networking/Tests/NetworkingTests/MockPeerEventTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import Utils

final class MockPeerEventTests {
final class MockPeerEventHandler: QuicEventHandler {
enum MockPeerAction {
case none
case mockHandshakeFailure
}

enum EventType {
case newConnection(listener: QuicListener, connection: QuicConnection, info: ConnectionInfo)
case shouldOpen(connection: QuicConnection, certificate: Data?)
Expand All @@ -18,8 +23,11 @@ final class MockPeerEventTests {
}

let events: ThreadSafeContainer<[EventType]> = .init([])
let mockAction: MockPeerAction

init() {}
init(_ action: MockPeerAction = .none) {
mockAction = action
}

func newConnection(
_ listener: QuicListener, connection: QuicConnection, info: ConnectionInfo
Expand All @@ -32,6 +40,9 @@ final class MockPeerEventTests {
}

func shouldOpen(_: QuicConnection, certificate: Data?) -> QuicStatus {
if mockAction == .mockHandshakeFailure {
return .code(.handshakeFailure)
}
guard let certificate else {
return .code(.requiredCert)
}
Expand Down Expand Up @@ -169,7 +180,7 @@ final class MockPeerEventTests {
func connected() async throws {
let serverHandler = MockPeerEventHandler()
let clientHandler = MockPeerEventHandler()
let privateKey1 = try Ed25519.SecretKey(from: Data32())
let privateKey1 = try Ed25519.SecretKey(from: Data32.random())
let cert = try generateSelfSignedCertificate(privateKey: privateKey1)
let serverConfiguration = try QuicConfiguration(
registration: registration,
Expand Down
104 changes: 98 additions & 6 deletions Networking/Tests/NetworkingTests/PeerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,98 @@ struct PeerTests {
typealias EphemeralHandler = MockEphemeralStreamHandler
}

@Test
func mockHandshakeFailure() async throws {
let mockPeerTest = try MockPeerEventTests()
let serverHandler = MockPeerEventTests.MockPeerEventHandler(
MockPeerEventTests.MockPeerEventHandler.MockPeerAction.mockHandshakeFailure
)
let alpns = [
PeerRole.validator: Alpn(genesisHeader: Data32(), builder: false).data,
PeerRole.builder: Alpn(genesisHeader: Data32(), builder: true).data,
]
let allAlpns = Array(alpns.values)
// Server setup with bad certificate
let serverConfiguration = try QuicConfiguration(
registration: mockPeerTest.registration,
pkcs12: mockPeerTest.certData,
alpns: allAlpns,
client: false,
settings: QuicSettings.defaultSettings
)

let listener = try QuicListener(
handler: serverHandler,
registration: mockPeerTest.registration,
configuration: serverConfiguration,
listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!,
alpns: allAlpns
)

let listenAddress = try listener.listenAddress()
let peer1 = try Peer(
options: PeerOptions<MockStreamHandler>(
role: .validator,
listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!,
genesisHeader: Data32(),
secretKey: Ed25519.SecretKey(from: Data32.random()),
presistentStreamHandler: MockPresentStreamHandler(),
ephemeralStreamHandler: MockEphemeralStreamHandler(),
serverSettings: .defaultSettings,
clientSettings: .defaultSettings
)
)

let connection1 = try peer1.connect(to: listenAddress, role: .validator)
try? await Task.sleep(for: .milliseconds(3000))
#expect(connection1.isClosed == true)
}

@Test
func mockShutdownBadCert() async throws {
let mockPeerTest = try MockPeerEventTests()
let serverHandler = MockPeerEventTests.MockPeerEventHandler()
let alpns = [
PeerRole.validator: Alpn(genesisHeader: Data32(), builder: false).data,
PeerRole.builder: Alpn(genesisHeader: Data32(), builder: true).data,
]
let allAlpns = Array(alpns.values)
// Server setup with bad certificate
let serverConfiguration = try QuicConfiguration(
registration: mockPeerTest.registration,
pkcs12: mockPeerTest.badCertData,
alpns: allAlpns,
client: false,
settings: QuicSettings.defaultSettings
)

let listener = try QuicListener(
handler: serverHandler,
registration: mockPeerTest.registration,
configuration: serverConfiguration,
listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!,
alpns: allAlpns
)

let listenAddress = try listener.listenAddress()
let peer1 = try Peer(
options: PeerOptions<MockStreamHandler>(
role: .validator,
listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!,
genesisHeader: Data32(),
secretKey: Ed25519.SecretKey(from: Data32.random()),
presistentStreamHandler: MockPresentStreamHandler(),
ephemeralStreamHandler: MockEphemeralStreamHandler(),
serverSettings: .defaultSettings,
clientSettings: .defaultSettings
)
)

let connection1 = try peer1.connect(to: listenAddress, role: .validator)
try? await Task.sleep(for: .milliseconds(1000))
#expect(connection1.isClosed == true)
}

@Test
func reopenUpStream() async throws {
let handler2 = MockPresentStreamHandler()
Expand Down Expand Up @@ -197,7 +289,7 @@ struct PeerTests {
peer1.broadcast(
kind: .uniqueA, message: .init(kind: .uniqueA, data: messageData)
)
try await Task.sleep(for: .milliseconds(1000))
try await Task.sleep(for: .milliseconds(2000))
let lastReceivedData2 = await handler2.lastReceivedData
#expect(lastReceivedData2 == messageData)
}
Expand Down Expand Up @@ -290,15 +382,15 @@ struct PeerTests {

let connection1 = try peer1.connect(to: peer2.listenAddress(), role: .validator)
let connection2 = try peer2.connect(to: peer1.listenAddress(), role: .validator)
try? await Task.sleep(for: .milliseconds(50))
try? await Task.sleep(for: .milliseconds(1000))
if !connection1.isClosed {
let data = try await connection1.request(MockRequest(kind: .typeA, data: Data("hello world".utf8)))
try? await Task.sleep(for: .milliseconds(50))
try? await Task.sleep(for: .milliseconds(500))
#expect(data == Data("hello world response".utf8))
}
if !connection2.isClosed {
let data = try await connection2.request(MockRequest(kind: .typeA, data: Data("hello world".utf8)))
try? await Task.sleep(for: .milliseconds(50))
try? await Task.sleep(for: .milliseconds(500))
#expect(data == Data("hello world response".utf8))
}
}
Expand Down Expand Up @@ -573,7 +665,7 @@ struct PeerTests {
to: peer2.listenAddress(), role: .validator
)

try? await Task.sleep(for: .milliseconds(50))
try? await Task.sleep(for: .milliseconds(500))

peer1.broadcast(
kind: .uniqueA, message: .init(kind: .uniqueA, data: Data("hello world".utf8))
Expand All @@ -583,7 +675,7 @@ struct PeerTests {
kind: .uniqueB, message: .init(kind: .uniqueB, data: Data("I am jam".utf8))
)
// Verify last received data
try? await Task.sleep(for: .milliseconds(200))
try? await Task.sleep(for: .milliseconds(500))
await #expect(handler2.lastReceivedData == Data("hello world".utf8))
await #expect(handler1.lastReceivedData == Data("I am jam".utf8))
}
Expand Down

0 comments on commit eb7a122

Please sign in to comment.