diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index ee925d0e..a7082a90 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -44,7 +44,7 @@ final class PSQLRowStream: @unchecked Sendable { } internal let rowDescription: [RowDescription.Column] - private let lookupTable: [String: Int] + internal let lookupTable: [String: Int] private var downstreamState: DownstreamState init( @@ -114,7 +114,7 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .consumed(.failure(error)) } - return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription) + return PostgresRowSequence(producer.sequence, rowStream: self) } func demand() { diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 3936b51e..9ee06358 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -9,14 +9,18 @@ public struct PostgresRowSequence: AsyncSequence, Sendable { typealias BackingSequence = NIOThrowingAsyncSequenceProducer - let backing: BackingSequence - let lookupTable: [String: Int] - let columns: [RowDescription.Column] + private let backing: BackingSequence + private let rowStream: PSQLRowStream + var lookupTable: [String: Int] { + self.rowStream.lookupTable + } + var columns: [RowDescription.Column] { + self.rowStream.rowDescription + } - init(_ backing: BackingSequence, lookupTable: [String: Int], columns: [RowDescription.Column]) { + init(_ backing: BackingSequence, rowStream: PSQLRowStream) { self.backing = backing - self.lookupTable = lookupTable - self.columns = columns + self.rowStream = rowStream } public func makeAsyncIterator() -> AsyncIterator { @@ -60,6 +64,8 @@ extension PostgresRowSequence { extension PostgresRowSequence.AsyncIterator: Sendable {} extension PostgresRowSequence { + /// Collect and return all rows. + /// - Returns: The rows. public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() for try await row in self { @@ -67,6 +73,39 @@ extension PostgresRowSequence { } return result } + + /// Collect and return all rows, alongside the query metadata. + /// - Returns: The query metadata and the rows. + public func collectWithMetadata() async throws -> (metadata: PostgresQueryMetadata, rows: [PostgresRow]) { + let rows = try await self.collect() + guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(self.rowStream.commandTag) + } + return (metadata, rows) + } + + /// Consumes all rows and returns the query metadata. + /// + /// If you don't need the returned query metadata, just use the for-try-await-loop syntax: + /// ```swift + /// for try await row in myRowSequence { + /// /// Process each row + /// } + /// ``` + /// + /// - Parameter onRow: Processes each row. + /// - Returns: The query metadata. + public func consume( + onRow: @Sendable (PostgresRow) throws -> () + ) async throws -> PostgresQueryMetadata { + for try await row in self { + try onRow(row) + } + guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(self.rowStream.commandTag) + } + return metadata + } } struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy { diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index b4c8e93f..6a5cd171 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -1,3 +1,4 @@ +import Atomics import Logging import XCTest import PostgresNIO @@ -46,6 +47,58 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testSelect10kRowsAndConsume() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + + let counter = ManagedAtomic(0) + let metadata = try await rows.consume { row in + let element = try row.decode(Int.self) + let newCounter = counter.wrappingIncrementThenLoad(ordering: .relaxed) + XCTAssertEqual(element, newCounter) + } + + XCTAssertEqual(metadata.command, "SELECT") + XCTAssertEqual(metadata.oid, nil) + XCTAssertEqual(metadata.rows, 10000) + + XCTAssertEqual(counter.load(ordering: .relaxed), end) + } + } + + func testSelect10kRowsAndCollect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + let (metadata, elements) = try await rows.collectWithMetadata() + var counter = 0 + for row in elements { + let element = try row.decode(Int.self) + XCTAssertEqual(element, counter + 1) + counter += 1 + } + + XCTAssertEqual(metadata.command, "SELECT") + XCTAssertEqual(metadata.oid, nil) + XCTAssertEqual(metadata.rows, 10000) + + XCTAssertEqual(counter, end) + } + } + func testSelectActiveConnection() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -207,7 +260,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in // Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257 - // Max columns limit is 1664, so we will only make 5 * 257 columns which is less + // Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less // Then we will insert 3 * 17 rows // In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings // If the test is successful, it means Postgres supports UInt16.max bindings @@ -241,13 +294,9 @@ final class AsyncPostgresConnectionTests: XCTestCase { unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)", binds: binds ) - try await connection.query(insertionQuery, logger: .psqlTest) - - let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1") - let countRows = try await connection.query(countQuery, logger: .psqlTest) - var countIterator = countRows.makeAsyncIterator() - let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default) - XCTAssertEqual(rowsCount, insertedRowsCount) + let result = try await connection.query(insertionQuery, logger: .psqlTest) + let metadata = try await result.collectWithMetadata().metadata + XCTAssertEqual(metadata.rows, rowsCount) let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1") try await connection.query(dropQuery, logger: .psqlTest)