diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index fd5a40528..7e3d41d86 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -7,28 +7,16 @@ import ( "database/sql" "errors" "fmt" - "log" - "math/rand" - "net/url" - "os" "testing" - "time" "github.com/lib/pq" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" - "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" - type TestCase struct { name string migrations []migrations.Migration @@ -40,49 +28,14 @@ type TestCase struct { type TestCases []TestCase -var tConnStr string - -// TestMain starts a postgres container to be used by all tests in the package. -// Each test then connects to the container and creates a new database. func TestMain(m *testing.M) { - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - os.Exit(1) - } - - tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - os.Exit(1) - } - - exitCode := m.Run() - - if err := ctr.Terminate(ctx); err != nil { - log.Printf("Failed to terminate container: %v", err) - } - - os.Exit(exitCode) + testutils.SharedTestMain(m) } func ExecuteTests(t *testing.T, tests TestCases) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // run all migrations except the last one @@ -142,89 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) { } } -// withMigratorAndConnectionToContainer: -// * connects to the test container created in TestMain. -// * creates a new database. -// * initializes pgroll in the new database. -// * runs the supplied test function with a migrator and a connection to the new database. -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - tDB, err := sql.Open("postgres", tConnStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := tDB.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - dbName := randomDBName() - - _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) - if err != nil { - t.Fatal(err) - } - - u, err := url.Parse(tConnStr) - if err != nil { - t.Fatal(err) - } - - u.Path = "/" + dbName - connStr := u.String() - - st, err := state.New(ctx, connStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - const lockTimeoutMs = 500 - mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", connStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - fn(mig, db) -} - -func randomDBName() string { - const length = 15 - const charset = "abcdefghijklmnopqrstuvwxyz" - - b := make([]byte, length) - for i := range b { - b[i] = charset[rand.Intn(len(charset))] // #nosec G404 - } - - return "testdb_" + string(b) -} - // Common assertions func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) { diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 3967d0eb6..93c740ca3 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -7,31 +7,28 @@ import ( "database/sql" "errors" "fmt" - "os" "testing" - "time" "github.com/lib/pq" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) const ( schema = "public" - // The version of postgres against which the tests are run - // if the POSTGRES_VERSION environment variable is not set. - defaultPostgresVersion = "15.3" ) +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + func TestSchemaIsCreatedfterMigrationStart(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -63,7 +60,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { t.Parallel() t.Run("when the previous version is a pgroll migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -104,7 +101,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { }) t.Run("when the previous version is an inferred DDL migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -157,7 +154,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -191,7 +188,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - withMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const version1 = "1_create_table" const version2 = "2_create_another_table" @@ -256,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestLockTimeoutIsEnforced(t *testing.T) { t.Parallel() - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Start a create table migration @@ -310,7 +307,7 @@ func TestLockTimeoutIsEnforced(t *testing.T) { func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -392,7 +389,7 @@ func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { func TestStatusMethodReturnsCorrectStatus(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Get the initial migration status before any migrations are run @@ -490,86 +487,6 @@ func addColumnOp(tableName string) *migrations.OpAddColumn { } } -func withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - mig, err := roll.New(ctx, cStr, schema, lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) - if err != nil { - t.Fatal(err) - } - - fn(mig, db) -} - -func withMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) -} - -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn) -} - func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any { t.Helper() versionSchema := roll.VersionedSchemaName(schema, version) diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index fbb14f2ad..a5decd431 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -5,29 +5,25 @@ package state_test import ( "context" "database/sql" - "os" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() // create a table in the public schema @@ -63,7 +59,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestReadSchema(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() tests := []struct { @@ -187,55 +183,3 @@ func TestReadSchema(t *testing.T) { } }) } - -func witStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - fn(st, db) -} diff --git a/pkg/testutils/db.go b/pkg/testutils/db.go new file mode 100644 index 000000000..53df83579 --- /dev/null +++ b/pkg/testutils/db.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import "math/rand" + +func randomDBName() string { + const length = 15 + const charset = "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] // #nosec G404 + } + + return "testdb_" + string(b) +} diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go new file mode 100644 index 000000000..c359b7ea0 --- /dev/null +++ b/pkg/testutils/util.go @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "context" + "database/sql" + "fmt" + "log" + "net/url" + "os" + "testing" + "time" + + "github.com/lib/pq" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + "github.com/xataio/pgroll/pkg/roll" + "github.com/xataio/pgroll/pkg/state" +) + +// The version of postgres against which the tests are run +// if the POSTGRES_VERSION environment variable is not set. +const defaultPostgresVersion = "15.3" + +// tConnStr holds the connection string to the test container created in TestMain. +var tConnStr string + +// SharedTestMain starts a postgres container to be used by all tests in a package. +// Each test then connects to the container and creates a new database. +func SharedTestMain(m *testing.M) { + ctx := context.Background() + + waitForLogs := wait. + ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5 * time.Second) + + pgVersion := os.Getenv("POSTGRES_VERSION") + if pgVersion == "" { + pgVersion = defaultPostgresVersion + } + + ctr, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:"+pgVersion), + testcontainers.WithWaitStrategy(waitForLogs), + ) + if err != nil { + os.Exit(1) + } + + tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable") + if err != nil { + os.Exit(1) + } + + exitCode := m.Run() + + if err := ctr.Terminate(ctx); err != nil { + log.Printf("Failed to terminate container: %v", err) + } + + os.Exit(exitCode) +} + +// WithMigratorAndConnectionToContainer: +// * connects to the test container created in TestMain. +// * creates a new database. +// * initializes pgroll in the new database. +// * runs the supplied test function with a migrator and a connection to the new database. +func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + const lockTimeoutMs = 500 + mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(mig, db) +} + +func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(st, db) +} + +func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(mig, db) +} + +func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) +}