Skip to content

WIP: Implement COPY … FROM STDIN #566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,168 @@ extension PostgresConnection {
}
}

// MARK: Copy from

/// A handle that is used to send data for a `COPY ... FROM STDIN` operation to the backend.
///
/// Every operation is forwarded to the `PostgresChannelHandler` on the event loop the writer was created with.
public struct PostgresCopyFromWriter: Sendable {
    /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be
    /// processed.
    ///
    /// The `PostgresCopyFromWriter` should cancel the data transfer.
    public struct CopyCancellationError: Error {
        /// The error that the backend sent us which cancelled the data transfer.
        ///
        /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
        /// new data is written by `write`.
        let underlyingError: PSQLError
    }

    /// The channel handler that performs the actual work. It may only be accessed on `eventLoop`.
    private let channelHandler: NIOLoopBound<PostgresChannelHandler>

    /// The event loop that `channelHandler` is bound to.
    private let eventLoop: any EventLoop

    init(handler: PostgresChannelHandler, eventLoop: any EventLoop) {
        self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
        self.eventLoop = eventLoop
    }

    /// Runs `body` with the channel handler on the writer's event loop.
    ///
    /// `body` is invoked synchronously when the caller is already on the event loop and is scheduled with
    /// `eventLoop.execute` otherwise.
    private func runOnEventLoop(_ body: @escaping @Sendable (PostgresChannelHandler) -> Void) {
        if self.eventLoop.inEventLoop {
            body(self.channelHandler.value)
        } else {
            self.eventLoop.execute {
                body(self.channelHandler.value)
            }
        }
    }

    /// Send data for a `COPY ... FROM STDIN` operation to the backend.
    ///
    /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
    /// a `CopyCancellationError`.
    ///
    /// - Parameter byteBuffer: The data to hand to the channel handler's `copyData`.
    public func write(_ byteBuffer: ByteBuffer) async throws {
        // First, wait until we have a writable buffer. This also throws a `CopyCancellationError` in case the backend
        // sent an error during the data transfer and thus cannot process any more data.
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            self.runOnEventLoop { handler in
                handler.waitForWritableBuffer(continuation)
            }
        }

        // Run the actual data transfer.
        self.runOnEventLoop { handler in
            handler.copyData(byteBuffer)
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
    /// the backend.
    func done() async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            self.runOnEventLoop { handler in
                handler.sendCopyDone(continuation: continuation)
            }
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
    /// the backend.
    ///
    /// - Parameter error: The error whose string interpolation is used as the `CopyFail` message.
    func failed(error: any Error) async throws {
        let message = "\(error)"
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            self.runOnEventLoop { handler in
                handler.sendCopyFailed(message: message, continuation: continuation)
            }
        }
    }

    /// Send a `Sync` message to the backend.
    func sync() {
        self.runOnEventLoop { handler in
            handler.sendSync()
        }
    }
}

/// Options that customize the query built for a `COPY ... FROM STDIN` operation.
public struct CopyFromOptions {
    /// The value to emit as the `DELIMITER` option of the query. When `nil`, no `DELIMITER` option is emitted.
    let delimiter: StaticString?

    /// Creates a set of options.
    ///
    /// - Parameter delimiter: The value for the `DELIMITER` option, or `nil` (the default) to omit the option.
    public init(delimiter: StaticString? = nil) {
        self.delimiter = delimiter
    }
}

/// Builds the `COPY <table>[(<columns>)] FROM STDIN [WITH (<options>)]` query for a copy operation.
///
/// - Parameters:
///   - table: The name of the table to copy into. It is inserted into the query verbatim.
///   - columns: The columns to copy into. When `nil`, no column list is emitted.
///   - options: Additional options that are emitted in the `WITH` clause.
/// - Returns: The query, created through `PostgresQuery`'s `unescaped` interpolation.
private func buildCopyFromQuery(
    table: String,
    columns: [StaticString]?,
    options: CopyFromOptions
) -> PostgresQuery {
    // NOTE(review): `table` is a runtime `String` that ends up in the query text without quoting or escaping.
    // Callers must not pass untrusted input here until identifier escaping is implemented.
    var query = "COPY \(table)"
    if let columns {
        // TODO: Is using `StaticString` sufficient here to prevent against SQL injection attacks or should we try to
        // escape the identifiers, essentially re-implementing `PQescapeIdentifier`?
        query += "(" + columns.map(\.description).joined(separator: ",") + ")"
    }
    query += " FROM STDIN"

    var queryOptions: [String] = []
    if let delimiter = options.delimiter {
        // Double embedded single quotes so that a `'` delimiter still forms a terminated SQL string literal.
        let escapedDelimiter = delimiter.description
            .split(separator: "'", omittingEmptySubsequences: false)
            .joined(separator: "''")
        queryOptions.append("DELIMITER '\(escapedDelimiter)'")
    }
    if !queryOptions.isEmpty {
        // All options belong in a single parenthesized, comma-separated list: `WITH (opt1, opt2)`.
        query += " WITH (" + queryOptions.joined(separator: ", ") + ")"
    }
    return "\(unescaped: query)"
}

extension PostgresConnection {
    // TODO: Instead of an arbitrary query, make this a structured data structure.

    /// Runs a `COPY ... FROM STDIN` operation for `table` and lets `writeData` provide the data.
    ///
    /// The query is written to the channel as an extended query. Once a `PostgresCopyFromWriter` has been handed
    /// back, `writeData` is invoked with it. When `writeData` returns without throwing, the writer's `done()` is
    /// called to finish the transfer.
    ///
    /// - Parameters:
    ///   - table: The table to copy into.
    ///   - columns: The columns to copy into, or `nil` to not emit a column list.
    ///   - options: Options that are emitted as part of the query.
    ///   - logger: The logger to use. The connection ID is attached to its metadata.
    ///   - writeData: A closure that sends the data through the given writer.
    ///   - file: The calling file. Not used by the current implementation.
    ///   - line: The calling line. Not used by the current implementation.
    /// - Throws: The underlying backend error if `writeData` throws a `CopyCancellationError`, otherwise whatever
    ///   `writeData` throws, or any error produced while starting or finishing the copy.
    public func copyFrom(
        table: String,
        columns: [StaticString]? = nil,
        options: CopyFromOptions = CopyFromOptions(),
        logger: Logger,
        writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void,
        file: String = #fileID,
        line: Int = #line
    ) async throws {
        let connectionLogger: Logger = {
            var annotated = logger
            annotated[postgresMetadataKey: .connectionID] = "\(self.id)"
            return annotated
        }()
        let copyQuery = buildCopyFromQuery(table: table, columns: columns, options: options)

        let copyWriter: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { triggerCopy in
            let queryContext = ExtendedQueryContext(
                copyFromQuery: copyQuery,
                triggerCopy: triggerCopy,
                logger: connectionLogger
            )
            self.channel.write(HandlerTask.extendedQuery(queryContext), promise: nil)
        }

        do {
            try await writeData(copyWriter)
        } catch let cancellation as PostgresCopyFromWriter.CopyCancellationError {
            // The backend cancelled the copy by sending us an error. A `Sync` message is needed to put the backend
            // out of the copy mode.
            copyWriter.sync()
            throw cancellation.underlyingError
        } catch {
            // Surface the error from the `writeData` closure instead of the one that Postgres gives us upon
            // receiving the `CopyFail` message.
            try? await copyWriter.failed(error: error)
            throw error
        }
        try await copyWriter.done()
    }

}

// MARK: PostgresDatabase conformance

extension PostgresConnection: PostgresDatabase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,24 @@ struct ConnectionStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
/// Fail a query's execution by throwing an error on the given continuation. When `sync` is `true`, send a
/// `sync` message to the backend.
case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?, sync: Bool)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
case succeedQueryContinuation(CheckedContinuation<Void, any Error>)

/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
///
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFailed`.
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)

/// Send a `CopyDone` message to the backend, followed by a `Sync`.
case sendCopyDone

/// Send a `CopyFail` message to the backend with the given error message.
case sendCopyFailed(message: String)

// --- streaming actions
// actions if query has requested next row but we are waiting for backend
Expand All @@ -106,6 +123,27 @@ struct ConnectionStateMachine {
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}

/// The action to perform after the state machine was informed about a change of the channel's writability.
enum ChannelWritabilityChangedAction {
    /// The writability change does not require any action.
    case none

    /// Resume the given continuation successfully.
    case resumeContinuation(CheckedContinuation<Void, any Error>)
}

/// The action to perform after the state machine was asked to wait for a writable buffer.
enum WaitForWritableBufferAction {
    /// The channel has backpressure and cannot handle any data right now. The channel should be flushed to help
    /// relieve the backpressure. Once the channel is writable again, this is communicated via
    /// `channelWritabilityChanged`.
    case waitForBackpressureRelieve

    /// Resume the given continuation successfully.
    case resumeContinuation(CheckedContinuation<Void, any Error>)

    /// Fail the given continuation with the given error.
    case failContinuation(CheckedContinuation<Void, any Error>, error: any Error)
}

private var state: State
private let requireBackendKeyData: Bool
Expand Down Expand Up @@ -587,6 +625,8 @@ struct ConnectionStateMachine {
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .copyFrom(_, let triggerCopy):
return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil, sync: false)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
Expand Down Expand Up @@ -660,6 +700,16 @@ struct ConnectionStateMachine {
preconditionFailure("Invalid state: \(self.state)")
}
}

/// Forwards a change of the channel's writability to the extended query state machine.
///
/// - Parameter isWritable: The writability that is passed on to the extended query state machine.
/// - Returns: The action produced by the extended query state machine, or `.none` if the connection is not
///   currently in the `extendedQuery` state.
mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
    guard case .extendedQuery(var extendedQuery, let context) = self.state else {
        return .none
    }

    self.state = .modifying // avoid CoW
    // Store the mutated query state back once the resulting action has been computed.
    defer { self.state = .extendedQuery(extendedQuery, context) }
    return extendedQuery.channelWritabilityChanged(isWritable: isWritable)
}

// MARK: - Running Queries -

Expand Down Expand Up @@ -751,6 +801,59 @@ struct ConnectionStateMachine {
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

/// Handles a `CopyInResponse` message that was received from the backend.
///
/// - Parameter copyInResponse: The received message. It is forwarded to the extended query state machine.
/// - Returns: The connection action derived from the extended query state machine's action. If the connection is
///   not running an incomplete extended query, the connection is closed with an `unexpectedBackendMessage` error.
mutating func copyInResponseReceived(
    _ copyInResponse: PostgresBackendMessage.CopyInResponseMessage
) -> ConnectionAction {
    guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
        // Report the message that was actually received. The previous code reported `.emptyQueryResponse`, which
        // made the resulting error point at the wrong backend message.
        // NOTE(review): assumes `PostgresBackendMessage` has a `copyInResponse` case carrying this payload - confirm.
        return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
    }

    self.state = .modifying // avoid CoW
    let action = queryState.copyInResponseReceived(copyInResponse)
    // The state must be restored before `modify(with:)` runs.
    self.state = .extendedQuery(queryState, connectionContext)
    return self.modify(with: action)
}

/// Wait for `channel` to be writable and able to handle more `CopyData` messages. Resume the given continuation
/// when the channel is able to handle more data.
///
/// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled
/// the data transfer to indicate that the frontend should not send any more data.
///
/// - Precondition: The connection must be in the `extendedQuery` state.
mutating func waitForWritableBuffer(channel: any Channel, continuation: CheckedContinuation<Void, any Error>) -> WaitForWritableBufferAction {
    guard case .extendedQuery(var extendedQuery, let context) = self.state else {
        preconditionFailure("Copy mode is only supported for extended queries")
    }

    self.state = .modifying // avoid CoW
    // Store the mutated query state back once the resulting action has been computed.
    defer { self.state = .extendedQuery(extendedQuery, context) }
    return extendedQuery.waitForWritableBuffer(channel: channel, continuation: continuation)
}

/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
///
/// - Parameter continuation: The continuation that is handed to the extended query state machine.
/// - Returns: The connection action derived from the extended query state machine's action.
/// - Precondition: The connection must be in the `extendedQuery` state.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    switch self.state {
    case .extendedQuery(var extendedQuery, let context):
        self.state = .modifying // avoid CoW
        let queryAction = extendedQuery.sendCopyDone(continuation: continuation)
        // The state must be restored before `modify(with:)` runs.
        self.state = .extendedQuery(extendedQuery, context)
        return self.modify(with: queryAction)
    default:
        preconditionFailure("Copy mode is only supported for extended queries")
    }
}

/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
///
/// - Parameters:
///   - message: The error message that is handed to the extended query state machine.
///   - continuation: The continuation that is handed to the extended query state machine.
/// - Returns: The connection action derived from the extended query state machine's action.
/// - Precondition: The connection must be in the `extendedQuery` state.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    switch self.state {
    case .extendedQuery(var extendedQuery, let context):
        self.state = .modifying // avoid CoW
        let queryAction = extendedQuery.sendCopyFailed(message: message, continuation: continuation)
        // The state must be restored before `modify(with:)` runs.
        self.state = .extendedQuery(extendedQuery, context)
        return self.modify(with: queryAction)
    default:
        preconditionFailure("Copy mode is only supported for extended queries")
    }
}

mutating func emptyQueryResponseReceived() -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
Expand Down Expand Up @@ -860,14 +963,21 @@ struct ConnectionStateMachine {
.forwardRows,
.forwardStreamComplete,
.wait,
.read:
.read,
.triggerCopyData,
.sendCopyDone,
.sendCopyFailed,
.succeedQueryContinuation:
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")

case .evaluateErrorAtConnectionLevel:
return .closeConnectionAndCleanup(cleanupContext)

case .failQuery(let queryContext, with: let error):
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
case .failQuery(let promise, with: let error):
return .failQuery(promise, with: error, cleanupContext: cleanupContext)

case .failQueryContinuation(let continuation, with: let error, let sync):
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync)

case .forwardStreamError(let error, let read):
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
Expand Down Expand Up @@ -1038,8 +1148,19 @@ extension ConnectionStateMachine {
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
case .failQueryContinuation(let continuation, with: let error, let sync):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync)
case .succeedQuery(let requestContext, with: let result):
return .succeedQuery(requestContext, with: result)
case .succeedQueryContinuation(let continuation):
return .succeedQueryContinuation(continuation)
case .triggerCopyData(let triggerCopy):
return .triggerCopyData(triggerCopy)
case .sendCopyDone:
return .sendCopyDone
case .sendCopyFailed(message: let message):
return .sendCopyFailed(message: message)
case .forwardRows(let buffer):
return .forwardRows(buffer)
case .forwardStreamComplete(let buffer, let commandTag):
Expand Down
Loading