diff --git a/Sources/WSClient/WebSocketClientChannel.swift b/Sources/WSClient/WebSocketClientChannel.swift index 0d5de1c..a7f6c4f 100644 --- a/Sources/WSClient/WebSocketClientChannel.swift +++ b/Sources/WSClient/WebSocketClientChannel.swift @@ -110,6 +110,7 @@ struct WebSocketClientChannel: ClientConnectionChannel { configuration: .init( extensions: extensions, autoPing: self.configuration.autoPing, + closeTimeout: self.configuration.closeTimeout, validateUTF8: self.configuration.validateUTF8 ), asyncChannel: webSocketChannel, diff --git a/Sources/WSClient/WebSocketClientConfiguration.swift b/Sources/WSClient/WebSocketClientConfiguration.swift index c8f647f..dcbb6ff 100644 --- a/Sources/WSClient/WebSocketClientConfiguration.swift +++ b/Sources/WSClient/WebSocketClientConfiguration.swift @@ -23,6 +23,8 @@ public struct WebSocketClientConfiguration: Sendable { public var additionalHeaders: HTTPFields /// WebSocket extensions public var extensions: [any WebSocketExtensionBuilder] + /// Close timeout + public var closeTimeout: Duration /// Automatic ping setup public var autoPing: AutoPingSetup /// Should text be validated to be UTF8 @@ -39,12 +41,14 @@ public struct WebSocketClientConfiguration: Sendable { maxFrameSize: Int = (1 << 14), additionalHeaders: HTTPFields = .init(), extensions: [WebSocketExtensionFactory] = [], + closeTimeout: Duration = .seconds(15), autoPing: AutoPingSetup = .disabled, validateUTF8: Bool = false ) { self.maxFrameSize = maxFrameSize self.additionalHeaders = additionalHeaders self.extensions = extensions.map { $0.build() } + self.closeTimeout = closeTimeout self.autoPing = autoPing self.validateUTF8 = validateUTF8 } diff --git a/Sources/WSCore/WebSocketHandler.swift b/Sources/WSCore/WebSocketHandler.swift index 9668e11..293fa31 100644 --- a/Sources/WSCore/WebSocketHandler.swift +++ b/Sources/WSCore/WebSocketHandler.swift @@ -72,10 +72,17 @@ public struct WebSocketCloseFrame: Sendable { let autoPing: AutoPingSetup let validateUTF8: Bool let reservedBits: WebSocketFrame.ReservedBits + let closeTimeout: Duration - @_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) { + @_spi(WSInternal) public init( + extensions: [any WebSocketExtension], + autoPing: AutoPingSetup, + closeTimeout: Duration = .seconds(15), + validateUTF8: Bool + ) { self.extensions = extensions self.autoPing = autoPing + self.closeTimeout = closeTimeout self.validateUTF8 = validateUTF8 // store reserved bits used by this handler self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in @@ -118,32 +125,24 @@ public struct WebSocketCloseFrame: Sendable { } do { let rt = try await asyncChannel.executeThenClose { inbound, outbound in - try await withTaskCancellationHandler { - try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in - let webSocketHandler = Self( - channel: asyncChannel.channel, - outbound: outbound, - type: type, - configuration: configuration, - context: context - ) - if case .enabled = configuration.autoPing.value { - /// Add task sending ping frames every so often and verifying a pong frame was sent back - group.addTask { - try await webSocketHandler.runAutoPingLoop() - return .init(closeCode: .goingAway, reason: "Ping timeout") - } - } - let rt = try await webSocketHandler.handle( - type: type, - inbound: inbound, - outbound: outbound, - handler: handler, - context: context - ) - group.cancelAll() - return rt - } + defer { + context.logger.trace("Closing WebSocket") + } + return try await withTaskCancellationHandler { + let webSocketHandler = Self( + channel: asyncChannel.channel, + outbound: outbound, + type: type, + configuration: configuration, + context: context + ) + return try await webSocketHandler.handle( + type: type, + inbound: inbound, + outbound: outbound, + handler: handler, + context: context + ) } onCancel: { Task { try await asyncChannel.channel.close(mode: .input) @@ -166,39 +165,57 @@ public struct WebSocketCloseFrame: Sendable { context: Context ) async throws -> WebSocketCloseFrame? { try await withGracefulShutdownHandler { - let webSocketOutbound = WebSocketOutboundWriter(handler: self) - var inboundIterator = inbound.makeAsyncIterator() - let webSocketInbound = WebSocketInboundStream( - iterator: inboundIterator, - handler: self - ) - let closeCode: WebSocketErrorCode - var clientError: Error? - do { - // handle websocket data and text - try await handler(webSocketInbound, webSocketOutbound, context) - closeCode = .normalClosure - } catch InternalError.close(let code) { - closeCode = code - } catch { - clientError = error - closeCode = .unexpectedServerError - } - do { - try await self.close(code: closeCode) - if case .closing = self.stateMachine.state { - // Close handshake. Wait for responding close or until inbound ends - while let frame = try await inboundIterator.next() { - if case .connectionClose = frame.opcode { - try await self.receivedClose(frame) - break + try await withThrowingTaskGroup(of: Void.self) { group in + if case .enabled = configuration.autoPing.value { + /// Add task sending ping frames every so often and verifying a pong frame was sent back + group.addTask { + try await self.runAutoPingLoop() + } + } + let webSocketOutbound = WebSocketOutboundWriter(handler: self) + var inboundIterator = inbound.makeAsyncIterator() + let webSocketInbound = WebSocketInboundStream( + iterator: inboundIterator, + handler: self + ) + let closeCode: WebSocketErrorCode + var clientError: Error? + do { + // handle websocket data and text + try await handler(webSocketInbound, webSocketOutbound, context) + closeCode = .normalClosure + } catch InternalError.close(let code) { + closeCode = code + } catch { + clientError = error + closeCode = .unexpectedServerError + } + do { + try await self.close(code: closeCode) + if case .closing = self.stateMachine.state { + group.addTask { + try await Task.sleep(for: self.configuration.closeTimeout) + try await self.channel.close(mode: .input) + } + // Close handshake. Wait for responding close or until inbound ends + while let frame = try await inboundIterator.next() { + if case .connectionClose = frame.opcode { + try await self.receivedClose(frame) + // only the server can close the connection, so clients + // should continue reading from inbound until it is closed + if type == .server { + break + } + } } } + // don't propagate error if channel is already closed + } catch ChannelError.ioOnClosedChannel {} + if type == .client, let clientError { + throw clientError } - // don't propagate error if channel is already closed - } catch ChannelError.ioOnClosedChannel {} - if type == .client, let clientError { - throw clientError + + group.cancelAll() } } onGracefulShutdown: { Task { diff --git a/Tests/WebSocketTests/AutobahnTests.swift b/Tests/WebSocketTests/AutobahnTests.swift index f4998ad..f23e6b8 100644 --- a/Tests/WebSocketTests/AutobahnTests.swift +++ b/Tests/WebSocketTests/AutobahnTests.swift @@ -23,6 +23,8 @@ import XCTest /// The Autobahn|Testsuite provides a fully automated test suite to verify client and server /// implementations of The WebSocket Protocol for specification conformance and implementation robustness. /// You can find out more at https://github.com/crossbario/autobahn-testsuite +/// +/// Before running these tests run `./scripts/autobahn-server.sh` to running the test server. final class AutobahnTests: XCTestCase { /// To run all the autobahn compression tests takes a long time. By default we only run a selection. /// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them. @@ -58,6 +60,9 @@ final class AutobahnTests: XCTestCase { cases: Set, extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)] ) async throws { + // These are broken in CI currently + try XCTSkipIf(ProcessInfo.processInfo.environment["CI"] != nil) + struct CaseInfo: Decodable { let id: String let description: String @@ -121,8 +126,6 @@ final class AutobahnTests: XCTestCase { } func test_3_ReservedBits() async throws { - // Reserved bits tests fail - try XCTSkipIf(true) try await self.autobahnTests(cases: .init(28..<35)) } diff --git a/Tests/WebSocketTests/ClientTests.swift b/Tests/WebSocketTests/ClientTests.swift new file mode 100644 index 0000000..87180d4 --- /dev/null +++ b/Tests/WebSocketTests/ClientTests.swift @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOSSL +import NIOWebSocket +import WSClient +import XCTest + +final class WebSocketClientTests: XCTestCase { + + func testEchoServer() async throws { + let clientLogger = { + var logger = Logger(label: "client") + logger.logLevel = .trace + return logger + }() + try await WebSocketClient.connect( + url: "wss://echo.websocket.org/", + tlsConfiguration: TLSConfiguration.makeClientConfiguration(), + logger: clientLogger + ) { inbound, outbound, _ in + var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator() + try await outbound.write(.text("hello")) + if let msg = try await inboundIterator.next() { + print(msg) + } + } + } +}