Skip to content

Commit

Permalink
Correctly place the SSL channel handler in front of the PostgresChann…
Browse files Browse the repository at this point in the history
…elHandler (#527)
  • Loading branch information
tkrajacic authored Dec 8, 2024
1 parent f2a6394 commit 96ed89f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable {
func start(configuration: InternalConfiguration) -> EventLoopFuture<Void> {
// 1. configure handlers

let configureSSLCallback: ((Channel) throws -> ())?
let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())?

switch configuration.tls.base {
case .prefer(let context), .require(let context):
configureSSLCallback = { channel in
configureSSLCallback = { channel, postgresChannelHandler in
channel.eventLoop.assertInEventLoop()

let sslHandler = try NIOSSLClientHandler(
context: context,
serverHostname: configuration.serverNameForTLS
)
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first)
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler))
}
case .disable:
configureSSLCallback = nil
Expand Down
8 changes: 4 additions & 4 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
private var decoder: NIOSingleStepByteToMessageProcessor<PostgresBackendMessageDecoder>
private var encoder: PostgresFrontendMessageEncoder!
private let configuration: PostgresConnection.InternalConfiguration
private let configureSSLCallback: ((Channel) throws -> Void)?
private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?

private var listenState = ListenStateMachine()
private var preparedStatementState = PreparedStatementStateMachine()
Expand All @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
configuration: PostgresConnection.InternalConfiguration,
eventLoop: EventLoop,
logger: Logger,
configureSSLCallback: ((Channel) throws -> Void)?
configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?
) {
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
self.eventLoop = eventLoop
Expand All @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
eventLoop: EventLoop,
state: ConnectionStateMachine = .init(.initialized),
logger: Logger = .psqlNoOpLogger,
configureSSLCallback: ((Channel) throws -> Void)?
configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?
) {
self.state = state
self.eventLoop = eventLoop
Expand Down Expand Up @@ -439,7 +439,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
// This method must only be called, if we signalized the StateMachine before that we are
// able to setup a SSL connection.
do {
try self.configureSSLCallback!(context.channel)
try self.configureSSLCallback!(context.channel, self)
let action = self.state.sslHandlerAdded()
self.run(action, with: context)
} catch {
Expand Down
6 changes: 3 additions & 3 deletions Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration())))
var addSSLCallbackIsHit = false
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
addSSLCallbackIsHit = true
}
let embedded = EmbeddedChannel(handlers: [
Expand Down Expand Up @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration())))
var addSSLCallbackIsHit = false
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
addSSLCallbackIsHit = true
}
let eventHandler = TestEventHandler()
Expand Down Expand Up @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase {
func testSSLUnsupportedClosesConnection() throws {
let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration())))

let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
XCTFail("This callback should never be exectuded")
throw PSQLError.sslUnsupported
}
Expand Down

0 comments on commit 96ed89f

Please sign in to comment.