diff --git a/pkg/migrations/dbactions.go b/pkg/migrations/dbactions.go index c1afd75f4..fa4a2dc5b 100644 --- a/pkg/migrations/dbactions.go +++ b/pkg/migrations/dbactions.go @@ -12,6 +12,7 @@ import ( "github.com/lib/pq" "github.com/xataio/pgroll/pkg/db" + "github.com/xataio/pgroll/pkg/schema" ) // DBAction is an interface for common database actions @@ -941,3 +942,54 @@ func (a *setReplicaIdentityAction) Execute(ctx context.Context) error { identitySQL)) return err } + +type alterReferencesAction struct { + conn db.DB + id string + referencedBy map[string][]*schema.ReferencedBy + table string + column string +} + +func NewAlterReferencesAction(conn db.DB, referencedBy map[string][]*schema.ReferencedBy, table, column string) *alterReferencesAction { + return &alterReferencesAction{ + conn: conn, + id: fmt.Sprintf("alter_references_%s_%s", table, column), + table: table, + column: column, + referencedBy: referencedBy, + } +} + +func (a *alterReferencesAction) ID() string { + return a.id +} + +func (a *alterReferencesAction) Execute(ctx context.Context) error { + for table, constraints := range a.referencedBy { + for _, constraint := range constraints { + // Drop the existing constraint + _, err := a.conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", + pq.QuoteIdentifier(table), + pq.QuoteIdentifier(constraint.Name), + )) + if err != nil { + return fmt.Errorf("dropping constraint %s on %s: %w", constraint.Name, table, err) + } + + // Recreate the constraint with the table and new column + newDef := strings.ReplaceAll(constraint.Definition, a.column, pq.QuoteIdentifier(TemporaryName(a.column))) + newDef = strings.ReplaceAll(newDef, a.table, pq.QuoteIdentifier(a.table)) + _, err = a.conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s", + pq.QuoteIdentifier(table), + pq.QuoteIdentifier(constraint.Name), + newDef, + )) + if err != nil { + return fmt.Errorf("altering references for %s.%s: %w", a.table, a.column, err) + } + + } + } + return nil +} diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index ac7f9fdd1..d862c98df 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -127,6 +127,7 @@ func (o *OpAlterColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAc // Rename the new column to the old column name return append(dbActions, []DBAction{ NewAlterSequenceOwnerAction(conn, table.Name, column.Name, TemporaryName(column.Name)), + NewAlterReferencesAction(conn, table.ReferencedBy, table.Name, o.Column), NewDropColumnAction(conn, table.Name, o.Column), NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column), diff --git a/pkg/migrations/op_alter_column_test.go b/pkg/migrations/op_alter_column_test.go index 4654cf202..a36c23175 100644 --- a/pkg/migrations/op_alter_column_test.go +++ b/pkg/migrations/op_alter_column_test.go @@ -778,6 +778,138 @@ func TestAlterPrimaryKeyColumns(t *testing.T) { }) }, }, + { + name: "alter a single primary key column when the column is used in a foreign key constraint", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "events", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: true, + }, + }, + }, + &migrations.OpCreateTable{ + Name: "people", + Columns: []migrations.Column{ + { + Name: "id", + Type: "int", + Pk: true, + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: true, + }, + { + Name: "manages", + Type: "serial", + Nullable: false, + References: &migrations.ForeignKeyReference{ + Table: "events", + Column: "id", + Name: "person_manages_event_fk", + }, + }, + }, + }, + }, + }, + { + Name: "02_alter_column", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "events", + Column: "id", + Type: ptr("bigint"), + Up: "CAST(id AS bigint)", + Down: "SELECT CASE WHEN id < 2147483647 THEN CAST(id AS int) ELSE 0 END", + }, + &migrations.OpAlterColumn{ + Table: "people", + Column: "manages", + Type: ptr("bigint"), + Up: "CAST(manages AS bigint)", + Down: "SELECT CASE WHEN manages < 2147483647 THEN CAST(manages AS int) ELSE 0 END", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + PrimaryKeyConstraintMustExist(t, db, schema, "people", "people_pkey") + ValidatedForeignKeyMustExist(t, db, schema, "people", "person_manages_event_fk") + + bigint := "31474836471" // A value larger than int can hold + + // Inserting a row with integer id into the old schema should succeed + MustInsert(t, db, schema, "01_create_table", "events", map[string]string{ + "id": "1", + "name": "pgroll v1 release party", + }) + MustInsert(t, db, schema, "01_create_table", "people", map[string]string{ + "id": "1", + "name": "alice", + "manages": "1", + }) + // Inserting a row with integer bigint id into the old schema should fail + MustNotInsert(t, db, schema, "01_create_table", "events", map[string]string{ + "id": bigint, + "name": "pgroll v2 release party", + }, testutils.NumericValueOutOfRangeErrorCode) + + // Inserting a row with bigint id into the new schema should succeed + MustInsert(t, db, schema, "02_alter_column", "events", map[string]string{ + "id": bigint, + "name": "pgroll v2 release party", + }) + + // Inserting a row with a bigint value into the new column should succeed + MustInsert(t, db, schema, "02_alter_column", "people", map[string]string{ + "id": "2", + "name": "bob", + "manages": bigint, + }) + // Inserting a row into the `people` table with a `manages` field that + // violates the FK constraint fails + MustNotInsert(t, db, schema, "02_alter_column", "people", map[string]string{ + "id": "10", + "name": "alice", + "manages": "2", + }, testutils.FKViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + PrimaryKeyConstraintMustExist(t, db, schema, "people", "people_pkey") + ValidatedForeignKeyMustExist(t, db, schema, "people", "person_manages_event_fk") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + PrimaryKeyConstraintMustExist(t, db, schema, "people", "people_pkey") + ValidatedForeignKeyMustExist(t, db, schema, "people", "person_manages_event_fk") + + // Inserting a row with integer bigint into the new schema should succeed + MustInsert(t, db, schema, "02_alter_column", "events", map[string]string{ + "id": "31474836472", + "name": "pgroll v3 release party", + }) + // Inserting a row into the `people` table with a `manages` field that + // violates the FK constraint fails + MustNotInsert(t, db, schema, "02_alter_column", "people", map[string]string{ + "id": "3", + "name": "carol", + "manages": "2", + }, testutils.FKViolationErrorCode) + }, + }, }) } diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index ce72ec443..618fc6369 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -62,10 +62,25 @@ type Table struct { // ExcludeConstraints is a map of all exclude constraints defined on the table ExcludeConstraints map[string]*ExcludeConstraint `json:"excludeConstraints"` + // ReferencedBy is a map of table names that reference this table by foreign key + // The key is the name of the referencing table, and the value is a slice of foreign key references + ReferencedBy map[string][]*ReferencedBy `json:"referencedBy"` + // Whether or not the table has been deleted in the virtual schema Deleted bool `json:"-"` } +type ReferencedBy struct { + // Name is the name of the foreign key constraint + Name string `json:"name"` + + // Table is the name of the table that is referenced + Table string `json:"table"` + + // Definition is the definition of the foreign key constraint + Definition string `json:"definition"` +} + // Column represents a column in a table type Column struct { // Name is the actual name in postgres diff --git a/pkg/state/init.sql b/pkg/state/init.sql index e03fa4cdc..028109418 100644 --- a/pkg/state/init.sql +++ b/pkg/state/init.sql @@ -465,7 +465,20 @@ BEGIN AND fk_constraint.contype = 'f' GROUP BY fk_constraint.conrelid, fk_constraint.conname, fk_constraint.confrelid, fk_cl.relname, fk_constraint.confkey, fk_constraint.confmatchtype, fk_constraint.confdeltype, fk_constraint.confupdtype) AS fk_info INNER JOIN pg_attribute ref_attr ON ref_attr.attrelid = fk_info.confrelid AND ref_attr.attnum = ANY (fk_info.confkey) -- join the columns of the referenced table - GROUP BY fk_info.conname, fk_info.conrelid, fk_info.columns, fk_info.confrelid, fk_info.confmatchtype, fk_info.confdeltype, fk_info.confupdtype, fk_info.relname) AS fk_details)))), '{}'::json) + GROUP BY fk_info.conname, fk_info.conrelid, fk_info.columns, fk_info.confrelid, fk_info.confmatchtype, fk_info.confdeltype, fk_info.confupdtype, fk_info.relname) AS fk_details), 'referencedBy', ( + SELECT + json_object_agg(ref_table, ref_constraints) + FROM ( + SELECT + ref_cl.relname AS ref_table, + json_agg(json_build_object('name', ref_constraint.conname, 'table', ref_cl.relname, 'definition', pg_get_constraintdef(ref_constraint.oid))) AS ref_constraints + FROM pg_constraint AS ref_constraint + INNER JOIN pg_class ref_cl ON ref_constraint.conrelid = ref_cl.oid + WHERE + ref_constraint.confrelid = t.oid + AND ref_constraint.contype = 'f' + GROUP BY ref_cl.relname + ) AS ref_fk)))), '{}'::json) FROM pg_class AS t INNER JOIN pg_namespace AS ns ON t.relnamespace = ns.oid LEFT JOIN pg_description AS descr ON t.oid = descr.objoid