Skip to content

Commit

Permalink
Convert ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) SQL to `pgrol…
Browse files Browse the repository at this point in the history
…l` operation (#507)

Convert SQL DDL of the form:

```sql
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)"
```

To the equivalent `pgroll` operation:

```json
[
  {
    "create_constraint": {
      "type": "unique",
      "table": "foo",
      "name": "bar",
      "columns": ["a"],
      "up": {
        "a": "...",
      },
      "down": {
        "a": "..."
      }
    }
  }
]
```

We need to be conservative when converting SQL statements to `pgroll`
operations to ensure that information present in the SQL is not lost
during the conversion.

There are several options possible as part of `ADD CONSTRAINT ...
UNIQUE` statements that aren't currently representable by the
`OpCreateConstraint` operation, for example:

```sql
ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a)
ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b)
```

In these cases we must resort to converting to an `OpRawSQL`. 

Tests are added to cover these unrepresentable cases.

Part of #504
  • Loading branch information
andrew-farries authored Dec 4, 2024
1 parent fd94011 commit 06bfebd
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 4 deletions.
100 changes: 100 additions & 0 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,30 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
case pgq.AlterTableType_AT_AlterColumnType:
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
case pgq.AlterTableType_AT_AddConstraint:
op, err = convertAlterTableAddConstraint(stmt, alterTableCmd)
}

if err != nil {
return nil, err
}

if op == nil {
return nil, nil
}

ops = append(ops, op)
}

return ops, nil
}

// convertAlterTableSetNotNull converts SQL statements like:
//
// `ALTER TABLE foo ALTER COLUMN a SET NOT NULL`
// `ALTER TABLE foo ALTER COLUMN a DROP NOT NULL`
//
// to an OpAlterColumn operation.
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Expand All @@ -55,6 +67,11 @@ func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCm
}, nil
}

// convertAlterTableAlterColumnType converts a SQL statement like:
//
// `ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text`
//
// to an OpAlterColumn operation.
func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
if !ok {
Expand All @@ -70,6 +87,89 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa
}, nil
}

// convertAlterTableAddConstraint converts SQL statements like:
//
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
//
// To an OpCreateConstraint operation.
func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_Constraint)
if !ok {
return nil, fmt.Errorf("expected constraint definition, got %T", cmd.GetDef().Node)
}

var op migrations.Operation
var err error
switch node.Constraint.GetContype() {
case pgq.ConstrType_CONSTR_UNIQUE:
op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint)
default:
return nil, nil
}

if err != nil {
return nil, err
}

return op, nil
}

// convertAlterTableAddUniqueConstraint converts SQL statements like:
//
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
//
// to an OpCreateConstraint operation.
func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) {
if !canConvertUniqueConstraint(constraint) {
return nil, nil
}

// Extract the columns covered by the unique constraint
columns := make([]string, 0, len(constraint.GetKeys()))
for _, keyNode := range constraint.GetKeys() {
key, ok := keyNode.Node.(*pgq.Node_String_)
if !ok {
return nil, fmt.Errorf("expected string key, got %T", keyNode)
}
columns = append(columns, key.String_.GetSval())
}

// Build the up and down SQL placeholders for each column covered by the
// constraint
upDown := make(map[string]string, len(columns))
for _, column := range columns {
upDown[column] = PlaceHolderSQL
}

return &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: constraint.GetConname(),
Table: stmt.GetRelation().GetRelname(),
Columns: columns,
Down: upDown,
Up: upDown,
}, nil
}

// canConvertUniqueConstraint checks if the unique constraint `constraint` can
// be faithfully converted to an OpCreateConstraint operation without losing
// information.
func canConvertUniqueConstraint(constraint *pgq.Constraint) bool {
if constraint.GetNullsNotDistinct() {
return false
}
if len(constraint.GetIncluding()) > 0 {
return false
}
if len(constraint.GetOptions()) > 0 {
return false
}
if constraint.GetIndexspace() != "" {
return false
}
return true
}

func ptr[T any](x T) *T {
return &x
}
35 changes: 32 additions & 3 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
expectedOp: expect.AlterTableOp3,
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)",
expectedOp: expect.AlterTableOp4,
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a, b)",
expectedOp: expect.AlterTableOp5,
},
}

for _, tc := range tests {
Expand All @@ -44,10 +52,31 @@ func TestConvertAlterTableStatements(t *testing.T) {

require.Len(t, ops, 1)

alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn)
require.True(t, ok)
assert.Equal(t, tc.expectedOp, ops[0])
})
}
}

func TestUnconvertableAlterTableAddConstraintStatements(t *testing.T) {
t.Parallel()

tests := []string{
// UNIQUE constraints with various options that are not representable by
// `OpCreateConstraint` operations
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) WITH (fillfactor=70)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) USING INDEX TABLESPACE baz",
}

for _, sql := range tests {
t.Run(sql, func(t *testing.T) {
ops, err := sql2pgroll.Convert(sql)
require.NoError(t, err)

require.Len(t, ops, 1)

assert.Equal(t, tc.expectedOp, alterColumnOps)
assert.Equal(t, expect.RawSQLOp(sql), ops[0])
})
}
}
2 changes: 1 addition & 1 deletion pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) {
columns := make([]migrations.Column, 0, len(stmt.TableElts))
for _, elt := range stmt.TableElts {
columns = append(columns, convertColumnDef(elt.GetColumnDef()))
Expand Down
24 changes: 24 additions & 0 deletions pkg/sql2pgroll/expect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ var AlterTableOp3 = &migrations.OpAlterColumn{
Down: sql2pgroll.PlaceHolderSQL,
}

var AlterTableOp4 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: "bar",
Table: "foo",
Columns: []string{"a"},
Down: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
Up: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
}

var AlterTableOp5 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: "bar",
Table: "foo",
Columns: []string{"a", "b"},
Down: map[string]string{
"a": sql2pgroll.PlaceHolderSQL,
"b": sql2pgroll.PlaceHolderSQL,
},
Up: map[string]string{
"a": sql2pgroll.PlaceHolderSQL,
"b": sql2pgroll.PlaceHolderSQL,
},
}

func ptr[T any](v T) *T {
return &v
}
9 changes: 9 additions & 0 deletions pkg/sql2pgroll/expect/raw_sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SPDX-License-Identifier: Apache-2.0

package expect

import "github.com/xataio/pgroll/pkg/migrations"

func RawSQLOp(sql string) *migrations.OpRawSQL {
return &migrations.OpRawSQL{Up: sql}
}

0 comments on commit 06bfebd

Please sign in to comment.