Skip to content

Support binary data transfer in COPY FROM #573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
404 changes: 404 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,33 @@ struct ConnectionStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
/// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a
/// `Sync` message to the backend.
case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?)
/// Fail a query's execution by resuming the continuation with the given error and send a `Sync` message to the
/// backend.
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
/// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend.
case succeedQueryContinuation(CheckedContinuation<Void, any Error>, sync: Bool)

/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
///
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`.
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)

/// Send a `CopyDone` and `Sync` message to the backend.
case sendCopyDoneAndSync

/// Send a `CopyFail` message to the backend with the given error message.
case sendCopyFail(message: String)

/// Fail the promise with the given error and close the connection.
///
/// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we
/// can't recover the connection because we can't send any messages to the backend, so we need to close it.
case failPromiseAndCloseConnection(EventLoopPromise<Void>, error: PSQLError, cleanupContext: CleanUpContext)

// --- streaming actions
// actions if query has requested next row but we are waiting for backend
Expand All @@ -107,6 +133,25 @@ struct ConnectionStateMachine {
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}

enum ChannelWritabilityChangedAction {
/// No action needs to be taken based on the writability change.
case none

/// Resume the given continuation successfully.
case succeedPromise(EventLoopPromise<Void>)
}

enum CheckBackendCanReceiveCopyDataAction {
/// Don't perform any action.
case none

/// Succeed the promise with a Void result.
case succeedPromise(EventLoopPromise<Void>)

/// Fail the promise with the given error.
case failPromise(EventLoopPromise<Void>, error: any Error)
}

private var state: State
private let requireBackendKeyData: Bool
private var taskQueue = CircularBuffer<PSQLTask>()
Expand Down Expand Up @@ -587,6 +632,8 @@ struct ConnectionStateMachine {
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .copyFrom(_, let triggerCopy):
return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
Expand Down Expand Up @@ -660,6 +707,16 @@ struct ConnectionStateMachine {
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
guard case .extendedQuery(var queryState, let connectionContext) = state else {
return .none
}
self.state = .modifying // avoid CoW
let action = queryState.channelWritabilityChanged(isWritable: isWritable)
self.state = .extendedQuery(queryState, connectionContext)
return action
}

// MARK: - Running Queries -

Expand Down Expand Up @@ -752,10 +809,56 @@ struct ConnectionStateMachine {
return self.modify(with: action)
}

mutating func copyInResponseReceived(
_ copyInResponse: PostgresBackendMessage.CopyInResponse
) -> ConnectionAction {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
}

self.state = .modifying // avoid CoW
let action = queryState.copyInResponseReceived(copyInResponse)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}


/// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data.
///
/// The promise may be failed if the backend indicated that it can't handle any more data by sending an
/// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer
/// should be aborted to avoid unnecessary work.
mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise<Void>) -> CheckBackendCanReceiveCopyDataAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise)
self.state = .extendedQuery(queryState, connectionContext)
return action
}

/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.sendCopyDone(continuation: continuation)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
}

self.state = .modifying // avoid CoW
let action = queryState.sendCopyFail(message: message, continuation: continuation)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

mutating func emptyQueryResponseReceived() -> ConnectionAction {
Expand All @@ -782,9 +885,10 @@ struct ConnectionStateMachine {

// MARK: Consumer

mutating func cancelQueryStream() -> ConnectionAction {
mutating func cancel() -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Tried to cancel stream without active query")
// We are not in a state in which we can cancel. Do nothing.
return .wait
}

self.state = .modifying // avoid CoW
Expand Down Expand Up @@ -866,14 +970,22 @@ struct ConnectionStateMachine {
.forwardRows,
.forwardStreamComplete,
.wait,
.read:
.read,
.triggerCopyData,
.sendCopyDoneAndSync,
.sendCopyFail,
.succeedQueryContinuation,
.failPromiseAndCloseConnection:
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")

case .evaluateErrorAtConnectionLevel:
return .closeConnectionAndCleanup(cleanupContext)

case .failQuery(let queryContext, with: let error):
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
case .failQuery(let promise, with: let error):
return .failQuery(promise, with: error, cleanupContext: cleanupContext)

case .failQueryContinuation(let continuation, with: let error, let sync):
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)

case .forwardStreamError(let error, let read):
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
Expand Down Expand Up @@ -1044,8 +1156,22 @@ extension ConnectionStateMachine {
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
case .failQueryContinuation(let continuation, with: let error, let sync):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
case .succeedQuery(let requestContext, with: let result):
return .succeedQuery(requestContext, with: result)
case .succeedQueryContinuation(let continuation, let sync):
return .succeedQueryContinuation(continuation, sync: sync)
case .triggerCopyData(let triggerCopy):
return .triggerCopyData(triggerCopy)
case .sendCopyDoneAndSync:
return .sendCopyDoneAndSync
case .sendCopyFail(message: let message):
return .sendCopyFail(message: message)
case .failPromiseAndCloseConnection(let promise, error: let error):
let cleanupContext = self.setErrorAndCreateCleanupContext(error)
return .failPromiseAndCloseConnection(promise, error: error, cleanupContext: cleanupContext)
case .forwardRows(let buffer):
return .forwardRows(buffer)
case .forwardStreamComplete(let buffer, let commandTag):
Expand Down
Loading
Loading