Skip to content

Commit

Permalink
Create testutils package
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Jan 2, 2024
1 parent c33c8f5 commit 20afaba
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 133 deletions.
136 changes: 3 additions & 133 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"net/url"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
"github.com/xataio/pgroll/pkg/testutils"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"

type TestCase struct {
name string
migrations []migrations.Migration
Expand All @@ -40,49 +28,14 @@ type TestCase struct {

type TestCases []TestCase

var tConnStr string

// TestMain starts a postgres container to be used by all tests in the package.
// Each test then connects to the container and creates a new database.
func TestMain(m *testing.M) {
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
os.Exit(1)
}

tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
log.Printf("Failed to terminate container: %v", err)
}

os.Exit(exitCode)
testutils.SharedTestMain(m)
}

func ExecuteTests(t *testing.T, tests TestCases) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// run all migrations except the last one
Expand Down Expand Up @@ -142,89 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) {
}
}

// withMigratorAndConnectionToContainer:
// * connects to the test container created in TestMain.
// * creates a new database.
// * initializes pgroll in the new database.
// * runs the supplied test function with a migrator and a connection to the new database.
func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

const lockTimeoutMs = 500
mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}

// Common assertions

func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) {
Expand Down
17 changes: 17 additions & 0 deletions pkg/testutils/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0

package testutils

import "math/rand"

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}
136 changes: 136 additions & 0 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// SPDX-License-Identifier: Apache-2.0

package testutils

import (
"context"
"database/sql"
"fmt"
"log"
"net/url"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"

// tConnStr holds the connection string to the test container created in TestMain.
var tConnStr string

// SharedTestMain starts a postgres container to be used by all tests in a package.
// Each test then connects to the container and creates a new database.
func SharedTestMain(m *testing.M) {
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
os.Exit(1)
}

tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
log.Printf("Failed to terminate container: %v", err)
}

os.Exit(exitCode)
}

// WithMigratorAndConnectionToContainer:
// * connects to the test container created in TestMain.
// * creates a new database.
// * initializes pgroll in the new database.
// * runs the supplied test function with a migrator and a connection to the new database.
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

const lockTimeoutMs = 500
mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

0 comments on commit 20afaba

Please sign in to comment.