From 952fb237c33306136feca137d0085381dfe81da1 Mon Sep 17 00:00:00 2001 From: Ryan Slade Date: Wed, 8 Jan 2025 09:39:19 +0100 Subject: [PATCH] Show completion estimate during backfill (#567) Instead of only showing the number of rows backfilled, show an estimate of the total number of rows completed as a percentage. For example: ``` 1500 records complete... (12.34%) ``` It attempts to estimate the total number of rows but will fall back to a full scan if the number of rows estimated is zero. Closes #492 --- cmd/start.go | 16 ++++++++++-- pkg/db/db.go | 34 +++++++++++++++++++++++++ pkg/db/db_test.go | 48 ++++++++++++++++++++++++++++++++++++ pkg/migrations/backfill.go | 47 ++++++++++++++++++++++++++++++++++- pkg/migrations/migrations.go | 2 +- pkg/roll/execute_test.go | 2 +- 6 files changed, 144 insertions(+), 5 deletions(-) diff --git a/cmd/start.go b/cmd/start.go index 39d8fde4..21e64b20 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -55,8 +55,20 @@ func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, co func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool) error { sp, _ := pterm.DefaultSpinner.WithText("Starting migration...").Start() - cb := func(n int64) { - sp.UpdateText(fmt.Sprintf("%d records complete...", n)) + cb := func(n int64, total int64) { + var percent float64 + if total > 0 { + percent = float64(n) / float64(total) * 100 + } + if percent > 100 { + // This can happen if we're on the last batch + percent = 100 + } + if total > 0 { + sp.UpdateText(fmt.Sprintf("%d records complete... 
(%.2f%%)", n, percent)) + } else { + sp.UpdateText(fmt.Sprintf("%d records complete...", n)) + } } err := m.Start(ctx, migration, cb) diff --git a/pkg/db/db.go b/pkg/db/db.go index 2c39ee85..be9a6b64 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -20,6 +20,7 @@ const ( type DB interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error Close() error } @@ -52,6 +53,28 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{ } } +// QueryContext wraps sql.DB.QueryContext, retrying queries on lock_timeout errors. +func (db *RDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + b := backoff.New(maxBackoffDuration, backoffInterval) + + for { + rows, err := db.DB.QueryContext(ctx, query, args...) + if err == nil { + return rows, nil + } + + pqErr := &pq.Error{} + if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode { + if err := sleepCtx(ctx, b.Duration()); err != nil { + return nil, err + } + continue + } + + return nil, err + } +} + // WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors. func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error { b := backoff.New(maxBackoffDuration, backoffInterval) @@ -95,3 +118,14 @@ func sleepCtx(ctx context.Context, d time.Duration) error { return nil } } + +// ScanFirstValue is a helper function to scan the first value with the assumption that Rows contains +// a single row with a single value. 
+func ScanFirstValue[T any](rows *sql.Rows, dest *T) error { + if rows.Next() { + if err := rows.Scan(dest); err != nil { + return err + } + } + return rows.Err() +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index 94f47761..975a71b5 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xataio/pgroll/internal/testutils" @@ -61,6 +62,53 @@ func TestExecContextWhenContextCancelled(t *testing.T) { }) } +func TestQueryContext(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // execute a query that should retry until the lock is released + rdb := &db.RDB{DB: conn} + rows, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test") + require.NoError(t, err) + + var count int + err = db.ScanFirstValue(rows, &count) + assert.NoError(t, err) + assert.Equal(t, 0, count) + }) +} + +func TestQueryContextWhenContextCancelled(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + + // create a table on which an exclusive lock is held for 2 seconds + setupTableLock(t, connStr, 2*time.Second) + + // set the lock timeout to 100ms + ensureLockTimeout(t, conn, 100) + + // execute a query that should retry until the lock is released + rdb := &db.RDB{DB: conn} + + // Cancel the context before the lock times out + go time.AfterFunc(500*time.Millisecond, cancel) + + _, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test") + require.Errorf(t, err, "context canceled") + }) +} + func TestWithRetryableTransaction(t *testing.T) { t.Parallel() diff 
--git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index e9829a20..7a591d6f 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -29,13 +29,18 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in return BackfillNotPossibleError{Table: table.Name} } + total, err := getRowCount(ctx, conn, table.Name) + if err != nil { + return fmt.Errorf("get row count for %q: %w", table.Name, err) + } + // Create a batcher for the table. b := newBatcher(table, batchSize) // Update each batch of rows, invoking callbacks for each one. for batch := 0; ; batch++ { for _, cb := range cbs { - cb(int64(batch * batchSize)) + cb(int64(batch*batchSize), total) } if err := b.updateBatch(ctx, conn); err != nil { @@ -55,6 +60,46 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in return nil } +// getRowCount will attempt to get the row count for the given table. It first attempts to get an +// estimate and if that is zero, falls back to a full table scan. 
+func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, error) { + // Try and get estimated row count + var currentSchema string + rows, err := conn.QueryContext(ctx, "select current_schema()") + if err != nil { + return 0, fmt.Errorf("getting current schema: %w", err) + } + if err := db.ScanFirstValue(rows, &currentSchema); err != nil { + return 0, fmt.Errorf("scanning current schema: %w", err) + } + + var total int64 + rows, err = conn.QueryContext(ctx, ` + SELECT n_live_tup AS estimate + FROM pg_stat_user_tables + WHERE schemaname = $1 AND relname = $2`, currentSchema, tableName) + if err != nil { + return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err) + } + if err := db.ScanFirstValue(rows, &total); err != nil { + return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err) + } + if total > 0 { + return total, nil + } + + // If the estimate is zero, fall back to full count + rows, err = conn.QueryContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, tableName)) + if err != nil { + return 0, fmt.Errorf("getting row count for %q: %w", tableName, err) + } + if err := db.ScanFirstValue(rows, &total); err != nil { + return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err) + } + + return total, nil +} + // checkBackfill will return an error if the backfill operation is not supported. 
func checkBackfill(table *schema.Table) error { cols := getIdentityColumns(table) diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 96191500..c2683753 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -12,7 +12,7 @@ import ( "github.com/xataio/pgroll/pkg/schema" ) -type CallbackFn func(int64) +type CallbackFn func(done int64, total int64) // Operation is an operation that can be applied to a schema type Operation interface { diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 6e6706ee..384d4b49 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -655,7 +655,7 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) { // Define a mock callback invoked := false - cb := func(n int64) { invoked = true } + cb := func(n, total int64) { invoked = true } // Start a migration that requires a backfill err = mig.Start(ctx, &migrations.Migration{