Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose query metadata in PostgresRowSequence #504

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ final class PSQLRowStream: @unchecked Sendable {
}

internal let rowDescription: [RowDescription.Column]
private let lookupTable: [String: Int]
internal let lookupTable: [String: Int]
private var downstreamState: DownstreamState

init(
Expand Down Expand Up @@ -114,7 +114,7 @@ final class PSQLRowStream: @unchecked Sendable {
self.downstreamState = .consumed(.failure(error))
}

return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription)
return PostgresRowSequence(producer.sequence, rowStream: self)
}

func demand() {
Expand Down
51 changes: 45 additions & 6 deletions Sources/PostgresNIO/New/PostgresRowSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@ public struct PostgresRowSequence: AsyncSequence, Sendable {

typealias BackingSequence = NIOThrowingAsyncSequenceProducer<DataRow, Error, AdaptiveRowBuffer, PSQLRowStream>

let backing: BackingSequence
let lookupTable: [String: Int]
let columns: [RowDescription.Column]
private let backing: BackingSequence
private let rowStream: PSQLRowStream
var lookupTable: [String: Int] {
self.rowStream.lookupTable
}
var columns: [RowDescription.Column] {
self.rowStream.rowDescription
}

init(_ backing: BackingSequence, lookupTable: [String: Int], columns: [RowDescription.Column]) {
init(_ backing: BackingSequence, rowStream: PSQLRowStream) {
self.backing = backing
self.lookupTable = lookupTable
self.columns = columns
self.rowStream = rowStream
}

public func makeAsyncIterator() -> AsyncIterator {
Expand Down Expand Up @@ -60,13 +64,48 @@ extension PostgresRowSequence {
extension PostgresRowSequence.AsyncIterator: Sendable {}

extension PostgresRowSequence {
/// Collect and return all rows.
/// - Returns: The rows.
public func collect() async throws -> [PostgresRow] {
var result = [PostgresRow]()
for try await row in self {
result.append(row)
}
return result
}

/// Collect and return all rows, alongside the query metadata.
/// - Returns: The query metadata and the rows.
public func collectWithMetadata() async throws -> (metadata: PostgresQueryMetadata, rows: [PostgresRow]) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to better names.

let rows = try await self.collect()
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
}
return (metadata, rows)
}

/// Consumes all rows and returns the query metadata.
///
/// If you don't need the returned query metadata, just use the for-try-await-loop syntax:
/// ```swift
/// for try await row in myRowSequence {
/// /// Process each row
/// }
/// ```
///
/// - Parameter onRow: Processes each row.
/// - Returns: The query metadata.
public func consume(
onRow: @Sendable (PostgresRow) throws -> ()
) async throws -> PostgresQueryMetadata {
Comment on lines +98 to +100
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to better names.

Copy link
Contributor Author

@MahdiBM MahdiBM Aug 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the closure is intentionally not async since it's generally not a good idea to be doing async things on each row.
It's still possible to accumulate the rows and do whatever you want after.

I could change that if you think with async it would be better.

for try await row in self {
try onRow(row)
}
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
}
return metadata
}
}

struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy {
Expand Down
65 changes: 57 additions & 8 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Atomics
import Logging
import XCTest
import PostgresNIO
Expand Down Expand Up @@ -46,6 +47,58 @@ final class AsyncPostgresConnectionTests: XCTestCase {
}
}

func testSelect10kRowsAndConsume() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)

let counter = ManagedAtomic(0)
let metadata = try await rows.consume { row in
let element = try row.decode(Int.self)
let newCounter = counter.wrappingIncrementThenLoad(ordering: .relaxed)
XCTAssertEqual(element, newCounter)
}

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, 10000)

XCTAssertEqual(counter.load(ordering: .relaxed), end)
}
}

func testSelect10kRowsAndCollect() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
let (metadata, elements) = try await rows.collectWithMetadata()
var counter = 0
for row in elements {
let element = try row.decode(Int.self)
XCTAssertEqual(element, counter + 1)
counter += 1
}

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, 10000)

XCTAssertEqual(counter, end)
}
}

func testSelectActiveConnection() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
Expand Down Expand Up @@ -207,7 +260,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {

try await withTestConnection(on: eventLoop) { connection in
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
// Then we will insert 3 * 17 rows
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
// If the test is successful, it means Postgres supports UInt16.max bindings
Expand Down Expand Up @@ -241,13 +294,9 @@ final class AsyncPostgresConnectionTests: XCTestCase {
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
binds: binds
)
try await connection.query(insertionQuery, logger: .psqlTest)

let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
let countRows = try await connection.query(countQuery, logger: .psqlTest)
var countIterator = countRows.makeAsyncIterator()
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
XCTAssertEqual(rowsCount, insertedRowsCount)
let result = try await connection.query(insertionQuery, logger: .psqlTest)
let metadata = try await result.collectWithMetadata().metadata
XCTAssertEqual(metadata.rows, rowsCount)
Comment on lines -246 to +299
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metadata contains the count of the inserted rows so no need to do another query.


let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
try await connection.query(dropQuery, logger: .psqlTest)
Expand Down
Loading