Skip to content

Commit

Permalink
Decouple SimpleQuery from ExtendedQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Aug 26, 2024
1 parent 4d01e30 commit ffa0dc6
Show file tree
Hide file tree
Showing 8 changed files with 978 additions and 251 deletions.
8 changes: 4 additions & 4 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ extension PostgresConnection {
}

/// Run a simple text-only query on the Postgres server the connection is connected to.
/// WARNING: This functions is not yet API and is incomplete.
/// WARNING: This function is not yet API and is incomplete.
/// The return type will change to another stream.
///
/// - Parameters:
Expand All @@ -460,13 +460,13 @@ extension PostgresConnection {
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
simpleQuery: query,
let context = SimpleQueryContext(
query: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
self.channel.write(HandlerTask.simpleQuery(context), promise: nil)

do {
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ struct ExtendedQueryStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)
case sendQuery(String)


// --- general actions
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
Expand Down Expand Up @@ -86,12 +85,6 @@ struct ExtendedQueryStateMachine {
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
}

case .simpleQuery(let query, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendQuery(query)
}
}
}

Expand All @@ -112,7 +105,7 @@ struct ExtendedQueryStateMachine {

self.isCancelled = true
switch queryContext.query {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: .queryCancelled)

case .prepareStatement(_, _, _, let eventLoopPromise):
Expand Down Expand Up @@ -178,19 +171,11 @@ struct ExtendedQueryStateMachine {
state = .noDataMessageReceived(queryContext)
return .succeedPreparedStatementCreation(promise, with: nil)
}

case .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.noData))
}
}

mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action {
let queryContext: ExtendedQueryContext
switch self.state {
case .messagesSent(let extendedQueryContext),
.parameterDescriptionReceived(let extendedQueryContext):
queryContext = extendedQueryContext
default:
guard case .parameterDescriptionReceived(let queryContext) = self.state else {
return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription)))
}

Expand All @@ -213,7 +198,7 @@ struct ExtendedQueryStateMachine {
}

switch queryContext.query {
case .unnamed, .executeStatement, .simpleQuery:
case .unnamed, .executeStatement:
return .wait

case .prepareStatement(_, _, _, let eventLoopPromise):
Expand All @@ -234,9 +219,6 @@ struct ExtendedQueryStateMachine {

case .prepareStatement:
return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete))

case .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))
}

case .noDataMessageReceived(let queryContext):
Expand Down Expand Up @@ -276,40 +258,20 @@ struct ExtendedQueryStateMachine {
return .wait
}

case .rowDescriptionReceived(let queryContext, let columns):
switch queryContext.query {
case .simpleQuery(_, let eventLoopPromise):
// When receiving a data row, we must ensure that the data row column count
// matches the previously received row description column count.
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}

return self.avoidingStateMachineCoW { state -> Action in
var demandStateMachine = RowStreamStateMachine()
demandStateMachine.receivedRow(dataRow)
state = .streaming(columns, demandStateMachine)
let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}

case .drain(let columns):
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}
// we ignore all rows and wait for readyForQuery
return .wait

case .initialized,
.messagesSent,
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.bindCompleteReceived,
.commandComplete,
.error:
Expand All @@ -330,36 +292,10 @@ struct ExtendedQueryStateMachine {
return .succeedQuery(eventLoopPromise, with: result)
}

case .prepareStatement, .simpleQuery:
case .prepareStatement:
preconditionFailure("Invalid state: \(self.state)")
}

case .messagesSent(let context):
switch context.query {
case .simpleQuery(_, let eventLoopGroup):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
return .succeedQuery(eventLoopGroup, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
}

case .rowDescriptionReceived(let context, _):
switch context.query {
case .simpleQuery(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
}


case .streaming(_, var demandStateMachine):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
Expand All @@ -370,12 +306,14 @@ struct ExtendedQueryStateMachine {
precondition(self.isCancelled)
self.state = .commandComplete(commandTag: commandTag)
return .wait

case .initialized,
.messagesSent,
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.commandComplete,
.error:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
Expand All @@ -385,32 +323,20 @@ struct ExtendedQueryStateMachine {
}

mutating func emptyQueryResponseReceived() -> Action {
switch self.state {
case .bindCompleteReceived(let queryContext):
switch queryContext.query {
case .unnamed(_, let eventLoopPromise),
.executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}
guard case .bindCompleteReceived(let queryContext) = self.state else {
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}

case .prepareStatement, .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
case .messagesSent(let queryContext):
switch queryContext.query {
case .simpleQuery(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}
case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
switch queryContext.query {
case .unnamed(_, let eventLoopPromise),
.executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}
default:

case .prepareStatement(_, _, _, _):
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
}
Expand Down Expand Up @@ -571,7 +497,7 @@ struct ExtendedQueryStateMachine {
return .evaluateErrorAtConnectionLevel(error)
} else {
switch context.query {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: error)
case .prepareStatement(_, _, _, let eventLoopPromise):
return .failPreparedStatementCreation(eventLoopPromise, with: error)
Expand Down Expand Up @@ -610,7 +536,7 @@ struct ExtendedQueryStateMachine {
switch context.query {
case .prepareStatement:
return true
case .unnamed, .executeStatement, .simpleQuery:
case .unnamed, .executeStatement:
return false
}

Expand Down
Loading

0 comments on commit ffa0dc6

Please sign in to comment.