From 150e1c95ddf0818bde00a83daa02d79351701617 Mon Sep 17 00:00:00 2001 From: Ryan Slade Date: Fri, 3 Jan 2025 13:41:41 +0100 Subject: [PATCH] Move estimate into function and fix a test --- pkg/migrations/backfill.go | 42 ++++++++++++++++++++++++-------------- pkg/roll/execute_test.go | 2 +- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index 2b4df61c..b094f022 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -29,21 +29,9 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in return BackfillNotPossibleError{Table: table.Name} } - var total int64 - - // Try and get estimated row count - row := conn.QueryRowContext(ctx, `SELECT n_live_tup AS estimate - FROM pg_stat_user_tables - WHERE relname = $1`, table.Name) - if err := row.Scan(&total); err != nil { - return fmt.Errorf("scanning row count estimate for %q: %w", table.Name, err) - } - // If the estimate is zero, fall back to full count - if total == 0 { - row = conn.QueryRowContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, table.Name)) - if err := row.Scan(&total); err != nil { - return fmt.Errorf("scanning row count for %q: %w", table.Name, err) - } + total, err := getRowCount(ctx, conn, table.Name) + if err != nil { + return fmt.Errorf("get row count for %q: %w", table.Name, err) } // Create a batcher for the table. @@ -72,6 +60,30 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in return nil } +// getRowCount will attempt to get the row count for the given table. It first attempts to get an +// estimate and if that is zero, falls back to a full table scan. +func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, error) { + var total int64 + // Try and get estimated row count + row := conn.QueryRowContext(ctx, ` + SELECT n_live_tup AS estimate + FROM pg_stat_user_tables + WHERE relname = $1`, tableName) + if err := row.Scan(&total); err != nil { + return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err) + } + if total > 0 { + return total, nil + } + + // If the estimate is zero, fall back to full count + row = conn.QueryRowContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, tableName)) + if err := row.Scan(&total); err != nil { + return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err) + } + return total, nil +} + // checkBackfill will return an error if the backfill operation is not supported. func checkBackfill(table *schema.Table) error { cols := getIdentityColumns(table) diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 6e6706ee..384d4b49 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -655,7 +655,7 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) { // Define a mock callback invoked := false - cb := func(n int64) { invoked = true } + cb := func(n, total int64) { invoked = true } // Start a migration that requires a backfill err = mig.Start(ctx, &migrations.Migration{