diff --git a/cmd/flags/flags.go b/cmd/flags/flags.go index b063804c..6d821b47 100644 --- a/cmd/flags/flags.go +++ b/cmd/flags/flags.go @@ -21,3 +21,7 @@ func StateSchema() string { func LockTimeout() int { return viper.GetInt("LOCK_TIMEOUT") } + +func Role() string { + return viper.GetString("ROLE") +} diff --git a/cmd/root.go b/cmd/root.go index 9f9c27b8..5936e728 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -23,11 +23,13 @@ func init() { rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration") rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state") rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations") + rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations") viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url")) viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema")) viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema")) viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout")) + viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role")) } var rootCmd = &cobra.Command{ @@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { schema := flags.Schema() stateSchema := flags.StateSchema() lockTimeout := flags.LockTimeout() + role := flags.Role() state, err := state.New(ctx, pgURL, stateSchema) if err != nil { return nil, err } - return roll.New(ctx, pgURL, schema, lockTimeout, state) + return roll.New(ctx, pgURL, schema, state, + roll.WithLockTimeoutMs(lockTimeout), + roll.WithRole(role), + ) } // Execute executes the root command. diff --git a/docs/README.md b/docs/README.md index 891af506..032d45c3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags: * `--schema`: The Postgres schema in which migrations will be run (default `"public"`). * `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`). * `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`). +* --role: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role). Each of these flags can also be set via an environment variable: * `PGROLL_PG_URL` * `PGROLL_SCHEMA` * `PGROLL_STATE_SCHEMA` * `PGROLL_LOCK_TIMEOUT` +* `PGROLL_ROLE` The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag. diff --git a/pkg/roll/options.go b/pkg/roll/options.go new file mode 100644 index 00000000..5c2b0028 --- /dev/null +++ b/pkg/roll/options.go @@ -0,0 +1,25 @@ +package roll + +type options struct { + // lock timeout in milliseconds for pgroll DDL operations + lockTimeoutMs int + + // optional role to set before executing migrations + role string +} + +type Option func(*options) + +// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations +func WithLockTimeoutMs(lockTimeoutMs int) Option { + return func(o *options) { + o.lockTimeoutMs = lockTimeoutMs + } +} + +// WithRole sets the role to set before executing migrations +func WithRole(role string) Option { + return func(o *options) { + o.role = role + } +} diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 30a200ef..e15d0ffe 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -26,7 +26,12 @@ type Roll struct { pgVersion PGVersion } -func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) { +func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) { + options := &options{} + for _, o := range opts { + o(options) + } + dsn, err := pq.ParseURL(pgURL) if err != nil { dsn = pgURL @@ -48,9 +53,18 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err) } - _, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs)) - if err != nil { - return nil, fmt.Errorf("unable to set lock_timeout: %w", err) + if options.lockTimeoutMs > 0 { + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs)) + if err != nil { + return nil, fmt.Errorf("unable to set lock_timeout: %w", err) + } + } + + if options.role != "" { + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role)) + if err != nil { + return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err) + } } var pgMajorVersion PGVersion diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 12836fe3..96f27fb8 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -153,7 +153,7 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s t.Fatal(err) } - mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st) + mig, err := roll.New(ctx, connStr, schema, st, roll.WithLockTimeoutMs(lockTimeoutMs)) if err != nil { t.Fatal(err) }