diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go new file mode 100644 index 00000000..3351cd5d --- /dev/null +++ b/pkg/sql2pgroll/alter_table.go @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +const PlaceHolderSQL = "TODO: Implement SQL data migration" + +// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations. +func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) { + if stmt.Objtype != pgq.ObjectType_OBJECT_TABLE { + return nil, nil + } + + var ops migrations.Operations + for _, cmd := range stmt.Cmds { + alterTableCmd := cmd.GetAlterTableCmd() + if alterTableCmd == nil { + continue + } + + //nolint:gocritic + switch alterTableCmd.Subtype { + case pgq.AlterTableType_AT_SetNotNull: + ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd)) + } + } + + return ops, nil +} + +func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation { + return &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Nullable: ptr(false), + Up: PlaceHolderSQL, + Down: PlaceHolderSQL, + } +} + +func ptr[T any](x T) *T { + return &x +} diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go new file mode 100644 index 00000000..b074d9f5 --- /dev/null +++ b/pkg/sql2pgroll/alter_table_test.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" + "github.com/xataio/pgroll/pkg/sql2pgroll/expect" +) + +func TestConvertAlterTableStatements(t *testing.T) { + t.Parallel() + + tests := []struct { + sql string + expectedOp migrations.Operation + }{ + { + sql: "ALTER TABLE foo ALTER COLUMN a SET NOT NULL", + expectedOp: expect.AlterTableOp1, + }, + } + + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(tc.sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn) + require.True(t, ok) + + assert.Equal(t, tc.expectedOp, alterColumnOps) + }) + } +} diff --git a/pkg/sql2pgroll/convert.go b/pkg/sql2pgroll/convert.go new file mode 100644 index 00000000..372f25cf --- /dev/null +++ b/pkg/sql2pgroll/convert.go @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +var ErrStatementCount = fmt.Errorf("expected exactly one statement") + +// Convert converts a SQL statement to a slice of pgroll operations. +func Convert(sql string) (migrations.Operations, error) { + ops, err := convert(sql) + if err != nil { + return nil, err + } + + if ops == nil { + return makeRawSQLOperation(sql), nil + } + + return ops, nil +} + +func convert(sql string) (migrations.Operations, error) { + tree, err := pgq.Parse(sql) + if err != nil { + return nil, fmt.Errorf("parse error: %w", err) + } + + stmts := tree.GetStmts() + if len(stmts) != 1 { + return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, len(stmts)) + } + node := stmts[0].GetStmt().GetNode() + + switch node := (node).(type) { + case *pgq.Node_CreateStmt: + return convertCreateStmt(node.CreateStmt) + case *pgq.Node_AlterTableStmt: + return convertAlterTableStmt(node.AlterTableStmt) + default: + return makeRawSQLOperation(sql), nil + } +} + +func makeRawSQLOperation(sql string) migrations.Operations { + return migrations.Operations{ + &migrations.OpRawSQL{Up: sql}, + } +} diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go new file mode 100644 index 00000000..63604895 --- /dev/null +++ b/pkg/sql2pgroll/create_table.go @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + "strings" + + pgq "github.com/pganalyze/pg_query_go/v6" + "github.com/xataio/pgroll/pkg/migrations" +) + +// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation. +func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) { + columns := make([]migrations.Column, 0, len(stmt.TableElts)) + for _, elt := range stmt.TableElts { + columns = append(columns, convertColumnDef(elt.GetColumnDef())) + } + + return migrations.Operations{ + &migrations.OpCreateTable{ + Name: stmt.Relation.Relname, + Columns: columns, + }, + }, nil +} + +func convertColumnDef(col *pgq.ColumnDef) migrations.Column { + ignoredTypeParts := map[string]bool{ + "pg_catalog": true, + } + + // Build the type name, including any schema qualifiers + typeParts := make([]string, 0, len(col.GetTypeName().Names)) + for _, node := range col.GetTypeName().Names { + typePart := node.GetString_().GetSval() + if _, ok := ignoredTypeParts[typePart]; ok { + continue + } + typeParts = append(typeParts, typePart) + } + + // Build the type modifiers, such as precision and scale for numeric types + var typeMods []string + for _, node := range col.GetTypeName().Typmods { + if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { + typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) + } + } + var typeModifier string + if len(typeMods) > 0 { + typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) + } + + // Build the array bounds for array types + var arrayBounds string + for _, node := range col.GetTypeName().ArrayBounds { + bound := node.GetInteger().GetIval() + if bound == -1 { + arrayBounds = "[]" + } else { + arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) + } + } + + // Determine column nullability, uniqueness, and primary key status + var notNull, unique, pk bool + var defaultValue *string + for _, constraint := range col.Constraints { + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_NOTNULL { + notNull = true + } + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_UNIQUE { + unique = true + } + if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_PRIMARY { + pk = true + notNull = true + } + } + + return migrations.Column{ + Name: col.Colname, + Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds, + Nullable: !notNull, + Unique: unique, + Default: defaultValue, + Pk: pk, + } +} diff --git a/pkg/sql2pgroll/create_table_test.go b/pkg/sql2pgroll/create_table_test.go new file mode 100644 index 00000000..5636f489 --- /dev/null +++ b/pkg/sql2pgroll/create_table_test.go @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" + "github.com/xataio/pgroll/pkg/sql2pgroll/expect" +) + +func TestConvertCreateTableStatements(t *testing.T) { + t.Parallel() + + tests := []struct { + sql string + expectedOp migrations.Operation + }{ + { + sql: "CREATE TABLE foo(a int)", + expectedOp: expect.CreateTableOp1, + }, + { + sql: "CREATE TABLE foo(a int NOT NULL)", + expectedOp: expect.CreateTableOp2, + }, + { + sql: "CREATE TABLE foo(a varchar(255))", + expectedOp: expect.CreateTableOp3, + }, + { + sql: "CREATE TABLE foo(a numeric(10, 2))", + expectedOp: expect.CreateTableOp4, + }, + { + sql: "CREATE TABLE foo(a int UNIQUE)", + expectedOp: expect.CreateTableOp5, + }, + { + sql: "CREATE TABLE foo(a int PRIMARY KEY)", + expectedOp: expect.CreateTableOp6, + }, + { + sql: "CREATE TABLE foo(a text[])", + expectedOp: expect.CreateTableOp7, + }, + { + sql: "CREATE TABLE foo(a text[5])", + expectedOp: expect.CreateTableOp8, + }, + { + sql: "CREATE TABLE foo(a text[5][3])", + expectedOp: expect.CreateTableOp9, + }, + } + + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(tc.sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + createTableOp, ok := ops[0].(*migrations.OpCreateTable) + require.True(t, ok) + + assert.Equal(t, tc.expectedOp, createTableOp) + }) + } +} diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go new file mode 100644 index 00000000..f1da4a85 --- /dev/null +++ b/pkg/sql2pgroll/expect/alter_table.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import ( + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" +) + +var AlterTableOp1 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "a", + Nullable: ptr(false), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/pkg/sql2pgroll/expect/create_table.go b/pkg/sql2pgroll/expect/create_table.go new file mode 100644 index 00000000..d8ded16c --- /dev/null +++ b/pkg/sql2pgroll/expect/create_table.go @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import "github.com/xataio/pgroll/pkg/migrations" + +var CreateTableOp1 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Nullable: true, + }, + }, +} + +var CreateTableOp2 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + }, + }, +} + +var CreateTableOp3 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "varchar(255)", + Nullable: true, + }, + }, +} + +var CreateTableOp4 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "numeric(10,2)", + Nullable: true, + }, + }, +} + +var CreateTableOp5 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Nullable: true, + Unique: true, + }, + }, +} + +var CreateTableOp6 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int4", + Pk: true, + }, + }, +} + +var CreateTableOp7 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[]", + Nullable: true, + }, + }, +} + +var CreateTableOp8 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[5]", + Nullable: true, + }, + }, +} + +var CreateTableOp9 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "text[5][3]", + Nullable: true, + }, + }, +}