Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions core/sequences/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,20 @@ const (

// Sequence represents a single sequence within the pg_sequence table.
type Sequence struct {
Id id.Sequence
DataTypeID id.Type
Persistence Persistence
Start int64
Current int64
Increment int64
Minimum int64
Maximum int64
Cache int64
Cycle bool
IsAtEnd bool
OwnerTable id.Table
OwnerColumn string
Id id.Sequence
DataTypeID id.Type
Persistence Persistence
Start int64
Current int64
Increment int64
Minimum int64
Maximum int64
Cache int64
Cycle bool
IsAtEnd bool
HasBeenCalled bool
OwnerTable id.Table
OwnerColumn string
}

var _ objinterface.Collection = (*Collection)(nil)
Expand Down Expand Up @@ -297,6 +298,7 @@ func (pgs *Collection) SetVal(ctx context.Context, name id.Sequence, newValue in
}
seq.Current = newValue
seq.IsAtEnd = false
seq.HasBeenCalled = false
if autoAdvance {
_, err := seq.nextValForSequence()
return err
Expand Down Expand Up @@ -450,6 +452,7 @@ func (sequence *Sequence) nextValForSequence() (int64, error) {
}
}
// We'll return the current value, so everything after this sets the value for the next call
sequence.HasBeenCalled = true
valueToReturn := sequence.Current
// Increment the current value
if sequence.Increment > 0 {
Expand Down
3 changes: 3 additions & 0 deletions core/sequences/collection_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func (*Collection) HandleMerge(ctx context.Context, mro merge.MergeRootObject) (
return ourCurrent
}
})
mergedSeq.HasBeenCalled = merge2.ResolveMergeValues(ourSeq.HasBeenCalled, theirSeq.HasBeenCalled, ancSeq.HasBeenCalled, hasAncestor, func(ourcalled, theirCalled bool) bool {
return ourcalled || theirCalled
})
return &mergedSeq, &merge.MergeStats{
Operation: merge.TableModified,
Adds: 0,
Expand Down
39 changes: 27 additions & 12 deletions core/sequences/root_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ import (
)

const (
FIELD_NAME_DATA_TYPE = "data_type"
FIELD_NAME_PERSISTENCE = "persistence"
FIELD_NAME_START = "start"
FIELD_NAME_CURRENT = "current"
FIELD_NAME_INCREMENT = "increment"
FIELD_NAME_MINIMUM = "minimum"
FIELD_NAME_MAXIMUM = "maximum"
FIELD_NAME_CACHE = "cache"
FIELD_NAME_CYCLE = "cycle"
FIELD_NAME_IS_AT_END = "is_at_end"
FIELD_NAME_OWNER_TABLE = "owner_table"
FIELD_NAME_OWNER_COLUMN = "owner_column"
FIELD_NAME_DATA_TYPE = "data_type"
FIELD_NAME_PERSISTENCE = "persistence"
FIELD_NAME_START = "start"
FIELD_NAME_CURRENT = "current"
FIELD_NAME_INCREMENT = "increment"
FIELD_NAME_MINIMUM = "minimum"
FIELD_NAME_MAXIMUM = "maximum"
FIELD_NAME_CACHE = "cache"
FIELD_NAME_CYCLE = "cycle"
FIELD_NAME_IS_AT_END = "is_at_end"
FIELD_NAME_HAS_BEEN_CALLED = "has_been_called"
FIELD_NAME_OWNER_TABLE = "owner_table"
FIELD_NAME_OWNER_COLUMN = "owner_column"
)

// DeserializeRootObject implements the interface objinterface.Collection.
Expand Down Expand Up @@ -109,6 +110,18 @@ func (pgs *Collection) DiffRootObjects(ctx context.Context, fromHash string, o o
ours.Current = diff.OurValue.(int64)
}
}
if ours.HasBeenCalled != theirs.HasBeenCalled {
diff := objinterface.RootObjectDiff{
Type: pgtypes.Bool,
FromHash: fromHash,
FieldName: FIELD_NAME_HAS_BEEN_CALLED,
}
if pgmerge.DiffValues(&diff, ours.HasBeenCalled, theirs.HasBeenCalled, ancestor.HasBeenCalled, hasAncestor) {
diffs = append(diffs, diff)
} else {
ours.HasBeenCalled = diff.OurValue.(bool)
}
}
if ours.Increment != theirs.Increment {
diff := objinterface.RootObjectDiff{
Type: pgtypes.Int64,
Expand Down Expand Up @@ -239,6 +252,8 @@ func (pgs *Collection) GetFieldType(ctx context.Context, fieldName string) *pgty
return pgtypes.Bool
case FIELD_NAME_IS_AT_END:
return pgtypes.Bool
case FIELD_NAME_HAS_BEEN_CALLED:
return pgtypes.Bool
case FIELD_NAME_OWNER_TABLE:
return pgtypes.Text
case FIELD_NAME_OWNER_COLUMN:
Expand Down
10 changes: 7 additions & 3 deletions core/sequences/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) {

// Create the writer
writer := utils.NewWriter(256)
writer.VariableUint(0) // Version
writer.VariableUint(1) // Version
// Write the sequence data
writer.Id(sequence.Id.AsId())
writer.Id(sequence.DataTypeID.AsId())
Expand All @@ -44,6 +44,7 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) {
writer.Int64(sequence.Cache)
writer.Bool(sequence.Cycle)
writer.Bool(sequence.IsAtEnd)
writer.Bool(sequence.HasBeenCalled)
writer.Id(sequence.OwnerTable.AsId())
writer.String(sequence.OwnerColumn)
// Returns the data
Expand All @@ -52,13 +53,13 @@ func (sequence *Sequence) Serialize(ctx context.Context) ([]byte, error) {

// DeserializeSequence returns the Sequence that was serialized in the byte slice. Returns an empty Sequence if data is
// nil or empty.
func DeserializeSequence(ctx context.Context, data []byte) (*Sequence, error) {
func DeserializeSequence(_ context.Context, data []byte) (*Sequence, error) {
if len(data) == 0 {
return nil, nil
}
reader := utils.NewReader(data)
version := reader.VariableUint()
if version != 0 {
if version > 1 {
return nil, errors.Errorf("version %d of sequences is not supported, please upgrade the server", version)
}

Expand All @@ -75,6 +76,9 @@ func DeserializeSequence(ctx context.Context, data []byte) (*Sequence, error) {
sequence.Cache = reader.Int64()
sequence.Cycle = reader.Bool()
sequence.IsAtEnd = reader.Bool()
if version >= 1 {
sequence.HasBeenCalled = reader.Bool()
}
sequence.OwnerTable = id.Table(reader.Id())
sequence.OwnerColumn = reader.String()
if !reader.IsEmpty() {
Expand Down
6 changes: 2 additions & 4 deletions server/tables/pgcatalog/pg_catalog_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/sequences"
"github.com/dolthub/doltgresql/server/functions"
)

Expand Down Expand Up @@ -60,9 +59,8 @@ type pgCatalogCache struct {
// pg_index / pg_indexes
pgIndexes *pgIndexCache

// pg_sequence
sequences []*sequences.Sequence
sequenceOids []id.Id
// pg_sequence / pg_sequences
sequences []*pgSequence

// pg_attrdef
attrdefCols []functions.ItemColumnDefault
Expand Down
53 changes: 36 additions & 17 deletions server/tables/pgcatalog/pg_sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ func InitPgSequence() {
// PgSequenceHandler is the handler for the pg_sequence table.
type PgSequenceHandler struct{}

// pgSequence represents a row in the pg_sequence table and pg_sequences view
type pgSequence struct {
sequence *sequences.Sequence
schema string
oid id.Id
}

var _ tables.Handler = PgSequenceHandler{}

// Name implements the interface tables.Handler.
Expand All @@ -45,37 +52,50 @@ func (p PgSequenceHandler) Name() string {
}

// RowIter implements the interface tables.Handler.
func (p PgSequenceHandler) RowIter(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
func (p PgSequenceHandler) RowIter(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) {
// Use cached data from this process if it exists
pgCatalogCache, err := getPgCatalogCache(ctx)
if err != nil {
return nil, err
}

if pgCatalogCache.sequences == nil {
var sequences []*sequences.Sequence
var sequenceOids []id.Id
err := functions.IterateCurrentDatabase(ctx, functions.Callbacks{
Sequence: func(ctx *sql.Context, _ functions.ItemSchema, sequence functions.ItemSequence) (cont bool, err error) {
sequences = append(sequences, sequence.Item)
sequenceOids = append(sequenceOids, sequence.OID.AsId())
return true, nil
},
})
err = cachePgSequences(ctx, pgCatalogCache)
if err != nil {
return nil, err
}
pgCatalogCache.sequences = sequences
pgCatalogCache.sequenceOids = sequenceOids
}

return &pgSequenceRowIter{
sequences: pgCatalogCache.sequences,
oids: pgCatalogCache.sequenceOids,
idx: 0,
}, nil
}

func cachePgSequences(ctx *sql.Context, pgCatalogCache *pgCatalogCache) error {
var sequences []*pgSequence

err := functions.IterateCurrentDatabase(ctx, functions.Callbacks{
Sequence: func(ctx *sql.Context, schema functions.ItemSchema, sequence functions.ItemSequence) (cont bool, err error) {
pgSeq := &pgSequence{
sequence: sequence.Item,
schema: schema.Item.SchemaName(),
oid: sequence.OID.AsId(),
}

sequences = append(sequences, pgSeq)
return true, nil
},
})
if err != nil {
return err
}

pgCatalogCache.sequences = sequences

return nil
}

// Schema implements the interface tables.Handler.
func (p PgSequenceHandler) PkSchema() sql.PrimaryKeySchema {
return sql.PrimaryKeySchema{
Expand All @@ -98,8 +118,7 @@ var pgSequenceSchema = sql.Schema{

// pgSequenceRowIter is the sql.RowIter for the pg_sequence table.
type pgSequenceRowIter struct {
sequences []*sequences.Sequence
oids []id.Id
sequences []*pgSequence
idx int
}

Expand All @@ -111,8 +130,8 @@ func (iter *pgSequenceRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, io.EOF
}
iter.idx++
sequence := iter.sequences[iter.idx-1]
oid := iter.oids[iter.idx-1]
sequence := iter.sequences[iter.idx-1].sequence
oid := iter.sequences[iter.idx-1].oid
return sql.Row{
oid, // seqrelid
sequence.DataTypeID.AsId(), // seqtypid
Expand Down
56 changes: 50 additions & 6 deletions server/tables/pgcatalog/pg_sequences.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,23 @@ func (p PgSequencesHandler) Name() string {
}

// RowIter implements the interface tables.Handler.
func (p PgSequencesHandler) RowIter(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
// TODO: Implement pg_sequences row iter
return emptyRowIter()
func (p PgSequencesHandler) RowIter(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) {
pgCatalogCache, err := getPgCatalogCache(ctx)
if err != nil {
return nil, err
}

if pgCatalogCache.sequences == nil {
err = cachePgSequences(ctx, pgCatalogCache)
if err != nil {
return nil, err
}
}

return &pgSequencesRowIter{
sequences: pgCatalogCache.sequences,
idx: 0,
}, nil
}

// Schema implements the interface tables.Handler.
Expand Down Expand Up @@ -72,16 +86,46 @@ var pgSequencesSchema = sql.Schema{

// pgSequencesRowIter is the sql.RowIter for the pg_sequences table.
type pgSequencesRowIter struct {
sequences []*pgSequence
idx int
}

var _ sql.RowIter = (*pgSequencesRowIter)(nil)

// Next implements the interface sql.RowIter.
func (iter *pgSequencesRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, io.EOF
func (iter *pgSequencesRowIter) Next(_ *sql.Context) (sql.Row, error) {
if iter.idx >= len(iter.sequences) {
return nil, io.EOF
}
sequence := iter.sequences[iter.idx].sequence
schemaName := iter.sequences[iter.idx].schema
iter.idx++

var lastValue interface{}
if sequence.HasBeenCalled {
if sequence.IsAtEnd {
lastValue = sequence.Current
} else {
lastValue = sequence.Current - sequence.Increment
}
}

return sql.Row{
schemaName, // schemaname
sequence.Id.SequenceName(), // sequencename
nil, // TODO sequenceowner
sequence.DataTypeID.TypeName(), // data_type
sequence.Start, // start_value
sequence.Minimum, // min_value
sequence.Maximum, // max_value
sequence.Increment, // increment_by
sequence.Cycle, // cycle
sequence.Cache, // cache_size
lastValue, // TODO last_value
}, nil
}

// Close implements the interface sql.RowIter.
func (iter *pgSequencesRowIter) Close(ctx *sql.Context) error {
func (iter *pgSequencesRowIter) Close(_ *sql.Context) error {
return nil
}
Loading
Loading