Skip to content

Commit

Permalink
Create testutils package
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Jan 2, 2024
1 parent c33c8f5 commit ee2ed9b
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 195 deletions.
136 changes: 3 additions & 133 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"net/url"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
"github.com/xataio/pgroll/pkg/testutils"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"

type TestCase struct {
name string
migrations []migrations.Migration
Expand All @@ -40,49 +28,14 @@ type TestCase struct {

type TestCases []TestCase

var tConnStr string

// TestMain starts a postgres container to be used by all tests in the package.
// Each test then connects to the container and creates a new database.
func TestMain(m *testing.M) {
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
os.Exit(1)
}

tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
log.Printf("Failed to terminate container: %v", err)
}

os.Exit(exitCode)
testutils.SharedTestMain(m)
}

func ExecuteTests(t *testing.T, tests TestCases) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// run all migrations except the last one
Expand Down Expand Up @@ -142,89 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) {
}
}

// withMigratorAndConnectionToContainer:
// * connects to the test container created in TestMain.
// * creates a new database.
// * initializes pgroll in the new database.
// * runs the supplied test function with a migrator and a connection to the new database.
func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

const lockTimeoutMs = 500
mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}

// Common assertions

func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) {
Expand Down
68 changes: 6 additions & 62 deletions pkg/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,25 @@ package state_test
import (
"context"
"database/sql"
"os"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/schema"
"github.com/xataio/pgroll/pkg/state"
"github.com/xataio/pgroll/pkg/testutils"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"
func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}

func TestSchemaOptionIsRespected(t *testing.T) {
t.Parallel()

witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
ctx := context.Background()

// create a table in the public schema
Expand Down Expand Up @@ -63,7 +59,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
func TestReadSchema(t *testing.T) {
t.Parallel()

witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
ctx := context.Background()

tests := []struct {
Expand Down Expand Up @@ -187,55 +183,3 @@ func TestReadSchema(t *testing.T) {
}
})
}

func witStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) {
t.Helper()
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := ctr.Terminate(ctx); err != nil {
t.Fatalf("Failed to terminate container: %v", err)
}
})

cStr, err := ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("postgres", cStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

st, err := state.New(ctx, cStr, "pgroll")
if err != nil {
t.Fatal(err)
}

fn(st, db)
}
17 changes: 17 additions & 0 deletions pkg/testutils/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0

package testutils

import "math/rand"

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}
Loading

0 comments on commit ee2ed9b

Please sign in to comment.