diff --git a/pkg/migrations/op_set_fk_test.go b/pkg/migrations/op_set_fk_test.go index 1d2891989..183c50d9f 100644 --- a/pkg/migrations/op_set_fk_test.go +++ b/pkg/migrations/op_set_fk_test.go @@ -443,6 +443,82 @@ func TestSetForeignKey(t *testing.T) { }, testutils.CheckViolationErrorCode) }, }, + { + name: "not null is preserved when adding a foreign key constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_tables", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + }, + }, + }, + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + Nullable: true, + }, + { + Name: "user_id", + Type: "integer", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "user_id", + References: &migrations.ForeignKeyReference{ + Name: "fk_users_id", + Table: "users", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the NOT NULL constraint on `user_id` fails. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "id": "1", + "title": "post by alice", + }, testutils.NotNullViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the NOT NULL constraint on `user_id` fails. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "id": "1", + "title": "post by alice", + }, testutils.NotNullViolationErrorCode) + }, + }, }) }