diff --git a/core/sequences/collection.go b/core/sequences/collection.go index c4f2853002..a450d380fc 100644 --- a/core/sequences/collection.go +++ b/core/sequences/collection.go @@ -50,19 +50,20 @@ const ( // Sequence represents a single sequence within the pg_sequence table. type Sequence struct { - Id id.Sequence - DataTypeID id.Type - Persistence Persistence - Start int64 - Current int64 - Increment int64 - Minimum int64 - Maximum int64 - Cache int64 - Cycle bool - IsAtEnd bool - OwnerTable id.Table - OwnerColumn string + Id id.Sequence + DataTypeID id.Type + Persistence Persistence + Start int64 + Current int64 + Increment int64 + Minimum int64 + Maximum int64 + Cache int64 + Cycle bool + IsAtEnd bool + HasBeenCalled bool + OwnerTable id.Table + OwnerColumn string } var _ objinterface.Collection = (*Collection)(nil) @@ -297,6 +298,7 @@ func (pgs *Collection) SetVal(ctx context.Context, name id.Sequence, newValue in } seq.Current = newValue seq.IsAtEnd = false + seq.HasBeenCalled = false if autoAdvance { _, err := seq.nextValForSequence() return err @@ -450,6 +452,7 @@ func (sequence *Sequence) nextValForSequence() (int64, error) { } } // We'll return the current value, so everything after this sets the value for the next call + sequence.HasBeenCalled = true valueToReturn := sequence.Current // Increment the current value if sequence.Increment > 0 { diff --git a/core/sequences/collection_funcs.go b/core/sequences/collection_funcs.go index 49023efbb8..a2cb9bea59 100644 --- a/core/sequences/collection_funcs.go +++ b/core/sequences/collection_funcs.go @@ -102,6 +102,9 @@ func (*Collection) HandleMerge(ctx context.Context, mro merge.MergeRootObject) ( return ourCurrent } }) + mergedSeq.HasBeenCalled = merge2.ResolveMergeValues(ourSeq.HasBeenCalled, theirSeq.HasBeenCalled, ancSeq.HasBeenCalled, hasAncestor, func(ourcalled, theirCalled bool) bool { + return ourcalled || theirCalled + }) return &mergedSeq, &merge.MergeStats{ Operation: merge.TableModified, Adds: 0, diff --git a/core/sequences/root_object.go b/core/sequences/root_object.go index 9ac202103c..35750e0327 100644 --- a/core/sequences/root_object.go +++ b/core/sequences/root_object.go @@ -27,18 +27,19 @@ import ( ) const ( - FIELD_NAME_DATA_TYPE = "data_type" - FIELD_NAME_PERSISTENCE = "persistence" - FIELD_NAME_START = "start" - FIELD_NAME_CURRENT = "current" - FIELD_NAME_INCREMENT = "increment" - FIELD_NAME_MINIMUM = "minimum" - FIELD_NAME_MAXIMUM = "maximum" - FIELD_NAME_CACHE = "cache" - FIELD_NAME_CYCLE = "cycle" - FIELD_NAME_IS_AT_END = "is_at_end" - FIELD_NAME_OWNER_TABLE = "owner_table" - FIELD_NAME_OWNER_COLUMN = "owner_column" + FIELD_NAME_DATA_TYPE = "data_type" + FIELD_NAME_PERSISTENCE = "persistence" + FIELD_NAME_START = "start" + FIELD_NAME_CURRENT = "current" + FIELD_NAME_INCREMENT = "increment" + FIELD_NAME_MINIMUM = "minimum" + FIELD_NAME_MAXIMUM = "maximum" + FIELD_NAME_CACHE = "cache" + FIELD_NAME_CYCLE = "cycle" + FIELD_NAME_IS_AT_END = "is_at_end" + FIELD_NAME_HAS_BEEN_CALLED = "has_been_called" + FIELD_NAME_OWNER_TABLE = "owner_table" + FIELD_NAME_OWNER_COLUMN = "owner_column" ) // DeserializeRootObject implements the interface objinterface.Collection. @@ -109,6 +110,18 @@ func (pgs *Collection) DiffRootObjects(ctx context.Context, fromHash string, o o ours.Current = diff.OurValue.(int64) } } + if ours.HasBeenCalled != theirs.HasBeenCalled { + diff := objinterface.RootObjectDiff{ + Type: pgtypes.Bool, + FromHash: fromHash, + FieldName: FIELD_NAME_HAS_BEEN_CALLED, + } + if pgmerge.DiffValues(&diff, ours.HasBeenCalled, theirs.HasBeenCalled, ancestor.HasBeenCalled, hasAncestor) { + diffs = append(diffs, diff) + } else { + ours.HasBeenCalled = diff.OurValue.(bool) + } + } if ours.Increment != theirs.Increment { diff := objinterface.RootObjectDiff{ Type: pgtypes.Int64, @@ -239,6 +252,8 @@ func (pgs *Collection) GetFieldType(ctx context.Context, fieldName string) *pgty return pgtypes.Bool case FIELD_NAME_IS_AT_END: return pgtypes.Bool + case FIELD_NAME_HAS_BEEN_CALLED: + return pgtypes.Bool case FIELD_NAME_OWNER_TABLE: return pgtypes.Text case FIELD_NAME_OWNER_COLUMN: diff --git a/core/sequences/serialization.go b/core/sequences/serialization.go index 0faea946c5..415888c0fc 100644 --- a/core/sequences/serialization.go +++ b/core/sequences/serialization.go @@ -31,7 +31,7 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) { // Create the writer writer := utils.NewWriter(256) - writer.VariableUint(0) // Version + writer.VariableUint(1) // Version // Write the sequence data writer.Id(sequence.Id.AsId()) writer.Id(sequence.DataTypeID.AsId()) @@ -44,6 +44,7 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) { writer.Int64(sequence.Cache) writer.Bool(sequence.Cycle) writer.Bool(sequence.IsAtEnd) + writer.Bool(sequence.HasBeenCalled) writer.Id(sequence.OwnerTable.AsId()) writer.String(sequence.OwnerColumn) // Returns the data @@ -52,13 +53,13 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) { // DeserializeSequence returns the Sequence that was serialized in the byte slice. Returns an empty Sequence if data is // nil or empty. -func DeserializeSequence(ctx context.Context, data []byte) (*Sequence, error) { +func DeserializeSequence(_ context.Context, data []byte) (*Sequence, error) { if len(data) == 0 { return nil, nil } reader := utils.NewReader(data) version := reader.VariableUint() - if version != 0 { + if version > 1 { return nil, errors.Errorf("version %d of sequences is not supported, please upgrade the server", version) } @@ -75,6 +76,9 @@ func DeserializeSequence(ctx context.Context, data []byte) (*Sequence, error) { sequence.Cache = reader.Int64() sequence.Cycle = reader.Bool() sequence.IsAtEnd = reader.Bool() + if version >= 1 { + sequence.HasBeenCalled = reader.Bool() + } sequence.OwnerTable = id.Table(reader.Id()) sequence.OwnerColumn = reader.String() if !reader.IsEmpty() { diff --git a/server/tables/pgcatalog/pg_catalog_cache.go b/server/tables/pgcatalog/pg_catalog_cache.go index d9a6ce1237..9ec62859fe 100644 --- a/server/tables/pgcatalog/pg_catalog_cache.go +++ b/server/tables/pgcatalog/pg_catalog_cache.go @@ -20,7 +20,6 @@ import ( "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/core/sequences" "github.com/dolthub/doltgresql/server/functions" ) @@ -60,9 +59,8 @@ type pgCatalogCache struct { // pg_index / pg_indexes pgIndexes *pgIndexCache - // pg_sequence - sequences []*sequences.Sequence - sequenceOids []id.Id + // pg_sequence / pg_sequences + sequences []*pgSequence // pg_attrdef attrdefCols []functions.ItemColumnDefault diff --git a/server/tables/pgcatalog/pg_sequence.go b/server/tables/pgcatalog/pg_sequence.go index 2050808dde..8768c89723 100644 --- a/server/tables/pgcatalog/pg_sequence.go +++ b/server/tables/pgcatalog/pg_sequence.go @@ -37,6 +37,13 @@ func InitPgSequence() { // PgSequenceHandler is the handler for the pg_sequence table. type PgSequenceHandler struct{} +// pgSequence represents a row in the pg_sequence table and pg_sequences view +type pgSequence struct { + sequence *sequences.Sequence + schema string + oid id.Id +} + var _ tables.Handler = PgSequenceHandler{} // Name implements the interface tables.Handler. @@ -45,7 +52,7 @@ func (p PgSequenceHandler) Name() string { } // RowIter implements the interface tables.Handler. -func (p PgSequenceHandler) RowIter(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { +func (p PgSequenceHandler) RowIter(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) { // Use cached data from this process if it exists pgCatalogCache, err := getPgCatalogCache(ctx) if err != nil { @@ -53,29 +60,42 @@ func (p PgSequenceHandler) RowIter(ctx *sql.Context, partition sql.Partition) (s } if pgCatalogCache.sequences == nil { - var sequences []*sequences.Sequence - var sequenceOids []id.Id - err := functions.IterateCurrentDatabase(ctx, functions.Callbacks{ - Sequence: func(ctx *sql.Context, _ functions.ItemSchema, sequence functions.ItemSequence) (cont bool, err error) { - sequences = append(sequences, sequence.Item) - sequenceOids = append(sequenceOids, sequence.OID.AsId()) - return true, nil - }, - }) + err = cachePgSequences(ctx, pgCatalogCache) if err != nil { return nil, err } - pgCatalogCache.sequences = sequences - pgCatalogCache.sequenceOids = sequenceOids } return &pgSequenceRowIter{ sequences: pgCatalogCache.sequences, - oids: pgCatalogCache.sequenceOids, idx: 0, }, nil } +func cachePgSequences(ctx *sql.Context, pgCatalogCache *pgCatalogCache) error { + var sequences []*pgSequence + + err := functions.IterateCurrentDatabase(ctx, functions.Callbacks{ + Sequence: func(ctx *sql.Context, schema functions.ItemSchema, sequence functions.ItemSequence) (cont bool, err error) { + pgSeq := &pgSequence{ + sequence: sequence.Item, + schema: schema.Item.SchemaName(), + oid: sequence.OID.AsId(), + } + + sequences = append(sequences, pgSeq) + return true, nil + }, + }) + if err != nil { + return err + } + + pgCatalogCache.sequences = sequences + + return nil +} + // Schema implements the interface tables.Handler. func (p PgSequenceHandler) PkSchema() sql.PrimaryKeySchema { return sql.PrimaryKeySchema{ @@ -98,8 +118,7 @@ var pgSequenceSchema = sql.Schema{ // pgSequenceRowIter is the sql.RowIter for the pg_sequence table. type pgSequenceRowIter struct { - sequences []*sequences.Sequence - oids []id.Id + sequences []*pgSequence idx int } @@ -111,8 +130,8 @@ func (iter *pgSequenceRowIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, io.EOF } iter.idx++ - sequence := iter.sequences[iter.idx-1] - oid := iter.oids[iter.idx-1] + sequence := iter.sequences[iter.idx-1].sequence + oid := iter.sequences[iter.idx-1].oid return sql.Row{ oid, // seqrelid sequence.DataTypeID.AsId(), // seqtypid diff --git a/server/tables/pgcatalog/pg_sequences.go b/server/tables/pgcatalog/pg_sequences.go index cc6a06f10f..12e156357d 100644 --- a/server/tables/pgcatalog/pg_sequences.go +++ b/server/tables/pgcatalog/pg_sequences.go @@ -42,9 +42,23 @@ func (p PgSequencesHandler) Name() string { } // RowIter implements the interface tables.Handler. -func (p PgSequencesHandler) RowIter(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { - // TODO: Implement pg_sequences row iter - return emptyRowIter() +func (p PgSequencesHandler) RowIter(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) { + pgCatalogCache, err := getPgCatalogCache(ctx) + if err != nil { + return nil, err + } + + if pgCatalogCache.sequences == nil { + err = cachePgSequences(ctx, pgCatalogCache) + if err != nil { + return nil, err + } + } + + return &pgSequencesRowIter{ + sequences: pgCatalogCache.sequences, + idx: 0, + }, nil } // Schema implements the interface tables.Handler. @@ -72,16 +86,46 @@ var pgSequencesSchema = sql.Schema{ // pgSequencesRowIter is the sql.RowIter for the pg_sequences table. type pgSequencesRowIter struct { + sequences []*pgSequence + idx int } var _ sql.RowIter = (*pgSequencesRowIter)(nil) // Next implements the interface sql.RowIter. -func (iter *pgSequencesRowIter) Next(ctx *sql.Context) (sql.Row, error) { - return nil, io.EOF +func (iter *pgSequencesRowIter) Next(_ *sql.Context) (sql.Row, error) { + if iter.idx >= len(iter.sequences) { + return nil, io.EOF + } + sequence := iter.sequences[iter.idx].sequence + schemaName := iter.sequences[iter.idx].schema + iter.idx++ + + var lastValue interface{} + if sequence.HasBeenCalled { + if sequence.IsAtEnd { + lastValue = sequence.Current + } else { + lastValue = sequence.Current - sequence.Increment + } + } + + return sql.Row{ + schemaName, // schemaname + sequence.Id.SequenceName(), // sequencename + nil, // TODO sequenceowner + sequence.DataTypeID.TypeName(), // data_type + sequence.Start, // start_value + sequence.Minimum, // min_value + sequence.Maximum, // max_value + sequence.Increment, // increment_by + sequence.Cycle, // cycle + sequence.Cache, // cache_size + lastValue, // TODO last_value + }, nil } // Close implements the interface sql.RowIter. -func (iter *pgSequencesRowIter) Close(ctx *sql.Context) error { +func (iter *pgSequencesRowIter) Close(_ *sql.Context) error { return nil } diff --git a/testing/go/pgcatalog_test.go b/testing/go/pgcatalog_test.go index 6cedcd412d..ec74a744b5 100644 --- a/testing/go/pgcatalog_test.go +++ b/testing/go/pgcatalog_test.go @@ -2832,7 +2832,7 @@ func TestPgSeclabels(t *testing.T) { func TestPgSequences(t *testing.T) { RunScripts(t, []ScriptTest{ { - Name: "pg_sequences", + Name: "select from pg_sequences", Assertions: []ScriptTestAssertion{ { Query: `SELECT * FROM "pg_catalog"."pg_sequences";`, @@ -2852,6 +2852,97 @@ func TestPgSequences(t *testing.T) { }, }, }, + { + Name: "default sequence values", + SetUpScript: []string{ + "CREATE SEQUENCE test;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT schemaname, sequencename, sequenceowner, data_type, start_value, min_value, " + + "max_value, increment_by, cycle, cache_size, last_value FROM pg_sequences", + Expected: []sql.Row{ + {"public", "test", nil, "int8", int64(1), int64(1), int64(9223372036854775807), int64(1), "f", int64(1), nil}, + }, + }, + }, + }, + { + Name: "custom sequence values", + SetUpScript: []string{ + "CREATE SEQUENCE test as integer start 10 minvalue 5 maxvalue 11 increment 2 cycle;", + "CREATE SCHEMA test_schema", + "CREATE SEQUENCE test_schema.secondseq", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM pg_sequences where sequencename = 'test'", + Expected: []sql.Row{ + {"public", "test", nil, "int4", int64(10), int64(5), int64(11), int64(2), "t", int64(1), nil}, + }, + }, + { + Query: "SELECT * FROM pg_sequences where sequencename = 'secondseq'", + Expected: []sql.Row{ + {"test_schema", "secondseq", nil, "int8", int64(1), int64(1), int64(9223372036854775807), int64(1), "f", int64(1), nil}, + }, + }, + }, + }, + { + Name: "multiple sequences", + SetUpScript: []string{ + "CREATE SEQUENCE c", + "CREATE SEQUENCE a", + "CREATE SEQUENCE b", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT schemaname, sequencename from pg_sequences", + Expected: []sql.Row{ + {"public", "a"}, + {"public", "b"}, + {"public", "c"}, + }, + }, + }, + }, + { + Name: "table with serial column", + SetUpScript: []string{ + "CREATE TABLE test (id serial)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT sequencename, data_type, start_value, max_value FROM pg_sequences", + Expected: []sql.Row{ + {"test_id_seq", "int4", int64(1), int64(2147483647)}, + }, + }, + }, + }, + { + Name: "last value set correctly", + SetUpScript: []string{ + "CREATE SEQUENCE seq1", + "CREATE SEQUENCE seq2", + "select nextval('seq2')", + "CREATE SEQUENCE seq3", + "CREATE SEQUENCE seq4", + "select setval('seq3', 2, false), setval('seq4', 2, true)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT sequencename, last_value FROM pg_sequences", + Expected: []sql.Row{ + {"seq1", nil}, + {"seq2", 1}, + {"seq3", nil}, + {"seq4", 2}, + }, + }, + }, + }, }) }