diff --git a/cmd/perform.go b/cmd/perform.go index 05d660b..42c8883 100644 --- a/cmd/perform.go +++ b/cmd/perform.go @@ -9,6 +9,7 @@ import ( "github.com/andrewpillar/mgrt/config" "github.com/andrewpillar/mgrt/database" + "github.com/andrewpillar/mgrt/migration" "github.com/andrewpillar/mgrt/revision" "github.com/andrewpillar/mgrt/util" ) @@ -34,15 +35,15 @@ func loadRevisions(c cli.Command, d revision.Direction) ([]*revision.Revision, e var err error switch d { - case revision.Up: - revisions, err = revision.Oldest() - break - case revision.Down: - revisions, err = revision.Latest() - break - default: - err = errors.New("unknown direction") - break + case revision.Up: + revisions, err = revision.Oldest() + break + case revision.Down: + revisions, err = revision.Latest() + break + default: + err = errors.New("unknown direction") + break } return revisions, err @@ -83,38 +84,8 @@ func perform(c cli.Command, d revision.Direction) { force := c.Flags.IsSet("force") - for _, r := range revisions { - r.Direction = d - - if err := r.GenHash(); err != nil { - util.ExitError("failed to perform revision", err) - } - - if err := db.Perform(r, force); err != nil { - if err != database.ErrAlreadyPerformed { - util.ExitError("failed to perform revision", fmt.Errorf("%s: %d", err, r.ID)) - } - - fmt.Printf("%s - %s: %d", d, err, r.ID) - - if r.Message != "" { - fmt.Printf(": %s", r.Message) - } - - fmt.Printf("\n") - continue - } - - if err := db.Log(r, force); err != nil { - util.ExitError("failed to log revision", err) - } - - fmt.Printf("%s - performed revision: %d", d, r.ID) - - if r.Message != "" { - fmt.Printf(": %s", r.Message) - } - - fmt.Printf("\n") + err = migration.Perform(db, revisions, d, force) + if err != nil { + util.ExitError("failed to perform revision ", err) } } diff --git a/database/mysql.go b/database/mysql.go index 26d7cf7..41b9454 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -55,6 +55,12 @@ func init() { databases["mysql"] = &MySQL{} } +func (p *MySQL) FromConn(db *sql.DB) { + p.database = &database{ + DB: db, + } +} + func (m *MySQL) Open(cfg *config.Config) error { dsn := fmt.Sprintf(mysqlDsn, cfg.Username, cfg.Password, cfg.Address, cfg.Database) diff --git a/database/postgres.go b/database/postgres.go index 918e414..9816bd7 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -52,6 +52,12 @@ func init() { databases["postgres"] = &Postgres{} } +func (p *Postgres) FromConn(db *sql.DB) { + p.database = &database{ + DB: db, + } +} + func (p *Postgres) Open(cfg *config.Config) error { host, port, err := net.SplitHostPort(cfg.Address) diff --git a/database/sqlite3.go b/database/sqlite3.go index 90e3c2e..54b1b70 100644 --- a/database/sqlite3.go +++ b/database/sqlite3.go @@ -32,6 +32,12 @@ func init() { databases["sqlite3"] = &SQLite3{} } +func (p *SQLite3) FromConn(db *sql.DB) { + p.database = &database{ + DB: db, + } +} + func (s *SQLite3) Open(cfg *config.Config) error { db, err := sql.Open("sqlite3", cfg.Address) diff --git a/migration/migrate.go b/migration/migrate.go new file mode 100644 index 0000000..95b38a0 --- /dev/null +++ b/migration/migrate.go @@ -0,0 +1,47 @@ +package migration + +import ( + "fmt" + + "github.com/andrewpillar/mgrt/database" + "github.com/andrewpillar/mgrt/revision" +) + +func Perform(db database.DB, revisions []*revision.Revision, d revision.Direction, force bool) error { + for _, r := range revisions { + + r.Direction = d + + if err := r.GenHash(); err != nil { + return err + } + + if err := db.Perform(r, force); err != nil { + if err != database.ErrAlreadyPerformed { + return fmt.Errorf("%s: %d", err, r.ID) + } + + fmt.Printf("%s - %s: %d", d, err, r.ID) + + if r.Message != "" { + fmt.Printf(": %s", r.Message) + } + + fmt.Printf("\n") + continue + } + + if err := db.Log(r, force); err != nil { + return err + } + + fmt.Printf("%s - performed revision: %d", d, r.ID) + + if r.Message != "" { + fmt.Printf(": %s", r.Message) + } + + fmt.Printf("\n") + } + return nil +} diff --git a/revision/revision.go b/revision/revision.go index 0c3830e..5befe43 100644 --- a/revision/revision.go +++ b/revision/revision.go @@ -17,8 +17,8 @@ import ( ) var ( - upFile = "up.sql" - downFile = "down.sql" + upFile = "up.sql" + downFile = "down.sql" reslug = regexp.MustCompile("[^a-zA-Z0-9]") redup = regexp.MustCompile("-{2,}") @@ -196,6 +196,10 @@ func Find(id string) (*Revision, error) { } } + if base == "" { + return nil, errors.New("no revision found with ID: " + id) + } + return resolveFromPath(filepath.Join(config.RevisionsDir(), base)) }