From 76695e29d348cf768310d656c7fbdf53ba30e209 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Tue, 2 Jan 2024 15:35:04 +0000 Subject: [PATCH] Add `testutils` package Add a `testutils` package containing a shared `TestMain` function and the `with*` functions used by tests in other packages. The `SharedTestMain` function creates a new postgres container. Each of the `with*` functions then creates a new database in that container for the purposes of the test. --- pkg/testutils/db.go | 17 +++ pkg/testutils/util.go | 254 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 271 insertions(+) create mode 100644 pkg/testutils/db.go create mode 100644 pkg/testutils/util.go diff --git a/pkg/testutils/db.go b/pkg/testutils/db.go new file mode 100644 index 000000000..53df83579 --- /dev/null +++ b/pkg/testutils/db.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import "math/rand" + +func randomDBName() string { + const length = 15 + const charset = "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] // #nosec G404 + } + + return "testdb_" + string(b) +} diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go new file mode 100644 index 000000000..15107fa48 --- /dev/null +++ b/pkg/testutils/util.go @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 + +package testutils + +import ( + "context" + "database/sql" + "fmt" + "log" + "net/url" + "os" + "testing" + "time" + + "github.com/lib/pq" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + "github.com/xataio/pgroll/pkg/roll" + "github.com/xataio/pgroll/pkg/state" +) + +// The version of postgres against which the tests are run +// if the POSTGRES_VERSION environment variable is not set. +const defaultPostgresVersion = "15.3" + +// tConnStr holds the connection string to the test container created in TestMain. +var tConnStr string + +// SharedTestMain starts a postgres container to be used by all tests in a package. +// Each test then connects to the container and creates a new database. +func SharedTestMain(m *testing.M) { + ctx := context.Background() + + waitForLogs := wait. + ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5 * time.Second) + + pgVersion := os.Getenv("POSTGRES_VERSION") + if pgVersion == "" { + pgVersion = defaultPostgresVersion + } + + ctr, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:"+pgVersion), + testcontainers.WithWaitStrategy(waitForLogs), + ) + if err != nil { + os.Exit(1) + } + + tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable") + if err != nil { + os.Exit(1) + } + + exitCode := m.Run() + + if err := ctr.Terminate(ctx); err != nil { + log.Printf("Failed to terminate container: %v", err) + } + + os.Exit(exitCode) +} + +func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + const lockTimeoutMs = 500 + mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(mig, db) +} + +func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(st, db) +} + +func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) + } + }) + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + fn(mig, db) +} + +func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn) +}