Skip to content

Commit

Permalink
Add timeout for close
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Dec 17, 2024
1 parent 677fecb commit bf85ba3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 61 deletions.
1 change: 1 addition & 0 deletions Sources/WSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct WebSocketClientChannel: ClientConnectionChannel {
configuration: .init(
extensions: extensions,
autoPing: self.configuration.autoPing,
closeTimeout: self.configuration.closeTimeout,
validateUTF8: self.configuration.validateUTF8
),
asyncChannel: webSocketChannel,
Expand Down
4 changes: 4 additions & 0 deletions Sources/WSClient/WebSocketClientConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public struct WebSocketClientConfiguration: Sendable {
public var additionalHeaders: HTTPFields
/// WebSocket extensions
public var extensions: [any WebSocketExtensionBuilder]
/// Close timeout
public var closeTimeout: Duration
/// Automatic ping setup
public var autoPing: AutoPingSetup
/// Should text be validated to be UTF8
Expand All @@ -39,12 +41,14 @@ public struct WebSocketClientConfiguration: Sendable {
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init(),
extensions: [WebSocketExtensionFactory] = [],
closeTimeout: Duration = .seconds(15),
autoPing: AutoPingSetup = .disabled,
validateUTF8: Bool = false
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
self.extensions = extensions.map { $0.build() }
self.closeTimeout = closeTimeout
self.autoPing = autoPing
self.validateUTF8 = validateUTF8
}
Expand Down
128 changes: 69 additions & 59 deletions Sources/WSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ public struct WebSocketCloseFrame: Sendable {
let autoPing: AutoPingSetup
let validateUTF8: Bool
let reservedBits: WebSocketFrame.ReservedBits
let closeTimeout: Duration

@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) {
@_spi(WSInternal) public init(
extensions: [any WebSocketExtension],
autoPing: AutoPingSetup,
closeTimeout: Duration = .seconds(15),
validateUTF8: Bool
) {
self.extensions = extensions
self.autoPing = autoPing
self.closeTimeout = closeTimeout
self.validateUTF8 = validateUTF8
// store reserved bits used by this handler
self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in
Expand Down Expand Up @@ -122,31 +129,20 @@ public struct WebSocketCloseFrame: Sendable {
context.logger.debug("Closing WebSocket")
}
return try await withTaskCancellationHandler {
try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in
let webSocketHandler = Self(
channel: asyncChannel.channel,
outbound: outbound,
type: type,
configuration: configuration,
context: context
)
if case .enabled = configuration.autoPing.value {
/// Add task sending ping frames every so often and verifying a pong frame was sent back
group.addTask {
try await webSocketHandler.runAutoPingLoop()
return .init(closeCode: .goingAway, reason: "Ping timeout")
}
}
let rt = try await webSocketHandler.handle(
type: type,
inbound: inbound,
outbound: outbound,
handler: handler,
context: context
)
group.cancelAll()
return rt
}
let webSocketHandler = Self(
channel: asyncChannel.channel,
outbound: outbound,
type: type,
configuration: configuration,
context: context
)
return try await webSocketHandler.handle(
type: type,
inbound: inbound,
outbound: outbound,
handler: handler,
context: context
)
} onCancel: {
Task {
try await asyncChannel.channel.close(mode: .input)
Expand All @@ -169,43 +165,57 @@ public struct WebSocketCloseFrame: Sendable {
context: Context
) async throws -> WebSocketCloseFrame? {
try await withGracefulShutdownHandler {
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
var inboundIterator = inbound.makeAsyncIterator()
let webSocketInbound = WebSocketInboundStream(
iterator: inboundIterator,
handler: self
)
let closeCode: WebSocketErrorCode
var clientError: Error?
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
closeCode = .normalClosure
} catch InternalError.close(let code) {
closeCode = code
} catch {
clientError = error
closeCode = .unexpectedServerError
}
do {
try await self.close(code: closeCode)
if case .closing = self.stateMachine.state {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
// only the server can close the connection, so clients
// should continue reading from inbound until it is closed
if type == .server {
break
try await withThrowingTaskGroup(of: Void.self) { group in
if case .enabled = configuration.autoPing.value {
/// Add task sending ping frames every so often and verifying a pong frame was sent back
group.addTask {
try await self.runAutoPingLoop()
}
}
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
var inboundIterator = inbound.makeAsyncIterator()
let webSocketInbound = WebSocketInboundStream(
iterator: inboundIterator,
handler: self
)
let closeCode: WebSocketErrorCode
var clientError: Error?
do {
// handle websocket data and text
try await handler(webSocketInbound, webSocketOutbound, context)
closeCode = .normalClosure
} catch InternalError.close(let code) {
closeCode = code
} catch {
clientError = error
closeCode = .unexpectedServerError
}
do {
try await self.close(code: closeCode)
if case .closing = self.stateMachine.state {
group.addTask {
try await Task.sleep(for: self.configuration.closeTimeout)
try await self.channel.close(mode: .input)
}
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
// only the server can close the connection, so clients
// should continue reading from inbound until it is closed
if type == .server {
break
}
}
}
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
if type == .client, let clientError {
throw clientError
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
if type == .client, let clientError {
throw clientError

group.cancelAll()
}
} onGracefulShutdown: {
Task {
Expand Down
4 changes: 2 additions & 2 deletions Tests/WebSocketTests/AutobahnTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import XCTest
/// The Autobahn|Testsuite provides a fully automated test suite to verify client and server
/// implementations of The WebSocket Protocol for specification conformance and implementation robustness.
/// You can find out more at https://github.com/crossbario/autobahn-testsuite
///
/// Before running these tests run `./scripts/autobahn-server.sh` to running the test server.
final class AutobahnTests: XCTestCase {
/// To run all the autobahn compression tests takes a long time. By default we only run a selection.
/// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them.
Expand Down Expand Up @@ -121,8 +123,6 @@ final class AutobahnTests: XCTestCase {
}

func test_3_ReservedBits() async throws {
// Reserved bits tests fail
try XCTSkipIf(true)
try await self.autobahnTests(cases: .init(28..<35))
}

Expand Down

0 comments on commit bf85ba3

Please sign in to comment.