diff --git a/core/casts/collection.go b/core/casts/collection.go new file mode 100644 index 0000000000..177cf343cf --- /dev/null +++ b/core/casts/collection.go @@ -0,0 +1,643 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casts + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/core/rootobject/objinterface" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// Collection contains a collection of casts. +type Collection struct { + mapHash hash.Hash // This is cached so that we don't have to calculate the hash every time + underlyingMap prolly.AddressMap + ns tree.NodeStore +} + +// CastType is the type of the cast, indicating which contexts it may be called in. +type CastType uint8 + +const ( + CastType_Explicit CastType = 0 + CastType_Assignment CastType = 1 + CastType_Implicit CastType = 2 +) + +// builtInCasts contains all casts that are built into the database by default. +var builtInCasts = map[id.Cast]Cast{} + +// Cast represents a cast between two types. +type Cast struct { + ID id.Cast + CastType CastType + Function id.Function + BuiltIn pgtypes.TypeCastFunction + UseInOut bool +} + +var _ objinterface.Collection = (*Collection)(nil) +var _ objinterface.RootObject = Cast{} + +// NewCollection returns a new Collection. +func NewCollection(ctx context.Context, underlyingMap prolly.AddressMap, ns tree.NodeStore) (*Collection, error) { + collection := &Collection{ + mapHash: underlyingMap.HashOf(), + underlyingMap: underlyingMap, + ns: ns, + } + return collection, nil +} + +// GetExplicitCast returns the explicit type cast function that will cast the source type to the target type. Returns +// a Cast with an invalid ID if such a cast is not valid. +func (pgc *Collection) GetExplicitCast(ctx context.Context, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (Cast, error) { + castID := id.NewCast(sourceType.ID, targetType.ID) + c, err := pgc.getCast(ctx, castID, sourceType, targetType, CastType_Explicit) + if err != nil { + return Cast{}, err + } + if c.ID.IsValid() { + return c, nil + } + // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. + if cast := pgc.getSizingOrIdentityCast(sourceType, targetType, CastType_Explicit); cast.ID.IsValid() { + return cast, nil + } + // We then check for a record to composite cast + if recordCast := pgc.getRecordCast(sourceType, targetType, CastType_Explicit); recordCast.ID.IsValid() { + return recordCast, nil + } + // All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html + if sourceType.TypCategory == pgtypes.TypeCategory_StringTypes { + return Cast{ + ID: castID, + CastType: CastType_Explicit, + Function: id.NullFunction, + UseInOut: true, + }, nil + } else if targetType.TypCategory == pgtypes.TypeCategory_StringTypes { + // All types have a built-in assignment cast to string types, which we can reference in an explicit cast + return Cast{ + ID: castID, + CastType: CastType_Explicit, + Function: id.NullFunction, + UseInOut: true, + }, nil + } + // It is always valid to convert from the `unknown` type + if sourceType.ID == pgtypes.Unknown.ID { + return Cast{ + ID: castID, + CastType: CastType_Explicit, + Function: id.NullFunction, + UseInOut: true, + }, nil + } + return Cast{}, nil +} + +// GetAssignmentCast returns the assignment type cast function that will cast the source type to the target type. +// Returns a Cast with an invalid ID if such a cast is not valid. +func (pgc *Collection) GetAssignmentCast(ctx context.Context, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (Cast, error) { + castID := id.NewCast(sourceType.ID, targetType.ID) + c, err := pgc.getCast(ctx, castID, sourceType, targetType, CastType_Assignment) + if err != nil { + return Cast{}, err + } + if c.ID.IsValid() { + if c.CastType == CastType_Explicit { + return Cast{}, nil + } + return c, nil + } + // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. + if cast := pgc.getSizingOrIdentityCast(sourceType, targetType, CastType_Assignment); cast.ID.IsValid() { + return cast, nil + } + // We then check for a record to composite cast + if recordCast := pgc.getRecordCast(sourceType, targetType, CastType_Assignment); recordCast.ID.IsValid() { + return recordCast, nil + } + // All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html + if targetType.TypCategory == pgtypes.TypeCategory_StringTypes { + return Cast{ + ID: castID, + CastType: CastType_Assignment, + Function: id.NullFunction, + UseInOut: true, + }, nil + } + // It is always valid to convert from the `unknown` type + if sourceType.ID == pgtypes.Unknown.ID { + return Cast{ + ID: castID, + CastType: CastType_Assignment, + Function: id.NullFunction, + UseInOut: true, + }, nil + } + return Cast{}, nil +} + +// GetImplicitCast returns the implicit type cast function that will cast the source type to the target type. Returns a +// Cast with an invalid ID if such a cast is not valid. +func (pgc *Collection) GetImplicitCast(ctx context.Context, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (Cast, error) { + castID := id.NewCast(sourceType.ID, targetType.ID) + c, err := pgc.getCast(ctx, castID, sourceType, targetType, CastType_Implicit) + if err != nil { + return Cast{}, err + } + if c.ID.IsValid() { + if c.CastType == CastType_Implicit { + return c, nil + } + return Cast{}, nil + } + // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. + if cast := pgc.getSizingOrIdentityCast(sourceType, targetType, CastType_Implicit); cast.ID.IsValid() { + return cast, nil + } + // We then check for a record to composite cast + if recordCast := pgc.getRecordCast(sourceType, targetType, CastType_Implicit); recordCast.ID.IsValid() { + return recordCast, nil + } + // It is always valid to convert from the `unknown` type + if sourceType.ID == pgtypes.Unknown.ID { + return Cast{ + ID: castID, + CastType: CastType_Implicit, + Function: id.NullFunction, + UseInOut: true, + }, nil + } + return Cast{}, nil +} + +// getCast is used by each individual Get function to handle the actual fetching of the cast. +func (pgc *Collection) getCast(ctx context.Context, castID id.Cast, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType, castType CastType) (Cast, error) { + if c, ok := builtInCasts[castID]; ok { + return c, nil + } + h, err := pgc.underlyingMap.Get(ctx, string(castID)) + if err != nil { + return Cast{}, err + } + if h.IsEmpty() { + // If there isn't a direct mapping, then we need to check if the types are array variants. + // As long as the base types are convertable, the array variants are also convertable. + if sourceType != nil && targetType != nil && sourceType.IsArrayType() && targetType.IsArrayType() { + fromBaseType := sourceType.ArrayBaseType() + toBaseType := targetType.ArrayBaseType() + var baseCast Cast + switch castType { + case CastType_Explicit: + baseCast, err = pgc.GetExplicitCast(ctx, fromBaseType, toBaseType) + if err != nil { + return Cast{}, err + } + case CastType_Assignment: + baseCast, err = pgc.GetAssignmentCast(ctx, fromBaseType, toBaseType) + if err != nil { + return Cast{}, err + } + case CastType_Implicit: + baseCast, err = pgc.GetImplicitCast(ctx, fromBaseType, toBaseType) + if err != nil { + return Cast{}, err + } + } + if baseCast.ID.IsValid() { + // We use a closure that can unwrap the slice, since conversion functions expect a singular non-nil value + evalFunc := func(ctx *sql.Context, vals any, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (any, error) { + var err error + oldVals := vals.([]any) + newVals := make([]any, len(oldVals)) + for i, oldVal := range oldVals { + if oldVal == nil { + continue + } + // Some errors are optional depending on the context, so we'll still process all values even + // after an error is received. + var nErr error + sourceBaseType := sourceType.ArrayBaseType() + targetBaseType := targetType.ArrayBaseType() + newVals[i], nErr = baseCast.Eval(ctx, oldVal, sourceBaseType, targetBaseType) + if nErr != nil && err == nil { + err = nErr + } + } + return newVals, err + } + return Cast{ + ID: castID, + CastType: castType, + Function: id.NullFunction, + BuiltIn: evalFunc, + UseInOut: false, + }, nil + } + } + return Cast{}, nil + } + data, err := pgc.ns.ReadBytes(ctx, h) + if err != nil { + return Cast{}, err + } + return DeserializeCast(ctx, data) +} + +// getSizingOrIdentityCast returns an identity cast if the two types are exactly the same, and a sizing cast if they +// only differ in their atttypmod values. Returns a Cast with an invalid ID if no cast is matched. This mirrors the +// behavior as described in: +// https://www.postgresql.org/docs/15/typeconv-query.html +func (pgc *Collection) getSizingOrIdentityCast(sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType, castType CastType) Cast { + // If we receive different types, then we can return immediately + if sourceType.ID != targetType.ID { + return Cast{} + } + // If we have different atttypmod values, then we need to do a sizing cast only if one exists + if sourceType.GetAttTypMod() != targetType.GetAttTypMod() { + // TODO: We don't have any sizing cast functions implemented, so for now we'll approximate using output to input. + // We can use the query below to find all implemented sizing cast functions. It's also detailed in the link above. + // Lastly, not all sizing functions accept a boolean, but for those that do, we need to see whether true is + // used for explicit casts, or whether true is used for implicit casts. + // SELECT + // format_type(c.castsource, NULL) AS source, + // format_type(c.casttarget, NULL) AS target, + // p.oid::regprocedure AS func + // FROM pg_cast c JOIN pg_proc p ON p.oid = c.castfunc WHERE c.castsource = c.casttarget ORDER BY 1,2; + return Cast{ + ID: id.NewCast(sourceType.ID, targetType.ID), + CastType: castType, + Function: id.NullFunction, + UseInOut: true, + } + } + // If there is no sizing cast, then we simply use the identity cast + return Cast{ + ID: id.NewCast(sourceType.ID, targetType.ID), + CastType: castType, + Function: id.NullFunction, + UseInOut: false, + } +} + +// getRecordCast handles casting from a record type to a composite type (if applicable). Returns a Cast with an invalid +// ID if not applicable. +func (pgc *Collection) getRecordCast(sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType, castType CastType) Cast { + // TODO: does casting to a record type always work for any composite type? + // https://www.postgresql.org/docs/15/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS seems to suggest so + // Also not sure if we should use the passthrough, or if we always default to implicit, assignment, or explicit + if sourceType.IsRecordType() && targetType.IsCompositeType() { + // When casting to a composite type, then we must match the arity and have valid casts for every position. + if targetType.IsRecordType() { + return Cast{ + ID: id.NewCast(sourceType.ID, targetType.ID), + CastType: castType, + Function: id.NullFunction, + UseInOut: false, + } + } else { + evalFunc := func(ctx *sql.Context, val any, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (any, error) { + vals, ok := val.([]pgtypes.RecordValue) + if !ok { + return nil, errors.New("casting input error from record type") + } + if len(targetType.CompositeAttrs) != len(vals) { + // TODO: these should go in DETAIL depending on the size + // Input has too few columns. + // Input has too many columns. + return nil, errors.Errorf("cannot cast type %s to %s", sourceType.Name(), targetType.Name()) + } + typeCollection, err := pgtypes.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, err + } + outputVals := make([]pgtypes.RecordValue, len(vals)) + for i := range vals { + valType, ok := vals[i].Type.(*pgtypes.DoltgresType) + if !ok { + return nil, errors.New("cannot cast record containing GMS type") + } + outputType, err := typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID) + if err != nil { + return nil, err + } + outputVals[i].Type = outputType + if vals[i].Value != nil { + var positionCast Cast + switch castType { + case CastType_Explicit: + positionCast, err = pgc.GetExplicitCast(ctx, valType, outputType) + if err != nil { + return nil, err + } + case CastType_Assignment: + positionCast, err = pgc.GetAssignmentCast(ctx, valType, outputType) + if err != nil { + return nil, err + } + case CastType_Implicit: + positionCast, err = pgc.GetImplicitCast(ctx, valType, outputType) + if err != nil { + return nil, err + } + } + if !positionCast.ID.IsValid() { + // TODO: this should be the DETAIL, with the actual error being "cannot cast type to " + return nil, errors.Errorf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1) + } + outputVals[i].Value, err = positionCast.Eval(ctx, vals[i].Value, valType, outputType) + if err != nil { + return nil, err + } + } + } + return outputVals, nil + } + return Cast{ + ID: id.NewCast(sourceType.ID, targetType.ID), + CastType: castType, + Function: id.NullFunction, + BuiltIn: evalFunc, + UseInOut: false, + } + } + } + return Cast{} +} + +// HasCast returns whether the given cast exists. +func (pgc *Collection) HasCast(ctx context.Context, castID id.Cast) bool { + if _, ok := builtInCasts[castID]; ok { + return true + } + ok, err := pgc.underlyingMap.Has(ctx, string(castID)) + if err == nil && ok { + return true + } + return false +} + +// AddCast adds a new cast. +func (pgc *Collection) AddCast(ctx context.Context, cast Cast) error { + // First we'll check to see if it exists + if pgc.HasCast(ctx, cast.ID) { + return errors.Errorf(`cast from type %s to type %s already exists`, + cast.ID.SourceType().TypeName(), cast.ID.TargetType().TypeName()) + } + if cast.BuiltIn != nil { + return errors.Errorf(`cannot create a built-in cast from type %s to type %s`, + cast.ID.SourceType().TypeName(), cast.ID.TargetType().TypeName()) + } + + // Now we'll add the cast to our map + data, err := cast.Serialize(ctx) + if err != nil { + return err + } + h, err := pgc.ns.WriteBytes(ctx, data) + if err != nil { + return err + } + mapEditor := pgc.underlyingMap.Editor() + if err = mapEditor.Add(ctx, string(cast.ID), h); err != nil { + return err + } + newMap, err := mapEditor.Flush(ctx) + if err != nil { + return err + } + pgc.underlyingMap = newMap + pgc.mapHash = pgc.underlyingMap.HashOf() + return nil +} + +// DropCast drops an existing cast. +func (pgc *Collection) DropCast(ctx context.Context, castIDs ...id.Cast) error { + if len(castIDs) == 0 { + return nil + } + // Check that each name exists before performing any deletions + for _, castID := range castIDs { + if _, ok := builtInCasts[castID]; !ok { + return errors.Errorf(`cannot delete built-in cast from type %s to type %s`, + castID.SourceType().TypeName(), castID.TargetType().TypeName()) + } + if ok, err := pgc.underlyingMap.Has(ctx, string(castID)); err != nil { + return err + } else if !ok { + return errors.Errorf(`cast from type %s to type %s does not exist`, + castID.SourceType().TypeName(), castID.TargetType().TypeName()) + } + } + + // Now we'll remove the casts from the map + mapEditor := pgc.underlyingMap.Editor() + for _, castID := range castIDs { + err := mapEditor.Delete(ctx, string(castID)) + if err != nil { + return err + } + } + newMap, err := mapEditor.Flush(ctx) + if err != nil { + return err + } + pgc.underlyingMap = newMap + pgc.mapHash = pgc.underlyingMap.HashOf() + return nil +} + +// resolveName returns the fully resolved name of the given cast. Returns an error if the name is ambiguous. +func (pgc *Collection) resolveName(ctx context.Context, schemaName string, formattedName string) (id.Cast, error) { + if len(formattedName) == 0 { + return id.NullCast, nil + } + + // Check for an exact match + fullID := pgc.tableNameToID(schemaName, formattedName) + if pgc.HasCast(ctx, fullID) { + return fullID, nil + } + + // Otherwise we'll iterate over all the names + var resolvedID id.Cast + err := pgc.IterateCasts(ctx, func(c Cast) (stop bool, err error) { + if !strings.EqualFold(string(c.ID), string(fullID)) { + return false, nil + } + // The above matches, so this counts as a match + if resolvedID.IsValid() { + castTableName := CastIDToTableName(c.ID) + resolvedTableName := CastIDToTableName(resolvedID) + return true, fmt.Errorf("`%s` is ambiguous, matches `%s` and `%s`", + formattedName, castTableName.String(), resolvedTableName.String()) + } + resolvedID = c.ID + return false, nil + }) + return resolvedID, err +} + +// IterateCasts iterates over all casts in the collection. +func (pgc *Collection) IterateCasts(ctx context.Context, callback func(c Cast) (stop bool, err error)) error { + for _, cast := range builtInCasts { + stop, err := callback(cast) + if err != nil { + return err + } else if stop { + return nil + } + } + return pgc.underlyingMap.IterAll(ctx, func(_ string, v hash.Hash) error { + data, err := pgc.ns.ReadBytes(ctx, v) + if err != nil { + return err + } + c, err := DeserializeCast(ctx, data) + if err != nil { + return err + } + stop, err := callback(c) + if err != nil { + return err + } else if stop { + return io.EOF + } else { + return nil + } + }) +} + +// Clone returns a new *Collection with the same contents as the original. +func (pgc *Collection) Clone(ctx context.Context) *Collection { + return &Collection{ + mapHash: pgc.mapHash, + underlyingMap: pgc.underlyingMap, + ns: pgc.ns, + } +} + +// Map writes any cached sequences to the underlying map, and then returns the underlying map. +func (pgc *Collection) Map(ctx context.Context) (prolly.AddressMap, error) { + return pgc.underlyingMap, nil +} + +// tableNameToID returns the ID that was encoded via the Name() call, as the returned TableName contains additional +// information (which this is able to process). +func (pgc *Collection) tableNameToID(schemaName string, formattedName string) id.Cast { + sections := strings.Split(strings.TrimSuffix(strings.TrimPrefix(formattedName, "("), ")"), ")|(") + if len(sections) != 4 { + return id.NullCast + } + return id.NewCast(id.NewType(sections[0], sections[1]), id.NewType(sections[2], sections[3])) +} + +// GetID implements the interface objinterface.RootObject. +func (cast Cast) GetID() id.Id { + return cast.ID.AsId() +} + +// DiffersFrom returns true when the hash that is associated with the underlying map for this collection is different +// from the hash in the given root. +func (pgc *Collection) DiffersFrom(ctx context.Context, root objinterface.RootValue) bool { + hashOnGivenRoot, err := pgc.LoadCollectionHash(ctx, root) + if err != nil { + return true + } + if pgc.mapHash.Equal(hashOnGivenRoot) { + return false + } + // An empty map should match an uninitialized collection on the root + count, err := pgc.underlyingMap.Count() + if err == nil && count == 0 && hashOnGivenRoot.IsEmpty() { + return false + } + return true +} + +// GetRootObjectID implements the interface objinterface.RootObject. +func (cast Cast) GetRootObjectID() objinterface.RootObjectID { + return objinterface.RootObjectID_Casts +} + +// HashOf implements the interface objinterface.RootObject. +func (cast Cast) HashOf(ctx context.Context) (hash.Hash, error) { + data, err := cast.Serialize(ctx) + if err != nil { + return hash.Hash{}, err + } + return hash.Of(data), nil +} + +// Name implements the interface rootobject.RootObject. +func (cast Cast) Name() doltdb.TableName { + return CastIDToTableName(cast.ID) +} + +// Eval evaluates the cast against the given value. +func (cast Cast) Eval(ctx *sql.Context, val any, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (any, error) { + if cast.UseInOut { + if val == nil { + return nil, nil + } + output, err := sourceType.IoOutput(ctx, val) + if err != nil { + return nil, err + } + return targetType.IoInput(ctx, output) + } + if cast.BuiltIn != nil { + return cast.BuiltIn(ctx, val, sourceType, targetType) + } + if cast.Function != id.NullFunction { + // TODO: get the function collection and call the pointed-to function (argument count determines parameters) + return nil, errors.Errorf(`cannot cast from type %s to type %s as CREATE CAST is not yet implemented`, + cast.ID.SourceType().TypeName(), cast.ID.TargetType().TypeName()) + } + // In this case, the values are binary-coercible, but we still check as we may deviate from Postgres for some reason + if _, _, err := targetType.Convert(ctx, val); err != nil { + return nil, errors.Errorf(`cast from type %s to type %s is mislabeled as binary-coercible`, + cast.ID.SourceType().TypeName(), cast.ID.TargetType().TypeName()) + } + return val, nil +} + +// CastIDToTableName returns the ID in a format that's better for user consumption. +func CastIDToTableName(castID id.Cast) doltdb.TableName { + name := fmt.Sprintf(`(%s)|(%s)|(%s)|(%s)`, + castID.SourceType().SchemaName(), + castID.SourceType().TypeName(), + castID.TargetType().SchemaName(), + castID.TargetType().TypeName()) + return doltdb.TableName{ + Name: name, + Schema: "", + } +} diff --git a/core/casts/collection_funcs.go b/core/casts/collection_funcs.go new file mode 100644 index 0000000000..8f836a6c1d --- /dev/null +++ b/core/casts/collection_funcs.go @@ -0,0 +1,133 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casts + +import ( + "context" + + "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/merge" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/tree" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/core/rootobject/objinterface" + "github.com/dolthub/doltgresql/flatbuffers/gen/serial" +) + +// storage is used to read from and write to the root. +var storage = objinterface.RootObjectSerializer{ + Bytes: (*serial.RootValue).CastsBytes, + RootValueAdd: serial.RootValueAddCasts, +} + +// HandleMerge implements the interface objinterface.Collection. +func (*Collection) HandleMerge(ctx context.Context, mro merge.MergeRootObject) (doltdb.RootObject, *merge.MergeStats, error) { + ourCast := mro.OurRootObj.(Cast) + theirCast := mro.TheirRootObj.(Cast) + // Ensure that they have the same identifier + if ourCast.ID != theirCast.ID { + return nil, nil, errors.Newf("attempted to merge different casts: `%s` and `%s`", + ourCast.Name().String(), theirCast.Name().String()) + } + ourHash, err := ourCast.HashOf(ctx) + if err != nil { + return nil, nil, err + } + theirHash, err := theirCast.HashOf(ctx) + if err != nil { + return nil, nil, err + } + if ourHash.Equal(theirHash) { + return mro.OurRootObj, &merge.MergeStats{ + Operation: merge.TableUnmodified, + Adds: 0, + Deletes: 0, + Modifications: 0, + DataConflicts: 0, + SchemaConflicts: 0, + ConstraintViolations: 0, + }, nil + } + // TODO: figure out a decent merge strategy + return nil, nil, errors.Errorf("unable to merge `%s`", theirCast.Name().String()) +} + +// LoadCollection implements the interface objinterface.Collection. +func (*Collection) LoadCollection(ctx context.Context, root objinterface.RootValue) (objinterface.Collection, error) { + return LoadCasts(ctx, root) +} + +// LoadCollectionHash implements the interface objinterface.Collection. +func (*Collection) LoadCollectionHash(ctx context.Context, root objinterface.RootValue) (hash.Hash, error) { + m, ok, err := storage.GetProllyMap(ctx, root) + if err != nil || !ok { + return hash.Hash{}, err + } + return m.HashOf(), nil +} + +// LoadCasts loads the casts collection from the given root. +func LoadCasts(ctx context.Context, root objinterface.RootValue) (*Collection, error) { + m, ok, err := storage.GetProllyMap(ctx, root) + if err != nil { + return nil, err + } + if !ok { + m, err = prolly.NewEmptyAddressMap(root.NodeStore()) + if err != nil { + return nil, err + } + } + return NewCollection(ctx, m, root.NodeStore()) +} + +// ResolveNameFromObjects implements the interface objinterface.Collection. +func (*Collection) ResolveNameFromObjects(ctx context.Context, name doltdb.TableName, rootObjects []objinterface.RootObject) (doltdb.TableName, id.Id, error) { + // There are root objects to search through, so we'll create a temporary store + ns := tree.NewTestNodeStore() + addressMap, err := prolly.NewEmptyAddressMap(ns) + if err != nil { + return doltdb.TableName{}, id.Null, err + } + tempCollection, err := NewCollection(ctx, addressMap, ns) + if err != nil { + return doltdb.TableName{}, id.Null, err + } + for _, rootObject := range rootObjects { + if c, ok := rootObject.(Cast); ok { + if err = tempCollection.AddCast(ctx, c); err != nil { + return doltdb.TableName{}, id.Null, err + } + } + } + return tempCollection.ResolveName(ctx, name) +} + +// Serializer implements the interface objinterface.Collection. +func (*Collection) Serializer() objinterface.RootObjectSerializer { + return storage +} + +// UpdateRoot implements the interface objinterface.Collection. +func (pgc *Collection) UpdateRoot(ctx context.Context, root objinterface.RootValue) (objinterface.RootValue, error) { + m, err := pgc.Map(ctx) + if err != nil { + return nil, err + } + return storage.WriteProllyMap(ctx, root, m) +} diff --git a/core/casts/init.go b/core/casts/init.go new file mode 100644 index 0000000000..41d25fce35 --- /dev/null +++ b/core/casts/init.go @@ -0,0 +1,22 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casts + +import "github.com/dolthub/doltgresql/core/id" + +// Init initializes this package. +func Init() map[id.Cast]Cast { + return builtInCasts +} diff --git a/core/casts/root_object.go b/core/casts/root_object.go new file mode 100644 index 0000000000..572f557923 --- /dev/null +++ b/core/casts/root_object.go @@ -0,0 +1,149 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casts + +import ( + "context" + "io" + + "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/store/hash" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/core/rootobject/objinterface" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// DeserializeRootObject implements the interface objinterface.Collection. +func (pgc *Collection) DeserializeRootObject(ctx context.Context, data []byte) (objinterface.RootObject, error) { + return DeserializeCast(ctx, data) +} + +// DiffRootObjects implements the interface objinterface.Collection. +func (pgc *Collection) DiffRootObjects(ctx context.Context, fromHash string, ours objinterface.RootObject, theirs objinterface.RootObject, ancestor objinterface.RootObject) ([]objinterface.RootObjectDiff, objinterface.RootObject, error) { + return nil, nil, errors.New("cast conflict detection has not yet been implemented") +} + +// DropRootObject implements the interface objinterface.Collection. +func (pgc *Collection) DropRootObject(ctx context.Context, identifier id.Id) error { + if identifier.Section() != id.Section_Cast { + return errors.Errorf(`cast %s does not exist`, identifier.String()) + } + return pgc.DropCast(ctx, id.Cast(identifier)) +} + +// GetFieldType implements the interface objinterface.Collection. +func (pgc *Collection) GetFieldType(ctx context.Context, fieldName string) *pgtypes.DoltgresType { + return nil +} + +// GetID implements the interface objinterface.Collection. +func (pgc *Collection) GetID() objinterface.RootObjectID { + return objinterface.RootObjectID_Casts +} + +// GetRootObject implements the interface objinterface.Collection. +func (pgc *Collection) GetRootObject(ctx context.Context, identifier id.Id) (objinterface.RootObject, bool, error) { + if identifier.Section() != id.Section_Cast { + return nil, false, nil + } + c, err := pgc.getCast(ctx, id.Cast(identifier), nil, nil, CastType_Explicit) + return c, err == nil && c.ID.IsValid(), err +} + +// HasRootObject implements the interface objinterface.Collection. +func (pgc *Collection) HasRootObject(ctx context.Context, identifier id.Id) (bool, error) { + if identifier.Section() != id.Section_Cast { + return false, nil + } + return pgc.HasCast(ctx, id.Cast(identifier)), nil +} + +// IDToTableName implements the interface objinterface.Collection. +func (pgc *Collection) IDToTableName(identifier id.Id) doltdb.TableName { + if identifier.Section() != id.Section_Cast { + return doltdb.TableName{} + } + return CastIDToTableName(id.Cast(identifier)) +} + +// IterAll implements the interface objinterface.Collection. +func (pgc *Collection) IterAll(ctx context.Context, callback func(rootObj objinterface.RootObject) (stop bool, err error)) error { + return pgc.IterateCasts(ctx, func(c Cast) (stop bool, err error) { + return callback(c) + }) +} + +// IterIDs implements the interface objinterface.Collection. +func (pgc *Collection) IterIDs(ctx context.Context, callback func(identifier id.Id) (stop bool, err error)) error { + err := pgc.underlyingMap.IterAll(ctx, func(k string, _ hash.Hash) error { + stop, err := callback(id.Id(k)) + if err != nil { + return err + } else if stop { + return io.EOF + } else { + return nil + } + }) + return err +} + +// PutRootObject implements the interface objinterface.Collection. +func (pgc *Collection) PutRootObject(ctx context.Context, rootObj objinterface.RootObject) error { + c, ok := rootObj.(Cast) + if !ok { + return errors.Newf("invalid cast root object: %T", rootObj) + } + return pgc.AddCast(ctx, c) +} + +// RenameRootObject implements the interface objinterface.Collection. +func (pgc *Collection) RenameRootObject(ctx context.Context, oldName id.Id, newName id.Id) error { + if !oldName.IsValid() || !newName.IsValid() || oldName.Section() != newName.Section() || oldName.Section() != id.Section_Cast { + return errors.New("cannot rename cast due to invalid id") + } + oldCastName := id.Cast(oldName) + newCastName := id.Cast(newName) + c, err := pgc.getCast(ctx, oldCastName, nil, nil, CastType_Explicit) + if err != nil { + return err + } + if err = pgc.DropCast(ctx, newCastName); err != nil { + return err + } + c.ID = newCastName + return pgc.AddCast(ctx, c) +} + +// ResolveName implements the interface objinterface.Collection. +func (pgc *Collection) ResolveName(ctx context.Context, name doltdb.TableName) (doltdb.TableName, id.Id, error) { + rawID, err := pgc.resolveName(ctx, name.Schema, name.Name) + if err != nil || !rawID.IsValid() { + return doltdb.TableName{}, id.Null, err + } + return CastIDToTableName(rawID), rawID.AsId(), nil +} + +// TableNameToID implements the interface objinterface.Collection. +func (pgc *Collection) TableNameToID(name doltdb.TableName) id.Id { + return pgc.tableNameToID(name.Schema, name.Name).AsId() +} + +// UpdateField implements the interface objinterface.Collection. +func (pgc *Collection) UpdateField(ctx context.Context, rootObject objinterface.RootObject, fieldName string, newValue any) (objinterface.RootObject, error) { + return nil, errors.New("updating through the conflicts table for this object type is not yet supported") +} diff --git a/core/casts/serialization.go b/core/casts/serialization.go new file mode 100644 index 0000000000..3dd4d2a2b5 --- /dev/null +++ b/core/casts/serialization.go @@ -0,0 +1,67 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casts + +import ( + "context" + + "github.com/cockroachdb/errors" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/utils" +) + +// Serialize returns the Cast as a byte slice. If the Cast is invalid, then this returns a nil slice. +func (cast Cast) Serialize(ctx context.Context) ([]byte, error) { + if !cast.ID.IsValid() { + return nil, nil + } + + // Initialize the writer and version + writer := utils.NewWriter(256) + writer.VariableUint(0) // Version + // Write the cast data + writer.Id(cast.ID.AsId()) + writer.Uint8(uint8(cast.CastType)) + writer.Id(cast.Function.AsId()) + writer.Bool(cast.UseInOut) + // Returns the data + return writer.Data(), nil +} + +// DeserializeCast returns the Cast that was serialized in the byte slice. Returns an empty Cast (invalid ID) if data is +// nil or empty. +func DeserializeCast(ctx context.Context, data []byte) (Cast, error) { + if len(data) == 0 { + return Cast{}, nil + } + reader := utils.NewReader(data) + version := reader.VariableUint() + if version != 0 { + return Cast{}, errors.Errorf("version %d of casts is not supported, please upgrade the server", version) + } + + // Read from the reader + t := Cast{} + t.ID = id.Cast(reader.Id()) + t.CastType = CastType(reader.Uint8()) + t.Function = id.Function(reader.Id()) + t.UseInOut = reader.Bool() + if !reader.IsEmpty() { + return Cast{}, errors.Errorf("extra data found while deserializing a cast") + } + // Return the deserialized object + return t, nil +} diff --git a/core/context.go b/core/context.go index bd2b6f52bc..60365c722a 100644 --- a/core/context.go +++ b/core/context.go @@ -22,8 +22,11 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/functions" "github.com/dolthub/doltgresql/core/procedures" @@ -44,6 +47,7 @@ type contextValues struct { procs *procedures.Collection trigs *triggers.Collection exts *extensions.Collection + casts *casts.Collection pgCatalogCache any } @@ -394,6 +398,35 @@ func GetTypesCollectionFromContext(ctx *sql.Context) (*typecollection.TypeCollec return cv.types, nil } +// GetCastsCollectionFromContext returns the given casts collection from the context. +// Will always return a collection if no error is returned. +func GetCastsCollectionFromContext(ctx *sql.Context) (*casts.Collection, error) { + // TODO: remove this nil check once contexts have been threaded everywhere + if ctx == nil { + ns := tree.NewTestNodeStore() + addressMap, err := prolly.NewEmptyAddressMap(ns) + if err != nil { + return nil, err + } + return casts.NewCollection(ctx, addressMap, ns) + } + cv, err := getContextValues(ctx) + if err != nil { + return nil, err + } + if cv.casts == nil { + _, root, err := GetRootFromContext(ctx) + if err != nil { + return nil, err + } + cv.casts, err = casts.LoadCasts(ctx, root) + if err != nil { + return nil, err + } + } + return cv.casts, nil +} + // CloseContextRootFinalizer finalizes any changes persisted within the context by writing them to the working root. // This should ONLY be called by the ContextRootFinalizer node. func CloseContextRootFinalizer(ctx *sql.Context) error { @@ -478,6 +511,8 @@ func updateSessionRootForDatabase(ctx *sql.Context, db string, cv *contextValues cv.types = nil } + // TODO: need to be able to persist cv.casts without an empty collection updating the root (no value != empty value) + // Setting the session working root doesn't do a check to see if anything actually changed or not before marking that // branch state dirty, and dolt only allows a single dirty working set per commit. So it's important here to only // update the session root if something actually changed for that db. @@ -551,6 +586,8 @@ func (cv *contextValues) clear(objID objinterface.RootObjectID) { // We don't cache these case objinterface.RootObjectID_Procedures: cv.procs = nil + case objinterface.RootObjectID_Casts: + cv.casts = nil default: panic("unhandled context clear object ID") } diff --git a/core/id/id.go b/core/id/id.go index e933feee7b..c9ab4fe7c5 100644 --- a/core/id/id.go +++ b/core/id/id.go @@ -49,6 +49,8 @@ const ( Null Id = "" // NullAccessMethod is an empty, invalid ID. This is exactly equivalent to Null. NullAccessMethod AccessMethod = "" + // NullCast is an empty, invalid ID. This is exactly equivalent to Null. + NullCast Cast = "" // NullCheck is an empty, invalid ID. This is exactly equivalent to Null. NullCheck Check = "" // NullCollation is an empty, invalid ID. This is exactly equivalent to Null. diff --git a/core/id/id_wrappers.go b/core/id/id_wrappers.go index 09245b4202..8c71494497 100644 --- a/core/id/id_wrappers.go +++ b/core/id/id_wrappers.go @@ -19,6 +19,9 @@ import "strconv" // AccessMethod is an Id wrapper for access methods. This wrapper must not be returned to the client. type AccessMethod Id +// Cast is an Id wrapper for casts. This wrapper must not be returned to the client. +type Cast Id + // Check is an Id wrapper for checks. This wrapper must not be returned to the client. type Check Id @@ -78,6 +81,14 @@ func NewAccessMethod(methodName string) AccessMethod { return AccessMethod(NewId(Section_AccessMethod, methodName)) } +// NewCast returns a new Cast. This wrapper must not be returned to the client. +func NewCast(sourceType Type, targetType Type) Cast { + if len(sourceType) == 0 && len(targetType) == 0 { + return NullCast + } + return Cast(NewId(Section_Cast, string(sourceType), string(targetType))) +} + // NewCheck returns a new Check. This wrapper must not be returned to the client. func NewCheck(schemaName string, tableName string, checkName string) Check { if len(schemaName) == 0 && len(tableName) == 0 && len(checkName) == 0 { @@ -228,6 +239,16 @@ func (id AccessMethod) MethodName() string { return Id(id).Segment(0) } +// SourceType returns the source type. +func (id Cast) SourceType() Type { + return Type(Id(id).Segment(0)) +} + +// TargetType returns the target type. +func (id Cast) TargetType() Type { + return Type(Id(id).Segment(1)) +} + // CheckName returns the check's name. func (id Check) CheckName() string { return Id(id).Segment(2) @@ -437,6 +458,9 @@ func (id View) ViewName() string { // IsValid returns whether the ID is valid. func (id AccessMethod) IsValid() bool { return Id(id).IsValid() } +// IsValid returns whether the ID is valid. +func (id Cast) IsValid() bool { return Id(id).IsValid() } + // IsValid returns whether the ID is valid. func (id Check) IsValid() bool { return Id(id).IsValid() } @@ -491,6 +515,9 @@ func (id View) IsValid() bool { return Id(id).IsValid() } // AsId returns the unwrapped ID. func (id AccessMethod) AsId() Id { return Id(id) } +// AsId returns the unwrapped ID. +func (id Cast) AsId() Id { return Id(id) } + // AsId returns the unwrapped ID. func (id Check) AsId() Id { return Id(id) } diff --git a/core/init.go b/core/init.go index a69ecb3e0c..0ade461d03 100644 --- a/core/init.go +++ b/core/init.go @@ -40,4 +40,11 @@ func Init() { pgtypes.GetTypesCollectionFromContext = func(ctx *sql.Context) (pgtypes.TypeCollection, error) { return GetTypesCollectionFromContext(ctx) } + pgtypes.GetAssignmentCast = func(ctx *sql.Context, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (pgtypes.Cast, error) { + castsColl, err := GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + return castsColl.GetAssignmentCast(ctx, sourceType, targetType) + } } diff --git a/core/rootobject/collection.go b/core/rootobject/collection.go index 2bb1208b17..2feca3b84f 100644 --- a/core/rootobject/collection.go +++ b/core/rootobject/collection.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/merge" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/conflicts" "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/functions" @@ -38,7 +39,7 @@ import ( var ( // globalCollections maps each ID to the collection. globalCollections = []objinterface.Collection{ - nil, + nil, // Corresponds to RootObjectID_None &sequences.Collection{}, &typecollection.TypeCollection{}, &functions.Collection{}, @@ -46,6 +47,7 @@ var ( &extensions.Collection{}, &conflicts.Collection{}, &procedures.Collection{}, + &casts.Collection{}, } ) diff --git a/core/rootobject/objinterface/interfaces.go b/core/rootobject/objinterface/interfaces.go index 6a9af84ba9..e3f0c310f5 100644 --- a/core/rootobject/objinterface/interfaces.go +++ b/core/rootobject/objinterface/interfaces.go @@ -38,6 +38,7 @@ const ( RootObjectID_Extensions RootObjectID_Conflicts RootObjectID_Procedures + RootObjectID_Casts ) const ( diff --git a/flatbuffers/gen/serial/rootvalue.go b/flatbuffers/gen/serial/rootvalue.go index fbdb77205a..b38ac31b30 100644 --- a/flatbuffers/gen/serial/rootvalue.go +++ b/flatbuffers/gen/serial/rootvalue.go @@ -26,7 +26,11 @@ type RootValue struct { func InitRootValueRoot(o *RootValue, buf []byte, offset flatbuffers.UOffsetT) error { n := flatbuffers.GetUOffsetT(buf[offset:]) - return o.Init(buf, n+offset) + o.Init(buf, n+offset) + if RootValueNumFields < o.Table().NumFields() { + return flatbuffers.ErrTableHasUnknownFields + } + return nil } func TryGetRootAsRootValue(buf []byte, offset flatbuffers.UOffsetT) (*RootValue, error) { @@ -34,18 +38,26 @@ func TryGetRootAsRootValue(buf []byte, offset flatbuffers.UOffsetT) (*RootValue, return x, InitRootValueRoot(x, buf, offset) } +func GetRootAsRootValue(buf []byte, offset flatbuffers.UOffsetT) *RootValue { + x := &RootValue{} + InitRootValueRoot(x, buf, offset) + return x +} + func TryGetSizePrefixedRootAsRootValue(buf []byte, offset flatbuffers.UOffsetT) (*RootValue, error) { x := &RootValue{} return x, InitRootValueRoot(x, buf, offset+flatbuffers.SizeUint32) } -func (rcv *RootValue) Init(buf []byte, i flatbuffers.UOffsetT) error { +func GetSizePrefixedRootAsRootValue(buf []byte, offset flatbuffers.UOffsetT) *RootValue { + x := &RootValue{} + InitRootValueRoot(x, buf, offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *RootValue) Init(buf []byte, i flatbuffers.UOffsetT) { rcv._tab.Bytes = buf rcv._tab.Pos = i - if RootValueNumFields < rcv.Table().NumFields() { - return flatbuffers.ErrTableHasUnknownFields - } - return nil } func (rcv *RootValue) Table() flatbuffers.Table { @@ -144,6 +156,18 @@ func (rcv *RootValue) MutateCollation(n Collation) bool { return rcv._tab.MutateUint16Slot(10, uint16(n)) } +func (rcv *RootValue) Schemas(obj *DatabaseSchema, j int) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + x := rcv._tab.Vector(o) + x += flatbuffers.UOffsetT(j) * 4 + x = rcv._tab.Indirect(x) + obj.Init(rcv._tab.Bytes, x) + return true + } + return false +} + func (rcv *RootValue) TrySchemas(obj *DatabaseSchema, j int) (bool, error) { o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) if o != 0 { @@ -405,7 +429,41 @@ func (rcv *RootValue) MutateProcedures(j int, n byte) bool { return false } -const RootValueNumFields = 12 +func (rcv *RootValue) Casts(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(28)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *RootValue) CastsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(28)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *RootValue) CastsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(28)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *RootValue) MutateCasts(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(28)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +const RootValueNumFields = 13 func RootValueStart(builder *flatbuffers.Builder) { builder.StartObject(RootValueNumFields) @@ -476,6 +534,12 @@ func RootValueAddProcedures(builder *flatbuffers.Builder, procedures flatbuffers func RootValueStartProceduresVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(1, numElems, 1) } +func RootValueAddCasts(builder *flatbuffers.Builder, casts flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(12, flatbuffers.UOffsetT(casts), 0) +} +func RootValueStartCastsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} func RootValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } @@ -486,7 +550,11 @@ type DatabaseSchema struct { func InitDatabaseSchemaRoot(o *DatabaseSchema, buf []byte, offset flatbuffers.UOffsetT) error { n := flatbuffers.GetUOffsetT(buf[offset:]) - return o.Init(buf, n+offset) + o.Init(buf, n+offset) + if DatabaseSchemaNumFields < o.Table().NumFields() { + return flatbuffers.ErrTableHasUnknownFields + } + return nil } func TryGetRootAsDatabaseSchema(buf []byte, offset flatbuffers.UOffsetT) (*DatabaseSchema, error) { @@ -494,18 +562,26 @@ func TryGetRootAsDatabaseSchema(buf []byte, offset flatbuffers.UOffsetT) (*Datab return x, InitDatabaseSchemaRoot(x, buf, offset) } +func GetRootAsDatabaseSchema(buf []byte, offset flatbuffers.UOffsetT) *DatabaseSchema { + x := &DatabaseSchema{} + InitDatabaseSchemaRoot(x, buf, offset) + return x +} + func TryGetSizePrefixedRootAsDatabaseSchema(buf []byte, offset flatbuffers.UOffsetT) (*DatabaseSchema, error) { x := &DatabaseSchema{} return x, InitDatabaseSchemaRoot(x, buf, offset+flatbuffers.SizeUint32) } -func (rcv *DatabaseSchema) Init(buf []byte, i flatbuffers.UOffsetT) error { +func GetSizePrefixedRootAsDatabaseSchema(buf []byte, offset flatbuffers.UOffsetT) *DatabaseSchema { + x := &DatabaseSchema{} + InitDatabaseSchemaRoot(x, buf, offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *DatabaseSchema) Init(buf []byte, i flatbuffers.UOffsetT) { rcv._tab.Bytes = buf rcv._tab.Pos = i - if DatabaseSchemaNumFields < rcv.Table().NumFields() { - return flatbuffers.ErrTableHasUnknownFields - } - return nil } func (rcv *DatabaseSchema) Table() flatbuffers.Table { diff --git a/flatbuffers/serial/rootvalue.fbs b/flatbuffers/serial/rootvalue.fbs index f8c4230e16..0393afff55 100644 --- a/flatbuffers/serial/rootvalue.fbs +++ b/flatbuffers/serial/rootvalue.fbs @@ -41,6 +41,8 @@ table RootValue { conflicts:[ubyte]; // Serialized AddressMap. procedures:[ubyte]; // Serialized AddressMap. + + casts:[ubyte]; // Serialized AddressMap. } table DatabaseSchema { diff --git a/server/analyzer/foreign_key.go b/server/analyzer/foreign_key.go index 0020e6a78b..6e55139c89 100755 --- a/server/analyzer/foreign_key.go +++ b/server/analyzer/foreign_key.go @@ -21,16 +21,26 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/server/functions/framework" "github.com/dolthub/doltgresql/server/types" ) // validateForeignKeyDefinition validates that the given foreign key definition is valid for creation func validateForeignKeyDefinition(ctx *sql.Context, fkDef sql.ForeignKeyConstraint, cols map[string]*sql.Column, parentCols map[string]*sql.Column) error { + var castsColl *casts.Collection + if len(fkDef.Columns) > 0 { + var err error + castsColl, err = core.GetCastsCollectionFromContext(ctx) + if err != nil { + return err + } + } for i := range fkDef.Columns { col := cols[strings.ToLower(fkDef.Columns[i])] parentCol := parentCols[strings.ToLower(fkDef.ParentColumns[i])] - if !foreignKeyComparableTypes(col.Type, parentCol.Type) { + if !foreignKeyComparableTypes(ctx, castsColl, col.Type, parentCol.Type) { return errors.Errorf("Key columns %q and %q are of incompatible types: %s and %s", col.Name, parentCol.Name, col.Type.String(), parentCol.Type.String()) } } @@ -39,7 +49,7 @@ func validateForeignKeyDefinition(ctx *sql.Context, fkDef sql.ForeignKeyConstrai // foreignKeyComparableTypes returns whether the two given types are able to be used as parent/child columns in a // foreign key. -func foreignKeyComparableTypes(from sql.Type, to sql.Type) bool { +func foreignKeyComparableTypes(ctx *sql.Context, castColl *casts.Collection, from sql.Type, to sql.Type) bool { dtFrom, ok := from.(*types.DoltgresType) if !ok { return false // should never be possible @@ -67,8 +77,8 @@ func foreignKeyComparableTypes(from sql.Type, to sql.Type) bool { // Additionally, we need to be able to convert freely between the two types in both directions, since we do this // during the process of enforcing the constraints - forwardConversion := types.GetAssignmentCast(dtFrom, dtTo) - reverseConversion := types.GetAssignmentCast(dtTo, dtFrom) + forwardConversion, fErr := castColl.GetAssignmentCast(ctx, dtFrom, dtTo) + reverseConversion, rErr := castColl.GetAssignmentCast(ctx, dtTo, dtFrom) - return forwardConversion != nil && reverseConversion != nil + return fErr == nil && rErr == nil && forwardConversion.ID.IsValid() && reverseConversion.ID.IsValid() } diff --git a/server/analyzer/optimize_functions.go b/server/analyzer/optimize_functions.go index 42e4e88cc5..5991ba7eec 100644 --- a/server/analyzer/optimize_functions.go +++ b/server/analyzer/optimize_functions.go @@ -64,7 +64,7 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc } // fill in default exprs if applicable - if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) { + if err := compiledFunction.ResolveDefaultValues(ctx, func(defExpr string) (sql.Expression, error) { return getDefaultExpr(ctx, a.Catalog, defExpr) }); err != nil { return nil, transform.SameTree, err @@ -105,7 +105,7 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc } // fill in default exprs if applicablea - if err = compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) { + if err = compiledFunction.ResolveDefaultValues(ctx, func(defExpr string) (sql.Expression, error) { return getDefaultExpr(ctx, a.Catalog, defExpr) }); err != nil { return nil, transform.SameTree, err diff --git a/server/analyzer/resolve_routine_defaults.go b/server/analyzer/resolve_routine_defaults.go index 86551b8fd1..5a867fbf85 100644 --- a/server/analyzer/resolve_routine_defaults.go +++ b/server/analyzer/resolve_routine_defaults.go @@ -117,7 +117,7 @@ func ResolveProcedureDefaults(ctx *sql.Context, a *analyzer.Analyzer, node sql.N } compiledFunction := framework.NewCompiledFunction(n.ProcedureName, n.Exprs, overloadTree, false) // fill in default exprs if applicable - if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) { + if err := compiledFunction.ResolveDefaultValues(ctx, func(defExpr string) (sql.Expression, error) { return getDefaultExpr(ctx, a.Catalog, defExpr) }); err != nil { return nil, transform.SameTree, err diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go index 029c3dbab9..debb73258b 100644 --- a/server/analyzer/resolve_values_types.go +++ b/server/analyzer/resolve_values_types.go @@ -38,7 +38,7 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s // We record which VDTs changed so we can fix up GetField types afterward. transformedVDTs := make(map[sql.TableId]sql.Schema) node, same, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { - newNode, same, err := transformValuesNode(n) + newNode, same, err := transformValuesNode(ctx, n) if err != nil { return nil, same, err } @@ -155,7 +155,7 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s } // transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types -func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { +func transformValuesNode(ctx *sql.Context, n sql.Node) (sql.Node, transform.TreeIdentity, error) { var values *plan.Values var expressionerNode sql.Expressioner switch v := n.(type) { @@ -202,7 +202,7 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { // Find common type for each column var newTuples [][]sql.Expression for colIdx := 0; colIdx < numCols; colIdx++ { - commonType, requiresCasts, err := framework.FindCommonType(columnTypes[colIdx]) + commonType, requiresCasts, err := framework.FindCommonType(ctx, columnTypes[colIdx]) if err != nil { return nil, transform.NewTree, err } diff --git a/server/cast/bit.go b/server/cast/bit.go index 512e898571..9971546c4d 100644 --- a/server/cast/bit.go +++ b/server/cast/bit.go @@ -18,24 +18,27 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initBit handles all casts that are built-in. This comprises only the "From" types. -func initBit() { - bitExplicit() - bitImplicit() +// initBit handles all casts that are built-in. This comprises only the source types. +func initBit(builtInCasts map[id.Cast]casts.Cast) { + bitExplicit(builtInCasts) + bitImplicit(builtInCasts) } -// bitExplicit registers all explicit casts. This comprises only the "From" types. -func bitExplicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ - FromType: pgtypes.Bit, - ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { +// bitExplicit registers all explicit casts. This comprises only the source types. +func bitExplicit(builtInCasts map[id.Cast]casts.Cast) { + bitToInt32 := id.NewCast(pgtypes.Bit.ID, pgtypes.Int32.ID) + builtInCasts[bitToInt32] = casts.Cast{ + ID: bitToInt32, + CastType: casts.CastType_Explicit, + Function: id.NullFunction, + BuiltIn: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { array, err := tree.ParseDBitArray(val.(string)) if err != nil { return nil, err @@ -45,11 +48,14 @@ func bitExplicit() { } return int32(array.AsInt64(32)), nil }, - }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ - FromType: pgtypes.Bit, - ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + UseInOut: false, + } + bitToInt64 := id.NewCast(pgtypes.Bit.ID, pgtypes.Int64.ID) + builtInCasts[bitToInt64] = casts.Cast{ + ID: bitToInt64, + CastType: casts.CastType_Explicit, + Function: id.NullFunction, + BuiltIn: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { array, err := tree.ParseDBitArray(val.(string)) if err != nil { return nil, err @@ -59,15 +65,16 @@ func bitExplicit() { } return array.AsInt64(64), nil }, - }) + UseInOut: false, + } } -// bitImplicit registers all implicit casts. This comprises only the "From" types. -func bitImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// bitImplicit registers all implicit casts. This comprises only the source types. +func bitImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bit, ToType: pgtypes.Bit, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { input := val.(string) array, err := tree.ParseDBitArray(input) if err != nil { @@ -80,10 +87,10 @@ func bitImplicit() { return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bit, ToType: pgtypes.VarBit, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { input := val.(string) array, err := tree.ParseDBitArray(input) if err != nil { diff --git a/server/cast/bool.go b/server/cast/bool.go index 9281750efd..90231ea191 100644 --- a/server/cast/bool.go +++ b/server/cast/bool.go @@ -17,22 +17,25 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initBool handles all casts that are built-in. This comprises only the "From" types. -func initBool() { - boolExplicit() - boolAssignment() +// initBool handles all casts that are built-in. This comprises only the source types. +func initBool(builtInCasts map[id.Cast]casts.Cast) { + boolExplicit(builtInCasts) + boolAssignment(builtInCasts) } -// boolExplicit registers all explicit casts. This comprises only the "From" types. -func boolExplicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ +// boolExplicit registers all explicit casts. This comprises only the source types. +func boolExplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bool, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(bool) { return int32(1), nil } else { @@ -42,12 +45,12 @@ func boolExplicit() { }) } -// boolAssignment registers all assignment casts. This comprises only the "From" types. -func boolAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// boolAssignment registers all assignment casts. This comprises only the source types. +func boolAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bool, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { str := "false" if val.(bool) { str = "true" @@ -55,10 +58,10 @@ func boolAssignment() { return handleStringCast(str, targetType) }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bool, ToType: pgtypes.Name, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { str := "f" if val.(bool) { str = "t" @@ -66,10 +69,10 @@ func boolAssignment() { return handleStringCast(str, targetType) }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bool, ToType: pgtypes.Text, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(bool) { return "true", nil } else { @@ -77,10 +80,10 @@ func boolAssignment() { } }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Bool, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { str := "false" if val.(bool) { str = "true" diff --git a/server/cast/char.go b/server/cast/char.go index efb5e81257..9e885638f4 100644 --- a/server/cast/char.go +++ b/server/cast/char.go @@ -21,34 +21,36 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initChar handles all casts that are built-in. This comprises only the "From" types. -func initChar() { - charAssignment() - charExplicit() - charImplicit() +// initChar handles all casts that are built-in. This comprises only the source types. +func initChar(builtInCasts map[id.Cast]casts.Cast) { + charAssignment(builtInCasts) + charExplicit(builtInCasts) + charImplicit(builtInCasts) } -// charAssignment registers all assignment casts. This comprises only the "From" types. -func charAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// charAssignment registers all assignment casts. This comprises only the source types. +func charAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.InternalChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return targetType.IoInput(ctx, val.(string)) }, }) } -// charExplicit registers all explicit casts. This comprises only the "From" types. -func charExplicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ +// charExplicit registers all explicit casts. This comprises only the source types. +func charExplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { out, err := strconv.ParseInt(strings.TrimSpace(val.(string)), 10, 32) if err != nil { return nil, errors.Errorf("invalid input syntax for type %s: %q", targetType.String(), val.(string)) @@ -61,33 +63,33 @@ func charExplicit() { }) } -// charImplicit registers all implicit casts. This comprises only the "From" types. -func charImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// charImplicit registers all implicit casts. This comprises only the source types. +func charImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return targetType.IoInput(ctx, val.(string)) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.Name, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.Text, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.BpChar, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) diff --git a/server/cast/date.go b/server/cast/date.go index 4c7d85fc31..6a8c2613ad 100644 --- a/server/cast/date.go +++ b/server/cast/date.go @@ -19,28 +19,30 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initDate handles all casts that are built-in. This comprises only the "From" types. -func initDate() { - dateImplicit() +// initDate handles all casts that are built-in. This comprises only the source types. +func initDate(builtInCasts map[id.Cast]casts.Cast) { + dateImplicit(builtInCasts) } -// dateImplicit registers all implicit casts. This comprises only the "From" types. -func dateImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// dateImplicit registers all implicit casts. This comprises only the source types. +func dateImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Date, ToType: pgtypes.Timestamp, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(time.Time), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Date, ToType: pgtypes.TimestampTZ, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(time.Time), nil }, }) diff --git a/server/cast/float32.go b/server/cast/float32.go index d86bc627cc..36031914e0 100644 --- a/server/cast/float32.go +++ b/server/cast/float32.go @@ -21,22 +21,24 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initFloat32 handles all casts that are built-in. This comprises only the "From" types. -func initFloat32() { - float32Assignment() - float32Implicit() +// initFloat32 handles all casts that are built-in. This comprises only the source types. +func initFloat32(builtInCasts map[id.Cast]casts.Cast) { + float32Assignment(builtInCasts) + float32Implicit(builtInCasts) } -// float32Assignment registers all assignment casts. This comprises only the "From" types. -func float32Assignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// float32Assignment registers all assignment casts. This comprises only the source types. +func float32Assignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float32, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := float32(math.RoundToEven(float64(valInterface.(float32)))) if val > 32767 || val < -32768 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") @@ -44,10 +46,10 @@ func float32Assignment() { return int16(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float32, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := float32(math.RoundToEven(float64(valInterface.(float32)))) if val > 2147483647 || val < -2147483648 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") @@ -55,10 +57,10 @@ func float32Assignment() { return int32(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float32, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := float32(math.RoundToEven(float64(valInterface.(float32)))) if val > 9223372036854775807 || val < -9223372036854775808 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") @@ -66,21 +68,21 @@ func float32Assignment() { return int64(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float32, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return pgtypes.GetNumericValueWithTypmod(decimal.NewFromFloat(float64(val.(float32))), targetType.GetAttTypMod()) }, }) } -// float32Implicit registers all implicit casts. This comprises only the "From" types. -func float32Implicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// float32Implicit registers all implicit casts. This comprises only the source types. +func float32Implicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float32, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float64(val.(float32)), nil }, }) diff --git a/server/cast/float64.go b/server/cast/float64.go index aedae498bd..c62699e83a 100644 --- a/server/cast/float64.go +++ b/server/cast/float64.go @@ -21,28 +21,30 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initFloat64 handles all casts that are built-in. This comprises only the "From" types. -func initFloat64() { - float64Assignment() +// initFloat64 handles all casts that are built-in. This comprises only the source types. +func initFloat64(builtInCasts map[id.Cast]casts.Cast) { + float64Assignment(builtInCasts) } -// float64Assignment registers all assignment casts. This comprises only the "From" types. -func float64Assignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// float64Assignment registers all assignment casts. This comprises only the source types. +func float64Assignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float64, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float32(val.(float64)), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float64, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := math.RoundToEven(valInterface.(float64)) if val > 32767 || val < -32768 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") @@ -50,10 +52,10 @@ func float64Assignment() { return int16(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float64, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := math.RoundToEven(valInterface.(float64)) if val > 2147483647 || val < -2147483648 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") @@ -61,10 +63,10 @@ func float64Assignment() { return int32(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float64, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, valInterface any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, valInterface any, _, targetType *pgtypes.DoltgresType) (any, error) { val := math.RoundToEven(valInterface.(float64)) if val > 9223372036854775807 || val < -9223372036854775808 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") @@ -72,10 +74,10 @@ func float64Assignment() { return int64(val), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Float64, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return pgtypes.GetNumericValueWithTypmod(decimal.NewFromFloat(val.(float64)), targetType.GetAttTypMod()) }, }) diff --git a/server/cast/init.go b/server/cast/init.go index 47f5b16efa..632a4f9f00 100644 --- a/server/cast/init.go +++ b/server/cast/init.go @@ -15,42 +15,36 @@ package cast import ( - "github.com/dolthub/doltgresql/server/functions/framework" - "github.com/dolthub/doltgresql/server/types" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" ) // Init initializes all casts in this package. -func Init() { - initBit() - initBool() - initChar() - initDate() - initFloat32() - initFloat64() - initInt16() - initInt32() - initInt64() - initInternalChar() - initInterval() - initJson() - initJsonB() - initName() - initNumeric() - initOid() - initRegclass() - initRegproc() - initRegtype() - initText() - initTime() - initTimestamp() - initTimestampTZ() - initTimeTZ() - initVarBit() - initVarChar() - - // This is a hack to get around import cycles. The types package needs these references for type conversions in - // some contexts - types.GetImplicitCast = framework.GetImplicitCast - types.GetAssignmentCast = framework.GetAssignmentCast - types.GetExplicitCast = framework.GetExplicitCast +func Init(builtInCasts map[id.Cast]casts.Cast) { + initBit(builtInCasts) + initBool(builtInCasts) + initChar(builtInCasts) + initDate(builtInCasts) + initFloat32(builtInCasts) + initFloat64(builtInCasts) + initInt16(builtInCasts) + initInt32(builtInCasts) + initInt64(builtInCasts) + initInternalChar(builtInCasts) + initInterval(builtInCasts) + initJson(builtInCasts) + initJsonB(builtInCasts) + initName(builtInCasts) + initNumeric(builtInCasts) + initOid(builtInCasts) + initRegclass(builtInCasts) + initRegproc(builtInCasts) + initRegtype(builtInCasts) + initText(builtInCasts) + initTime(builtInCasts) + initTimestamp(builtInCasts) + initTimestampTZ(builtInCasts) + initTimeTZ(builtInCasts) + initVarBit(builtInCasts) + initVarChar(builtInCasts) } diff --git a/server/cast/int16.go b/server/cast/int16.go index 5ac3fffa2d..56cf283bce 100644 --- a/server/cast/int16.go +++ b/server/cast/int16.go @@ -18,87 +18,89 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initInt16 handles all casts that are built-in. This comprises only the "From" types. -func initInt16() { - int16Implicit() +// initInt16 handles all casts that are built-in. This comprises only the source types. +func initInt16(builtInCasts map[id.Cast]casts.Cast) { + int16Implicit(builtInCasts) } -// int16Implicit registers all implicit casts. This comprises only the "From" types. -func int16Implicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// int16Implicit registers all implicit casts. This comprises only the source types. +func int16Implicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float32(val.(int16)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float64(val.(int16)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int32(val.(int16)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(val.(int16)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return decimal.NewFromInt(int64(val.(int16))), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int16))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int16))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Regclass, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int16))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int16))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Regproc, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int16))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int16))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int16, ToType: pgtypes.Regtype, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int16))); internalID.IsValid() { return internalID, nil } diff --git a/server/cast/int32.go b/server/cast/int32.go index cbf46fb5c0..64c4b1e970 100644 --- a/server/cast/int32.go +++ b/server/cast/int32.go @@ -16,39 +16,39 @@ package cast import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initInt32 handles all casts that are built-in. This comprises only the "From" types. -func initInt32() { - int32Explicit() - int32Assignment() - int32Implicit() +// initInt32 handles all casts that are built-in. This comprises only the source types. +func initInt32(builtInCasts map[id.Cast]casts.Cast) { + int32Explicit(builtInCasts) + int32Assignment(builtInCasts) + int32Implicit(builtInCasts) } -// int32Explicit registers all explicit casts. This comprises only the "From" types. -func int32Explicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ +// int32Explicit registers all explicit casts. This comprises only the source types. +func int32Explicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Bool, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(int32) != 0, nil }, }) } -// int32Assignment registers all assignment casts. This comprises only the "From" types. -func int32Assignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// int32Assignment registers all assignment casts. This comprises only the source types. +func int32Assignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int32) > 32767 || val.(int32) < -32768 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } @@ -57,70 +57,70 @@ func int32Assignment() { }) } -// int32Implicit registers all implicit casts. This comprises only the "From" types. -func int32Implicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// int32Implicit registers all implicit casts. This comprises only the source types. +func int32Implicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float32(val.(int32)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float64(val.(int32)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(val.(int32)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return decimal.NewFromInt(int64(val.(int32))), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int32))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int32))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Regclass, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int32))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int32))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Regproc, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int32))); internalID.IsValid() { return internalID, nil } return id.NewOID(uint32(val.(int32))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Regtype, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if internalID := id.Cache().ToInternal(uint32(val.(int32))); internalID.IsValid() { return internalID, nil } diff --git a/server/cast/int64.go b/server/cast/int64.go index bbfdbf5ca9..5d9dee1c60 100644 --- a/server/cast/int64.go +++ b/server/cast/int64.go @@ -16,37 +16,37 @@ package cast import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initInt64 handles all casts that are built-in. This comprises only the "From" types. -func initInt64() { - int64Assignment() - int64Implicit() +// initInt64 handles all casts that are built-in. This comprises only the source types. +func initInt64(builtInCasts map[id.Cast]casts.Cast) { + int64Assignment(builtInCasts) + int64Implicit(builtInCasts) } -// int64Assignment registers all assignment casts. This comprises only the "From" types. -func int64Assignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// int64Assignment registers all assignment casts. This comprises only the source types. +func int64Assignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > 32767 || val.(int64) < -32768 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } return int16(val.(int64)), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > 2147483647 || val.(int64) < -2147483648 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") } @@ -55,33 +55,33 @@ func int64Assignment() { }) } -// int64Implicit registers all implicit casts. This comprises only the "From" types. -func int64Implicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// int64Implicit registers all implicit casts. This comprises only the source types. +func int64Implicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float32(val.(int64)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return float64(val.(int64)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return decimal.NewFromInt(val.(int64)), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } @@ -91,10 +91,10 @@ func int64Implicit() { return id.NewOID(uint32(val.(int64))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Regclass, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } @@ -104,10 +104,10 @@ func int64Implicit() { return id.NewOID(uint32(val.(int64))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Regproc, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } @@ -117,10 +117,10 @@ func int64Implicit() { return id.NewOID(uint32(val.(int64))).AsId(), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int64, ToType: pgtypes.Regtype, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } diff --git a/server/cast/internal_char.go b/server/cast/internal_char.go index 30d9ceb064..88617b5196 100644 --- a/server/cast/internal_char.go +++ b/server/cast/internal_char.go @@ -20,41 +20,43 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initInternalChar handles all casts that are built-in. This comprises only the "From" types. -func initInternalChar() { - internalCharAssignment() - internalCharExplicit() - internalCharImplicit() +// initInternalChar handles all casts that are built-in. This comprises only the source types. +func initInternalChar(builtInCasts map[id.Cast]casts.Cast) { + internalCharAssignment(builtInCasts) + internalCharExplicit(builtInCasts) + internalCharImplicit(builtInCasts) } -// internalCharAssignment registers all assignment casts. This comprises only the "From" types. -func internalCharAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// internalCharAssignment registers all assignment casts. This comprises only the source types. +func internalCharAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.InternalChar, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return targetType.IoInput(ctx, val.(string)) }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.InternalChar, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) } -// internalCharExplicit registers all explicit casts. This comprises only the "From" types. -func internalCharExplicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ +// internalCharExplicit registers all explicit casts. This comprises only the source types. +func internalCharExplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.InternalChar, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { s := val.(string) if len(s) == 0 { return int32(0), nil @@ -71,12 +73,12 @@ func internalCharExplicit() { }) } -// internalCharImplicit registers all implicit casts. This comprises only the "From" types. -func internalCharImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// internalCharImplicit registers all implicit casts. This comprises only the source types. +func internalCharImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.InternalChar, ToType: pgtypes.Text, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/interval.go b/server/cast/interval.go index ef0d8e1c9c..4ffc57859b 100644 --- a/server/cast/interval.go +++ b/server/cast/interval.go @@ -17,26 +17,27 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/server/functions" - - "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initInterval handles all casts that are built-in. This comprises only the "From" types. -func initInterval() { - intervalAssignment() - intervalImplicit() +// initInterval handles all casts that are built-in. This comprises only the source types. +func initInterval(builtInCasts map[id.Cast]casts.Cast) { + intervalAssignment(builtInCasts) + intervalImplicit(builtInCasts) } -// intervalAssignment registers all assignment casts. This comprises only the "From" types. -func intervalAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// intervalAssignment registers all assignment casts. This comprises only the source types. +func intervalAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Interval, ToType: pgtypes.Time, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { dur := val.(duration.Duration) // the month and day of the duration are excluded return timeofday.FromInt(dur.Nanos() / functions.NanosPerMicro), nil @@ -44,12 +45,12 @@ func intervalAssignment() { }) } -// intervalImplicit registers all implicit casts. This comprises only the "From" types. -func intervalImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// intervalImplicit registers all implicit casts. This comprises only the source types. +func intervalImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Interval, ToType: pgtypes.Interval, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(duration.Duration), nil }, }) diff --git a/server/cast/json.go b/server/cast/json.go index 78afbf028b..58955bcca5 100644 --- a/server/cast/json.go +++ b/server/cast/json.go @@ -17,21 +17,24 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initJson handles all casts that are built-in. This comprises only the "From" types. -func initJson() { - jsonAssignment() +// initJson handles all casts that are built-in. This comprises only the source types. +func initJson(builtInCasts map[id.Cast]casts.Cast) { + jsonAssignment(builtInCasts) } -// jsonAssignment registers all assignment casts. This comprises only the "From" types. -func jsonAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// jsonAssignment registers all assignment casts. This comprises only the source types. +func jsonAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Json, ToType: pgtypes.JsonB, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return targetType.IoInput(ctx, val.(string)) }, }) diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index db236c68f7..f01ef7a7a2 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -16,26 +16,27 @@ package cast import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initJsonB handles all casts that are built-in. This comprises only the "From" types. -func initJsonB() { - jsonbExplicit() - jsonbAssignment() +// initJsonB handles all casts that are built-in. This comprises only the source types. +func initJsonB(builtInCasts map[id.Cast]casts.Cast) { + jsonbExplicit(builtInCasts) + jsonbAssignment(builtInCasts) } -// jsonbExplicit registers all explicit casts. This comprises only the "From" types. -func jsonbExplicit() { - framework.MustAddExplicitTypeCast(framework.TypeCast{ +// jsonbExplicit registers all explicit casts. This comprises only the source types. +func jsonbExplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Bool, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -54,10 +55,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -77,10 +78,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -100,10 +101,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -126,10 +127,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -152,10 +153,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -178,10 +179,10 @@ func jsonbExplicit() { } }, }) - framework.MustAddExplicitTypeCast(framework.TypeCast{ + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { switch value := val.(pgtypes.JsonDocument).Value.(type) { case pgtypes.JsonValueObject: return nil, errors.Errorf("cannot cast jsonb object to type %s", targetType.String()) @@ -202,12 +203,12 @@ func jsonbExplicit() { }) } -// jsonbAssignment registers all assignment casts. This comprises only the "From" types. -func jsonbAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// jsonbAssignment registers all assignment casts. This comprises only the source types. +func jsonbAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.JsonB, ToType: pgtypes.Json, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return pgtypes.JsonB.IoOutput(ctx, val) }, }) diff --git a/server/cast/name.go b/server/cast/name.go index 05747f1407..298b787392 100644 --- a/server/cast/name.go +++ b/server/cast/name.go @@ -17,40 +17,43 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initName handles all casts that are built-in. This comprises only the "From" types. -func initName() { - nameAssignment() - nameImplicit() +// initName handles all casts that are built-in. This comprises only the source types. +func initName(builtInCasts map[id.Cast]casts.Cast) { + nameAssignment(builtInCasts) + nameImplicit(builtInCasts) } -// nameAssignment registers all assignment casts. This comprises only the "From" types. -func nameAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// nameAssignment registers all assignment casts. This comprises only the source types. +func nameAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Name, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Name, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) } -// nameImplicit registers all implicit casts. This comprises only the "From" types. -func nameImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// nameImplicit registers all implicit casts. This comprises only the source types. +func nameImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Name, ToType: pgtypes.Text, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/numeric.go b/server/cast/numeric.go index 8eab4ef909..ead85ecf00 100644 --- a/server/cast/numeric.go +++ b/server/cast/numeric.go @@ -16,26 +16,27 @@ package cast import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initNumeric handles all casts that are built-in. This comprises only the "From" types. -func initNumeric() { - numericAssignment() - numericImplicit() +// initNumeric handles all casts that are built-in. This comprises only the source types. +func initNumeric(builtInCasts map[id.Cast]casts.Cast) { + numericAssignment(builtInCasts) + numericImplicit(builtInCasts) } -// numericAssignment registers all assignment casts. This comprises only the "From" types. -func numericAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// numericAssignment registers all assignment casts. This comprises only the source types. +func numericAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Int16, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { d := val.(decimal.Decimal) if d.LessThan(pgtypes.NumericValueMinInt16) || d.GreaterThan(pgtypes.NumericValueMaxInt16) { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") @@ -43,10 +44,10 @@ func numericAssignment() { return int16(d.Round(0).IntPart()), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { d := val.(decimal.Decimal) if d.LessThan(pgtypes.NumericValueMinInt32) || d.GreaterThan(pgtypes.NumericValueMaxInt32) { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") @@ -54,10 +55,10 @@ func numericAssignment() { return int32(d.Round(0).IntPart()), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { d := val.(decimal.Decimal) if d.LessThan(pgtypes.NumericValueMinInt64) || d.GreaterThan(pgtypes.NumericValueMaxInt64) { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") @@ -67,28 +68,28 @@ func numericAssignment() { }) } -// numericImplicit registers all implicit casts. This comprises only the "From" types. -func numericImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// numericImplicit registers all implicit casts. This comprises only the source types. +func numericImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Float32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { f, _ := val.(decimal.Decimal).Float64() return float32(f), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Float64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { f, _ := val.(decimal.Decimal).Float64() return f, nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Numeric, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return pgtypes.GetNumericValueWithTypmod(val.(decimal.Decimal), targetType.GetAttTypMod()) }, }) diff --git a/server/cast/oid.go b/server/cast/oid.go index 2aac193643..5b08b7cda2 100644 --- a/server/cast/oid.go +++ b/server/cast/oid.go @@ -17,56 +17,56 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initOid handles all casts that are built-in. This comprises only the "From" types. -func initOid() { - oidAssignment() - oidImplicit() +// initOid handles all casts that are built-in. This comprises only the source types. +func initOid(builtInCasts map[id.Cast]casts.Cast) { + oidAssignment(builtInCasts) + oidImplicit(builtInCasts) } -// oidAssignment registers all assignment casts. This comprises only the "From" types. -func oidAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// oidAssignment registers all assignment casts. This comprises only the source types. +func oidAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Oid, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int32(id.Cache().ToOID(val.(id.Id))), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Oid, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(id.Cache().ToOID(val.(id.Id))), nil }, }) } -// oidImplicit registers all implicit casts. This comprises only the "From" types. -func oidImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// oidImplicit registers all implicit casts. This comprises only the source types. +func oidImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Oid, ToType: pgtypes.Regclass, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Oid, ToType: pgtypes.Regproc, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Oid, ToType: pgtypes.Regtype, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/regclass.go b/server/cast/regclass.go index dd7f7e1eac..e5e656f872 100644 --- a/server/cast/regclass.go +++ b/server/cast/regclass.go @@ -17,42 +17,42 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initRegclass handles all casts that are built-in. This comprises only the "From" types. -func initRegclass() { - regclassAssignment() - regclassImplicit() +// initRegclass handles all casts that are built-in. This comprises only the source types. +func initRegclass(builtInCasts map[id.Cast]casts.Cast) { + regclassAssignment(builtInCasts) + regclassImplicit(builtInCasts) } -// regclassAssignment registers all assignment casts. This comprises only the "From" types. -func regclassAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// regclassAssignment registers all assignment casts. This comprises only the source types. +func regclassAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regclass, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int32(id.Cache().ToOID(val.(id.Id))), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regclass, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(id.Cache().ToOID(val.(id.Id))), nil }, }) } -// regclassImplicit registers all implicit casts. This comprises only the "From" types. -func regclassImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// regclassImplicit registers all implicit casts. This comprises only the source types. +func regclassImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regclass, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/regproc.go b/server/cast/regproc.go index d9dc9b83d0..a2e4b93237 100644 --- a/server/cast/regproc.go +++ b/server/cast/regproc.go @@ -17,42 +17,42 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initRegproc handles all casts that are built-in. This comprises only the "From" types. -func initRegproc() { - regprocAssignment() - regprocImplicit() +// initRegproc handles all casts that are built-in. This comprises only the source types. +func initRegproc(builtInCasts map[id.Cast]casts.Cast) { + regprocAssignment(builtInCasts) + regprocImplicit(builtInCasts) } -// regprocAssignment registers all assignment casts. This comprises only the "From" types. -func regprocAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// regprocAssignment registers all assignment casts. This comprises only the source types. +func regprocAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regproc, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int32(id.Cache().ToOID(val.(id.Id))), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regproc, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(id.Cache().ToOID(val.(id.Id))), nil }, }) } -// regprocImplicit registers all implicit casts. This comprises only the "From" types. -func regprocImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// regprocImplicit registers all implicit casts. This comprises only the source types. +func regprocImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regproc, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/regtype.go b/server/cast/regtype.go index 60cc36cce7..1f66de9a89 100644 --- a/server/cast/regtype.go +++ b/server/cast/regtype.go @@ -17,42 +17,42 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initRegtype handles all casts that are built-in. This comprises only the "From" types. -func initRegtype() { - regtypeAssignment() - regtypeImplicit() +// initRegtype handles all casts that are built-in. This comprises only the source types. +func initRegtype(builtInCasts map[id.Cast]casts.Cast) { + regtypeAssignment(builtInCasts) + regtypeImplicit(builtInCasts) } -// regtypeAssignment registers all assignment casts. This comprises only the "From" types. -func regtypeAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// regtypeAssignment registers all assignment casts. This comprises only the source types. +func regtypeAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regtype, ToType: pgtypes.Int32, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int32(id.Cache().ToOID(val.(id.Id))), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regtype, ToType: pgtypes.Int64, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return int64(id.Cache().ToOID(val.(id.Id))), nil }, }) } -// regtypeImplicit registers all implicit casts. This comprises only the "From" types. -func regtypeImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// regtypeImplicit registers all implicit casts. This comprises only the source types. +func regtypeImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Regtype, ToType: pgtypes.Oid, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) diff --git a/server/cast/text.go b/server/cast/text.go index 2ff5ff3620..ee09a7a694 100644 --- a/server/cast/text.go +++ b/server/cast/text.go @@ -17,61 +17,57 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initText handles all casts that are built-in. This comprises only the "From" types. -func initText() { - textAssignment() - textImplicit() +// initText handles all casts that are built-in. This comprises only the source types. +func initText(builtInCasts map[id.Cast]casts.Cast) { + textAssignment(builtInCasts) + textImplicit(builtInCasts) } -// textAssignment registers all assignment casts. This comprises only the "From" types. -func textAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ - FromType: pgtypes.Text, - ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return handleStringCast(val.(string), targetType) - }, - }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// textAssignment registers all assignment casts. This comprises only the source types. +func textAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Text, ToType: pgtypes.InternalChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) } -// textImplicit registers all implicit casts. This comprises only the "From" types. -func textImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// textImplicit registers all implicit casts. This comprises only the source types. +func textImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Text, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Text, ToType: pgtypes.Name, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Text, ToType: pgtypes.Regclass, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return targetType.IoInput(ctx, val.(string)) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Text, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) diff --git a/server/cast/time.go b/server/cast/time.go index 45129a7516..19bb2593de 100644 --- a/server/cast/time.go +++ b/server/cast/time.go @@ -17,24 +17,25 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/timeofday" - "github.com/dolthub/doltgresql/server/functions" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initTime handles all casts that are built-in. This comprises only the "From" types. -func initTime() { - timeImplicit() +// initTime handles all casts that are built-in. This comprises only the source types. +func initTime(builtInCasts map[id.Cast]casts.Cast) { + timeImplicit(builtInCasts) } -// timeImplicit registers all implicit casts. This comprises only the "From" types. -func timeImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// timeImplicit registers all implicit casts. This comprises only the source types. +func timeImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Time, ToType: pgtypes.Interval, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { t := val.(timeofday.TimeOfDay) dur := functions.GetIntervalDurationFromTimeComponents(0, 0, 0, int64(t.Hour()), int64(t.Minute()), int64(t.Second()), int64(t.Microsecond())*1000) return dur, nil diff --git a/server/cast/timestamp.go b/server/cast/timestamp.go index ca86d62855..85640118a5 100644 --- a/server/cast/timestamp.go +++ b/server/cast/timestamp.go @@ -19,24 +19,26 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/pgdate" "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initTimestamp handles all casts that are built-in. This comprises only the "From" types. -func initTimestamp() { - timestampAssignment() - timestampImplicit() +// initTimestamp handles all casts that are built-in. This comprises only the source types. +func initTimestamp(builtInCasts map[id.Cast]casts.Cast) { + timestampAssignment(builtInCasts) + timestampImplicit(builtInCasts) } -// timestampAssignment registers all assignment casts. This comprises only the "From" types. -func timestampAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// timestampAssignment registers all assignment casts. This comprises only the source types. +func timestampAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Timestamp, ToType: pgtypes.Date, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { d, err := pgdate.MakeDateFromTime(val.(time.Time)) if err != nil { return nil, err @@ -44,28 +46,28 @@ func timestampAssignment() { return d.ToTime() }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Timestamp, ToType: pgtypes.Time, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return timeofday.FromTime(val.(time.Time)), nil }, }) } -// timestampImplicit registers all implicit casts. This comprises only the "From" types. -func timestampImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// timestampImplicit registers all implicit casts. This comprises only the source types. +func timestampImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Timestamp, ToType: pgtypes.Timestamp, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(time.Time), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Timestamp, ToType: pgtypes.TimestampTZ, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { // TODO: check return val.(time.Time), nil }, diff --git a/server/cast/timestamptz.go b/server/cast/timestamptz.go index 78283d6bda..b953da65da 100644 --- a/server/cast/timestamptz.go +++ b/server/cast/timestamptz.go @@ -19,6 +19,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/pgdate" "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/postgres/parser/timetz" @@ -26,18 +28,18 @@ import ( pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initTimestampTZ handles all casts that are built-in. This comprises only the "From" types. -func initTimestampTZ() { - timestampTZAssignment() - timestampTZImplicit() +// initTimestampTZ handles all casts that are built-in. This comprises only the source types. +func initTimestampTZ(builtInCasts map[id.Cast]casts.Cast) { + timestampTZAssignment(builtInCasts) + timestampTZImplicit(builtInCasts) } -// timestampTZAssignment registers all assignment casts. This comprises only the "From" types. -func timestampTZAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// timestampTZAssignment registers all assignment casts. This comprises only the source types. +func timestampTZAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimestampTZ, ToType: pgtypes.Date, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { d, err := pgdate.MakeDateFromTime(val.(time.Time)) if err != nil { return nil, err @@ -45,36 +47,36 @@ func timestampTZAssignment() { return d.ToTime() }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimestampTZ, ToType: pgtypes.Time, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return timeofday.FromTime(val.(time.Time)), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimestampTZ, ToType: pgtypes.Timestamp, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { // TODO: check return val.(time.Time), nil }, }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimestampTZ, ToType: pgtypes.TimeTZ, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return timetz.MakeTimeTZFromTime(val.(time.Time)), nil }, }) } -// timestampTZImplicit registers all implicit casts. This comprises only the "From" types. -func timestampTZImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// timestampTZImplicit registers all implicit casts. This comprises only the source types. +func timestampTZImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimestampTZ, ToType: pgtypes.TimestampTZ, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(time.Time), nil }, }) diff --git a/server/cast/timetz.go b/server/cast/timetz.go index f109c5a420..8a20a39519 100644 --- a/server/cast/timetz.go +++ b/server/cast/timetz.go @@ -17,35 +17,36 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/timetz" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initTimeTZ handles all casts that are built-in. This comprises only the "From" types. -func initTimeTZ() { - timeTZAssignment() - timeTZImplicit() +// initTimeTZ handles all casts that are built-in. This comprises only the source types. +func initTimeTZ(builtInCasts map[id.Cast]casts.Cast) { + timeTZAssignment(builtInCasts) + timeTZImplicit(builtInCasts) } -// timeTZAssignment registers all assignment casts. This comprises only the "From" types. -func timeTZAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// timeTZAssignment registers all assignment casts. This comprises only the source types. +func timeTZAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimeTZ, ToType: pgtypes.Time, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(timetz.TimeTZ).TimeOfDay, nil }, }) } -// timeTZImplicit registers all implicit casts. This comprises only the "From" types. -func timeTZImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// timeTZImplicit registers all implicit casts. This comprises only the source types. +func timeTZImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.TimeTZ, ToType: pgtypes.TimeTZ, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val.(timetz.TimeTZ), nil }, }) diff --git a/server/cast/varbit.go b/server/cast/varbit.go index 6cd7de67d1..51fcdedd0b 100644 --- a/server/cast/varbit.go +++ b/server/cast/varbit.go @@ -17,23 +17,24 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" - "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initVarBit handles all casts that are built-in. This comprises only the "From" types. -func initVarBit() { - varBitImplicit() +// initVarBit handles all casts that are built-in. This comprises only the source types. +func initVarBit(builtInCasts map[id.Cast]casts.Cast) { + varBitImplicit(builtInCasts) } -// varBitImplicit registers all implicit casts. This comprises only the "From" types. -func varBitImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// varBitImplicit registers all implicit casts. This comprises only the source types. +func varBitImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarBit, ToType: pgtypes.Bit, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { input := val.(string) array, err := tree.ParseDBitArray(input) if err != nil { @@ -46,10 +47,10 @@ func varBitImplicit() { return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarBit, ToType: pgtypes.VarBit, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { input := val.(string) array, err := tree.ParseDBitArray(input) if err != nil { diff --git a/server/cast/varchar.go b/server/cast/varchar.go index 552f44dff8..a97ed63b2b 100644 --- a/server/cast/varchar.go +++ b/server/cast/varchar.go @@ -17,61 +17,57 @@ package cast import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/casts" + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// initVarChar handles all casts that are built-in. This comprises only the "From" types. -func initVarChar() { - varcharAssignment() - varcharImplicit() +// initVarChar handles all casts that are built-in. This comprises only the source types. +func initVarChar(builtInCasts map[id.Cast]casts.Cast) { + varcharAssignment(builtInCasts) + varcharImplicit(builtInCasts) } -// varcharAssignment registers all assignment casts. This comprises only the "From" types. -func varcharAssignment() { - framework.MustAddAssignmentTypeCast(framework.TypeCast{ - FromType: pgtypes.VarChar, - ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return handleStringCast(val.(string), targetType) - }, - }) - framework.MustAddAssignmentTypeCast(framework.TypeCast{ +// varcharAssignment registers all assignment casts. This comprises only the source types. +func varcharAssignment(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddAssignmentTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarChar, ToType: pgtypes.InternalChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) } -// varcharImplicit registers all implicit casts. This comprises only the "From" types. -func varcharImplicit() { - framework.MustAddImplicitTypeCast(framework.TypeCast{ +// varcharImplicit registers all implicit casts. This comprises only the source types. +func varcharImplicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarChar, ToType: pgtypes.BpChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarChar, ToType: pgtypes.Name, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarChar, ToType: pgtypes.Text, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return val, nil }, }) - framework.MustAddImplicitTypeCast(framework.TypeCast{ + framework.MustAddImplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.VarChar, ToType: pgtypes.VarChar, - Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { return handleStringCast(val.(string), targetType) }, }) diff --git a/server/connection_data.go b/server/connection_data.go index f064290613..a718faf9da 100644 --- a/server/connection_data.go +++ b/server/connection_data.go @@ -229,7 +229,8 @@ func checkCompatibleTypes(existingOid, newOid uint32, newName string) error { var err error existing := pgtypes.GetTypeByID(id.Type(id.Cache().ToInternal(existingOid))) newType := pgtypes.GetTypeByID(id.Type(id.Cache().ToInternal(newOid))) - if _, _, err = framework.FindCommonType([]*pgtypes.DoltgresType{existing, newType}); err != nil { + // TODO: sql.Context needs to be threaded everywhere + if _, _, err = framework.FindCommonType(nil, []*pgtypes.DoltgresType{existing, newType}); err != nil { err = errors.Errorf("parameter %s is used for incompatible types: %s and %s", newName, existing.String(), newType.String()) } return err diff --git a/server/expression/array.go b/server/expression/array.go index 0a91fe70bc..0620977bc3 100644 --- a/server/expression/array.go +++ b/server/expression/array.go @@ -21,6 +21,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -63,6 +64,10 @@ func (array *Array) Children() []sql.Expression { func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) { resultTyp := array.coercedType.ArrayBaseType() values := make([]any, len(array.children)) + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } for i, expr := range array.children { val, err := expr.Eval(ctx, row) if err != nil { @@ -80,12 +85,15 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) { } // We always cast the element, as there may be parameter restrictions in place - castFunc := framework.GetImplicitCast(doltgresType, resultTyp) - if castFunc == nil { + cast, err := castsColl.GetImplicitCast(ctx, doltgresType, resultTyp) + if err != nil { + return nil, err + } + if !cast.ID.IsValid() { return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String()) } - values[i], err = castFunc(ctx, val, resultTyp) + values[i], err = cast.Eval(ctx, val, doltgresType, resultTyp) if err != nil { return nil, err } @@ -171,7 +179,8 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres childrenTypes = append(childrenTypes, childType) } } - targetType, _, err := framework.FindCommonType(childrenTypes) + // TODO: sql.Context needs to be threaded everywhere + targetType, _, err := framework.FindCommonType(nil, childrenTypes) if err != nil { return nil, errors.Errorf("ARRAY %s", err.Error()) } diff --git a/server/expression/assignment_cast.go b/server/expression/assignment_cast.go index e897192f43..e1dce6f07a 100644 --- a/server/expression/assignment_cast.go +++ b/server/expression/assignment_cast.go @@ -16,30 +16,29 @@ package expression import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/doltgresql/server/functions/framework" + "github.com/dolthub/doltgresql/core" pgtypes "github.com/dolthub/doltgresql/server/types" ) // AssignmentCast handles assignment casts. type AssignmentCast struct { - expr sql.Expression - fromType *pgtypes.DoltgresType - toType *pgtypes.DoltgresType + expr sql.Expression + sourceType *pgtypes.DoltgresType + targetType *pgtypes.DoltgresType } var _ sql.Expression = (*AssignmentCast)(nil) // NewAssignmentCast returns a new *AssignmentCast expression. -func NewAssignmentCast(expr sql.Expression, fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) *AssignmentCast { - toType = checkForDomainType(toType) - fromType = checkForDomainType(fromType) +func NewAssignmentCast(expr sql.Expression, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) *AssignmentCast { + targetType = checkForDomainType(targetType) + sourceType = checkForDomainType(sourceType) return &AssignmentCast{ - expr: expr, - fromType: fromType, - toType: toType, + expr: expr, + sourceType: sourceType, + targetType: targetType, } } @@ -54,12 +53,19 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil || val == nil { return val, err } - castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType) - if castFunc == nil { + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + cast, err := castsColl.GetAssignmentCast(ctx, ac.sourceType, ac.targetType) + if err != nil { + return nil, err + } + if !cast.ID.IsValid() { return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s", - ac.toType.String(), ac.fromType.String(), ac.expr.String()) + ac.targetType.String(), ac.sourceType.String(), ac.expr.String()) } - return castFunc(ctx, val, ac.toType) + return cast.Eval(ctx, val, ac.sourceType, ac.targetType) } // IsNullable implements the sql.Expression interface. @@ -79,7 +85,7 @@ func (ac *AssignmentCast) String() string { // Type implements the sql.Expression interface. func (ac *AssignmentCast) Type() sql.Type { - return ac.toType + return ac.targetType } // WithChildren implements the sql.Expression interface. @@ -87,9 +93,11 @@ func (ac *AssignmentCast) WithChildren(children ...sql.Expression) (sql.Expressi if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(ac, len(children), 1) } - return NewAssignmentCast(children[0], ac.fromType, ac.toType), nil + return NewAssignmentCast(children[0], ac.sourceType, ac.targetType), nil } +// checkForDomainType returns the underlying type if the given type is a domain type. Casting always applies to the base +// type. func checkForDomainType(t *pgtypes.DoltgresType) *pgtypes.DoltgresType { if t.TypType == pgtypes.TypeType_Domain { t = t.DomainUnderlyingBaseType() diff --git a/server/expression/explicit_cast.go b/server/expression/explicit_cast.go index a3017766ad..c6173d5cc5 100644 --- a/server/expression/explicit_cast.go +++ b/server/expression/explicit_cast.go @@ -23,7 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/dolthub/doltgresql/server/functions/framework" + "github.com/dolthub/doltgresql/core" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -79,7 +79,7 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - fromType, ok := c.sqlChild.Type().(*pgtypes.DoltgresType) + sourceType, ok := c.sqlChild.Type().(*pgtypes.DoltgresType) if !ok { // We'll leverage GMSCast to handle the conversion from a GMS type to a Doltgres type. // Rather than re-evaluating the expression, we put the result in a literal. @@ -88,7 +88,7 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - fromType = gmsCast.DoltgresType() + sourceType = gmsCast.DoltgresType() } if val == nil { if c.castToType.TypType == pgtypes.TypeType_Domain && !c.domainNullable { @@ -98,14 +98,21 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { } baseCastToType := checkForDomainType(c.castToType) - castFunction := framework.GetExplicitCast(fromType, baseCastToType) - if castFunction == nil { + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + cast, err := castsColl.GetExplicitCast(ctx, sourceType, baseCastToType) + if err != nil { + return nil, err + } + if !cast.ID.IsValid() { return nil, errors.Errorf( "EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", - fromType.String(), c.castToType.String(), c.sqlChild.String(), + sourceType.String(), c.castToType.String(), c.sqlChild.String(), ) } - castResult, err := castFunction(ctx, val, c.castToType) + castResult, err := cast.Eval(ctx, val, sourceType, c.castToType) if err != nil { // For string types and string array types, we intentionally ignore the error as using a length-restricted cast // is a way to intentionally truncate the data. All string types will always return the truncated result, even diff --git a/server/expression/implicit_cast.go b/server/expression/implicit_cast.go index fe2474a9fc..1e386c0cc4 100644 --- a/server/expression/implicit_cast.go +++ b/server/expression/implicit_cast.go @@ -16,18 +16,17 @@ package expression import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/doltgresql/server/functions/framework" + "github.com/dolthub/doltgresql/core" pgtypes "github.com/dolthub/doltgresql/server/types" ) // ImplicitCast handles implicit casts. type ImplicitCast struct { - expr sql.Expression - fromType *pgtypes.DoltgresType - toType *pgtypes.DoltgresType + expr sql.Expression + sourceType *pgtypes.DoltgresType + targetType *pgtypes.DoltgresType } var _ sql.Expression = (*ImplicitCast)(nil) @@ -37,9 +36,9 @@ func NewImplicitCast(expr sql.Expression, fromType *pgtypes.DoltgresType, toType toType = checkForDomainType(toType) fromType = checkForDomainType(fromType) return &ImplicitCast{ - expr: expr, - fromType: fromType, - toType: toType, + expr: expr, + sourceType: fromType, + targetType: toType, } } @@ -54,11 +53,18 @@ func (ic *ImplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil || val == nil { return val, err } - castFunc := framework.GetImplicitCast(ic.fromType, ic.toType) - if castFunc == nil { - return nil, errors.Errorf("target is of type %s but expression is of type %s", ic.toType.String(), ic.fromType.String()) + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + cast, err := castsColl.GetImplicitCast(ctx, ic.sourceType, ic.targetType) + if err != nil { + return nil, err + } + if !cast.ID.IsValid() { + return nil, errors.Errorf("target is of type %s but expression is of type %s", ic.targetType.String(), ic.sourceType.String()) } - return castFunc(ctx, val, ic.toType) + return cast.Eval(ctx, val, ic.sourceType, ic.targetType) } // IsNullable implements the sql.Expression interface. @@ -78,7 +84,7 @@ func (ic *ImplicitCast) String() string { // Type implements the sql.Expression interface. func (ic *ImplicitCast) Type() sql.Type { - return ic.toType + return ic.targetType } // WithChildren implements the sql.Expression interface. @@ -86,5 +92,5 @@ func (ic *ImplicitCast) WithChildren(children ...sql.Expression) (sql.Expression if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(ic, len(children), 1) } - return NewImplicitCast(children[0], ic.fromType, ic.toType), nil + return NewImplicitCast(children[0], ic.sourceType, ic.targetType), nil } diff --git a/server/functions/dolt_procedures.go b/server/functions/dolt_procedures.go index c5a01df15e..14ee8908c6 100644 --- a/server/functions/dolt_procedures.go +++ b/server/functions/dolt_procedures.go @@ -27,6 +27,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/server/auth" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -172,17 +173,23 @@ func drainRowIter(ctx *sql.Context, rowIter sql.RowIter) (any, error) { } else if err != nil { return nil, err } + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } // The conversion to []text needs []any, not sql.Row rowSlice := make([]any, len(row)) for i := range row { - fromType, err := typeForElement(row[i]) + sourceType, err := typeForElement(row[i]) if err != nil { return nil, err } - - castFn := framework.GetExplicitCast(fromType, pgtypes.Text) - textVal, err := castFn(ctx, row[i], pgtypes.Text) + cast, err := castsColl.GetExplicitCast(ctx, sourceType, pgtypes.Text) + if err != nil { + return nil, err + } + textVal, err := cast.Eval(ctx, row[i], sourceType, pgtypes.Text) if err != nil { return nil, err } diff --git a/server/functions/framework/cast.go b/server/functions/framework/cast.go index 96b2d237b4..60492bf744 100644 --- a/server/functions/framework/cast.go +++ b/server/functions/framework/cast.go @@ -15,21 +15,13 @@ package framework import ( - "sync" - - "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - - "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/id" pgtypes "github.com/dolthub/doltgresql/server/types" ) -// TODO: Right now, all casts are global. We should decide how to handle this in the presence of branches, sessions, etc. - -// getCastFunction is used to recursively call the cast function for when the inner logic sees that it has two array -// types. This sidesteps providing -type getCastFunction func(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction +// TODO: no need to use these functions, should instead add everything directly to built-in. +// For now, this just makes the transition easier since it's less to rewrite // TypeCast is used to cast from one type to another. type TypeCast struct { @@ -38,373 +30,47 @@ type TypeCast struct { Function pgtypes.TypeCastFunction } -// explicitTypeCastMutex is used to lock the explicit type cast map and array when writing. -var explicitTypeCastMutex = &sync.RWMutex{} - -// explicitTypeCastsMap is a map that maps: from -> to -> function. -var explicitTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{} - -// explicitTypeCastsArray is a slice that holds all registered explicit casts from the given type. -var explicitTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{} - -// assignmentTypeCastMutex is used to lock the assignment type cast map and array when writing. -var assignmentTypeCastMutex = &sync.RWMutex{} - -// assignmentTypeCastsMap is a map that maps: from -> to -> function. -var assignmentTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{} - -// assignmentTypeCastsArray is a slice that holds all registered assignment casts from the given type. -var assignmentTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{} - -// implicitTypeCastMutex is used to lock the implicit type cast map and array when writing. -var implicitTypeCastMutex = &sync.RWMutex{} - -// implicitTypeCastsMap is a map that maps: from -> to -> function. -var implicitTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{} - -// implicitTypeCastsArray is a slice that holds all registered implicit casts from the given type. -var implicitTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{} - -// AddExplicitTypeCast registers the given explicit type cast. -func AddExplicitTypeCast(cast TypeCast) error { - return addTypeCast(explicitTypeCastMutex, explicitTypeCastsMap, explicitTypeCastsArray, cast) -} - -// AddAssignmentTypeCast registers the given assignment type cast. -func AddAssignmentTypeCast(cast TypeCast) error { - return addTypeCast(assignmentTypeCastMutex, assignmentTypeCastsMap, assignmentTypeCastsArray, cast) -} - -// AddImplicitTypeCast registers the given implicit type cast. -func AddImplicitTypeCast(cast TypeCast) error { - return addTypeCast(implicitTypeCastMutex, implicitTypeCastsMap, implicitTypeCastsArray, cast) -} - // MustAddExplicitTypeCast registers the given explicit type cast. Panics if an error occurs. -func MustAddExplicitTypeCast(cast TypeCast) { - if err := AddExplicitTypeCast(cast); err != nil { - panic(err) +func MustAddExplicitTypeCast(builtInCasts map[id.Cast]casts.Cast, cast TypeCast) { + castID := id.NewCast(cast.FromType.ID, cast.ToType.ID) + if _, ok := builtInCasts[castID]; ok { + panic("duplicate built-in cast") } -} - -// MustAddAssignmentTypeCast registers the given assignment type cast. Panics if an error occurs. -func MustAddAssignmentTypeCast(cast TypeCast) { - if err := AddAssignmentTypeCast(cast); err != nil { - panic(err) + builtInCasts[castID] = casts.Cast{ + ID: castID, + CastType: casts.CastType_Explicit, + Function: id.NullFunction, + BuiltIn: cast.Function, + UseInOut: false, } } -// MustAddImplicitTypeCast registers the given implicit type cast. Panics if an error occurs. -func MustAddImplicitTypeCast(cast TypeCast) { - if err := AddImplicitTypeCast(cast); err != nil { - panic(err) - } -} - -// GetPotentialExplicitCasts returns all registered explicit type casts from the given type. -func GetPotentialExplicitCasts(fromType id.Type) []*pgtypes.DoltgresType { - return getPotentialCasts(explicitTypeCastMutex, explicitTypeCastsArray, fromType) -} - -// GetPotentialAssignmentCasts returns all registered assignment and implicit type casts from the given type. -func GetPotentialAssignmentCasts(fromType id.Type) []*pgtypes.DoltgresType { - assignment := getPotentialCasts(assignmentTypeCastMutex, assignmentTypeCastsArray, fromType) - implicit := GetPotentialImplicitCasts(fromType) - both := make([]*pgtypes.DoltgresType, len(assignment)+len(implicit)) - copy(both, assignment) - copy(both[len(assignment):], implicit) - return both -} - -// GetPotentialImplicitCasts returns all registered implicit type casts from the given type. -func GetPotentialImplicitCasts(fromType id.Type) []*pgtypes.DoltgresType { - return getPotentialCasts(implicitTypeCastMutex, implicitTypeCastsArray, fromType) -} - -// GetExplicitCast returns the explicit type cast function that will cast the "from" type to the "to" type. Returns nil -// if such a cast is not valid. -func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction { - if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { - return tcf - } else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { - return tcf - } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { - return tcf - } - // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. - if cast := getSizingOrIdentityCast(fromType, toType, true); cast != nil { - return cast - } - if recordCast := getRecordCast(fromType, toType, GetExplicitCast); recordCast != nil { - return recordCast - } - // All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html - if fromType.TypCategory == pgtypes.TypeCategory_StringTypes { - return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := fromType.IoOutput(ctx, val) - if err != nil { - return nil, err - } - return targetType.IoInput(ctx, str) - } - } else if toType.TypCategory == pgtypes.TypeCategory_StringTypes { - // All types have a built-in assignment cast to string types, which we can reference in an explicit cast - return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := fromType.IoOutput(ctx, val) - if err != nil { - return nil, err - } - return targetType.IoInput(ctx, str) - } - } - // It is always valid to convert from the `unknown` type - if fromType.ID == pgtypes.Unknown.ID { - return UnknownLiteralCast - } - return nil -} - -// GetAssignmentCast returns the assignment type cast function that will cast the "from" type to the "to" type. Returns -// nil if such a cast is not valid. -func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction { - if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil { - return tcf - } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil { - return tcf - } - // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. - if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil { - return cast - } - // We then check for a record to composite cast - if recordCast := getRecordCast(fromType, toType, GetAssignmentCast); recordCast != nil { - return recordCast - } - // All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html - if toType.TypCategory == pgtypes.TypeCategory_StringTypes { - return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := fromType.IoOutput(ctx, val) - if err != nil { - return nil, err - } - return targetType.IoInput(ctx, str) - } - } - // It is always valid to convert from the `unknown` type - if fromType.ID == pgtypes.Unknown.ID { - return UnknownLiteralCast - } - return nil -} - -// GetImplicitCast returns the implicit type cast function that will cast the "from" type to the "to" type. Returns nil -// if such a cast is not valid. -func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction { - if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetImplicitCast); tcf != nil { - return tcf - } - // We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user. - if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil { - return cast - } - // We then check for a record to composite cast - if recordCast := getRecordCast(fromType, toType, GetImplicitCast); recordCast != nil { - return recordCast - } - // It is always valid to convert from the `unknown` type - if fromType.ID == pgtypes.Unknown.ID { - return UnknownLiteralCast - } - return nil -} - -// addTypeCast registers the given type cast. -func addTypeCast(mutex *sync.RWMutex, - castMap map[id.Type]map[id.Type]pgtypes.TypeCastFunction, - castArray map[id.Type][]*pgtypes.DoltgresType, cast TypeCast) error { - mutex.Lock() - defer mutex.Unlock() - - toMap, ok := castMap[cast.FromType.ID] - if !ok { - toMap = map[id.Type]pgtypes.TypeCastFunction{} - castMap[cast.FromType.ID] = toMap - castArray[cast.FromType.ID] = nil - } - if _, ok := toMap[cast.ToType.ID]; ok { - // TODO: return the actual Postgres error - return errors.Errorf("cast from `%s` to `%s` already exists", cast.FromType.String(), cast.ToType.String()) - } - toMap[cast.ToType.ID] = cast.Function - castArray[cast.FromType.ID] = append(castArray[cast.FromType.ID], cast.ToType) - return nil -} - -// getPotentialCasts returns all registered type casts from the given type. -func getPotentialCasts(mutex *sync.RWMutex, castArray map[id.Type][]*pgtypes.DoltgresType, fromType id.Type) []*pgtypes.DoltgresType { - mutex.RLock() - defer mutex.RUnlock() - - return castArray[fromType] -} - -// getCast returns the type cast function that will cast the "from" type to the "to" type. Returns nil if such a cast is -// not valid. -func getCast(mutex *sync.RWMutex, - castMap map[id.Type]map[id.Type]pgtypes.TypeCastFunction, - fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, outerFunc getCastFunction) pgtypes.TypeCastFunction { - mutex.RLock() - defer mutex.RUnlock() - - if toMap, ok := castMap[fromType.ID]; ok { - if f, ok := toMap[toType.ID]; ok { - return f - } - } - // If there isn't a direct mapping, then we need to check if the types are array variants. - // As long as the base types are convertable, the array variants are also convertable. - if fromType.IsArrayType() && toType.IsArrayType() { - fromBaseType := fromType.ArrayBaseType() - toBaseType := toType.ArrayBaseType() - if baseCast := outerFunc(fromBaseType, toBaseType); baseCast != nil { - // We use a closure that can unwrap the slice, since conversion functions expect a singular non-nil value - return func(ctx *sql.Context, vals any, targetType *pgtypes.DoltgresType) (any, error) { - var err error - oldVals := vals.([]any) - newVals := make([]any, len(oldVals)) - for i, oldVal := range oldVals { - if oldVal == nil { - continue - } - // Some errors are optional depending on the context, so we'll still process all values even - // after an error is received. - var nErr error - targetBaseType := targetType.ArrayBaseType() - newVals[i], nErr = baseCast(ctx, oldVal, targetBaseType) - if nErr != nil && err == nil { - err = nErr - } - } - return newVals, err - } - } - - } - return nil -} - -// getSizingOrIdentityCast returns an identity cast if the two types are exactly the same, and a sizing cast if they -// only differ in their atttypmod values. Returns nil if no functions are matched. This mirrors the behavior as described in: -// https://www.postgresql.org/docs/15/typeconv-query.html -func getSizingOrIdentityCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, isExplicitCast bool) pgtypes.TypeCastFunction { - // If we receive different types, then we can return immediately - if fromType.ID != toType.ID { - return nil - } - // If we have different atttypmod values, then we need to do a sizing cast only if one exists - if fromType.GetAttTypMod() != toType.GetAttTypMod() { - // TODO: We don't have any sizing cast functions implemented, so for now we'll approximate using output to input. - // We can use the query below to find all implemented sizing cast functions. It's also detailed in the link above. - // Lastly, not all sizing functions accept a boolean, but for those that do, we need to see whether true is - // used for explicit casts, or whether true is used for implicit casts. - // SELECT - // format_type(c.castsource, NULL) AS source, - // format_type(c.casttarget, NULL) AS target, - // p.oid::regprocedure AS func - // FROM pg_cast c JOIN pg_proc p ON p.oid = c.castfunc WHERE c.castsource = c.casttarget ORDER BY 1,2; - return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := fromType.IoOutput(ctx, val) - if err != nil { - return nil, err - } - return targetType.IoInput(ctx, str) - } +// MustAddAssignmentTypeCast registers the given assignment type cast. Panics if an error occurs. +func MustAddAssignmentTypeCast(builtInCasts map[id.Cast]casts.Cast, cast TypeCast) { + castID := id.NewCast(cast.FromType.ID, cast.ToType.ID) + if _, ok := builtInCasts[castID]; ok { + panic("duplicate built-in cast") } - // If there is no sizing cast, then we simply use the identity cast - return IdentityCast -} - -// getRecordCast handles casting from a record type to a composite type (if applicable). Returns nil if not applicable. -func getRecordCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, passthrough func(*pgtypes.DoltgresType, *pgtypes.DoltgresType) pgtypes.TypeCastFunction) pgtypes.TypeCastFunction { - // TODO: does casting to a record type always work for any composite type? - // https://www.postgresql.org/docs/15/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS seems to suggest so - // Also not sure if we should use the passthrough, or if we always default to implicit, assignment, or explicit - if fromType.IsRecordType() && toType.IsCompositeType() { - // When casting to a composite type, then we must match the arity and have valid casts for every position. - if toType.IsRecordType() { - return IdentityCast - } else { - return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - vals, ok := val.([]pgtypes.RecordValue) - if !ok { - return nil, errors.New("casting input error from record type") - } - if len(targetType.CompositeAttrs) != len(vals) { - // TODO: these should go in DETAIL depending on the size - // Input has too few columns. - // Input has too many columns. - return nil, errors.Newf("cannot cast type %s to %s", fromType.Name(), targetType.Name()) - } - typeCollection, err := core.GetTypesCollectionFromContext(ctx) - if err != nil { - return nil, err - } - outputVals := make([]pgtypes.RecordValue, len(vals)) - for i := range vals { - valType, ok := vals[i].Type.(*pgtypes.DoltgresType) - if !ok { - return nil, errors.New("cannot cast record containing GMS type") - } - outputType, err := typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID) - if err != nil { - return nil, err - } - outputVals[i].Type = outputType - if vals[i].Value != nil { - positionCast := passthrough(valType, outputType) - if positionCast == nil { - // TODO: this should be the DETAIL, with the actual error being "cannot cast type to " - return nil, errors.Newf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1) - } - outputVals[i].Value, err = positionCast(ctx, vals[i].Value, outputType) - if err != nil { - return nil, err - } - } - } - return outputVals, nil - } - } + builtInCasts[castID] = casts.Cast{ + ID: castID, + CastType: casts.CastType_Assignment, + Function: id.NullFunction, + BuiltIn: cast.Function, + UseInOut: false, } - return nil } -// IdentityCast returns the input value. -func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return val, nil -} - -// UnknownLiteralCast is used when casting from an unknown literal to any type, as unknown literals are treated special in -// some contexts. -func UnknownLiteralCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := pgtypes.Unknown.IoOutput(ctx, val) - if err != nil { - return nil, err +// MustAddImplicitTypeCast registers the given implicit type cast. Panics if an error occurs. +func MustAddImplicitTypeCast(builtInCasts map[id.Cast]casts.Cast, cast TypeCast) { + castID := id.NewCast(cast.FromType.ID, cast.ToType.ID) + if _, ok := builtInCasts[castID]; ok { + panic("duplicate built-in cast") + } + builtInCasts[castID] = casts.Cast{ + ID: castID, + CastType: casts.CastType_Implicit, + Function: id.NullFunction, + BuiltIn: cast.Function, + UseInOut: false, } - return targetType.IoInput(ctx, str) } diff --git a/server/functions/framework/common_type.go b/server/functions/framework/common_type.go index fc69c9a693..7dba44cfdf 100644 --- a/server/functions/framework/common_type.go +++ b/server/functions/framework/common_type.go @@ -16,6 +16,9 @@ package framework import ( "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/core" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -23,7 +26,7 @@ import ( // FindCommonType returns the common type that given types can convert to. Returns false if no implicit casts are needed // to resolve the given types as the returned common type. // https://www.postgresql.org/docs/15/typeconv-union-case.html -func FindCommonType(types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, requiresCasts bool, err error) { +func FindCommonType(ctx *sql.Context, types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, requiresCasts bool, err error) { candidateType := pgtypes.Unknown differentTypes := false for _, typ := range types { @@ -53,14 +56,24 @@ func FindCommonType(types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, req return nil, false, errors.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String()) } } + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, false, err + } // Attempt to find the most general type (or the preferred type in the type category) for _, typ := range types { if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID { continue - } else if GetImplicitCast(typ, candidateType) != nil { + } else if cast, err := castsColl.GetImplicitCast(ctx, typ, candidateType); err != nil || cast.ID.IsValid() { + if err != nil { + return nil, false, err + } // typ can convert to the candidate type, so the candidate type is at least as general continue - } else if GetImplicitCast(candidateType, typ) != nil { + } else if cast, err = castsColl.GetImplicitCast(ctx, candidateType, typ); err != nil || cast.ID.IsValid() { + if err != nil { + return nil, false, err + } // the candidate type can convert to typ, but not vice versa, so typ is likely more general candidateType = typ if candidateType.IsPreferred { @@ -73,7 +86,9 @@ func FindCommonType(types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, req for _, typ := range types { if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID { continue - } else if GetImplicitCast(typ, candidateType) == nil { + } else if cast, err := castsColl.GetImplicitCast(ctx, typ, candidateType); err != nil { + return nil, false, err + } else if !cast.ID.IsValid() { return nil, false, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) } } diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index 99f0fe7781..baee990779 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -48,7 +48,8 @@ func NewCompiledAggregateFunction(name string, args []sql.Expression, functions // newCompiledAggregateFunctionInternal is called internally, which skips steps that may have already been processed. func newCompiledAggregateFunctionInternal(name string, args []sql.Expression, overloads *Overloads, fnOverloads []Overload, newBuffer NewBufferFn) *CompiledAggregateFunction { - cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) + // TODO: sql.Context should be threaded everywhere + cf := newCompiledFunctionInternal(nil, name, args, overloads, fnOverloads, false, nil) c := &CompiledAggregateFunction{ CompiledFunction: cf, newBuffer: newBuffer, diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 52e6c7099b..d15f18251d 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -24,8 +24,11 @@ import ( "github.com/dolthub/go-mysql-server/sql/procedures" "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/extensions/pg_extension" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -61,11 +64,13 @@ var _ sql.RowIterExpression = (*CompiledFunction)(nil) // NewCompiledFunction returns a newly compiled function. func NewCompiledFunction(name string, args []sql.Expression, functions *Overloads, isOperator bool) *CompiledFunction { - return newCompiledFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), isOperator, nil) + // TODO: sql.Context needs to be threaded everywhere + return newCompiledFunctionInternal(nil, name, args, functions, functions.overloadsForParams(len(args)), isOperator, nil) } // newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. func newCompiledFunctionInternal( + ctx *sql.Context, name string, args []sql.Expression, overloads *Overloads, @@ -89,7 +94,7 @@ func newCompiledFunctionInternal( return c } // Next we'll resolve the overload based on the parameters given. - overload, err := c.resolve(overloads, fnOverloads, originalTypes) + overload, err := c.resolve(ctx, overloads, fnOverloads, originalTypes) if err != nil { c.stashedErr = err return c @@ -304,8 +309,8 @@ func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, err targetType = targetParamTypes[i] } - if c.overload.casts[i] != nil { - args[i], err = c.overload.casts[i](ctx, arg, targetType) + if c.overload.casts[i].ID.IsValid() { + args[i], err = c.overload.casts[i].Eval(ctx, arg, exprTypes[i], targetType) if err != nil { return nil, err } @@ -412,7 +417,8 @@ func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Express } // We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did) - return newCompiledFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.IsOperator, c.runner), nil + // TODO: sql.Context needs to be threaded everywhere + return newCompiledFunctionInternal(nil, c.Name, children, c.overloads, c.fnOverloads, c.IsOperator, c.runner), nil } // SetStatementRunner implements the interface analyzer.Interpreter. @@ -464,7 +470,7 @@ func (c *CompiledFunction) GetQuickFunction() QuickFunction { // resolve returns an overloadMatch that either matches the given parameters exactly, or is a viable match after casting. // Returns an invalid overloadMatch if a viable match is not found. -func (c *CompiledFunction) resolve(overloads *Overloads, fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) (overloadMatch, error) { +func (c *CompiledFunction) resolve(ctx *sql.Context, overloads *Overloads, fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) (overloadMatch, error) { // First check for an exact match exactMatch, found := overloads.ExactMatchForTypes(argTypes...) if found { @@ -480,15 +486,15 @@ func (c *CompiledFunction) resolve(overloads *Overloads, fnOverloads []Overload, // There are no exact matches, so now we'll look through all overloads to determine the best match. This is // much more work, but there's a performance penalty for runtime overload resolution in Postgres as well. if c.IsOperator { - return c.resolveOperator(argTypes, overloads, fnOverloads) + return c.resolveOperator(ctx, argTypes, overloads, fnOverloads) } else { - return c.resolveFunction(argTypes, fnOverloads) + return c.resolveFunction(ctx, argTypes, fnOverloads) } } // resolveOperator resolves an operator according to the rules defined by Postgres. // https://www.postgresql.org/docs/15/typeconv-oper.html -func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, overloads *Overloads, fnOverloads []Overload) (overloadMatch, error) { +func (c *CompiledFunction) resolveOperator(ctx *sql.Context, argTypes []*pgtypes.DoltgresType, overloads *Overloads, fnOverloads []Overload) (overloadMatch, error) { // Binary operators treat unknown literals as the other type, so we'll account for that here to see if we can find // an "exact" match. if len(argTypes) == 2 { @@ -496,12 +502,18 @@ func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, ove rightUnknownType := argTypes[1].ID == pgtypes.Unknown.ID if (leftUnknownType && !rightUnknownType) || (!leftUnknownType && rightUnknownType) { var typ *pgtypes.DoltgresType - casts := []pgtypes.TypeCastFunction{IdentityCast, IdentityCast} + identity := casts.Cast{ + ID: id.NewCast(argTypes[0].ID, argTypes[1].ID), + CastType: casts.CastType_Explicit, + Function: id.NullFunction, + UseInOut: false, + } + opCasts := []casts.Cast{identity, identity} if leftUnknownType { - casts[0] = UnknownLiteralCast + opCasts[0].UseInOut = true typ = argTypes[1] } else { - casts[1] = UnknownLiteralCast + opCasts[1].UseInOut = true typ = argTypes[0] } if exactMatch, ok := overloads.ExactMatchForTypes(typ, typ); ok { @@ -512,20 +524,23 @@ func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, ove argTypes: []*pgtypes.DoltgresType{typ, typ}, variadic: -1, }, - casts: casts, + casts: opCasts, }, nil } } } // From this point, the steps appear to be the same for functions and operators - return c.resolveFunction(argTypes, fnOverloads) + return c.resolveFunction(ctx, argTypes, fnOverloads) } // resolveFunction resolves a function according to the rules defined by Postgres. // https://www.postgresql.org/docs/15/typeconv-func.html -func (c *CompiledFunction) resolveFunction(argTypes []*pgtypes.DoltgresType, overloads []Overload) (overloadMatch, error) { +func (c *CompiledFunction) resolveFunction(ctx *sql.Context, argTypes []*pgtypes.DoltgresType, overloads []Overload) (overloadMatch, error) { // First we'll discard all overloads that do not have implicitly-convertible param types - compatibleOverloads := c.typeCompatibleOverloads(overloads, argTypes) + compatibleOverloads, err := c.typeCompatibleOverloads(ctx, overloads, argTypes) + if err != nil { + return overloadMatch{}, err + } // No compatible overloads available, return early if len(compatibleOverloads) == 0 { @@ -573,22 +588,36 @@ func (c *CompiledFunction) resolveFunction(argTypes []*pgtypes.DoltgresType, ove // typeCompatibleOverloads returns all overloads that have a matching number of params whose types can be // implicitly converted to the ones provided. This is the set of all possible overloads that could be used with the // param types provided. -func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) []overloadMatch { +func (c *CompiledFunction) typeCompatibleOverloads(ctx *sql.Context, fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) ([]overloadMatch, error) { + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } var compatible []overloadMatch for _, overload := range fnOverloads { isConvertible := true - overloadCasts := make([]pgtypes.TypeCastFunction, len(argTypes)) + overloadCasts := make([]casts.Cast, len(argTypes)) // Polymorphic parameters must be gathered so that we can later verify that they all have matching base types var polymorphicParameters []*pgtypes.DoltgresType var polymorphicTargets []*pgtypes.DoltgresType for i := range argTypes { paramType := overload.argTypes[i] if paramType.IsValidForPolymorphicType(argTypes[i]) { - overloadCasts[i] = IdentityCast + overloadCasts[i] = casts.Cast{ + ID: id.NewCast(argTypes[i].ID, paramType.ID), + CastType: casts.CastType_Explicit, + Function: id.NullFunction, + UseInOut: false, + } polymorphicParameters = append(polymorphicParameters, paramType) polymorphicTargets = append(polymorphicTargets, argTypes[i]) } else { - if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil { + var err error + overloadCasts[i], err = castsColl.GetImplicitCast(ctx, argTypes[i], paramType) + if err != nil { + return nil, err + } + if !overloadCasts[i].ID.IsValid() { isConvertible = false break } @@ -599,7 +628,7 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy compatible = append(compatible, overloadMatch{params: overload, casts: overloadCasts}) } } - return compatible + return compatible, nil } // closestTypeMatches returns the set of overload candidates that have the most exact type matches for the arg types @@ -825,7 +854,7 @@ func getTypeIfRowType(isSRF bool, t *pgtypes.DoltgresType) *pgtypes.DoltgresType // ResolveDefaultValues adds missing arguments if there is any using the default value set on the parameter. // It checks if it's a valid SQL function that has fewer arguments than defined parameters. -func (c *CompiledFunction) ResolveDefaultValues(getDefExpr func(defExpr string) (sql.Expression, error)) error { +func (c *CompiledFunction) ResolveDefaultValues(ctx *sql.Context, getDefExpr func(defExpr string) (sql.Expression, error)) error { if !c.overload.Valid() { return nil } @@ -835,6 +864,10 @@ func (c *CompiledFunction) ResolveDefaultValues(getDefExpr func(defExpr string) } if len(c.Arguments) < len(sqlFunc.ParameterTypes) { + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return err + } for i, param := range sqlFunc.ParameterTypes { if i < len(c.Arguments) { if exprTypeId := c.Arguments[i].Type().(*pgtypes.DoltgresType).ID; exprTypeId != pgtypes.Unknown.ID && param.ID != exprTypeId { @@ -848,7 +881,11 @@ func (c *CompiledFunction) ResolveDefaultValues(getDefExpr func(defExpr string) return err } c.Arguments = append(c.Arguments, cdv) - c.overload.casts = append(c.overload.casts, GetImplicitCast(cdv.Type().(*pgtypes.DoltgresType), sqlFunc.ParameterTypes[i])) + implicitCast, err := castsColl.GetImplicitCast(ctx, cdv.Type().(*pgtypes.DoltgresType), sqlFunc.ParameterTypes[i]) + if err != nil { + return err + } + c.overload.casts = append(c.overload.casts, implicitCast) } } } diff --git a/server/functions/framework/intermediate_function.go b/server/functions/framework/intermediate_function.go index 7b0740992f..62cb09a906 100644 --- a/server/functions/framework/intermediate_function.go +++ b/server/functions/framework/intermediate_function.go @@ -31,5 +31,6 @@ func (f IntermediateFunction) Compile(name string, parameters ...sql.Expression) if f.Functions == nil { return nil } - return newCompiledFunctionInternal(name, parameters, f.Functions, f.AllOverloads, f.IsOperator, nil) + // TODO: sql.Context needs to be threaded through everything + return newCompiledFunctionInternal(nil, name, parameters, f.Functions, f.AllOverloads, f.IsOperator, nil) } diff --git a/server/functions/framework/interpreted_function.go b/server/functions/framework/interpreted_function.go index bf1e6c8f3d..b451aa6c83 100644 --- a/server/functions/framework/interpreted_function.go +++ b/server/functions/framework/interpreted_function.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/lib/pq" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -141,39 +142,38 @@ func (iFunc InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgs if rows[0][0] == nil { return nil, nil } - fromType, ok := sch[0].Type.(*pgtypes.DoltgresType) + sourceType, ok := sch[0].Type.(*pgtypes.DoltgresType) if !ok { // TODO: We ensure we have a DoltgresType, but we should also convert the value to // ensure it's in the correct form for the DoltgresType. This logic lives in // pgexpressions.GMSCast, but need to be extracted to avoid a dependency cycle // so it can be used here and from server.plpgsql. - fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type) + sourceType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type) if err != nil { return nil, err } } - castFunc := GetAssignmentCast(fromType, targetType) - if castFunc == nil { + castsColl, err := core.GetCastsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + cast, err := castsColl.GetAssignmentCast(ctx, sourceType, targetType) + if err != nil { + return nil, err + } + if !cast.ID.IsValid() { // TODO: We're using assignment casting, but for some reason we have to use I/O casting here, which is incorrect? // We need to dig into this and figure out exactly what's happening, as this is "wrong" according to what // I understand. This lines up more with explicit casting, but it's supposed to be assignment. // Maybe there are specific rules for pgsql? - if fromType.TypCategory == pgtypes.TypeCategory_StringTypes { - castFunc = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val == nil { - return nil, nil - } - str, err := fromType.IoOutput(ctx, val) - if err != nil { - return nil, err - } - return targetType.IoInput(ctx, str) - } + if sourceType.TypCategory == pgtypes.TypeCategory_StringTypes { + cast.ID = id.NewCast(sourceType.ID, targetType.ID) + cast.UseInOut = true } else { return nil, errors.New("no valid cast for return value") } } - return castFunc(subCtx, rows[0][0], targetType) + return cast.Eval(subCtx, rows[0][0], sourceType, targetType) }) } diff --git a/server/functions/framework/overloads.go b/server/functions/framework/overloads.go index 29cd2f4841..e6ae3144d9 100644 --- a/server/functions/framework/overloads.go +++ b/server/functions/framework/overloads.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/errors" + "github.com/dolthub/doltgresql/core/casts" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -198,7 +199,7 @@ func (o *Overload) coalesceVariadicValues(returnValues []any) []any { // as the type cast functions required to convert every argument to its appropriate parameter type type overloadMatch struct { params Overload - casts []pgtypes.TypeCastFunction + casts []casts.Cast } // Valid returns whether this overload is valid (has a callable function) diff --git a/server/initialization/initialization.go b/server/initialization/initialization.go index f15eb5becb..3a35594dbe 100644 --- a/server/initialization/initialization.go +++ b/server/initialization/initialization.go @@ -21,6 +21,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/servercfg" "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/casts" "github.com/dolthub/doltgresql/core/rootobject" "github.com/dolthub/doltgresql/server/analyzer" "github.com/dolthub/doltgresql/server/auth" @@ -53,7 +54,8 @@ func Initialize(dEnv *env.DoltEnv, cfg *doltgresservercfg.DoltgresConfig) { unary.Init() functions.Init() aggregate.Init() - cast.Init() + builtInCasts := casts.Init() + cast.Init(builtInCasts) framework.Initialize() servercfg.DefaultUnixSocketFilePath = cfgdetails.DefaultPostgresUnixSocketFilePath tables.Init() diff --git a/server/types/type.go b/server/types/type.go index 173bfd03d1..33100f3592 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -469,17 +469,9 @@ func (t *DoltgresType) Convert(ctx context.Context, v interface{}) (interface{}, return nil, sql.InRange, ErrUnhandledType.New(t.String(), v) } -// GetImplicitCast is a reference to the implicit cast logic in the functions/framework package, which we can't use -// here due to import cycles -var GetImplicitCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction - -// GetAssignmentCast is a reference to the assignment cast logic in the functions/framework package, which we can't use -// here due to import cycles -var GetAssignmentCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction - -// GetExplicitCast is a reference to the explicit cast logic in the functions/framework package, which we can't use -// here due to import cycles -var GetExplicitCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction +// GetAssignmentCast is a reference to the assignment cast logic in the core package, which we can't use here due to +// import cycles +var GetAssignmentCast func(ctx *sql.Context, fromType *DoltgresType, toType *DoltgresType) (Cast, error) // ConvertToType implements the types.ExtendedType interface. func (t *DoltgresType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val any) (any, sql.ConvertInRange, error) { @@ -488,8 +480,11 @@ func (t *DoltgresType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val return nil, sql.InRange, errors.Errorf("expected DoltgresType, got %T", typ) } - castFn := GetAssignmentCast(dt, t) - if castFn == nil { + cast, err := GetAssignmentCast(ctx, dt, t) + if err != nil { + return nil, sql.InRange, err + } + if cast == nil { // In the case that we have an unknown type string literal, we attempt to parse it with the target type's // input function // TODO: this is probably not the best place to perform this conversion, it would probably be better as an @@ -511,7 +506,7 @@ func (t *DoltgresType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val return nil, sql.InRange, errors.Errorf("no assignment cast from %s to %s", dt.Name(), t.Name()) } - castResult, err := castFn(ctx, val, t) + castResult, err := cast.Eval(ctx, val, dt, t) if err != nil && errors.Is(err, ErrCastOutOfRange) { // TODO: this could be either an overflow or an underflow, we should distinguish return castResult, sql.Overflow, nil @@ -1148,4 +1143,4 @@ func (t *DoltgresType) ConvertSerialized(ctx context.Context, other val.TupleTyp // TypeCastFunction is a function that takes a value of a particular kind of type, and returns it as another kind of type. // The targetType given should match the "To" type used to obtain the cast. -type TypeCastFunction func(ctx *sql.Context, val any, targetType *DoltgresType) (any, error) +type TypeCastFunction func(ctx *sql.Context, val any, sourceType *DoltgresType, targetType *DoltgresType) (any, error) diff --git a/server/types/utils.go b/server/types/utils.go index ace04be8ea..1c3077dae6 100644 --- a/server/types/utils.go +++ b/server/types/utils.go @@ -62,9 +62,24 @@ type TypeCollection interface { GetType(context.Context, id.Type) (*DoltgresType, error) } +// Cast is an interface from the core package, redeclared here to get around import cycles. +type Cast interface { + Eval(ctx *sql.Context, val any, sourceType *DoltgresType, targetType *DoltgresType) (any, error) +} + +// CastsCollection is an interface from the core package, redeclared here to get around import cycles. +type CastsCollection interface { + GetExplicitCast(ctx context.Context, sourceType *DoltgresType, targetType *DoltgresType) (Cast, error) + GetAssignmentCast(ctx context.Context, sourceType *DoltgresType, targetType *DoltgresType) (Cast, error) + GetImplicitCast(ctx context.Context, sourceType *DoltgresType, targetType *DoltgresType) (Cast, error) +} + // GetTypesCollectionFromContext is a function from the core package, redeclared here to get around import cycles. var GetTypesCollectionFromContext func(*sql.Context) (TypeCollection, error) +// GetCastsCollectionFromContext is a function from the core package, redeclared here to get around import cycles. +var GetCastsCollectionFromContext func(*sql.Context) (CastsCollection, error) + // FromGmsType returns a DoltgresType that is most similar to the given GMS type. // It returns UNKNOWN type for GMS types that are not handled. func FromGmsType(typ sql.Type) *DoltgresType {