diff --git a/cmd/migration.go b/cmd/migration.go index 30b7716ac..226c461d0 100644 --- a/cmd/migration.go +++ b/cmd/migration.go @@ -39,12 +39,14 @@ var ( }, } + repeatable bool + migrationNewCmd = &cobra.Command{ Use: "new ", Short: "Create an empty migration script", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return new.Run(args[0], os.Stdin, afero.NewOsFs()) + return new.Run(repeatable, args[0], os.Stdin, afero.NewOsFs()) }, } @@ -149,6 +151,8 @@ func init() { migrationFetchCmd.MarkFlagsMutuallyExclusive("db-url", "linked", "local") migrationCmd.AddCommand(migrationFetchCmd) // Build new command + newFlags := migrationNewCmd.Flags() + newFlags.BoolVarP(&repeatable, "repeatable", "r", false, "Creates a repeatable migration instead of a versioned migration.") migrationCmd.AddCommand(migrationNewCmd) rootCmd.AddCommand(migrationCmd) } diff --git a/internal/migration/list/list.go b/internal/migration/list/list.go index 3107d4ec6..5d5ddd56f 100644 --- a/internal/migration/list/list.go +++ b/internal/migration/list/list.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "strconv" + "strings" "github.com/charmbracelet/glamour" "github.com/go-errors/errors" @@ -68,6 +69,40 @@ func makeTable(remoteMigrations, localMigrations []string) string { j++ } } + + for i, j := 0, 0; i < len(remoteMigrations) || j < len(localMigrations); { + if i < len(remoteMigrations) && !strings.HasPrefix(remoteMigrations[i], "r_") { + i++ + continue + } + + if j < len(localMigrations) && !strings.HasPrefix(localMigrations[j], "r_") { + j++ + continue + } + + // Append repeatable migrations to table + if i >= len(remoteMigrations) { + table += fmt.Sprintf("|`%s`|` `|` `|\n", localMigrations[j]) + j++ + } else if j >= len(localMigrations) { + table += fmt.Sprintf("|` `|`%s`|` `|\n", remoteMigrations[i]) + i++ + } else { + if localMigrations[j] < remoteMigrations[i] { + table += fmt.Sprintf("|`%s`|` `|` `|\n", localMigrations[j]) + j++ + } else if remoteMigrations[i] < localMigrations[j] { + table += fmt.Sprintf("|` `|`%s`|` `|\n", remoteMigrations[i]) + i++ + } else { + table += fmt.Sprintf("|`%s`|`%s`|` `|\n", localMigrations[j], remoteMigrations[i]) + i++ + j++ + } + } + } + return table } @@ -99,7 +134,7 @@ func LoadLocalVersions(fsys afero.Fs) ([]string, error) { func LoadPartialMigrations(version string, fsys afero.Fs) ([]string, error) { filter := func(v string) bool { - return version == "" || v <= version + return version == "" || strings.HasPrefix(version, "r_") || v <= version } return migration.ListLocalMigrations(utils.MigrationsDir, afero.NewIOFS(fsys), filter) } diff --git a/internal/migration/new/new.go b/internal/migration/new/new.go index be232b591..86f36f03d 100644 --- a/internal/migration/new/new.go +++ b/internal/migration/new/new.go @@ -11,8 +11,16 @@ import ( "github.com/supabase/cli/internal/utils" ) -func Run(migrationName string, stdin afero.File, fsys afero.Fs) error { - path := GetMigrationPath(utils.GetCurrentTimestamp(), migrationName) +func Run(repeatable bool, migrationName string, stdin afero.File, fsys afero.Fs) error { + var path string + + if repeatable { + // if migration name already exists, repeatable migration will be overwritten + path = GetRepeatableMigrationPath(migrationName) + } else { + path = GetMigrationPath(utils.GetCurrentTimestamp(), migrationName) + } + if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(path)); err != nil { return err } @@ -33,6 +41,11 @@ func GetMigrationPath(timestamp, name string) string { return filepath.Join(utils.MigrationsDir, fullName) } +func GetRepeatableMigrationPath(name string) string { + fullName := fmt.Sprintf("r_%s.sql", name) + return filepath.Join(utils.MigrationsDir, fullName) +} + func CopyStdinIfExists(stdin afero.File, dst io.Writer) error { if fi, err := stdin.Stat(); err != nil { return errors.Errorf("failed to initialise stdin: %w", err) diff --git a/internal/migration/new/new_test.go b/internal/migration/new/new_test.go index 39e2fe115..0291c422c 100644 --- a/internal/migration/new/new_test.go +++ b/internal/migration/new/new_test.go @@ -12,14 +12,14 @@ import ( ) func TestNewCommand(t *testing.T) { - t.Run("creates new migration file", func(t *testing.T) { + t.Run("creates new common migration file", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Setup empty stdin stdin, err := fsys.Create("/dev/stdin") require.NoError(t, err) // Run test - assert.NoError(t, Run("test_migrate", stdin, fsys)) + assert.NoError(t, Run(false, "test_migrate", stdin, fsys)) // Validate output files, err := afero.ReadDir(fsys, utils.MigrationsDir) assert.NoError(t, err) @@ -27,7 +27,44 @@ func TestNewCommand(t *testing.T) { assert.Regexp(t, `([0-9]{14})_test_migrate\.sql`, files[0].Name()) }) - t.Run("streams content from pipe", func(t *testing.T) { + t.Run("creates new repeatable migration file", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup empty stdin + stdin, err := fsys.Create("/dev/stdin") + require.NoError(t, err) + // Run test + assert.NoError(t, Run(true, "repeatable_test_migrate", stdin, fsys)) + // Validate output + files, err := afero.ReadDir(fsys, utils.MigrationsDir) + assert.NoError(t, err) + assert.Equal(t, 1, len(files)) + assert.Regexp(t, `r_repeatable_test_migrate\.sql`, files[0].Name()) + }) + + t.Run("streams content from pipe to common migration", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup stdin + r, w, err := os.Pipe() + require.NoError(t, err) + script := "create table pet;\ndrop table pet;\n" + _, err = w.WriteString(script) + require.NoError(t, err) + require.NoError(t, w.Close()) + // Run test + assert.NoError(t, Run(false, "test_migrate", r, fsys)) + // Validate output + files, err := afero.ReadDir(fsys, utils.MigrationsDir) + assert.NoError(t, err) + assert.Equal(t, 1, len(files)) + path := filepath.Join(utils.MigrationsDir, files[0].Name()) + contents, err := afero.ReadFile(fsys, path) + assert.NoError(t, err) + assert.Equal(t, []byte(script), contents) + }) + + t.Run("streams content from pipe to repeatable migration", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Setup stdin @@ -38,7 +75,7 @@ func TestNewCommand(t *testing.T) { require.NoError(t, err) require.NoError(t, w.Close()) // Run test - assert.NoError(t, Run("test_migrate", r, fsys)) + assert.NoError(t, Run(true, "repeatable_test_migrate", r, fsys)) // Validate output files, err := afero.ReadDir(fsys, utils.MigrationsDir) assert.NoError(t, err) @@ -56,7 +93,8 @@ func TestNewCommand(t *testing.T) { stdin, err := fsys.Create("/dev/stdin") require.NoError(t, err) // Run test - assert.Error(t, Run("test_migrate", stdin, afero.NewReadOnlyFs(fsys))) + assert.Error(t, Run(false, "test_migrate", stdin, afero.NewReadOnlyFs(fsys))) + assert.Error(t, Run(true, "repeatable_test_migrate", stdin, afero.NewReadOnlyFs(fsys))) }) t.Run("throws error on closed pipe", func(t *testing.T) { @@ -67,6 +105,7 @@ func TestNewCommand(t *testing.T) { require.NoError(t, err) require.NoError(t, r.Close()) // Run test - assert.Error(t, Run("test_migrate", r, fsys)) + assert.Error(t, Run(false, "test_migrate", r, fsys)) + assert.Error(t, Run(true, "repeatable_test_migrate", r, fsys)) }) } diff --git a/internal/migration/up/up.go b/internal/migration/up/up.go index d33117f33..3f2531d99 100644 --- a/internal/migration/up/up.go +++ b/internal/migration/up/up.go @@ -23,6 +23,7 @@ func Run(ctx context.Context, includeAll bool, config pgconn.Config, fsys afero. if err != nil { return err } + return migration.ApplyMigrations(ctx, pending, conn, afero.NewIOFS(fsys)) } diff --git a/pkg/migration/file.go b/pkg/migration/file.go index 79da1c86a..cb45a5430 100644 --- a/pkg/migration/file.go +++ b/pkg/migration/file.go @@ -23,7 +23,7 @@ type MigrationFile struct { Statements []string } -var migrateFilePattern = regexp.MustCompile(`^([0-9]+)_(.*)\.sql$`) +var migrateFilePattern = regexp.MustCompile(`^([0-9]+|r)_(.*)\.sql$`) func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { lines, err := parseFile(path, fsys) @@ -38,6 +38,10 @@ func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { file.Version = matches[1] file.Name = matches[2] } + // Repeatable migration version => r_name + if file.Version == "r" { + file.Version += "_" + file.Name + } return &file, nil } diff --git a/pkg/migration/list.go b/pkg/migration/list.go index ec8bf2ac8..b4161538d 100644 --- a/pkg/migration/list.go +++ b/pkg/migration/list.go @@ -47,12 +47,16 @@ func ListLocalMigrations(migrationsDir string, fsys fs.FS, filter ...func(string } matches := migrateFilePattern.FindStringSubmatch(filename) if len(matches) == 0 { - fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"_name.sql\")\n", filename) + fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"_name.sql\" or \"r_name.sql\")\n", filename) continue } path := filepath.Join(migrationsDir, filename) for _, keep := range filter { - if version := matches[1]; keep(version) { + version := matches[1] + if version == "r" && len(matches) > 2 { + version += "_" + matches[2] + } + if keep(version) { clean = append(clean, path) } } diff --git a/pkg/migration/list_test.go b/pkg/migration/list_test.go index 80d09fd40..a9babf61c 100644 --- a/pkg/migration/list_test.go +++ b/pkg/migration/list_test.go @@ -73,6 +73,7 @@ func TestLocalMigrations(t *testing.T) { fsys := fs.MapFS{ "20211208000000_init.sql": &fs.MapFile{}, "20211208000001_invalid.ts": &fs.MapFile{}, + "r_invalid.ts": &fs.MapFile{}, } // Run test versions, err := ListLocalMigrations(".", fsys)