diff --git a/internal/db/diff/diff.go b/internal/db/diff/diff.go index 646ee005d..51c38b65d 100644 --- a/internal/db/diff/diff.go +++ b/internal/db/diff/diff.go @@ -33,16 +33,18 @@ type DiffFunc func(context.Context, string, string, []string) (string, error) func Run(ctx context.Context, schema []string, file string, config pgconn.Config, differ DiffFunc, fsys afero.Fs, options ...func(*pgx.ConnConfig)) (err error) { // Sanity checks. if utils.IsLocalDatabase(config) { - if declared, err := loadDeclaredSchemas(fsys); err != nil { - return err - } else if container, err := createShadowIfNotExists(ctx, declared); err != nil { - return err - } else if len(container) > 0 { - defer utils.DockerRemove(container) - if err := start.WaitForHealthyService(ctx, start.HealthTimeout, container); err != nil { + if err := utils.AssertSupabaseDbIsRunning(); errors.Is(err, utils.ErrNotRunning) { + if err := start.StartDatabase(ctx, "", fsys, os.Stderr); err != nil { return err } - if err := migrateBaseDatabase(ctx, container, declared, fsys, options...); err != nil { + } else if err != nil { + return err + } + if declared, err := loadDeclaredSchemas(fsys); err != nil { + return err + } else if len(declared) > 0 { + config.Database = "_declared" + if err := migrateBaseDatabase(ctx, config.Database, declared, fsys); err != nil { return err } } @@ -72,22 +74,6 @@ func Run(ctx context.Context, schema []string, file string, config pgconn.Config return nil } -func createShadowIfNotExists(ctx context.Context, migrations []string) (string, error) { - if len(migrations) == 0 { - return "", nil - } - if err := utils.AssertSupabaseDbIsRunning(); !errors.Is(err, utils.ErrNotRunning) { - return "", err - } - fmt.Fprintln(os.Stderr, "Creating local database from declarative schemas:") - msg := make([]string, len(migrations)) - for i, m := range migrations { - msg[i] = fmt.Sprintf(" • %s", utils.Bold(m)) - } - fmt.Fprintln(os.Stderr, strings.Join(msg, "\n")) - return CreateShadowDatabase(ctx, utils.Config.Db.Port) -} - func loadDeclaredSchemas(fsys afero.Fs) ([]string, error) { if schemas := utils.Config.Db.Migrations.SchemaPaths; len(schemas) > 0 { return schemas.Files(afero.NewIOFS(fsys)) @@ -169,26 +155,51 @@ func MigrateShadowDatabase(ctx context.Context, container string, fsys afero.Fs, if err != nil { return err } + if err := start.InitDatabase(ctx, container[:12], os.Stderr); err != nil { + return err + } conn, err := ConnectShadowDatabase(ctx, 10*time.Second, options...) if err != nil { return err } defer conn.Close(context.Background()) - if err := start.SetupDatabase(ctx, conn, container[:12], os.Stderr, fsys); err != nil { + if err := start.SetupDatabase(ctx, conn, os.Stderr, fsys); err != nil { return err } return migration.ApplyMigrations(ctx, migrations, conn, afero.NewIOFS(fsys)) } -func migrateBaseDatabase(ctx context.Context, container string, migrations []string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...) +func createBaseDatabase(ctx context.Context, database string, options ...func(*pgx.ConnConfig)) error { + pgc := pgconn.Config{Database: "template1"} + conn, err := utils.ConnectLocalPostgres(ctx, pgc, options...) if err != nil { return err } defer conn.Close(context.Background()) - if err := start.SetupDatabase(ctx, conn, container[:12], os.Stderr, fsys); err != nil { + if _, err := conn.Exec(ctx, "DROP DATABASE IF EXISTS "+database); err != nil { + return errors.Errorf("failed to drop database: %w", err) + } + if _, err := conn.Exec(ctx, fmt.Sprintf("CREATE DATABASE %s TEMPLATE _shadow", database)); err != nil { + return errors.Errorf("failed to create database: %w", err) + } + return nil +} + +func migrateBaseDatabase(ctx context.Context, database string, migrations []string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + fmt.Fprintln(os.Stderr, "Creating local database from declarative schemas:") + msg := make([]string, len(migrations)) + for i, m := range migrations { + msg[i] = fmt.Sprintf(" • %s", utils.Bold(m)) + } + fmt.Fprintln(os.Stderr, strings.Join(msg, "\n")) + if err := createBaseDatabase(ctx, database, options...); err != nil { return err } + conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Database: database}, options...) + if err != nil { + return err + } + defer conn.Close(context.Background()) return migration.SeedGlobals(ctx, migrations, conn, afero.NewIOFS(fsys)) } diff --git a/internal/db/start/start.go b/internal/db/start/start.go index 608e5ca61..44c94419e 100644 --- a/internal/db/start/start.go +++ b/internal/db/start/start.go @@ -247,9 +247,18 @@ func initCurrentBranch(fsys afero.Fs) error { return utils.WriteFile(utils.CurrBranchPath, []byte("main"), fsys) } -func initSchema(ctx context.Context, conn *pgx.Conn, host string, w io.Writer) error { +func InitDatabase(ctx context.Context, host string, w io.Writer) error { fmt.Fprintln(w, "Initialising schema...") if utils.Config.Db.MajorVersion <= 14 { + pgc := pgconn.Config{} + if host != utils.DbId { + pgc.Port = utils.Config.Db.ShadowPort + } + conn, err := utils.ConnectLocalPostgres(ctx, pgc) + if err != nil { + return err + } + defer conn.Close(context.Background()) if file, err := migration.NewMigrationFromReader(strings.NewReader(utils.GlobalsSql)); err != nil { return err } else if err := file.ExecBatch(ctx, conn); err != nil { @@ -336,7 +345,7 @@ func initAuthJob(host string) utils.DockerJob { } } -func initSchema15(ctx context.Context, host string) error { +func initSchema15(ctx context.Context, host string, options ...func(*pgx.ConnConfig)) error { // Apply service migrations var initJobs []utils.DockerJob if utils.Config.Realtime.Enabled { @@ -354,25 +363,44 @@ func initSchema15(ctx context.Context, host string) error { return err } } + // Only create template on main database + if host != utils.DbId { + return nil + } + pgc := pgconn.Config{User: "supabase_admin", Database: "template1"} + conn, err := utils.ConnectLocalPostgres(ctx, pgc, options...) + if err != nil { + return err + } + defer conn.Close(context.Background()) + if _, err := conn.Exec(ctx, "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = 'postgres'"); err != nil { + return errors.Errorf("failed to disconnect clients: %w", err) + } + if _, err := conn.Exec(ctx, "SET ROLE postgres"); err != nil { + return errors.Errorf("failed to switch role: %w", err) + } + if _, err := conn.Exec(ctx, "CREATE DATABASE _shadow TEMPLATE postgres"); err != nil { + return errors.Errorf("failed to create template: %w", err) + } return nil } func SetupLocalDatabase(ctx context.Context, version string, fsys afero.Fs, w io.Writer, options ...func(*pgx.ConnConfig)) error { + if err := InitDatabase(ctx, utils.DbId, w); err != nil { + return err + } conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...) if err != nil { return err } defer conn.Close(context.Background()) - if err := SetupDatabase(ctx, conn, utils.DbId, w, fsys); err != nil { + if err := SetupDatabase(ctx, conn, w, fsys); err != nil { return err } return apply.MigrateAndSeed(ctx, version, conn, fsys) } -func SetupDatabase(ctx context.Context, conn *pgx.Conn, host string, w io.Writer, fsys afero.Fs) error { - if err := initSchema(ctx, conn, host, w); err != nil { - return err - } +func SetupDatabase(ctx context.Context, conn *pgx.Conn, w io.Writer, fsys afero.Fs) error { // Create vault secrets first so roles.sql can reference them if err := vault.UpsertVaultSecrets(ctx, utils.Config.Db.Vault, conn); err != nil { return err diff --git a/internal/migration/squash/squash.go b/internal/migration/squash/squash.go index 22f3278de..d2db9a669 100644 --- a/internal/migration/squash/squash.go +++ b/internal/migration/squash/squash.go @@ -87,12 +87,15 @@ func squashMigrations(ctx context.Context, migrations []string, fsys afero.Fs, o if err := start.WaitForHealthyService(ctx, start.HealthTimeout, shadow); err != nil { return err } + if err := start.InitDatabase(ctx, shadow[:12], os.Stderr); err != nil { + return err + } conn, err := diff.ConnectShadowDatabase(ctx, 10*time.Second, options...) if err != nil { return err } defer conn.Close(context.Background()) - if err := start.SetupDatabase(ctx, conn, shadow[:12], os.Stderr, fsys); err != nil { + if err := start.SetupDatabase(ctx, conn, os.Stderr, fsys); err != nil { return err } // Assuming entities in managed schemas are not altered, we can simply diff the dumps before and after migrations.