Skip to content

Commit

Permalink
Use QueryContext instead of QueryRowContext and retry on lock timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanslade committed Jan 7, 2025
1 parent 30b57d1 commit 418ff81
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
37 changes: 33 additions & 4 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const (

type DB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
Close() error
}
Expand Down Expand Up @@ -53,9 +53,26 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{
}
}

// QueryRowContext wraps sql.DB.QueryRowContext.
func (db *RDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
return db.DB.QueryRowContext(ctx, query, args...)
// QueryContext wraps sql.DB.QueryContext, retrying queries on lock_timeout errors.
func (db *RDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
b := backoff.New(maxBackoffDuration, backoffInterval)

for {
rows, err := db.DB.QueryContext(ctx, query, args...)
if err == nil {
return rows, nil
}

pqErr := &pq.Error{}
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
if err := sleepCtx(ctx, b.Duration()); err != nil {
return nil, err
}
continue
}

return nil, err
}
}

// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors.
Expand Down Expand Up @@ -101,3 +118,15 @@ func sleepCtx(ctx context.Context, d time.Duration) error {
return nil
}
}

// ScanFirstValue is a helper function to scan the first value with the assumption that Rows contains
// a single row with a single value.
func ScanFirstValue[T any](rows *sql.Rows, dest *T) error {
for rows.Next() {
if err := rows.Scan(dest); err != nil {
return err
}
break

Check failure on line 129 in pkg/db/db.go

View workflow job for this annotation

GitHub Actions / lint

SA4004: the surrounding loop is unconditionally terminated (staticcheck)

Check failure on line 129 in pkg/db/db.go

View workflow job for this annotation

GitHub Actions / lint

SA4004: the surrounding loop is unconditionally terminated (staticcheck)
}
return rows.Err()
}
25 changes: 18 additions & 7 deletions pkg/migrations/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,40 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in
// getRowCount will attempt to get the row count for the given table. It first attempts to get an
// estimate and if that is zero, falls back to a full table scan.
func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, error) {
var total int64
// Try and get estimated row count
var currentSchema string
row := conn.QueryRowContext(ctx, "select current_schema()")
if err := row.Scan(&currentSchema); err != nil {
rows, err := conn.QueryContext(ctx, "select current_schema()")
if err != nil {
return 0, fmt.Errorf("getting current schema: %w", err)
}
row = conn.QueryRowContext(ctx, `
if err := db.ScanFirstValue(rows, &currentSchema); err != nil {
return 0, fmt.Errorf("scanning current schema: %w", err)
}

var total int64
rows, err = conn.QueryContext(ctx, `
SELECT n_live_tup AS estimate
FROM pg_stat_user_tables
WHERE schemaname = $1 AND relname = $2`, currentSchema, tableName)
if err := row.Scan(&total); err != nil {
if err != nil {
return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err)
}
if err := db.ScanFirstValue(rows, &total); err != nil {
return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err)
}
if total > 0 {
return total, nil
}

// If the estimate is zero, fall back to full count
row = conn.QueryRowContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, tableName))
if err := row.Scan(&total); err != nil {
rows, err = conn.QueryContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, tableName))
if err != nil {
return 0, fmt.Errorf("getting row count for %q: %w", tableName, err)
}
if err := db.ScanFirstValue(rows, &total); err != nil {
return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err)
}

return total, nil
}

Expand Down

0 comments on commit 418ff81

Please sign in to comment.