diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index e267d8f9..4f884af4 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -694,6 +694,168 @@ extension PostgresConnection { } } +// MARK: Copy from + +/// A handle to send +public struct PostgresCopyFromWriter: Sendable { + /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed. + /// + /// The `PostgresCopyFromWriter` should cancel the data transfer. + public struct CopyCancellationError: Error { + /// The error that the backend sent us which cancelled the data transfer. + /// + /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before + /// new data is written by `write`. + let underlyingError: PSQLError + } + + private let channelHandler: NIOLoopBound + private let eventLoop: any EventLoop + + init(handler: PostgresChannelHandler, eventLoop: any EventLoop) { + self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop) + self.eventLoop = eventLoop + } + + /// Send data for a `COPY ... FROM STDIN` operation to the backend. + /// + /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws + /// a `CopyFailedError`. + public func write(_ byteBuffer: ByteBuffer) async throws { + // First, wait that we have a writable buffer. This also throws a `CopyFailedError` in case the backend sent an + // error during the data transfer and thus cannot process any more data. + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.waitForWritableBuffer(continuation) + } else { + eventLoop.execute { + self.channelHandler.value.waitForWritableBuffer(continuation) + } + } + } + + // Run the actual data transfer + if eventLoop.inEventLoop { + self.channelHandler.value.copyData(byteBuffer) + } else { + eventLoop.execute { + self.channelHandler.value.copyData(byteBuffer) + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to + /// the backend. + func done() async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to + /// the backend. + func failed(error: any Error) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyFailed(message: "\(error)", continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyFailed(message: "\(error)", continuation: continuation) + } + } + } + } + + /// Send a `Sync` message to the backend. + func sync() { + if eventLoop.inEventLoop { + self.channelHandler.value.sendSync() + } else { + eventLoop.execute { + self.channelHandler.value.sendSync() + } + } + } +} + +public struct CopyFromOptions { + let delimiter: StaticString? + + public init(delimiter: StaticString? = nil) { + self.delimiter = delimiter + } +} + +private func buildCopyFromQuery( + table: String, + columns: [StaticString]?, + options: CopyFromOptions +) -> PostgresQuery { + var query = "COPY \(table)" + if let columns { + // TODO: Is using `StaticString` sufficient here to prevent against SQL injection attacks or should we try to + // escape the identifiers, essentially re-implementing `PQescapeIdentifier`? + query += "(" + columns.map(\.description).joined(separator: ",") + ")" + } + query += " FROM STDIN" + var queryOptions: [String] = [] + if let delimiter = options.delimiter { + queryOptions.append("DELIMITER '\(delimiter)'") + } + if !queryOptions.isEmpty { + query += " WITH " + query += queryOptions.map { "(\($0))" }.joined(separator: ",") + } + return "\(unescaped: query)" +} + +extension PostgresConnection { + // TODO: Instead of an arbitrary query, make this a structured data structure. + // TODO: Write doc comment + public func copyFrom( + table: String, + columns: [StaticString]? = nil, + options: CopyFromOptions = CopyFromOptions(), + logger: Logger, + writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void, + file: String = #fileID, + line: Int = #line + ) async throws { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + let writer = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let context = ExtendedQueryContext( + copyFromQuery: buildCopyFromQuery(table: table, columns: columns, options: options), + triggerCopy: continuation, + logger: logger + ) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + } + + do { + try await writeData(writer) + } catch let error as PostgresCopyFromWriter.CopyCancellationError { + // If the copy was cancelled because the backend sent us an error, we need to send a `Sync` message to put + // the backend out of the copy mode. + writer.sync() + throw error.underlyingError + } catch { + // Throw the error from the `writeData` closure instead of the one that Postgres gives us upon receiving the + // `CopyFail` message. + try? await writer.failed(error: error) + throw error + } + try await writer.done() + } + +} + // MARK: PostgresDatabase conformance extension PostgresConnection: PostgresDatabase { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..bd25d6a6 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -88,7 +88,24 @@ struct ConnectionStateMachine { case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + /// Fail a query's execution by throwing an error on the given continuation. When `sync` is `true`, send a + /// `sync` message to the backend. + case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?, sync: Bool) case succeedQuery(EventLoopPromise, with: QueryResult) + case succeedQueryContinuation(CheckedContinuation) + + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.copyDone` or ``PostgresChannelHandler.copyFailed``. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` message to the backend, followed by a `Sync`. + case sendCopyDone + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFailed(message: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend @@ -106,6 +123,27 @@ struct ConnectionStateMachine { case succeedClose(CloseCommandContext) case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?) } + + enum ChannelWritabilityChangedAction { + /// No action needs to be taken based on the writability change. + case none + + /// Resume the given continuation successfully. + case resumeContinuation(CheckedContinuation) + } + + enum WaitForWritableBufferAction { + /// The channel has backpressure and cannot handle any data right now. We should flush the channel to help + /// relieve backpressure. Once the channel is writable again, this will be communicated via + /// `channelWritabilityChanged` + case waitForBackpressureRelieve + + /// Resume the given continuation successfully. + case resumeContinuation(CheckedContinuation) + + /// Fail the continuation with the given error. + case failContinuation(CheckedContinuation, error: any Error) + } private var state: State private let requireBackendKeyData: Bool @@ -587,6 +625,8 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil, sync: false) case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } @@ -660,6 +700,16 @@ struct ConnectionStateMachine { preconditionFailure("Invalid state: \(self.state)") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction { + guard case .extendedQuery(var queryState, let connectionContext) = state else { + return .none + } + self.state = .modifying // avoid CoW + let action = queryState.channelWritabilityChanged(isWritable: isWritable) + self.state = .extendedQuery(queryState, connectionContext) + return action + } // MARK: - Running Queries - @@ -751,6 +801,59 @@ struct ConnectionStateMachine { self.state = .extendedQuery(queryState, connectionContext) return self.modify(with: action) } + + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponseMessage + ) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + self.state = .modifying // avoid CoW + let action = queryState.copyInResponseReceived(copyInResponse) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + /// Wait fo `channel` to be writable and be able to handle more `CopyData` messages. Resume the given continuation + /// when the channel is able handle more data. + /// + /// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled + /// the data transfer to indicate that the frontend should not send any more data. + mutating func waitForWritableBuffer(channel: any Channel, continuation: CheckedContinuation) -> WaitForWritableBufferAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.waitForWritableBuffer(channel: channel, continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return action + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyDone(continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFail(message: String, continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyFailed(message: message, continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { @@ -860,14 +963,21 @@ struct ConnectionStateMachine { .forwardRows, .forwardStreamComplete, .wait, - .read: + .read, + .triggerCopyData, + .sendCopyDone, + .sendCopyFailed, + .succeedQueryContinuation: preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) - case .failQuery(let queryContext, with: let error): - return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + case .failQuery(let promise, with: let error): + return .failQuery(promise, with: error, cleanupContext: cleanupContext) + + case .failQueryContinuation(let continuation, with: let error, let sync): + return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync) case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) @@ -1038,8 +1148,19 @@ extension ConnectionStateMachine { case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .failQueryContinuation(let continuation, with: let error, let sync): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync) case .succeedQuery(let requestContext, with: let result): return .succeedQuery(requestContext, with: result) + case .succeedQueryContinuation(let continuation): + return .succeedQueryContinuation(continuation) + case .triggerCopyData(let triggerCopy): + return .triggerCopyData(triggerCopy) + case .sendCopyDone: + return .sendCopyDone + case .sendCopyFailed(message: let message): + return .sendCopyFailed(message: message) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 087a6c24..3d5bd5c1 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -1,7 +1,16 @@ import NIOCore struct ExtendedQueryStateMachine { - + + private enum CopyingDataState { + /// The write channel is ready to handle more data + case readyToSend + + /// The write channel has backpressure. Once that is relieved, we should resume the given continuation to allow more + /// data to be sent by the client. + case pendingBackpressureRelieve(CheckedContinuation) + } + private enum State { case initialized(ExtendedQueryContext) case messagesSent(ExtendedQueryContext) @@ -12,6 +21,15 @@ struct ExtendedQueryStateMachine { case noDataMessageReceived(ExtendedQueryContext) case emptyQueryResponseReceived + /// We are currently copying data to the backend using `CopyData` messages. + case copyingData(CopyingDataState) + + /// We copied data to the backend and are done with that, either by sending a `CopyDone` or `CopyFail` message. + /// We are now expecting a `CommandComplete` or `ErrorResponse`. + /// + /// Once that is received the continuation is resumed. + case copyingFinished(CheckedContinuation) + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -32,13 +50,30 @@ struct ExtendedQueryStateMachine { // --- general actions case failQuery(EventLoopPromise, with: PSQLError) + /// Fail a query's execution by throwing an error on the given continuation. If `sync` is `true`, send a `sync` + /// message to the backend to put it out of the copy mode. + case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool) case succeedQuery(EventLoopPromise, with: QueryResult) + case succeedQueryContinuation(CheckedContinuation) case evaluateErrorAtConnectionLevel(PSQLError) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.copyDone` or ``PostgresChannelHandler.copyFailed``. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` message to the backend, followed by a `Sync`. + case sendCopyDone + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFailed(message: String) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -63,7 +98,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed(let query, _): + case .unnamed(let query, _), .copyFrom(let query, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) return .sendParseDescribeBindExecuteSync(query) @@ -91,7 +126,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -107,10 +142,15 @@ struct ExtendedQueryStateMachine { switch queryContext.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) - + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(triggerCopy, with: .queryCancelled, sync: false) case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } + case .copyingData: + return .sendCopyFailed(message: "Query cancelled") + case .copyingFinished(let continuation): + return .failQueryContinuation(continuation, with: .queryCancelled, sync: true) case .streaming(let columns, var streamStateMachine): precondition(!self.isCancelled) @@ -160,7 +200,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .wait @@ -198,7 +238,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return .wait case .prepareStatement(_, _, _, let eventLoopPromise): @@ -217,7 +257,7 @@ struct ExtendedQueryStateMachine { return .succeedQuery(eventLoopPromise, with: result) } - case .prepareStatement: + case .prepareStatement, .copyFrom: return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) } @@ -235,7 +275,9 @@ struct ExtendedQueryStateMachine { .streaming, .drain, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) case .modifying: @@ -274,7 +316,9 @@ struct ExtendedQueryStateMachine { .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) case .modifying: preconditionFailure("Invalid state") @@ -291,10 +335,19 @@ struct ExtendedQueryStateMachine { let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } - + case .copyFrom: + // We expect to transition through `copyingData` to `copyingFinished` before receiving a + // `CommandCompleted` message for copy queries. + preconditionFailure("Invalid state: \(self.state)") case .prepareStatement: preconditionFailure("Invalid state: \(self.state)") } + + case .copyingFinished(let continuation): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .succeedQueryContinuation(continuation) + } case .streaming(_, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in @@ -315,13 +368,89 @@ struct ExtendedQueryStateMachine { .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, - .error: + .error, + .copyingData: return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) case .modifying: preconditionFailure("Invalid state") } } + /// When a `CopyInResponse` message is received from the backend, return a continuation to which a + /// `PostgresCopyFromWriter` can be yielded to trigger a data transfer to the backend. + /// + /// If we are not currently in a state to handle the `CopyInResponse`, an error is thrown. + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponseMessage + ) -> Action { + guard case .bindCompleteReceived(let queryContext) = self.state, + case .copyFrom(_, let triggerCopy) = queryContext.query else { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + return avoidingStateMachineCoW { state in + state = .copyingData(.readyToSend) + return .triggerCopyData(triggerCopy) + } + } + + /// Wait fo `channel` to be writable and be able to handle more `CopyData` messages. Resume the given continuation + /// when the channel is able handle more data. + /// + /// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled + /// the data transfer to indicate that the frontend should not send any more data. + mutating func waitForWritableBuffer( + channel: any Channel, + continuation: CheckedContinuation + ) -> ConnectionStateMachine.WaitForWritableBufferAction { + if case .error(let error) = self.state { + return .failContinuation(continuation, error: PostgresCopyFromWriter.CopyCancellationError(underlyingError: error)) + } + guard case .copyingData(let copyingSubstate) = self.state else { + preconditionFailure("Must be in copy mode to copy data") + } + guard case .readyToSend = copyingSubstate else { + preconditionFailure("Not ready to send data") + } + if channel.isWritable { + return .resumeContinuation(continuation) + } + return avoidingStateMachineCoW { state in + // Even if the buffer isn't writable, we write the current chunk of data to it. We just don't resume + // the continuation. This will prevent more writes from happening to build up more write backpressure. + state = .copyingData(.pendingBackpressureRelieve(continuation)) + return .waitForBackpressureRelieve + } + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + return .failQueryContinuation(continuation, with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyDone") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation) + return .sendCopyDone + } + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFailed(message: String, continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + return .failQueryContinuation(continuation, with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyFail") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation) + return .sendCopyFailed(message: message) + } + + } + mutating func emptyQueryResponseReceived() -> Action { guard case .bindCompleteReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) @@ -336,7 +465,7 @@ struct ExtendedQueryStateMachine { return .succeedQuery(eventLoopPromise, with: result) } - case .prepareStatement(_, _, _, _): + case .prepareStatement, .copyFrom: return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) } } @@ -353,6 +482,8 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .rowDescriptionReceived, .noDataMessageReceived: return self.setAndFireError(error) + case .copyingData, .copyingFinished: + return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) case .commandComplete, .emptyQueryResponseReceived: @@ -403,7 +534,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: preconditionFailure("Requested to consume next row without anything going on.") case .commandComplete, .error: @@ -427,7 +560,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .wait case .streaming(let columns, var demandStateMachine): @@ -454,7 +589,9 @@ struct ExtendedQueryStateMachine { .parameterDescriptionReceived, .noDataMessageReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .read case .streaming(let columns, var demandStateMachine): precondition(!self.isCancelled) @@ -480,6 +617,16 @@ struct ExtendedQueryStateMachine { preconditionFailure("Invalid state") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ConnectionStateMachine.ChannelWritabilityChangedAction { + guard case .copyingData(.pendingBackpressureRelieve(let continuation)) = state else { + return .none + } + return self.avoidingStateMachineCoW { state in + state = .copyingData(.readyToSend) + return .resumeContinuation(continuation) + } + } // MARK: Private Methods @@ -499,11 +646,18 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(triggerCopy, with: error, sync: false) case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } - + case .copyingData: + self.state = .error(error) + return .evaluateErrorAtConnectionLevel(error) + case .copyingFinished(let continuation): + self.state = .error(error) + return .failQueryContinuation(continuation, with: error, sync: true) case .drain: self.state = .error(error) return .evaluateErrorAtConnectionLevel(error) @@ -536,11 +690,19 @@ struct ExtendedQueryStateMachine { switch context.query { case .prepareStatement: return true - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return false } - case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .bindCompleteReceived, + .streaming, + .drain, + .copyingData, + .copyingFinished: return false case .modifying: diff --git a/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift new file mode 100644 index 00000000..a678e44f --- /dev/null +++ b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift @@ -0,0 +1,6 @@ +/// Any `CheckedContinuation` that has an error type of `any Error`. +protocol AnyErrorContinuation { + func resume(throwing error: any Error) +} + +extension CheckedContinuation: AnyErrorContinuation where E == any Error {} diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..641121fb --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponseMessage: Hashable { + enum Format: Int { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: Int(rawFormat)) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0..) + /// A `COPY ... FROM STDIN` query that copies data from the frontend into a table. + /// + /// When `triggerCopy` is yielded a `PostgresCopyFromWriter`, the frontend will start sending data to the + /// backend via `CopyData` messages and finalize the data transfer using a `CopyDone` message to the backend or + /// using a `CopyFail` message. + /// + /// `queryCompleted` is a promise that is finished when the entire query is done. + case copyFrom(PostgresQuery, triggerCopy: CheckedContinuation) case executeStatement(PSQLExecuteStatement, EventLoopPromise) case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } @@ -50,6 +60,15 @@ final class ExtendedQueryContext: Sendable { self.logger = logger } + init( + copyFromQuery query: PostgresQuery, + triggerCopy: CheckedContinuation, + logger: Logger + ) { + self.query = .copyFrom(query, triggerCopy: triggerCopy) + self.logger = logger + } + init( executeStatement: PSQLExecuteStatement, logger: Logger, diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index 792beec3..d8e86b05 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable { case bindComplete case closeComplete case commandComplete(String) + case copyInResponse(CopyInResponseMessage) case dataRow(DataRow) case emptyQueryResponse case error(ErrorResponse) @@ -96,6 +97,9 @@ extension PostgresBackendMessage { } return .commandComplete(commandTag) + case .copyInResponse: + return try .copyInResponse(.decode(from: &buffer)) + case .dataRow: return try .dataRow(.decode(from: &buffer)) @@ -131,9 +135,9 @@ extension PostgresBackendMessage { case .rowDescription: return try .rowDescription(.decode(from: &buffer)) - - case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: - preconditionFailure() + + case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: + preconditionFailure("Unknown message kind: \(messageID)") } } } @@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible { return ".closeComplete" case .commandComplete(let commandTag): return ".commandComplete(\(String(reflecting: commandTag)))" + case .copyInResponse(let copyInResponse): + return ".copyInResponse(\(String(reflecting: copyInResponse)))" case .dataRow(let dataRow): return ".dataRow(\(String(reflecting: dataRow)))" case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 0a14849a..5bd47925 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: @@ -169,10 +171,61 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.run(action, with: context) } + /// Wait for the channel to be writable and be able to handle more `CopyData` messages. Resume the given + /// continuation when the channel is able handle more data. + /// + /// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled + /// the data transfer to indicate that the frontend should not send any more data. + func waitForWritableBuffer(_ continuation: CheckedContinuation) { + let action = self.state.waitForWritableBuffer(channel: handlerContext!.channel, continuation: continuation) + switch action { + case .waitForBackpressureRelieve: + self.handlerContext!.channel.flush() + case .resumeContinuation(let continuation): + continuation.resume() + case .failContinuation(_, error: let error): + continuation.resume(throwing: error) + } + } + + /// Send a `CopyData` message to the backend using the given data. + func copyData(_ data: ByteBuffer) { + self.encoder.copyData(data: data) + self.handlerContext!.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + func sendCopyDone(continuation: CheckedContinuation) { + let action = self.state.sendCopyDone(continuation: continuation) + self.run(action, with: self.handlerContext!) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + func sendCopyFailed(message: String, continuation: CheckedContinuation) { + let action = self.state.sendCopyFail(message: message, continuation: continuation) + self.run(action, with: self.handlerContext!) + } + + /// Send a `Sync` message to the backend. + func sendSync() { + self.encoder.sync() + self.handlerContext!.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + func channelReadComplete(context: ChannelHandlerContext) { let action = self.state.channelReadComplete() self.run(action, with: context) } + + func channelWritabilityChanged(context: ChannelHandlerContext) { + let action = self.state.channelWritabilityChanged(isWritable: context.channel.isWritable) + switch action { + case .none: + break + case .resumeContinuation(let continuation): + continuation.resume() + } + } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { self.logger.trace("User inbound event received", metadata: [ @@ -353,12 +406,32 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) case .succeedQuery(let promise, with: let result): self.succeedQuery(promise, result: result, context: context) + case .succeedQueryContinuation(let continuation): + continuation.resume() case .failQuery(let promise, with: let error, let cleanupContext): promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - + case .failQueryContinuation(let continuation, with: let error, let cleanupContext, let sync): + if sync { + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + continuation.resume(throwing: error) + case .triggerCopyData(let triggerCopy): + let writer = PostgresCopyFromWriter(handler: self, eventLoop: eventLoop) + triggerCopy.resume(returning: writer) + case .sendCopyDone: + self.encoder.copyDone() + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendCopyFailed(message: let message): + self.encoder.copyFail(message: message) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .forwardRows(let rows): self.rowStream!.receive(rows) diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..c9b9f99f 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,25 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + mutating func copyData(data: ByteBuffer) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: UInt32(data.readableBytes)) + self.buffer.writeImmutableBuffer(data) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +216,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index d541899b..b3821cac 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -379,4 +379,36 @@ final class IntegrationTests: XCTestCase { } } + func testCopyIntoFrom() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], options: CopyFromOptions(delimiter: ","), logger: .psqlTest) { writer in + let records: [(id: Int, name: String)] = [ + (1, "Alice"), + (42, "Bob") + ] + for record in records { + var buffer = ByteBuffer() + buffer.writeString("\(record.id),\(record.name)\n") + try await writer.write(buffer) + } + } + let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) } + guard rows.count == 2 else { + XCTFail("Expected 2 columns, received \(rows.count)") + return + } + XCTAssertEqual(rows[0].0, 1) + XCTAssertEqual(rows[0].1, "Alice") + XCTAssertEqual(rows[1].0, 42) + XCTAssertEqual(rows[1].1, "Bob") + } + } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index ae484acc..872664af 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } - func testExtendedQueryIsCancelledImmediatly() { + func testExtendedQueryIsCancelledImmediately() { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest diff --git a/Tests/PostgresNIOTests/New/Extensions/AssertThrowsError.swift b/Tests/PostgresNIOTests/New/Extensions/AssertThrowsError.swift new file mode 100644 index 00000000..078a4b34 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/AssertThrowsError.swift @@ -0,0 +1,17 @@ +import XCTest + +/// Same as `XCTAssertThrows` but allows the expression to be async +func assertThrowsError( + _ expression: @autoclosure () async throws -> T, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line, + errorHandler: (_ error: Error) -> Void = { _ in } +) async { + do { + _ = try await expression() + XCTFail("Expression was expected to throw but did not throw", file: file, line: line) + } catch { + errorHandler(error) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..4a153856 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponseMessage: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..7f939151 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Equatable { + let data: ByteBuffer + } + + struct CopyFail: Equatable { + let message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..f3c3c7b5 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -95,10 +95,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -107,10 +104,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -155,10 +149,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -168,10 +159,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -204,10 +192,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -283,7 +268,7 @@ class PostgresConnectionTests: XCTestCase { } } - func testCloseClosesImmediatly() async throws { + func testCloseClosesImmediately() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +623,209 @@ class PostgresConnectionTests: XCTestCase { } } + func testCopyDataSucceeds() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + } + } + + let copyMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(copyMessage.parse.query, "COPY copy_table FROM STDIN") + XCTAssertEqual(copyMessage.bind.parameters, []) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1\tAlice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + + func testCopyDataWriterFails() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + await assertThrowsError( + try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in + throw MyError() + } + ) { error in + XCTAssert(error is MyError, "Expected error of type MyError, got \(error)") + } + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.result, .failed(message: "My error")) + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: "COPY from stdin failed: My error", + .sqlState : "57014" // query_canceled + ]))) + + // Ensure we get a `sync` message so that the backend transitions out of copy mode. + let syncMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(syncMessage, .sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + + func testBackendSendsErrorDuringCopy() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + await assertThrowsError( + try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + } + ) { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02") + } + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + _ = try await channel.waitForCopyData() + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + + // Ensure we get a `sync` message so that the backend transitions out of copy mode. + let syncMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(syncMessage, .sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + + + func testBackendSendsErrorDuringCopyBeforeCopyDone() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + let backendDidSendErrorExpectation = self.expectation(description: "Backend did send error") + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + await assertThrowsError( + try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + channel.flush() + _ = await XCTWaiter.fulfillment(of: [backendDidSendErrorExpectation]) + } + ) { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02") + } + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + + let copyDataMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(copyDataMessage, .copyData(PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n")))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + backendDidSendErrorExpectation.fulfill() + + // We don't expect to receive a CopyDone or CopyFail message if the server sent us an error. + + // Ensure we get a `sync` message so that the backend transitions out of copy mode. + let syncMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(syncMessage, .sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + + func testWriteDataClosureTerminatesWhenServerThrowsError() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + let expectation = self.expectation(description: "Backend sent error") + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + await assertThrowsError( + try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + channel.flush() + _ = await XCTWaiter.fulfillment(of: [expectation]) + do { + try await writer.write(ByteBuffer(staticString: "2\tBob\n")) + XCTFail("Expected error to be thrown") + } catch { + XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)") + throw error + } + } + ) { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02") + } + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + + let dataMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(dataMessage, .copyData(PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n")))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + expectation.fulfill() + + // We don't expect to receive a CopyDone or CopyFail message if the server sent us an error. + + // Ensure we get a `sync` message so that the backend transitions out of copy mode. + let syncMessage = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(syncMessage, .sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + + func testCopyDataWithOptions() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + try await connection.copyFrom(table: "copy_table", columns: ["id", "name"], options: CopyFromOptions(delimiter: ","), logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1,Alice\n")) + } + } + + let copyMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(copyMessage.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (DELIMITER ',')") + XCTAssertEqual(copyMessage.bind.parameters, []) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.sendCopyInResponseForTwoTextualColumns() + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1,Alice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in @@ -692,6 +880,38 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + struct CopyDataRequest { + enum Result: Equatable { + /// The data copy finished successfully with a `CopyDone` message. + case done + /// The data copy finished with a `CopyFail` message containing the following error message. + case failed(message: String) + } + + /// The data that was transferred. + var data: ByteBuffer + + /// The `CopyDone` or `CopyFail` message that finalized the data transfer. + var result: Result + } + + func waitForCopyData() async throws -> CopyDataRequest { + var copiedData = ByteBuffer() + while true { + let message = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + switch message { + case .copyData(let data): + copiedData.writeImmutableBuffer(data.data) + case .copyDone: + return CopyDataRequest(data: copiedData, result: .done) + case .copyFail(let message): + return CopyDataRequest(data: copiedData, result: .failed(message: message.message)) + default: + fatalError("Unexpected message") + } + } + } + func waitForPrepareRequest() async throws -> PrepareRequest { let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) @@ -751,6 +971,18 @@ extension NIOAsyncTestingChannel { try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) try await self.testingEventLoop.executeInContext { self.read() } } + + /// Send the messages up to `BindComplete` for an unnamed query that does not bind any parameters. + func sendUnpreparedRequestWithNoParametersBindResponse() async throws { + try await writeInbound(PostgresBackendMessage.parseComplete) + try await writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await writeInbound(PostgresBackendMessage.noData) + try await writeInbound(PostgresBackendMessage.bindComplete) + } + + func sendCopyInResponseForTwoTextualColumns() async throws { + try await writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual]))) + } } struct UnpreparedRequest {