From f4c37b82c42d8f60c37b5f5cd103f19492bfd01d Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Mon, 8 Jan 2024 10:00:01 +0000 Subject: [PATCH] test: Change test isolation model to 'database per test' (#220) Improve the reliability and performance of the test suite by moving the test isolation model from 'container per test' to 'database per test'. The current test suite works well locally but is [very flaky ](https://github.com/xataio/pgroll/actions) when run on Github Actions. The cause of the flakiness is the 'container per test' isolation model, under which each testcase in each test starts its own Postgres container. This model worked well initially but as the number of tests has increased the actions runner often fails to make available the required number of containers for such a large number of parallel tests. This PR changes the isolation model to 'database per test'. Each package creates one Postgres container and then each testcase in each test creates a database within that container. This greatly reduces the number of simultaneous containers required, making the test suite faster and more reliable. Each job in the test matrix sees a 60-70% reduction in duration and (anecdotally) far fewer failures with no failures observed in ~20 runs. --- pkg/migrations/op_common_test.go | 85 +------------- pkg/roll/execute_test.go | 109 +++--------------- pkg/state/state_test.go | 68 +---------- pkg/testutils/db.go | 17 +++ pkg/testutils/util.go | 192 +++++++++++++++++++++++++++++++ 5 files changed, 234 insertions(+), 237 deletions(-) create mode 100644 pkg/testutils/db.go create mode 100644 pkg/testutils/util.go diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index e0c17a8e..7e3d41d8 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -7,25 +7,16 @@ import ( "database/sql" "errors" "fmt" - "os" "testing" - "time" "github.com/lib/pq" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" - "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" - type TestCase struct { name string migrations []migrations.Migration @@ -37,10 +28,14 @@ type TestCase struct { type TestCases []TestCase +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + func ExecuteTests(t *testing.T, tests TestCases) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // run all migrations except the last one @@ -100,74 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) { } } -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - const lockTimeoutMs = 500 - mig, err := roll.New(ctx, cStr, "public", lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - fn(mig, db) -} - // Common assertions func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) { diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 3967d0eb..93c740ca 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -7,31 +7,28 @@ import ( "database/sql" "errors" "fmt" - "os" "testing" - "time" "github.com/lib/pq" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) const ( schema = "public" - // The version of postgres against which the tests are run - // if the POSTGRES_VERSION environment variable is not set. - defaultPostgresVersion = "15.3" ) +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + func TestSchemaIsCreatedfterMigrationStart(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -63,7 +60,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { t.Parallel() t.Run("when the previous version is a pgroll migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -104,7 +101,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { }) t.Run("when the previous version is an inferred DDL migration", func(t *testing.T) { - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const ( firstVersion = "1_create_table" @@ -157,7 +154,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -191,7 +188,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - withMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() const version1 = "1_create_table" const version2 = "2_create_another_table" @@ -256,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestLockTimeoutIsEnforced(t *testing.T) { t.Parallel() - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Start a create table migration @@ -310,7 +307,7 @@ func TestLockTimeoutIsEnforced(t *testing.T) { func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() version := "1_create_table" @@ -392,7 +389,7 @@ func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { func TestStatusMethodReturnsCorrectStatus(t *testing.T) { t.Parallel() - withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { + testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() // Get the initial migration status before any migrations are run @@ -490,86 +487,6 @@ func addColumnOp(tableName string) *migrations.OpAddColumn { } } -func withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - mig, err := roll.New(ctx, cStr, schema, lockTimeoutMs, st) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) - if err != nil { - t.Fatal(err) - } - - fn(mig, db) -} - -func withMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) -} - -func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn) -} - func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any { t.Helper() versionSchema := roll.VersionedSchemaName(schema, version) diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index fbb14f2a..a5decd43 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -5,29 +5,25 @@ package state_test import ( "context" "database/sql" - "os" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/schema" "github.com/xataio/pgroll/pkg/state" + "github.com/xataio/pgroll/pkg/testutils" ) -// The version of postgres against which the tests are run -// if the POSTGRES_VERSION environment variable is not set. -const defaultPostgresVersion = "15.3" +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} func TestSchemaOptionIsRespected(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() // create a table in the public schema @@ -63,7 +59,7 @@ func TestSchemaOptionIsRespected(t *testing.T) { func TestReadSchema(t *testing.T) { t.Parallel() - witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { + testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) { ctx := context.Background() tests := []struct { @@ -187,55 +183,3 @@ func TestReadSchema(t *testing.T) { } }) } - -func witStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { - t.Helper() - ctx := context.Background() - - waitForLogs := wait. - ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5 * time.Second) - - pgVersion := os.Getenv("POSTGRES_VERSION") - if pgVersion == "" { - pgVersion = defaultPostgresVersion - } - - ctr, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:"+pgVersion), - testcontainers.WithWaitStrategy(waitForLogs), - ) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := ctr.Terminate(ctx); err != nil { - t.Fatalf("Failed to terminate container: %v", err) - } - }) - - cStr, err := ctr.ConnectionString(ctx, "sslmode=disable") - if err != nil { - t.Fatal(err) - } - - db, err := sql.Open("postgres", cStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - st, err := state.New(ctx, cStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - fn(st, db) -} diff --git a/pkg/testutils/db.go b/pkg/testutils/db.go new file mode 100644 index 00000000..53df8357 --- /dev/null +++ b/pkg/testutils/db.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import "math/rand" + +func randomDBName() string { + const length = 15 + const charset = "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] // #nosec G404 + } + + return "testdb_" + string(b) +} diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go new file mode 100644 index 00000000..12836fe3 --- /dev/null +++ b/pkg/testutils/util.go @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "context" + "database/sql" + "fmt" + "log" + "net/url" + "os" + "testing" + "time" + + "github.com/lib/pq" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + "github.com/xataio/pgroll/pkg/roll" + "github.com/xataio/pgroll/pkg/state" +) + +// The version of postgres against which the tests are run +// if the POSTGRES_VERSION environment variable is not set. +const defaultPostgresVersion = "15.3" + +// tConnStr holds the connection string to the test container created in TestMain. +var tConnStr string + +// SharedTestMain starts a postgres container to be used by all tests in a package. +// Each test then connects to the container and creates a new database. +func SharedTestMain(m *testing.M) { + ctx := context.Background() + + waitForLogs := wait. + ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5 * time.Second) + + pgVersion := os.Getenv("POSTGRES_VERSION") + if pgVersion == "" { + pgVersion = defaultPostgresVersion + } + + ctr, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:"+pgVersion), + testcontainers.WithWaitStrategy(waitForLogs), + ) + if err != nil { + os.Exit(1) + } + + tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable") + if err != nil { + os.Exit(1) + } + + exitCode := m.Run() + + if err := ctr.Terminate(ctx); err != nil { + log.Printf("Failed to terminate container: %v", err) + } + + os.Exit(exitCode) +} + +func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(st, db) +} + +func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(mig, db) +} + +func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) +} + +func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn) +}