Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unique & FK constraints info to the schema #218

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/xataio/pgroll
go 1.21

require (
github.com/google/go-cmp v0.6.0
github.com/lib/pq v1.10.9
github.com/pterm/pterm v0.12.69
github.com/spf13/cobra v1.7.0
Expand All @@ -11,6 +12,7 @@ require (
github.com/testcontainers/testcontainers-go v0.23.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.23.0
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
gotest.tools/v3 v3.5.0
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
Expand Down
17 changes: 17 additions & 0 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type Table struct {

// The columns that make up the primary key
PrimaryKey []string `json:"primaryKey"`

ForeignKeys map[string]ForeignKey `json:"foreignKeys"`
}

type Column struct {
Expand All @@ -57,6 +59,7 @@ type Column struct {

Default *string `json:"default"`
Nullable bool `json:"nullable"`
Unique bool `json:"unique"`

// Optional comment for the column
Comment string `json:"comment"`
Expand All @@ -67,6 +70,20 @@ type Index struct {
Name string `json:"name"`
}

type ForeignKey struct {
// Name is the name of the foreign key in postgres
Name string `json:"name"`

// The columns that the foreign key is defined on
Columns []string `json:"columns"`

// The table that the foreign key references
ReferencedTable string `json:"referencedTable"`

// The columns in the referenced table that the foreign key references
ReferencedColumns []string `json:"referencedColumns"`
}

func (s *Schema) GetTable(name string) *Table {
if s.Tables == nil {
return nil
Expand Down
38 changes: 37 additions & 1 deletion pkg/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,21 @@ BEGIN
)
ELSE format_type(attr.atttypid, attr.atttypmod)
END AS type,
descr.description AS comment
descr.description AS comment,
(EXISTS (
SELECT 1
FROM pg_constraint
WHERE conrelid = attr.attrelid
AND conkey::int[] @> ARRAY[attr.attnum::int]
AND contype = 'u'
) OR EXISTS (
SELECT 1
FROM pg_index
JOIN pg_class ON pg_class.oid = pg_index.indexrelid
WHERE indrelid = attr.attrelid
AND indisunique
AND pg_index.indkey::int[] @> ARRAY[attr.attnum::int]
)) AS unique
FROM
pg_attribute AS attr
INNER JOIN pg_type AS tp ON attr.atttypid = tp.oid
Expand Down Expand Up @@ -158,6 +172,28 @@ BEGIN
))
FROM pg_index pi
WHERE pi.indrelid = t.oid::regclass
),
'foreignKeys', (
SELECT json_object_agg(fk_details.conname, json_build_object(
'name', fk_details.conname,
'columns', fk_details.columns,
'referencedTable', fk_details.referencedTable,
'referencedColumns', fk_details.referencedColumns
))
FROM (
SELECT
fk_constraint.conname,
array_agg(fk_attr.attname ORDER BY fk_constraint.conkey::int[]) AS columns,
fk_cl.relname AS referencedTable,
array_agg(ref_attr.attname ORDER BY fk_constraint.confkey::int[]) AS referencedColumns
FROM pg_constraint AS fk_constraint
INNER JOIN pg_class fk_cl ON fk_constraint.confrelid = fk_cl.oid
INNER JOIN pg_attribute fk_attr ON fk_attr.attrelid = fk_constraint.conrelid AND fk_attr.attnum = ANY(fk_constraint.conkey)
INNER JOIN pg_attribute ref_attr ON ref_attr.attrelid = fk_constraint.confrelid AND ref_attr.attnum = ANY(fk_constraint.confkey)
WHERE fk_constraint.conrelid = t.oid
AND fk_constraint.contype = 'f'
GROUP BY fk_constraint.conname, fk_cl.relname
) AS fk_details
)
)) FROM pg_class AS t
INNER JOIN pg_namespace AS ns ON t.relnamespace = ns.oid
Expand Down
131 changes: 131 additions & 0 deletions pkg/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/schema"
"github.com/xataio/pgroll/pkg/state"
)

Expand Down Expand Up @@ -57,6 +60,134 @@ func TestSchemaOptionIsRespected(t *testing.T) {
})
}

func TestReadSchema(t *testing.T) {
t.Parallel()

witStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
ctx := context.Background()

tests := []struct {
name string
createStmt string
wantSchema *schema.Schema
}{
{
name: "one table",
createStmt: "CREATE TABLE public.table1 (id int)",
wantSchema: &schema.Schema{
Name: "public",
Tables: map[string]schema.Table{
"table1": {
Name: "table1",
Columns: map[string]schema.Column{
"id": {
Name: "id",
Type: "integer",
Nullable: true,
},
},
},
},
},
},
{
name: "unique, not null",
createStmt: "CREATE TABLE public.table1 (id int NOT NULL, CONSTRAINT id_unique UNIQUE(id))",
wantSchema: &schema.Schema{
Name: "public",
Tables: map[string]schema.Table{
"table1": {
Name: "table1",
Columns: map[string]schema.Column{
"id": {
Name: "id",
Type: "integer",
Nullable: false,
Unique: true,
},
},
Indexes: map[string]schema.Index{
"id_unique": {
Name: "id_unique",
},
},
},
},
},
},
{
name: "foreign key",
createStmt: "CREATE TABLE public.table1 (id int PRIMARY KEY); CREATE TABLE public.table2 (fk int NOT NULL, CONSTRAINT fk_fkey FOREIGN KEY (fk) REFERENCES public.table1 (id))",
wantSchema: &schema.Schema{
Name: "public",
Tables: map[string]schema.Table{
"table1": {
Name: "table1",
Columns: map[string]schema.Column{
"id": {
Name: "id",
Type: "integer",
Nullable: false,
Unique: true,
},
},
PrimaryKey: []string{"id"},
Indexes: map[string]schema.Index{
"table1_pkey": {
Name: "table1_pkey",
},
},
},
"table2": {
Name: "table2",
Columns: map[string]schema.Column{
"fk": {
Name: "fk",
Type: "integer",
Nullable: false,
},
},
ForeignKeys: map[string]schema.ForeignKey{
"fk_fkey": {
Name: "fk_fkey",
Columns: []string{"fk"},
ReferencedTable: "table1",
ReferencedColumns: []string{"id"},
},
},
},
},
},
},
}

// init the state
if err := state.Init(ctx); err != nil {
t.Fatal(err)
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := db.ExecContext(ctx, "DROP SCHEMA public CASCADE; CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}

if _, err := db.ExecContext(ctx, tt.createStmt); err != nil {
t.Fatal(err)
}

gotSchema, err := state.ReadSchema(ctx, "public")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tt.wantSchema, gotSchema, cmpopts.IgnoreFields(schema.Table{}, "OID")); diff != "" {
t.Errorf("expected schema mismatch (-want +got):\n%s", diff)
}
})
}
})
}

func witStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.DB)) {
t.Helper()
ctx := context.Background()
Expand Down
Loading