Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve foreign key constraints on columns duplicated for backfilling #230

Merged
merged 5 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// SPDX-License-Identifier: Apache-2.0

package migrations

import (
"context"
"database/sql"
"fmt"
"slices"
"strings"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

type Duplicator struct {
conn *sql.DB
table *schema.Table
column *schema.Column
asName string
}

// NewColumnDuplicator creates a new Duplicator for a column.
func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator {
return &Duplicator{
conn: conn,
table: table,
column: column,
asName: TemporaryName(column.Name),
}
}

// Duplicate creates a new column with the same type and foreign key
// constraints as the original column.
func (d *Duplicator) Duplicate(ctx context.Context) error {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
)

sql := fmt.Sprintf(cAlterTableSQL,
pq.QuoteIdentifier(d.table.Name),
pq.QuoteIdentifier(d.asName),
d.column.Type)

for _, fk := range d.table.ForeignKeys {
if slices.Contains(fk.Columns, d.column.Name) {
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
pq.QuoteIdentifier(TemporaryName(fk.Name)),
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, d.column.Name, d.asName)), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "))
}
}

_, err := d.conn.ExecContext(ctx, sql)

return err
}

func copyAndReplace(xs []string, oldValue, newValue string) []string {
ys := slices.Clone(xs)

for i, c := range ys {
if c == oldValue {
ys[i] = newValue
}
}
return ys
}
9 changes: 8 additions & 1 deletion pkg/migrations/op_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
)

type OpName string
Expand All @@ -33,8 +34,14 @@ const (
OpNameChangeType OpName = "change_type"
)

const temporaryPrefix = "_pgroll_new_"

func TemporaryName(name string) string {
return "_pgroll_new_" + name
return temporaryPrefix + name
}

func StripTemporaryPrefix(name string) string {
return strings.TrimPrefix(name, temporaryPrefix)
}

func ReadMigration(r io.Reader) (*Migration, error) {
Expand Down
11 changes: 5 additions & 6 deletions pkg/migrations/op_set_notnull.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri
column := table.GetColumn(o.Column)

// Create a copy of the column on the underlying table.
if err := duplicateColumn(ctx, conn, table, *column); err != nil {
d := NewColumnDuplicator(conn, table, column)
if err := d.Duplicate(ctx); err != nil {
return fmt.Errorf("failed to duplicate column: %w", err)
}

Expand Down Expand Up @@ -130,11 +131,9 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Sch
}

// Rename the new column to the old column name
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Column)))
if err != nil {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil {
return err
}

Expand Down
77 changes: 77 additions & 0 deletions pkg/migrations/op_set_notnull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,83 @@ func TestSetNotNull(t *testing.T) {
afterComplete: func(t *testing.T, db *sql.DB) {
},
},
{
name: "setting a foreign key column to not null retains the foreign key constraint",
migrations: []migrations.Migration{
{
Name: "01_add_departments_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "departments",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: false,
},
},
},
},
},
{
Name: "02_add_employees_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "employees",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: false,
},
{
Name: "department_id",
Type: "integer",
Nullable: true,
References: &migrations.ForeignKeyReference{
Name: "fk_employee_department",
Table: "departments",
Column: "id",
},
},
},
},
},
},
{
Name: "03_set_not_null",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "employees",
Column: "department_id",
Nullable: ptr(false),
Up: "(SELECT CASE WHEN department_id IS NULL THEN 1 ELSE department_id END)",
Down: "department_id",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// A temporary FK constraint has been created on the temporary column
ConstraintMustExist(t, db, "public", "employees", migrations.TemporaryName("fk_employee_department"))
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The foreign key constraint still exists on the column
ConstraintMustExist(t, db, "public", "employees", "fk_employee_department")
},
},
})
}

Expand Down
53 changes: 53 additions & 0 deletions pkg/migrations/rename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0

package migrations

import (
"context"
"database/sql"
"fmt"
"slices"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

// RenameDuplicatedColumn renames a duplicated column to its original name and renames any foreign keys
// on the duplicated column to their original name.
func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error {
const (
cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s`
cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s`
)

// Rename the old column to the new column name
renameColumnSQL := fmt.Sprintf(cRenameColumnSQL,
pq.QuoteIdentifier(table.Name),
pq.QuoteIdentifier(TemporaryName(column.Name)),
pq.QuoteIdentifier(column.Name))

_, err := conn.ExecContext(ctx, renameColumnSQL)
if err != nil {
return fmt.Errorf("failed to rename duplicated column %q: %w", column.Name, err)
}

// Rename any foreign keys on the duplicated column from their temporary name
// to their original name
var renameConstraintSQL string
for _, fk := range table.ForeignKeys {
if slices.Contains(fk.Columns, TemporaryName(column.Name)) {
renameConstraintSQL = fmt.Sprintf(cRenameConstraintSQL,
pq.QuoteIdentifier(table.Name),
pq.QuoteIdentifier(fk.Name),
pq.QuoteIdentifier(StripTemporaryPrefix(fk.Name)),
)

_, err = conn.ExecContext(ctx, renameConstraintSQL)
if err != nil {
return fmt.Errorf("failed to rename column constraint %q: %w", fk.Name, err)
}
}
}

return nil
}
Loading