Skip to content

Commit

Permalink
Create testutils package
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Jan 2, 2024
1 parent c33c8f5 commit 52b3d13
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 291 deletions.
136 changes: 3 additions & 133 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"net/url"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
"github.com/xataio/pgroll/pkg/testutils"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
const defaultPostgresVersion = "15.3"

type TestCase struct {
name string
migrations []migrations.Migration
Expand All @@ -40,49 +28,14 @@ type TestCase struct {

type TestCases []TestCase

var tConnStr string

// TestMain starts a postgres container to be used by all tests in the package.
// Each test then connects to the container and creates a new database.
func TestMain(m *testing.M) {
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
os.Exit(1)
}

tConnStr, err = ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
log.Printf("Failed to terminate container: %v", err)
}

os.Exit(exitCode)
testutils.SharedTestMain(m)
}

func ExecuteTests(t *testing.T, tests TestCases) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// run all migrations except the last one
Expand Down Expand Up @@ -142,89 +95,6 @@ func ExecuteTests(t *testing.T, tests TestCases) {
}
}

// withMigratorAndConnectionToContainer:
// * connects to the test container created in TestMain.
// * creates a new database.
// * initializes pgroll in the new database.
// * runs the supplied test function with a migrator and a connection to the new database.
func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

const lockTimeoutMs = 500
mig, err := roll.New(ctx, connStr, "public", lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

fn(mig, db)
}

func randomDBName() string {
const length = 15
const charset = "abcdefghijklmnopqrstuvwxyz"

b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))] // #nosec G404
}

return "testdb_" + string(b)
}

// Common assertions

func ViewMustExist(t *testing.T, db *sql.DB, schema, version, view string) {
Expand Down
109 changes: 13 additions & 96 deletions pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,28 @@ import (
"database/sql"
"errors"
"fmt"
"os"
"testing"
"time"

"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/state"
"github.com/xataio/pgroll/pkg/testutils"
)

const (
schema = "public"
// The version of postgres against which the tests are run
// if the POSTGRES_VERSION environment variable is not set.
defaultPostgresVersion = "15.3"
)

func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}

func TestSchemaIsCreatedfterMigrationStart(t *testing.T) {
t.Parallel()

withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
version := "1_create_table"

Expand Down Expand Up @@ -63,7 +60,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
t.Parallel()

t.Run("when the previous version is a pgroll migration", func(t *testing.T) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
const (
firstVersion = "1_create_table"
Expand Down Expand Up @@ -104,7 +101,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
})

t.Run("when the previous version is an inferred DDL migration", func(t *testing.T) {
withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
const (
firstVersion = "1_create_table"
Expand Down Expand Up @@ -157,7 +154,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) {
t.Parallel()

withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
version := "1_create_table"

Expand Down Expand Up @@ -191,7 +188,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) {
func TestSchemaOptionIsRespected(t *testing.T) {
t.Parallel()

withMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainer(t, "schema1", func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
const version1 = "1_create_table"
const version2 = "2_create_another_table"
Expand Down Expand Up @@ -256,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
func TestLockTimeoutIsEnforced(t *testing.T) {
t.Parallel()

withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// Start a create table migration
Expand Down Expand Up @@ -310,7 +307,7 @@ func TestLockTimeoutIsEnforced(t *testing.T) {
func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) {
t.Parallel()

withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()
version := "1_create_table"

Expand Down Expand Up @@ -392,7 +389,7 @@ func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) {
func TestStatusMethodReturnsCorrectStatus(t *testing.T) {
t.Parallel()

withMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// Get the initial migration status before any migrations are run
Expand Down Expand Up @@ -490,86 +487,6 @@ func addColumnOp(tableName string) *migrations.OpAddColumn {
}
}

func withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

waitForLogs := wait.
ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

pgVersion := os.Getenv("POSTGRES_VERSION")
if pgVersion == "" {
pgVersion = defaultPostgresVersion
}

ctr, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:"+pgVersion),
testcontainers.WithWaitStrategy(waitForLogs),
)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := ctr.Terminate(ctx); err != nil {
t.Fatalf("Failed to terminate container: %v", err)
}
})

cStr, err := ctr.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatal(err)
}

st, err := state.New(ctx, cStr, "pgroll")
if err != nil {
t.Fatal(err)
}
err = st.Init(ctx)
if err != nil {
t.Fatal(err)
}

mig, err := roll.New(ctx, cStr, schema, lockTimeoutMs, st)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatalf("Failed to close migrator connection: %v", err)
}
})

db, err := sql.Open("postgres", cStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
if err != nil {
t.Fatal(err)
}

fn(mig, db)
}

func withMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
}

func withMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
withMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn)
}

func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any {
t.Helper()
versionSchema := roll.VersionedSchemaName(schema, version)
Expand Down
Loading

0 comments on commit 52b3d13

Please sign in to comment.