diff --git a/cmd/start.go b/cmd/start.go index 39d8fde4..49754f2b 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -55,8 +55,12 @@ func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, co func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool) error { sp, _ := pterm.DefaultSpinner.WithText("Starting migration...").Start() - cb := func(n int64) { - sp.UpdateText(fmt.Sprintf("%d records complete...", n)) + cb := func(n int64, total int64) { + var percent float64 + if total > 0 { + percent = float64(n) / float64(total) * 100 + } + sp.UpdateText(fmt.Sprintf("%d records complete... (%.2f%%)", n, percent)) } err := m.Start(ctx, migration, cb) diff --git a/pkg/db/db.go b/pkg/db/db.go index 2c39ee85..03653219 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -20,6 +20,7 @@ const ( type DB interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error Close() error } @@ -52,6 +53,11 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{ } } +// QueryRowContext wraps sql.DB.QueryRowContext. +func (db *RDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return db.DB.QueryRowContext(ctx, query, args...) +} + // WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors. func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error { b := backoff.New(maxBackoffDuration, backoffInterval) diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index e9829a20..2b4df61c 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -29,13 +29,30 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in return BackfillNotPossibleError{Table: table.Name} } + var total int64 + + // Try and get estimated row count + row := conn.QueryRowContext(ctx, `SELECT n_live_tup AS estimate + FROM pg_stat_user_tables + WHERE relname = $1`, table.Name) + if err := row.Scan(&total); err != nil { + return fmt.Errorf("scanning row count estimate for %q: %w", table.Name, err) + } + // If the estimate is zero, fall back to full count + if total == 0 { + row = conn.QueryRowContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, table.Name)) + if err := row.Scan(&total); err != nil { + return fmt.Errorf("scanning row count for %q: %w", table.Name, err) + } + } + // Create a batcher for the table. b := newBatcher(table, batchSize) // Update each batch of rows, invoking callbacks for each one. for batch := 0; ; batch++ { for _, cb := range cbs { - cb(int64(batch * batchSize)) + cb(int64(batch*batchSize), total) } if err := b.updateBatch(ctx, conn); err != nil { diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 96191500..c2683753 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -12,7 +12,7 @@ import ( "github.com/xataio/pgroll/pkg/schema" ) -type CallbackFn func(int64) +type CallbackFn func(done int64, total int64) // Operation is an operation that can be applied to a schema type Operation interface {