diff --git a/Networking/Sources/MsQuicSwift/QuicConnection.swift b/Networking/Sources/MsQuicSwift/QuicConnection.swift index 5218d34d..48b88ff8 100644 --- a/Networking/Sources/MsQuicSwift/QuicConnection.swift +++ b/Networking/Sources/MsQuicSwift/QuicConnection.swift @@ -124,7 +124,7 @@ public final class QuicConnection: Sendable { } } - fileprivate func close() { + func close() { storage.write { storage in storage = nil } @@ -272,8 +272,8 @@ private class ConnectionHandle { case QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED: logger.debug("Peer stream started") let streamPtr = event.pointee.PEER_STREAM_STARTED.Stream - if let connection { - let stream = QuicStream(connection: connection, stream: streamPtr!, handler: connection.handler) + if let connection, let streamPtr, connection.api != nil { + let stream = QuicStream(connection: connection, stream: streamPtr, handler: connection.handler) connection.handler.streamStarted(connection, stream: stream) } else { logger.warning("Stream started but connection is gone?") diff --git a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift index 34029236..146419a0 100644 --- a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift +++ b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift @@ -304,4 +304,85 @@ struct QuicListenerTests { #expect(receivedData2[0] == Data("other test data 2".utf8)) #expect(receivedData2[1] == Data("another replay to 2".utf8)) } + + @Test + func mockConnectionShutdown() async throws { + let serverHandler = MockQuicEventHandler() + let clientHandler = MockQuicEventHandler() + + // create listener + + let quicSettings = QuicSettings.defaultSettings + let serverConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: pkcs12Data, + alpns: [Data("testalpn".utf8)], + client: false, + settings: quicSettings + ) + + let listener = try QuicListener( + handler: serverHandler, + registration: registration, + configuration: serverConfiguration, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + alpns: [Data("testalpn".utf8)] + ) + + let listenAddress = try listener.listenAddress() + let (ipAddress, port) = listenAddress.getAddressAndPort() + #expect(ipAddress == "127.0.0.1") + #expect(port != 0) + + // create connection to listener + + let clientConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: pkcs12Data, + alpns: [Data("testalpn".utf8)], + client: true, + settings: quicSettings + ) + + let clientConnection = try QuicConnection( + handler: clientHandler, + registration: registration, + configuration: clientConfiguration + ) + + try clientConnection.connect(to: listenAddress) + + let stream1 = try clientConnection.createStream() + + try? await Task.sleep(for: .milliseconds(100)) + let (serverConnection, info) = serverHandler.events.value.compactMap { + switch $0 { + case let .newConnection(_, connection, info): + (connection, info) as (QuicConnection, ConnectionInfo)? + default: + nil + } + }.first! + + let (ipAddress2, _) = info.remoteAddress.getAddressAndPort() + + #expect(info.negotiatedAlpn == Data("testalpn".utf8)) + #expect(info.serverName == "127.0.0.1") + #expect(info.localAddress == listenAddress) + #expect(ipAddress2 == "127.0.0.1") + + try stream1.send(data: Data("test data 1".utf8)) + serverConnection.close() + + try? await Task.sleep(for: .milliseconds(1000)) + let receivedData = serverHandler.events.value.compactMap { + switch $0 { + case let .dataReceived(_, data): + data + default: + nil + } + } + #expect(receivedData == []) + } }