Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve CHECK constraints on duplication #244

Merged
merged 7 commits into from
Jan 19, 2024
Merged
17 changes: 14 additions & 3 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ func (d *Duplicator) WithType(t string) *Duplicator {
// constraints as the original column.
func (d *Duplicator) Duplicate(ctx context.Context) error {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
)

// Generate SQL to duplicate the column's name and type
Expand All @@ -68,6 +69,16 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}
}

// Generate SQL to duplicate any check constraints on the column
for _, cc := range d.table.CheckConstraints {
if slices.Contains(cc.Columns, d.column.Name) {
sql += fmt.Sprintf(", "+cAddCheckConstraintSQL,
pq.QuoteIdentifier(DuplicationName(cc.Name)),
rewriteCheckExpression(cc.Definition, d.column.Name, d.asName),
)
}
}

_, err := d.conn.ExecContext(ctx, sql)

return err
Expand Down
58 changes: 58 additions & 0 deletions pkg/migrations/op_change_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/testutils"
)

func TestChangeColumnType(t *testing.T) {
Expand Down Expand Up @@ -287,6 +288,63 @@ func TestChangeColumnType(t *testing.T) {
}, rows)
},
},
{
name: "changing column type preserves any check constraints on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "integer",
Pk: true,
},
{
Name: "username",
Type: "text",
Nullable: true,
Check: &migrations.CheckConstraint{
Name: "username_length",
Constraint: "length(username) > 3",
},
},
},
},
},
},
{
Name: "02_change_type",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "users",
Column: "username",
Type: "varchar(255)",
Up: "username",
Down: "username",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
"id": "1",
"username": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
"id": "2",
"username": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
65 changes: 65 additions & 0 deletions pkg/migrations/op_set_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,71 @@ func TestSetCheckConstraint(t *testing.T) {
ValidatedForeignKeyMustExist(t, db, "public", "employees", "fk_employee_department")
},
},
{
name: "existing check constraints are preserved when adding a check constraint",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "title",
Type: "text",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
},
{
Name: "body",
Type: "text",
},
},
},
},
},
{
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "body",
Check: &migrations.CheckConstraint{
Name: "check_body_length",
Constraint: "length(body) > 3",
},
Up: "(SELECT CASE WHEN length(body) <= 3 THEN LPAD(body, 4, '-') ELSE body END)",
Down: "body",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The check constraint on the `title` column still exists.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"id": "1",
"title": "a",
"body": "this is the post body",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The check constraint on the `title` column still exists.
MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{
"id": "2",
"title": "b",
"body": "this is another post body",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
85 changes: 85 additions & 0 deletions pkg/migrations/op_set_fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,91 @@ func TestSetForeignKey(t *testing.T) {
ValidatedForeignKeyMustExist(t, db, "public", "posts", "fk_users_id_1")
},
},
{
name: "check constraints on a column are preserved when adding a foreign key constraint",
migrations: []migrations.Migration{
{
Name: "01_add_tables",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
},
},
},
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "title",
Type: "text",
Check: &migrations.CheckConstraint{
Name: "title_length",
Constraint: "length(title) > 3",
},
},
{
Name: "user_id",
Type: "integer",
},
},
},
},
},
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "user_id",
References: &migrations.ForeignKeyReference{
Name: "fk_users_id",
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Set up the users table with a reference row
MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{
"name": "alice",
})

// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"id": "1",
"user_id": "1",
"title": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{
"id": "2",
"user_id": "1",
"title": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
58 changes: 57 additions & 1 deletion pkg/migrations/op_set_notnull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func TestSetNotNull(t *testing.T) {
},
},
{
name: "setting a nullable column to not null retains any default defined on the column",
name: "setting a column to not null retains any default defined on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Expand Down Expand Up @@ -364,6 +364,62 @@ func TestSetNotNull(t *testing.T) {
}, rows)
},
},
{
name: "setting a column to not null retains any check constraints defined on the column",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "integer",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: true,
Check: &migrations.CheckConstraint{
Name: "name_length",
Constraint: "length(name) > 3",
},
},
},
},
},
},
{
Name: "02_set_not_null",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "users",
Column: "name",
Nullable: ptr(false),
Up: "(SELECT CASE WHEN name IS NULL THEN 'anonymous' ELSE name END)",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{
"id": "1",
"name": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{
"id": "2",
"name": "b",
}, testutils.CheckViolationErrorCode)
},
},
})
}

Expand Down
Loading
Loading