Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show completion estimate during backfill #567

Merged
merged 7 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,20 @@ func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, co

func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool) error {
sp, _ := pterm.DefaultSpinner.WithText("Starting migration...").Start()
cb := func(n int64) {
sp.UpdateText(fmt.Sprintf("%d records complete...", n))
cb := func(n int64, total int64) {
var percent float64
if total > 0 {
percent = float64(n) / float64(total) * 100
}
if percent > 100 {
// This can happen if we're on the last batch
percent = 100
}
if total > 0 {
sp.UpdateText(fmt.Sprintf("%d records complete... (%.2f%%)", n, percent))
} else {
sp.UpdateText(fmt.Sprintf("%d records complete...", n))
}
}

err := m.Start(ctx, migration, cb)
Expand Down
34 changes: 34 additions & 0 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (

// DB is the set of database operations the rest of the code depends on.
// RDB is the implementation visible in this file; it layers retry-on-lock_timeout
// behaviour over a *sql.DB.
type DB interface {
// ExecContext runs a statement; the signature mirrors sql.DB.ExecContext.
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
// QueryContext runs a query that returns rows; the signature mirrors
// sql.DB.QueryContext. The caller owns the returned *sql.Rows.
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
// WithRetryableTransaction runs f inside a transaction. The RDB implementation
// retries on lock_timeout errors.
WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
// Close presumably releases the underlying database handle; confirm against
// the implementation, which is not visible here.
Close() error
}
Expand Down Expand Up @@ -52,6 +53,28 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{
}
}

// QueryContext runs query through the wrapped sql.DB and hands back the
// resulting rows. A failure caused by a lock_timeout error is retried after a
// backoff pause; any other failure, or a context error raised while pausing,
// is returned to the caller as-is.
func (db *RDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	retry := backoff.New(maxBackoffDuration, backoffInterval)

	for {
		rows, queryErr := db.DB.QueryContext(ctx, query, args...)
		if queryErr == nil {
			return rows, nil
		}

		// Only lock_timeout failures are worth another attempt.
		var pgErr *pq.Error
		if !errors.As(queryErr, &pgErr) || pgErr.Code != lockNotAvailableErrorCode {
			return nil, queryErr
		}

		// Pause before the next attempt, giving up if the context ends first.
		if sleepErr := sleepCtx(ctx, retry.Duration()); sleepErr != nil {
			return nil, sleepErr
		}
	}
}

// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors.
func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
b := backoff.New(maxBackoffDuration, backoffInterval)
Expand Down Expand Up @@ -95,3 +118,14 @@ func sleepCtx(ctx context.Context, d time.Duration) error {
return nil
}
}

// ScanFirstValue scans the first column of the first row of rows into dest,
// on the assumption that rows contains a single row with a single value.
//
// rows is always closed before ScanFirstValue returns, so the caller does not
// have to (calling Close again is harmless). Without this, a result set that
// still has a row pending would keep its database connection checked out of
// the pool.
//
// If rows contains no row at all, dest is left untouched and the error
// reported by rows.Err (nil on a clean end of results) is returned.
func ScanFirstValue[T any](rows *sql.Rows, dest *T) error {
	// Close is idempotent; its error is not actionable once the value is read.
	defer rows.Close()

	if rows.Next() {
		if err := rows.Scan(dest); err != nil {
			return err
		}
	}
	return rows.Err()
}
48 changes: 48 additions & 0 deletions pkg/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/xataio/pgroll/internal/testutils"
Expand Down Expand Up @@ -61,6 +62,53 @@ func TestExecContextWhenContextCancelled(t *testing.T) {
})
}

// TestQueryContext checks that RDB.QueryContext keeps retrying a query that
// hits the lock timeout until the conflicting table lock is released, and
// then returns the query's rows.
func TestQueryContext(t *testing.T) {
	t.Parallel()

	testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
		ctx := context.Background()
		// create a table on which an exclusive lock is held for 2 seconds
		setupTableLock(t, connStr, 2*time.Second)

		// set the lock timeout to 100ms
		ensureLockTimeout(t, conn, 100)

		// execute a query that should retry until the lock is released
		rdb := &db.RDB{DB: conn}
		rows, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test")
		require.NoError(t, err)
		// Release the result set (and the connection it holds) even when an
		// assertion below fails.
		defer rows.Close()

		var count int
		err = db.ScanFirstValue(rows, &count)
		assert.NoError(t, err)
		assert.Equal(t, 0, count)
	})
}

// TestQueryContextWhenContextCancelled checks that RDB.QueryContext stops
// retrying and returns an error when its context is cancelled while the
// conflicting table lock is still held.
func TestQueryContextWhenContextCancelled(t *testing.T) {
	t.Parallel()

	testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// create a table on which an exclusive lock is held for 2 seconds
		setupTableLock(t, connStr, 2*time.Second)

		// set the lock timeout to 100ms
		ensureLockTimeout(t, conn, 100)

		rdb := &db.RDB{DB: conn}

		// Cancel the context well before the 2 second lock is released.
		// time.AfterFunc already runs cancel on its own goroutine, so no
		// extra `go` statement is needed.
		timer := time.AfterFunc(500*time.Millisecond, cancel)
		defer timer.Stop()

		// execute a query that keeps retrying until the context is cancelled
		_, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test")
		// Errorf only asserts that err is non-nil; the string is the message
		// printed on failure, not a pattern matched against err.
		require.Errorf(t, err, "context canceled")
	})
}

func TestWithRetryableTransaction(t *testing.T) {
t.Parallel()

Expand Down
47 changes: 46 additions & 1 deletion pkg/migrations/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in
return BackfillNotPossibleError{Table: table.Name}
}

total, err := getRowCount(ctx, conn, table.Name)
if err != nil {
return fmt.Errorf("get row count for %q: %w", table.Name, err)
}

// Create a batcher for the table.
b := newBatcher(table, batchSize)

// Update each batch of rows, invoking callbacks for each one.
for batch := 0; ; batch++ {
for _, cb := range cbs {
cb(int64(batch * batchSize))
cb(int64(batch*batchSize), total)
}

if err := b.updateBatch(ctx, conn); err != nil {
Expand All @@ -55,6 +60,46 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in
return nil
}

// getRowCount returns the number of rows in tableName.
//
// It first reads the n_live_tup estimate for the table in the current schema
// from pg_stat_user_tables. If that estimate is zero (or no statistics row
// exists), it falls back to an exact `count(*)`, which scans the table.
//
// tableName must be an unqualified table name: it is matched against relname
// in the statistics lookup and quoted as a single identifier in the fallback.
//
// Each result set is closed as soon as its value has been read, so that only
// one pooled connection is held at a time.
func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, error) {
	// Resolve the schema the statistics lookup should be restricted to.
	var currentSchema string
	rows, err := conn.QueryContext(ctx, "select current_schema()")
	if err != nil {
		return 0, fmt.Errorf("getting current schema: %w", err)
	}
	err = db.ScanFirstValue(rows, &currentSchema)
	_ = rows.Close() // value already read; a close error is not actionable
	if err != nil {
		return 0, fmt.Errorf("scanning current schema: %w", err)
	}

	// Try the cheap estimate first.
	var total int64
	rows, err = conn.QueryContext(ctx, `
SELECT n_live_tup AS estimate
FROM pg_stat_user_tables
WHERE schemaname = $1 AND relname = $2`, currentSchema, tableName)
	if err != nil {
		return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err)
	}
	err = db.ScanFirstValue(rows, &total)
	_ = rows.Close() // value already read; a close error is not actionable
	if err != nil {
		return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err)
	}
	if total > 0 {
		return total, nil
	}

	// If the estimate is zero, fall back to full count. The table name cannot
	// be passed as a bind parameter, so it is quoted as an identifier to keep
	// mixed-case or otherwise unusual names intact.
	rows, err = conn.QueryContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, quoteIdent(tableName)))
	if err != nil {
		return 0, fmt.Errorf("getting row count for %q: %w", tableName, err)
	}
	err = db.ScanFirstValue(rows, &total)
	_ = rows.Close() // value already read; a close error is not actionable
	if err != nil {
		return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err)
	}

	return total, nil
}

// quoteIdent wraps name in double quotes, doubling any embedded double quote,
// so that it can be used verbatim as a PostgreSQL identifier.
func quoteIdent(name string) string {
	quoted := make([]byte, 0, len(name)+2)
	quoted = append(quoted, '"')
	for i := 0; i < len(name); i++ {
		if name[i] == '"' {
			quoted = append(quoted, '"')
		}
		quoted = append(quoted, name[i])
	}
	quoted = append(quoted, '"')
	return string(quoted)
}

// checkBackfill will return an error if the backfill operation is not supported.
func checkBackfill(table *schema.Table) error {
cols := getIdentityColumns(table)
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/xataio/pgroll/pkg/schema"
)

type CallbackFn func(int64)
// CallbackFn is a progress callback invoked during a backfill. done is the
// number of records processed so far; total is the row count of the table
// being backfilled, which may be an estimate rather than an exact figure.
type CallbackFn func(done int64, total int64)

// Operation is an operation that can be applied to a schema
type Operation interface {
Expand Down
2 changes: 1 addition & 1 deletion pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) {

// Define a mock callback
invoked := false
cb := func(n int64) { invoked = true }
cb := func(n, total int64) { invoked = true }

// Start a migration that requires a backfill
err = mig.Start(ctx, &migrations.Migration{
Expand Down