diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 2d6b3d6b..6f59825b 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -98,11 +98,13 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa }, nil } -// convertAlterTableAddConstraint converts SQL statements like: +// convertAlterTableAddConstraint converts SQL statements that add UNIQUE or FOREIGN KEY constraints, +// for example: // // `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)` +// `ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);` // -// To an OpCreateConstraint operation. +// An OpCreateConstraint operation is returned. func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { node, ok := cmd.GetDef().Node.(*pgq.Node_Constraint) if !ok { @@ -114,6 +116,8 @@ func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTabl switch node.Constraint.GetContype() { case pgq.ConstrType_CONSTR_UNIQUE: op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint) + case pgq.ConstrType_CONSTR_FOREIGN: + op, err = convertAlterTableAddForeignKeyConstraint(stmt, node.Constraint) default: return nil, nil } @@ -162,6 +166,87 @@ func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint * }, nil } +func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) { + if !canConvertAlterTableAddForeignKeyConstraint(constraint) { + return nil, nil + } + + columns := make([]string, len(constraint.GetFkAttrs())) + for i := range columns { + columns[i] = constraint.GetFkAttrs()[i].GetString_().GetSval() + } + + foreignColumns := make([]string, len(constraint.GetPkAttrs())) + for i := range columns { + foreignColumns[i] = constraint.GetPkAttrs()[i].GetString_().GetSval() + } + + migs := make(map[string]string) + for _, column := range columns { + migs[column] = PlaceHolderSQL + } + + var onDelete migrations.ForeignKeyReferenceOnDelete + switch constraint.GetFkDelAction() { + case "a": + onDelete = migrations.ForeignKeyReferenceOnDeleteNOACTION + case "c": + onDelete = migrations.ForeignKeyReferenceOnDeleteCASCADE + case "r": + onDelete = migrations.ForeignKeyReferenceOnDeleteRESTRICT + case "d": + onDelete = migrations.ForeignKeyReferenceOnDeleteSETDEFAULT + case "n": + onDelete = migrations.ForeignKeyReferenceOnDeleteSETNULL + default: + return nil, fmt.Errorf("unknown delete action: %q", constraint.GetFkDelAction()) + } + + tableName := stmt.GetRelation().GetRelname() + if stmt.GetRelation().GetSchemaname() != "" { + tableName = stmt.GetRelation().GetSchemaname() + "." + tableName + } + + foreignTable := constraint.GetPktable().GetRelname() + if constraint.GetPktable().GetSchemaname() != "" { + foreignTable = constraint.GetPktable().GetSchemaname() + "." + foreignTable + } + + return &migrations.OpCreateConstraint{ + Columns: columns, + Up: migs, + Down: migs, + Name: constraint.GetConname(), + References: &migrations.OpCreateConstraintReferences{ + Columns: foreignColumns, + OnDelete: onDelete, + Table: foreignTable, + }, + Table: tableName, + Type: migrations.OpCreateConstraintTypeForeignKey, + }, nil +} + +func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool { + switch constraint.GetFkUpdAction() { + case "r", "c", "n", "d": + // RESTRICT, CASCADE, SET NULL, SET DEFAULT + return false + case "a": + // NO ACTION, the default + break + } + switch constraint.GetFkMatchtype() { + case "f": + // FULL + return false + case "s": + // SIMPLE, the default + break + } + return true +} + // convertAlterTableSetColumnDefault converts SQL statements like: // // `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo'` diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 65c39ba8..f3665aa2 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -80,6 +80,38 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo DROP COLUMN bar RESTRICT ", expectedOp: expect.DropColumnOp1, }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d);", + expectedOp: expect.AddForeignKeyOp1WithOnDelete(migrations.ForeignKeyReferenceOnDeleteNOACTION), + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON DELETE NO ACTION;", + expectedOp: expect.AddForeignKeyOp1WithOnDelete(migrations.ForeignKeyReferenceOnDeleteNOACTION), + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON DELETE RESTRICT;", + expectedOp: expect.AddForeignKeyOp1WithOnDelete(migrations.ForeignKeyReferenceOnDeleteRESTRICT), + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON DELETE SET DEFAULT ;", + expectedOp: expect.AddForeignKeyOp1WithOnDelete(migrations.ForeignKeyReferenceOnDeleteSETDEFAULT), + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON DELETE SET NULL;", + expectedOp: expect.AddForeignKeyOp1WithOnDelete(migrations.ForeignKeyReferenceOnDeleteSETNULL), + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);", + expectedOp: expect.AddForeignKeyOp2, + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c) NOT VALID;", + expectedOp: expect.AddForeignKeyOp2, + }, + { + sql: "ALTER TABLE schema_a.foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES schema_a.bar (c);", + expectedOp: expect.AddForeignKeyOp3, + }, } for _, tc := range tests { @@ -116,6 +148,15 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { // Non literal default values "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT now()", + + // Unsupported foreign key statements + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE RESTRICT;", + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE CASCADE;", + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET NULL;", + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET DEFAULT;", + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH FULL;", + // MATCH PARTIAL is not implemented in the actual parser yet + //"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH PARTIAL;", } for _, sql := range tests { diff --git a/pkg/sql2pgroll/expect/add_foreign_key.go b/pkg/sql2pgroll/expect/add_foreign_key.go new file mode 100644 index 00000000..590a20ad --- /dev/null +++ b/pkg/sql2pgroll/expect/add_foreign_key.go @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import ( + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" +) + +func AddForeignKeyOp1WithOnDelete(onDelete migrations.ForeignKeyReferenceOnDelete) *migrations.OpCreateConstraint { + return &migrations.OpCreateConstraint{ + Columns: []string{"a", "b"}, + Name: "fk_bar_cd", + References: &migrations.OpCreateConstraintReferences{ + Columns: []string{"c", "d"}, + OnDelete: onDelete, + Table: "bar", + }, + Table: "foo", + Type: migrations.OpCreateConstraintTypeForeignKey, + Up: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + "b": sql2pgroll.PlaceHolderSQL, + }, + Down: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + "b": sql2pgroll.PlaceHolderSQL, + }, + } +} + +var AddForeignKeyOp2 = &migrations.OpCreateConstraint{ + Columns: []string{"a"}, + Name: "fk_bar_c", + References: &migrations.OpCreateConstraintReferences{ + Columns: []string{"c"}, + OnDelete: migrations.ForeignKeyReferenceOnDeleteNOACTION, + Table: "bar", + }, + Table: "foo", + Type: migrations.OpCreateConstraintTypeForeignKey, + Up: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + }, + Down: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + }, +} + +var AddForeignKeyOp3 = &migrations.OpCreateConstraint{ + Columns: []string{"a"}, + Name: "fk_bar_c", + References: &migrations.OpCreateConstraintReferences{ + Columns: []string{"c"}, + OnDelete: migrations.ForeignKeyReferenceOnDeleteNOACTION, + Table: "schema_a.bar", + }, + Table: "schema_a.foo", + Type: migrations.OpCreateConstraintTypeForeignKey, + Up: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + }, + Down: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + }, +}