Skip to content

Commit

Permalink
Add testutils package
Browse files Browse the repository at this point in the history
Add a `testutils` package containing a shared `TestMain` function and
the `with*` functions used by tests in other packages.

The `SharedTestMain` function creates a new postgres container.

Each of the `with*` functions then creates a new database in that
container for the purposes of the test.
  • Loading branch information
andrew-farries committed Jan 2, 2024
1 parent 66ccf91 commit 76695e2
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pkg/testutils/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0

package testutils

import "math/rand"

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}
254 changes: 254 additions & 0 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// SPDX-License-Identifier: Apache-2.0

package testutils

import (
"context"
"database/sql"
"fmt"
"log"
"net/url"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"

// tConnStr holds the connection string to the test container created in TestMain.
var tConnStr string

// SharedTestMain starts a postgres container to be used by all tests in a package.
// Each test then connects to the container and creates a new database.
func SharedTestMain(m *testing.M) {
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
os.Exit(1)
}

tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
log.Printf("Failed to terminate container: %v", err)
}

os.Exit(exitCode)
}

func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

const lockTimeoutMs = 500
mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(st, db)
}

func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
}

0 comments on commit 76695e2

Please sign in to comment.