diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index e0c17a8e..7e3d41d8 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -7,25 +7,16 @@ import ( "database/sql" "errors" "fmt" - "os" "testing" - "time" "github.com/lib/pq" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" - "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" - type TestCase struct { name string migrations []migrations.Migration @@ -37,10 +28,14 @@ type TestCase struct { type TestCases []TestCase +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + func ExecuteTests(t *testing.T, tests TestCases) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // run all migrations except the last one @@ -100,74 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) { } } -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - const lockTimeoutMs = 500 - mig, err := roll.New(ctx, cStr, "public", lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - fn(mig, db) -} - // Common assertions func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) { diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 3967d0eb..93c740ca 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -7,31 +7,28 @@ import ( "database/sql" "errors" "fmt" - "os" "testing" - "time" "github.com/lib/pq" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) const ( schema = "public" - // The version of postgres against which the tests are run - // if the POSTGRES_VERSION environment variable is not set. - defaultPostgresVersion = "15.3" ) +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + func TestSchemaIsCreatedfterMigrationStart(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -63,7 +60,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { t.Parallel() t.Run("when the previous version is a pgroll migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -104,7 +101,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { }) t.Run("when the previous version is an inferred DDL migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -157,7 +154,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -191,7 +188,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - withMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const version1 = "1_create_table" const version2 = "2_create_another_table" @@ -256,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestLockTimeoutIsEnforced(t *testing.T) { t.Parallel() - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Start a create table migration @@ -310,7 +307,7 @@ func TestLockTimeoutIsEnforced(t *testing.T) { func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -392,7 +389,7 @@ func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { func TestStatusMethodReturnsCorrectStatus(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Get the initial migration status before any migrations are run @@ -490,86 +487,6 @@ func addColumnOp(tableName string) *migrations.OpAddColumn { } } -func withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - mig, err := roll.New(ctx, cStr, schema, lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) - if err != nil { - t.Fatal(err) - } - - fn(mig, db) -} - -func withMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) -} - -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn) -} - func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any { t.Helper() versionSchema := roll.VersionedSchemaName(schema, version) diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index fbb14f2a..a5decd43 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -5,29 +5,25 @@ package state_test import ( "context" "database/sql" - "os" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() // create a table in the public schema @@ -63,7 +59,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestReadSchema(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() tests := []struct { @@ -187,55 +183,3 @@ func TestReadSchema(t *testing.T) { } }) } - -func witStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - fn(st, db) -}