diff --git a/pkg/migrations/op_set_check_test.go b/pkg/migrations/op_set_check_test.go index ecf959eec..fed7c063b 100644 --- a/pkg/migrations/op_set_check_test.go +++ b/pkg/migrations/op_set_check_test.go @@ -359,6 +359,60 @@ func TestSetCheckConstraint(t *testing.T) { }, testutils.CheckViolationErrorCode) }, }, + { + name: "not null is preserved when adding a check constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_add_check_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "title", + Check: &migrations.CheckConstraint{ + Name: "check_title_length", + Constraint: "length(title) > 3", + }, + Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)", + Down: "title", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the NOT NULL constraint on `title` fails. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "1", + }, testutils.NotNullViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the NOT NULL constraint on `title` fails. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "1", + }, testutils.NotNullViolationErrorCode) + }, + }, }) }