diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index feb193210..c9ca2665d 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -492,12 +492,22 @@ func MustInsert(t *testing.T, db *sql.DB, schema, version, table string, record } } -func MustNotInsert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) { +func MustNotInsert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string, errorCode string) { t.Helper() - if err := insert(t, db, schema, version, table, record); err == nil { + err := insert(t, db, schema, version, table, record) + if err == nil { t.Fatal("Expected INSERT to fail") } + + var pqErr *pq.Error + if ok := errors.As(err, &pqErr); ok { + if pqErr.Code.Name() != errorCode { + t.Fatalf("Expected INSERT to fail with %q, got %q", errorCode, pqErr.Code.Name()) + } + } else { + t.Fatalf("INSERT failed with unknown error: %v", err) + } } func insert(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) error { diff --git a/pkg/testutils/error_codes.go b/pkg/testutils/error_codes.go new file mode 100644 index 000000000..a30ea70fe --- /dev/null +++ b/pkg/testutils/error_codes.go @@ -0,0 +1,8 @@ +package testutils + +const ( + CheckViolationErrorCode string = "check_violation" + FKViolationErrorCode string = "foreign_key_violation" + NotNullViolationErrorCode string = "not_null_violation" + UniqueViolationErrorCode string = "unique_violation" +)