Skip to content

Commit

Permalink
Preserve CHECK constraints on duplication (#244)
Browse files Browse the repository at this point in the history
When duplicating a column for backfilling ensure that any `CHECK`
constraints on the original column are re-created on the duplicated
column. The `CHECK` constraint is initially created as `NOT VALID` then
validated after migration completion.

This is part of #227

As of this PR, column properties that are preserved when duplicating a
column for backfilling are:

* `DEFAULT` values
* `FOREIGN KEY` constraints
* `CHECK` constraints
  • Loading branch information
andrew-farries authored Jan 19, 2024
1 parent 7a38e1f commit 68c5fbf
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 11 deletions.
17 changes: 14 additions & 3 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ func (d *Duplicator) WithType(t string) *Duplicator {
// constraints as the original column.
func (d *Duplicator) Duplicate(ctx context.Context) error {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
)

// Generate SQL to duplicate the column's name and type
Expand All @@ -68,6 +69,16 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}
}

// Generate SQL to duplicate any check constraints on the column
for _, cc := range d.table.CheckConstraints {
if slices.Contains(cc.Columns, d.column.Name) {
sql += fmt.Sprintf(", "+cAddCheckConstraintSQL,
pq.QuoteIdentifier(DuplicationName(cc.Name)),
rewriteCheckExpression(cc.Definition, d.column.Name, d.asName),
)
}
}

_, err := d.conn.ExecContext(ctx, sql)

return err
Expand Down
58 changes: 58 additions & 0 deletions pkg/migrations/op_change_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/testutils"
)

func TestChangeColumnType(t *testing.T) {
Expand Down Expand Up @@ -287,6 +288,63 @@ func TestChangeColumnType(t *testing.T) {
}, rows)
},
},
{
name: "changing column type preserves any check constraints on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "integer",
Pk: true,
},
{
Name: "username",
Type: "text",
Nullable: true,
Check: &migrations.CheckConstraint{
Name: "username_length",
Constraint: "length(username) > 3",
},
},
},
},
},
},
{
Name: "02_change_type",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "users",
Column: "username",
Type: "varchar(255)",
Up: "username",
Down: "username",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
"id": "1",
"username": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
"id": "2",
"username": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
65 changes: 65 additions & 0 deletions pkg/migrations/op_set_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,71 @@ func TestSetCheckConstraint(t *testing.T) {
ValidatedForeignKeyMustExist(t, db, "public", "employees", "fk_employee_department")
},
},
{
name: "existing check constraints are preserved when adding a check constraint",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "title",
Type: "text",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
},
{
Name: "body",
Type: "text",
},
},
},
},
},
{
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "body",
Check: &migrations.CheckConstraint{
Name: "check_body_length",
Constraint: "length(body) > 3",
},
Up: "(SELECT CASE WHEN length(body) <= 3 THEN LPAD(body, 4, '-') ELSE body END)",
Down: "body",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The check constraint on the `title` column still exists.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"id": "1",
"title": "a",
"body": "this is the post body",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The check constraint on the `title` column still exists.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"id": "2",
"title": "b",
"body": "this is another post body",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
85 changes: 85 additions & 0 deletions pkg/migrations/op_set_fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,91 @@ func TestSetForeignKey(t *testing.T) {
ValidatedForeignKeyMustExist(t, db, "public", "posts", "fk_users_id_1")
},
},
{
name: "check constraints on a column are preserved when adding a foreign key constraint",
migrations: []migrations.Migration{
{
Name: "01_add_tables",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
},
},
},
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "title",
Type: "text",
Check: &migrations.CheckConstraint{
Name: "title_length",
Constraint: "length(title) > 3",
},
},
{
Name: "user_id",
Type: "integer",
},
},
},
},
},
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "user_id",
References: &migrations.ForeignKeyReference{
Name: "fk_users_id",
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Set up the users table with a reference row
MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{
"name": "alice",
})

// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"id": "1",
"user_id": "1",
"title": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"id": "2",
"user_id": "1",
"title": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
58 changes: 57 additions & 1 deletion pkg/migrations/op_set_notnull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func TestSetNotNull(t *testing.T) {
},
},
{
name: "setting a nullable column to not null retains any default defined on the column",
name: "setting a column to not null retains any default defined on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Expand Down Expand Up @@ -364,6 +364,62 @@ func TestSetNotNull(t *testing.T) {
}, rows)
},
},
{
name: "setting a column to not null retains any check constraints defined on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "integer",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: true,
Check: &migrations.CheckConstraint{
Name: "name_length",
Constraint: "length(name) > 3",
},
},
},
},
},
},
{
Name: "02_set_not_null",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "users",
Column: "name",
Nullable: ptr(false),
Up: "(SELECT CASE WHEN name IS NULL THEN 'anonymous' ELSE name END)",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{
"id": "1",
"name": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{
"id": "2",
"name": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
Loading

0 comments on commit 68c5fbf

Please sign in to comment.