From fc7370f140b9b27cda0a236caf20094231e0f611 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Mon, 15 Jan 2024 16:01:52 +0000 Subject: [PATCH] Preserve foreign key constraints when changing a column's type (#232) Preserve any foreign key constraints defined on columns when changing a column's type. Uses the `Duplicator` type added in #230 in the 'change type' operation to ensure that FK constraints are preserved when the column is duplicated for backfilling. https://github.com/xataio/pgroll/pull/230 did the same thing for preserving FK constraints when a column has a `NOT NULL` constraint added. Part of https://github.com/xataio/pgroll/issues/227 --- pkg/migrations/duplicate.go | 25 +- pkg/migrations/op_change_type.go | 26 +-- pkg/migrations/op_change_type_test.go | 320 ++++++++++++++++---------- 3 files changed, 223 insertions(+), 148 deletions(-) diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index 23b2a1e1..d35b3f80 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -14,22 +14,29 @@ import ( ) type Duplicator struct { - conn *sql.DB - table *schema.Table - column *schema.Column - asName string + conn *sql.DB + table *schema.Table + column *schema.Column + asName string + withType string } // NewColumnDuplicator creates a new Duplicator for a column. func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator { return &Duplicator{ - conn: conn, - table: table, - column: column, - asName: TemporaryName(column.Name), + conn: conn, + table: table, + column: column, + asName: TemporaryName(column.Name), + withType: column.Type, } } +func (d *Duplicator) WithType(t string) *Duplicator { + d.withType = t + return d +} + // Duplicate creates a new column with the same type and foreign key // constraints as the original column. func (d *Duplicator) Duplicate(ctx context.Context) error { @@ -41,7 +48,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { sql := fmt.Sprintf(cAlterTableSQL, pq.QuoteIdentifier(d.table.Name), pq.QuoteIdentifier(d.asName), - d.column.Type) + d.withType) for _, fk := range d.table.ForeignKeys { if slices.Contains(fk.Columns, d.column.Name) { diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index 5eba2fbd..9a6cd2e2 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -26,7 +26,8 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. - if err := duplicateColumnForTypeChange(ctx, conn, table, *column, o.Type); err != nil { + d := NewColumnDuplicator(conn, table, column).WithType(o.Type) + if err := d.Duplicate(ctx); err != nil { return fmt.Errorf("failed to duplicate column: %w", err) } @@ -99,12 +100,13 @@ func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, s *schema.Sch } // Rename the new column to the old column name - _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", - pq.QuoteIdentifier(o.Table), - pq.QuoteIdentifier(TemporaryName(o.Column)), - pq.QuoteIdentifier(o.Column))) + table := s.GetTable(o.Table) + column := table.GetColumn(o.Column) + if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil { + return err + } - return err + return nil } func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB) error { @@ -143,15 +145,3 @@ func (o *OpChangeType) Validate(ctx context.Context, s *schema.Schema) error { } return nil } - -func duplicateColumnForTypeChange(ctx context.Context, conn *sql.DB, table *schema.Table, column schema.Column, newType string) error { - column.Name = TemporaryName(column.Name) - column.Type = newType - - _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", - pq.QuoteIdentifier(table.Name), - schemaColumnToSQL(column), - )) - - return err -} diff --git a/pkg/migrations/op_change_type_test.go b/pkg/migrations/op_change_type_test.go index 16df8886..b8208048 100644 --- a/pkg/migrations/op_change_type_test.go +++ b/pkg/migrations/op_change_type_test.go @@ -14,137 +14,215 @@ import ( func TestChangeColumnType(t *testing.T) { t.Parallel() - ExecuteTests(t, TestCases{{ - name: "change column type", - migrations: []migrations.Migration{ - { - Name: "01_add_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "reviews", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, - }, - { - Name: "username", - Type: "text", + ExecuteTests(t, TestCases{ + { + name: "change column type", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "reviews", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "username", + Type: "text", + }, + { + Name: "product", + Type: "text", + }, + { + Name: "rating", + Type: "text", + Default: ptr("0"), + }, }, - { - Name: "product", - Type: "text", + }, + }, + }, + { + Name: "02_change_type", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "reviews", + Column: "rating", + Type: "integer", + Up: "CAST (rating AS integer)", + Down: "CAST (rating AS text)", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + newVersionSchema := roll.VersionedSchemaName("public", "02_change_type") + + // The new (temporary) `rating` column should exist on the underlying table. + ColumnMustExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) + + // The `rating` column in the new view must have the correct type. + ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer") + + // Inserting into the new `rating` column should work. + MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{ + "username": "alice", + "product": "apple", + "rating": "5", + }) + + // The value inserted into the new `rating` column has been backfilled into + // the old `rating` column. + rows := MustSelect(t, db, "public", "01_add_table", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "alice", "product": "apple", "rating": "5"}, + }, rows) + + // Inserting into the old `rating` column should work. + MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ + "username": "bob", + "product": "banana", + "rating": "8", + }) + + // The value inserted into the old `rating` column has been backfilled into + // the new `rating` column. + rows = MustSelect(t, db, "public", "02_change_type", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "alice", "product": "apple", "rating": 5}, + {"id": 2, "username": "bob", "product": "banana", "rating": 8}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new (temporary) `rating` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating"))) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + newVersionSchema := roll.VersionedSchemaName("public", "02_change_type") + + // The new (temporary) `rating` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) + + // The `rating` column in the new view must have the correct type. + ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer") + + // Inserting into the new view should work. + MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{ + "username": "carl", + "product": "carrot", + "rating": "3", + }) + + // Selecting from the new view should succeed. + rows := MustSelect(t, db, "public", "02_change_type", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "alice", "product": "apple", "rating": 5}, + {"id": 2, "username": "bob", "product": "banana", "rating": 8}, + {"id": 3, "username": "carl", "product": "carrot", "rating": 3}, + }, rows) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating"))) + }, + }, + { + name: "changing column type preserves any foreign key constraints on the column", + migrations: []migrations.Migration{ + { + Name: "01_add_departments_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "departments", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, }, - { - Name: "rating", - Type: "text", - Default: ptr("0"), + }, + }, + }, + { + Name: "02_add_employees_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "employees", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, + { + Name: "department_id", + Type: "integer", + References: &migrations.ForeignKeyReference{ + Name: "fk_employee_department", + Table: "departments", + Column: "id", + }, + }, }, }, }, }, - }, - { - Name: "02_change_type", - Operations: migrations.Operations{ - &migrations.OpAlterColumn{ - Table: "reviews", - Column: "rating", - Type: "integer", - Up: "CAST (rating AS integer)", - Down: "CAST (rating AS text)", + { + Name: "03_change_type", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "employees", + Column: "department_id", + Type: "bigint", + Up: "department_id", + Down: "department_id", + }, }, }, }, + afterStart: func(t *testing.T, db *sql.DB) { + // A temporary FK constraint has been created on the temporary column + ConstraintMustExist(t, db, "public", "employees", migrations.TemporaryName("fk_employee_department")) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The foreign key constraint still exists on the column + ConstraintMustExist(t, db, "public", "employees", "fk_employee_department") + }, }, - afterStart: func(t *testing.T, db *sql.DB) { - newVersionSchema := roll.VersionedSchemaName("public", "02_change_type") - - // The new (temporary) `rating` column should exist on the underlying table. - ColumnMustExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) - - // The `rating` column in the new view must have the correct type. - ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer") - - // Inserting into the new `rating` column should work. - MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{ - "username": "alice", - "product": "apple", - "rating": "5", - }) - - // The value inserted into the new `rating` column has been backfilled into - // the old `rating` column. - rows := MustSelect(t, db, "public", "01_add_table", "reviews") - assert.Equal(t, []map[string]any{ - {"id": 1, "username": "alice", "product": "apple", "rating": "5"}, - }, rows) - - // Inserting into the old `rating` column should work. - MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ - "username": "bob", - "product": "banana", - "rating": "8", - }) - - // The value inserted into the old `rating` column has been backfilled into - // the new `rating` column. - rows = MustSelect(t, db, "public", "02_change_type", "reviews") - assert.Equal(t, []map[string]any{ - {"id": 1, "username": "alice", "product": "apple", "rating": 5}, - {"id": 2, "username": "bob", "product": "banana", "rating": 8}, - }, rows) - }, - afterRollback: func(t *testing.T, db *sql.DB) { - // The new (temporary) `rating` column should not exist on the underlying table. - ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) - - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) - - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating"))) - }, - afterComplete: func(t *testing.T, db *sql.DB) { - newVersionSchema := roll.VersionedSchemaName("public", "02_change_type") - - // The new (temporary) `rating` column should not exist on the underlying table. - ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating")) - - // The `rating` column in the new view must have the correct type. - ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer") - - // Inserting into the new view should work. - MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{ - "username": "carl", - "product": "carrot", - "rating": "3", - }) - - // Selecting from the new view should succeed. - rows := MustSelect(t, db, "public", "02_change_type", "reviews") - assert.Equal(t, []map[string]any{ - {"id": 1, "username": "alice", "product": "apple", "rating": 5}, - {"id": 2, "username": "bob", "product": "banana", "rating": 8}, - {"id": 3, "username": "carl", "product": "carrot", "rating": 3}, - }, rows) - - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) - - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating"))) - }, - }}) + }) } func TestChangeColumnTypeValidation(t *testing.T) {