Skip to content

Commit

Permalink
Don't let client close the connection wait for server to initiate clo…
Browse files Browse the repository at this point in the history
…se (#10)

* Don't let client close the connection wait for server

* Add timeout for close

* Change Closing websocket debug to trace

* Disable autobahn tests in CI
  • Loading branch information
adam-fowler authored Dec 19, 2024
1 parent e9763cf commit 06e5e25
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 59 deletions.
1 change: 1 addition & 0 deletions Sources/WSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct WebSocketClientChannel: ClientConnectionChannel {
configuration: .init(
extensions: extensions,
autoPing: self.configuration.autoPing,
closeTimeout: self.configuration.closeTimeout,
validateUTF8: self.configuration.validateUTF8
),
asyncChannel: webSocketChannel,
Expand Down
4 changes: 4 additions & 0 deletions Sources/WSClient/WebSocketClientConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public struct WebSocketClientConfiguration: Sendable {
public var additionalHeaders: HTTPFields
/// WebSocket extensions
public var extensions: [any WebSocketExtensionBuilder]
/// Close timeout
public var closeTimeout: Duration
/// Automatic ping setup
public var autoPing: AutoPingSetup
/// Should text be validated to be UTF8
Expand All @@ -39,12 +41,14 @@ public struct WebSocketClientConfiguration: Sendable {
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init(),
extensions: [WebSocketExtensionFactory] = [],
closeTimeout: Duration = .seconds(15),
autoPing: AutoPingSetup = .disabled,
validateUTF8: Bool = false
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
self.extensions = extensions.map { $0.build() }
self.closeTimeout = closeTimeout
self.autoPing = autoPing
self.validateUTF8 = validateUTF8
}
Expand Down
131 changes: 74 additions & 57 deletions Sources/WSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ public struct WebSocketCloseFrame: Sendable {
let autoPing: AutoPingSetup
let validateUTF8: Bool
let reservedBits: WebSocketFrame.ReservedBits
let closeTimeout: Duration

@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) {
@_spi(WSInternal) public init(
extensions: [any WebSocketExtension],
autoPing: AutoPingSetup,
closeTimeout: Duration = .seconds(15),
validateUTF8: Bool
) {
self.extensions = extensions
self.autoPing = autoPing
self.closeTimeout = closeTimeout
self.validateUTF8 = validateUTF8
// store reserved bits used by this handler
self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in
Expand Down Expand Up @@ -118,32 +125,24 @@ public struct WebSocketCloseFrame: Sendable {
}
do {
let rt = try await asyncChannel.executeThenClose { inbound, outbound in
try await withTaskCancellationHandler {
try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in
let webSocketHandler = Self(
channel: asyncChannel.channel,
outbound: outbound,
type: type,
configuration: configuration,
context: context
)
if case .enabled = configuration.autoPing.value {
/// Add task sending ping frames every so often and verifying a pong frame was sent back
group.addTask {
try await webSocketHandler.runAutoPingLoop()
return .init(closeCode: .goingAway, reason: "Ping timeout")
}
}
let rt = try await webSocketHandler.handle(
type: type,
inbound: inbound,
outbound: outbound,
handler: handler,
context: context
)
group.cancelAll()
return rt
}
defer {
context.logger.trace("Closing WebSocket")
}
return try await withTaskCancellationHandler {
let webSocketHandler = Self(
channel: asyncChannel.channel,
outbound: outbound,
type: type,
configuration: configuration,
context: context
)
return try await webSocketHandler.handle(
type: type,
inbound: inbound,
outbound: outbound,
handler: handler,
context: context
)
} onCancel: {
Task {
try await asyncChannel.channel.close(mode: .input)
Expand All @@ -166,39 +165,57 @@ public struct WebSocketCloseFrame: Sendable {
context: Context
) async throws -> WebSocketCloseFrame? {
try await withGracefulShutdownHandler {
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
var inboundIterator = inbound.makeAsyncIterator()
let webSocketInbound = WebSocketInboundStream(
iterator: inboundIterator,
handler: self
)
let closeCode: WebSocketErrorCode
var clientError: Error?
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
closeCode = .normalClosure
} catch InternalError.close(let code) {
closeCode = code
} catch {
clientError = error
closeCode = .unexpectedServerError
}
do {
try await self.close(code: closeCode)
if case .closing = self.stateMachine.state {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
break
try await withThrowingTaskGroup(of: Void.self) { group in
if case .enabled = configuration.autoPing.value {
/// Add task sending ping frames every so often and verifying a pong frame was sent back
group.addTask {
try await self.runAutoPingLoop()
}
}
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
var inboundIterator = inbound.makeAsyncIterator()
let webSocketInbound = WebSocketInboundStream(
iterator: inboundIterator,
handler: self
)
let closeCode: WebSocketErrorCode
var clientError: Error?
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
closeCode = .normalClosure
} catch InternalError.close(let code) {
closeCode = code
} catch {
clientError = error
closeCode = .unexpectedServerError
}
do {
try await self.close(code: closeCode)
if case .closing = self.stateMachine.state {
group.addTask {
try await Task.sleep(for: self.configuration.closeTimeout)
try await self.channel.close(mode: .input)
}
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
// only the server can close the connection, so clients
// should continue reading from inbound until it is closed
if type == .server {
break
}
}
}
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
if type == .client, let clientError {
throw clientError
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
if type == .client, let clientError {
throw clientError

group.cancelAll()
}
} onGracefulShutdown: {
Task {
Expand Down
7 changes: 5 additions & 2 deletions Tests/WebSocketTests/AutobahnTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import XCTest
/// The Autobahn|Testsuite provides a fully automated test suite to verify client and server
/// implementations of The WebSocket Protocol for specification conformance and implementation robustness.
/// You can find out more at https://github.com/crossbario/autobahn-testsuite
///
/// Before running these tests run `./scripts/autobahn-server.sh` to running the test server.
final class AutobahnTests: XCTestCase {
/// To run all the autobahn compression tests takes a long time. By default we only run a selection.
/// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them.
Expand Down Expand Up @@ -58,6 +60,9 @@ final class AutobahnTests: XCTestCase {
cases: Set<Int>,
extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]
) async throws {
// These are broken in CI currently
try XCTSkipIf(ProcessInfo.processInfo.environment["CI"] != nil)

struct CaseInfo: Decodable {
let id: String
let description: String
Expand Down Expand Up @@ -121,8 +126,6 @@ final class AutobahnTests: XCTestCase {
}

func test_3_ReservedBits() async throws {
// Reserved bits tests fail
try XCTSkipIf(true)
try await self.autobahnTests(cases: .init(28..<35))
}

Expand Down
42 changes: 42 additions & 0 deletions Tests/WebSocketTests/ClientTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOSSL
import NIOWebSocket
import WSClient
import XCTest

final class WebSocketClientTests: XCTestCase {

func testEchoServer() async throws {
let clientLogger = {
var logger = Logger(label: "client")
logger.logLevel = .trace
return logger
}()
try await WebSocketClient.connect(
url: "wss://echo.websocket.org/",
tlsConfiguration: TLSConfiguration.makeClientConfiguration(),
logger: clientLogger
) { inbound, outbound, _ in
var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator()
try await outbound.write(.text("hello"))
if let msg = try await inboundIterator.next() {
print(msg)
}
}
}
}

0 comments on commit 06e5e25

Please sign in to comment.