Skip to content

Commit

Permalink
Allow configuration of backfill batch size (#406)
Browse files Browse the repository at this point in the history
It can be controlled via the `--backfill-batch-size` command line
parameter or by setting the `PGROLL_BACKFILL_BATCH_SIZE` environment
variable.

It can also be set programmatically via the `roll.WithBackfillBatchSize`
function.

If unset, it will default to 1000.

Part of #168
  • Loading branch information
ryanslade authored Oct 16, 2024
1 parent 681a3eb commit ced761b
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cmd/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ func LockTimeout() int {
return viper.GetInt("LOCK_TIMEOUT")
}

// BackfillBatchSize returns the configured number of rows to backfill per
// batch, read from the BACKFILL_BATCH_SIZE configuration key (settable via
// the --backfill-batch-size flag or the PGROLL_BACKFILL_BATCH_SIZE
// environment variable).
func BackfillBatchSize() int {
	return viper.GetInt("BACKFILL_BATCH_SIZE")
}

// Role returns the optional Postgres role to set when executing migrations,
// read from the ROLE configuration key.
func Role() string { return viper.GetString("ROLE") }
5 changes: 5 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/xataio/pgroll/cmd/flags"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
Expand All @@ -23,12 +24,14 @@ func init() {
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
rootCmd.PersistentFlags().Int("backfill-batch-size", roll.DefaultBackfillBatchSize, "Number of rows backfilled in each batch")
rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations")

viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
viper.BindPFlag("BACKFILL_BATCH_SIZE", rootCmd.PersistentFlags().Lookup("backfill-batch-size"))
viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role"))
}

Expand All @@ -44,6 +47,7 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
stateSchema := flags.StateSchema()
lockTimeout := flags.LockTimeout()
role := flags.Role()
backfillBatchSize := flags.BackfillBatchSize()

state, err := state.New(ctx, pgURL, stateSchema)
if err != nil {
Expand All @@ -53,6 +57,7 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
return roll.New(ctx, pgURL, schema, state,
roll.WithLockTimeoutMs(lockTimeout),
roll.WithRole(role),
roll.WithBackfillBatchSize(backfillBatchSize),
)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/migrations/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// 2. Get the first batch of rows from the table, ordered by the primary key.
// 3. Update each row in the batch, setting the value of the primary key column to itself.
// 4. Repeat steps 2 and 3 until no more rows are returned.
func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...CallbackFn) error {
func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize int, cbs ...CallbackFn) error {
// get the backfill column
identityColumn := getIdentityColumn(table)
if identityColumn == nil {
Expand All @@ -31,7 +31,7 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, cbs ...Callb
table: table,
identityColumn: identityColumn,
lastValue: nil,
batchSize: 1000,
batchSize: batchSize,
}

// Update each batch of rows, invoking callbacks for each one.
Expand Down
2 changes: 1 addition & 1 deletion pkg/roll/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (m *Roll) ensureView(ctx context.Context, version, name string, table schem

func (m *Roll) performBackfills(ctx context.Context, tables []*schema.Table, cbs ...migrations.CallbackFn) error {
for _, table := range tables {
if err := migrations.Backfill(ctx, m.pgConn, table, cbs...); err != nil {
if err := migrations.Backfill(ctx, m.pgConn, table, m.backfillBatchSize, cbs...); err != nil {
errRollback := m.Rollback(ctx)

return errors.Join(
Expand Down
10 changes: 10 additions & 0 deletions pkg/roll/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ type options struct {
// additional entries to add to the search_path during migration execution
searchPath []string

// the number of rows to backfill in each batch
backfillBatchSize int

migrationHooks MigrationHooks
}

Expand Down Expand Up @@ -99,3 +102,10 @@ func WithSearchPath(schemas ...string) Option {
o.searchPath = schemas
}
}

// WithBackfillBatchSize sets the number of rows backfilled in each batch
// during a migration. Values of zero or below are replaced with
// DefaultBackfillBatchSize when the Roll instance is constructed.
func WithBackfillBatchSize(batchSize int) Option {
	setter := func(o *options) {
		o.backfillBatchSize = batchSize
	}
	return setter
}
18 changes: 13 additions & 5 deletions pkg/roll/roll.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ import (

type PGVersion int

const PGVersion15 PGVersion = 15
const (
	// PGVersion15 identifies Postgres major version 15.
	PGVersion15 PGVersion = 15

	// DefaultBackfillBatchSize is the number of rows backfilled in each
	// batch when no batch size is configured (see WithBackfillBatchSize).
	DefaultBackfillBatchSize int = 1000
)

type Roll struct {
pgConn db.DB
Expand All @@ -31,10 +34,11 @@ type Roll struct {
// disable creation of version schema for raw SQL migrations
noVersionSchemaForRawSQL bool

migrationHooks MigrationHooks
state *state.State
pgVersion PGVersion
sqlTransformer migrations.SQLTransformer
migrationHooks MigrationHooks
state *state.State
pgVersion PGVersion
sqlTransformer migrations.SQLTransformer
backfillBatchSize int
}

// New creates a new Roll instance
Expand All @@ -43,6 +47,9 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
for _, o := range opts {
o(rollOpts)
}
if rollOpts.backfillBatchSize <= 0 {
rollOpts.backfillBatchSize = DefaultBackfillBatchSize
}

conn, err := setupConn(ctx, pgURL, schema, *rollOpts)
if err != nil {
Expand Down Expand Up @@ -71,6 +78,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
noVersionSchemaForRawSQL: rollOpts.noVersionSchemaForRawSQL,
migrationHooks: rollOpts.migrationHooks,
sqlTransformer: sqlTransformer,
backfillBatchSize: rollOpts.backfillBatchSize,
}, nil
}

Expand Down

0 comments on commit ced761b

Please sign in to comment.