diff --git a/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj b/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj index df94040..e351766 100644 --- a/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj +++ b/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj @@ -15,7 +15,6 @@ true v4.5 - 5 true @@ -80,4 +79,4 @@ - \ No newline at end of file + diff --git a/src/EntityFramework6.Npgsql/NpgsqlMigrationSqlGenerator.cs b/src/EntityFramework6.Npgsql/NpgsqlMigrationSqlGenerator.cs index e00aa01..5cb03d4 100644 --- a/src/EntityFramework6.Npgsql/NpgsqlMigrationSqlGenerator.cs +++ b/src/EntityFramework6.Npgsql/NpgsqlMigrationSqlGenerator.cs @@ -30,6 +30,7 @@ using System.Data.Entity.Core.Metadata.Edm; using System.Data.Entity.Spatial; using System.Linq; +using JetBrains.Annotations; namespace Npgsql { @@ -38,10 +39,10 @@ namespace Npgsql /// public class NpgsqlMigrationSqlGenerator : MigrationSqlGenerator { - List migrationStatments; - private List addedSchemas; - private List addedExtensions; - private Version serverVersion; + List _migrationStatments; + List _addedSchemas; + List _addedExtensions; + Version _serverVersion; /// /// Generates the migration sql. @@ -49,117 +50,76 @@ public class NpgsqlMigrationSqlGenerator : MigrationSqlGenerator /// The operations in the migration /// The provider manifest token used for server versioning. public override IEnumerable Generate( - IEnumerable migrationOperations, string providerManifestToken) + [NotNull] IEnumerable migrationOperations, + [NotNull] string providerManifestToken) { - migrationStatments = new List(); - addedSchemas = new List(); - addedExtensions = new List(); - serverVersion = new Version(providerManifestToken); + _migrationStatments = new List(); + _addedSchemas = new List(); + _addedExtensions = new List(); + _serverVersion = new Version(providerManifestToken); Convert(migrationOperations); - return migrationStatments; + return _migrationStatments; } #region MigrationOperation to MigrationStatement converters #region General - protected virtual void Convert(IEnumerable operations) + protected virtual void Convert([NotNull] IEnumerable operations) { foreach (var migrationOperation in operations) { if (migrationOperation is AddColumnOperation) - { Convert(migrationOperation as AddColumnOperation); - } else if (migrationOperation is AlterColumnOperation) - { Convert(migrationOperation as AlterColumnOperation); - } else if (migrationOperation is CreateTableOperation) - { Convert(migrationOperation as CreateTableOperation); - } else if (migrationOperation is DropForeignKeyOperation) - { Convert(migrationOperation as DropForeignKeyOperation); - } else if (migrationOperation is DropTableOperation) - { Convert(migrationOperation as DropTableOperation); - } else if (migrationOperation is MoveTableOperation) - { Convert(migrationOperation as MoveTableOperation); - } else if (migrationOperation is RenameTableOperation) - { Convert(migrationOperation as RenameTableOperation); - } else if (migrationOperation is AddForeignKeyOperation) - { Convert(migrationOperation as AddForeignKeyOperation); - } else if (migrationOperation is DropIndexOperation) - { Convert(migrationOperation as DropIndexOperation); - } else if (migrationOperation is SqlOperation) - { AddStatment((migrationOperation as SqlOperation).Sql, (migrationOperation as SqlOperation).SuppressTransaction); - } else if (migrationOperation is AddPrimaryKeyOperation) - { Convert(migrationOperation as AddPrimaryKeyOperation); - } else if (migrationOperation is CreateIndexOperation) - { Convert(migrationOperation as CreateIndexOperation); - } else if (migrationOperation is RenameIndexOperation) - { Convert(migrationOperation as RenameIndexOperation); - } else if (migrationOperation is DropColumnOperation) - { Convert(migrationOperation as DropColumnOperation); - } else if (migrationOperation is DropPrimaryKeyOperation) - { Convert(migrationOperation as DropPrimaryKeyOperation); - } else if (migrationOperation is HistoryOperation) - { Convert(migrationOperation as HistoryOperation); - } else if (migrationOperation is RenameColumnOperation) - { Convert(migrationOperation as RenameColumnOperation); - } else if (migrationOperation is UpdateDatabaseOperation) - { Convert((migrationOperation as UpdateDatabaseOperation).Migrations as IEnumerable); - } else - { throw new NotImplementedException("Unhandled MigrationOperation " + migrationOperation.GetType().Name + " in " + GetType().Name); - } } } - private void AddStatment(string sql, bool suppressTransacion = false) - { - migrationStatments.Add(new MigrationStatement - { - Sql = sql, - SuppressTransaction = suppressTransacion, - BatchTerminator = ";" - }); - } + void AddStatment(string sql, bool suppressTransacion = false) + => _migrationStatments.Add(new MigrationStatement + { + Sql = sql, + SuppressTransaction = suppressTransacion, + BatchTerminator = ";" + }); - private void AddStatment(StringBuilder sql, bool suppressTransacion = false) - { - AddStatment(sql.ToString(), suppressTransacion); - } + void AddStatment(StringBuilder sql, bool suppressTransacion = false) + => AddStatment(sql.ToString(), suppressTransacion); #endregion @@ -170,7 +130,7 @@ protected virtual void Convert(HistoryOperation historyOperation) foreach (var command in historyOperation.CommandTrees) { var npgsqlCommand = new NpgsqlCommand(); - NpgsqlServices.Instance.TranslateCommandTree(serverVersion, command, npgsqlCommand, false); + NpgsqlServices.Instance.TranslateCommandTree(_serverVersion, command, npgsqlCommand, false); AddStatment(npgsqlCommand.CommandText); } } @@ -181,12 +141,10 @@ protected virtual void Convert(HistoryOperation historyOperation) protected virtual void Convert(CreateTableOperation createTableOperation) { - StringBuilder sql = new StringBuilder(); - int dotIndex = createTableOperation.Name.IndexOf('.'); + var sql = new StringBuilder(); + var dotIndex = createTableOperation.Name.IndexOf('.'); if (dotIndex != -1) - { CreateSchema(createTableOperation.Name.Remove(dotIndex)); - } sql.Append("CREATE TABLE "); AppendTableName(createTableOperation.Name, sql); @@ -222,7 +180,7 @@ protected virtual void Convert(CreateTableOperation createTableOperation) protected virtual void Convert(DropTableOperation dropTableOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("DROP TABLE "); AppendTableName(dropTableOperation.Name, sql); AddStatment(sql); @@ -230,7 +188,7 @@ protected virtual void Convert(DropTableOperation dropTableOperation) protected virtual void Convert(RenameTableOperation renameTableOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(renameTableOperation.Name, sql); sql.Append(" RENAME TO "); @@ -238,15 +196,13 @@ protected virtual void Convert(RenameTableOperation renameTableOperation) AddStatment(sql); } - private void CreateSchema(string schemaName) + void CreateSchema(string schemaName) { - if (schemaName == "public" || addedSchemas.Contains(schemaName)) + if (schemaName == "public" || _addedSchemas.Contains(schemaName)) return; - addedSchemas.Add(schemaName); - if (serverVersion.Major > 9 || (serverVersion.Major == 9 && serverVersion.Minor >= 3)) - { + _addedSchemas.Add(schemaName); + if (_serverVersion.Major > 9 || (_serverVersion.Major == 9 && _serverVersion.Minor >= 3)) AddStatment("CREATE SCHEMA IF NOT EXISTS " + schemaName); - } else { //TODO: CREATE PROCEDURE that checks if schema already exists on servers < 9.3 @@ -254,7 +210,7 @@ private void CreateSchema(string schemaName) } } - //private void CreateExtension(string exensionName) + //void CreateExtension(string exensionName) //{ // //This is compatible only with server 9.1+ // if (serverVersion.Major > 9 || (serverVersion.Major == 9 && serverVersion.Minor >= 1)) @@ -268,7 +224,7 @@ private void CreateSchema(string schemaName) protected virtual void Convert(MoveTableOperation moveTableOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); var newSchema = moveTableOperation.NewSchema ?? "dbo"; CreateSchema(newSchema); sql.Append("ALTER TABLE "); @@ -283,7 +239,7 @@ protected virtual void Convert(MoveTableOperation moveTableOperation) #region Columns protected virtual void Convert(AddColumnOperation addColumnOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(addColumnOperation.Table, sql); sql.Append(" ADD "); @@ -293,7 +249,7 @@ protected virtual void Convert(AddColumnOperation addColumnOperation) protected virtual void Convert(DropColumnOperation dropColumnOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(dropColumnOperation.Table, sql); sql.Append(" DROP COLUMN \""); @@ -304,7 +260,7 @@ protected virtual void Convert(DropColumnOperation dropColumnOperation) protected virtual void Convert(AlterColumnOperation alterColumnOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); //TYPE AppendAlterColumn(alterColumnOperation, sql); @@ -359,13 +315,11 @@ protected virtual void Convert(AlterColumnOperation alterColumnOperation) } } else - { sql.Append(" DROP DEFAULT"); - } AddStatment(sql); } - private void AppendAlterColumn(AlterColumnOperation alterColumnOperation, StringBuilder sql) + void AppendAlterColumn(AlterColumnOperation alterColumnOperation, StringBuilder sql) { sql.Append("ALTER TABLE "); AppendTableName(alterColumnOperation.Table, sql); @@ -376,7 +330,7 @@ private void AppendAlterColumn(AlterColumnOperation alterColumnOperation, String protected virtual void Convert(RenameColumnOperation renameColumnOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(renameColumnOperation.Table, sql); sql.Append(" RENAME COLUMN \""); @@ -393,7 +347,7 @@ protected virtual void Convert(RenameColumnOperation renameColumnOperation) protected virtual void Convert(AddForeignKeyOperation addForeignKeyOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(addForeignKeyOperation.DependentTable, sql); sql.Append(" ADD CONSTRAINT \""); @@ -419,21 +373,19 @@ protected virtual void Convert(AddForeignKeyOperation addForeignKeyOperation) sql.Append(")"); if (addForeignKeyOperation.CascadeDelete) - { sql.Append(" ON DELETE CASCADE"); - } AddStatment(sql); } protected virtual void Convert(DropForeignKeyOperation dropForeignKeyOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(dropForeignKeyOperation.DependentTable, sql); - if (serverVersion.Major < 9) - sql.Append(" DROP CONSTRAINT \"");//TODO: http://piecesformthewhole.blogspot.com/2011/04/dropping-foreign-key-if-it-exists-in.html ? - else - sql.Append(" DROP CONSTRAINT IF EXISTS \""); + sql.Append(_serverVersion.Major < 9 + ? " DROP CONSTRAINT \"" //TODO: http://piecesformthewhole.blogspot.com/2011/04/dropping-foreign-key-if-it-exists-in.html ? + : " DROP CONSTRAINT IF EXISTS \"" + ); sql.Append(dropForeignKeyOperation.Name); sql.Append('"'); AddStatment(sql); @@ -441,7 +393,7 @@ protected virtual void Convert(DropForeignKeyOperation dropForeignKeyOperation) protected virtual void Convert(CreateIndexOperation createIndexOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("CREATE "); if (createIndexOperation.IsUnique) @@ -465,16 +417,11 @@ protected virtual void Convert(CreateIndexOperation createIndexOperation) protected virtual void Convert(RenameIndexOperation renameIndexOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); - if (serverVersion.Major > 9 || (serverVersion.Major == 9 && serverVersion.Minor >= 2)) - { - sql.Append("ALTER INDEX IF EXISTS "); - } - else - { - sql.Append("ALTER INDEX "); - } + sql.Append(_serverVersion.Major > 9 || (_serverVersion.Major == 9 && _serverVersion.Minor >= 2) + ? "ALTER INDEX IF EXISTS " + : "ALTER INDEX "); sql.Append(GetSchemaNameFromFullTableName(renameIndexOperation.Table)); sql.Append(".\""); @@ -485,13 +432,11 @@ protected virtual void Convert(RenameIndexOperation renameIndexOperation) AddStatment(sql); } - private string GetSchemaNameFromFullTableName(string tableFullName) + string GetSchemaNameFromFullTableName(string tableFullName) { - int dotIndex = tableFullName.IndexOf('.'); - if (dotIndex != -1) - return tableFullName.Remove(dotIndex); - else - return "dto";//TODO: Check always setting dto schema if no schema in table name is not bug + var dotIndex = tableFullName.IndexOf('.'); + return dotIndex != -1 ? tableFullName.Remove(dotIndex) : "dto"; + //TODO: Check always setting dto schema if no schema in table name is not bug } /// @@ -499,18 +444,15 @@ private string GetSchemaNameFromFullTableName(string tableFullName) /// /// /// - private string GetTableNameFromFullTableName(string tableName) + string GetTableNameFromFullTableName(string tableName) { - int dotIndex = tableName.IndexOf('.'); - if (dotIndex != -1) - return tableName.Substring(dotIndex + 1); - else - return tableName; + var dotIndex = tableName.IndexOf('.'); + return dotIndex != -1 ? tableName.Substring(dotIndex + 1) : tableName; } protected virtual void Convert(DropIndexOperation dropIndexOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("DROP INDEX IF EXISTS "); sql.Append(GetSchemaNameFromFullTableName(dropIndexOperation.Table)); sql.Append(".\""); @@ -521,7 +463,7 @@ protected virtual void Convert(DropIndexOperation dropIndexOperation) protected virtual void Convert(AddPrimaryKeyOperation addPrimaryKeyOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(addPrimaryKeyOperation.Table, sql); sql.Append(" ADD CONSTRAINT \""); @@ -542,7 +484,7 @@ protected virtual void Convert(AddPrimaryKeyOperation addPrimaryKeyOperation) protected virtual void Convert(DropPrimaryKeyOperation dropPrimaryKeyOperation) { - StringBuilder sql = new StringBuilder(); + var sql = new StringBuilder(); sql.Append("ALTER TABLE "); AppendTableName(dropPrimaryKeyOperation.Table, sql); sql.Append(" DROP CONSTRAINT \""); @@ -557,7 +499,7 @@ protected virtual void Convert(DropPrimaryKeyOperation dropPrimaryKeyOperation) #region Misc functions - private void AppendColumn(ColumnModel column, StringBuilder sql) + void AppendColumn(ColumnModel column, StringBuilder sql) { sql.Append('"'); sql.Append(column.Name); @@ -581,19 +523,19 @@ private void AppendColumn(ColumnModel column, StringBuilder sql) { switch (column.Type) { - case PrimitiveTypeKind.Guid: - //CreateExtension("uuid-ossp"); - //If uuid-ossp is not enabled migrations throw exception - AddStatment("select * from uuid_generate_v4()"); - sql.Append(" DEFAULT uuid_generate_v4()"); - break; - case PrimitiveTypeKind.Byte: - case PrimitiveTypeKind.SByte: - case PrimitiveTypeKind.Int16: - case PrimitiveTypeKind.Int32: - case PrimitiveTypeKind.Int64: - //TODO: Add support for setting "SERIAL" - break; + case PrimitiveTypeKind.Guid: + //CreateExtension("uuid-ossp"); + //If uuid-ossp is not enabled migrations throw exception + AddStatment("select * from uuid_generate_v4()"); + sql.Append(" DEFAULT uuid_generate_v4()"); + break; + case PrimitiveTypeKind.Byte: + case PrimitiveTypeKind.SByte: + case PrimitiveTypeKind.Int16: + case PrimitiveTypeKind.Int32: + case PrimitiveTypeKind.Int64: + //TODO: Add support for setting "SERIAL" + break; } } else if (column.IsNullable != null @@ -606,152 +548,145 @@ private void AppendColumn(ColumnModel column, StringBuilder sql) } } - private void AppendColumnType(ColumnModel column, StringBuilder sql, bool setSerial) + void AppendColumnType(ColumnModel column, StringBuilder sql, bool setSerial) { if (column.StoreType != null) { sql.Append(column.StoreType); return; } + switch (column.Type) { - case PrimitiveTypeKind.Binary: - sql.Append("bytea"); - break; - case PrimitiveTypeKind.Boolean: - sql.Append("boolean"); - break; - case PrimitiveTypeKind.DateTime: - if (column.Precision != null) - sql.Append("timestamp(" + column.Precision + ")"); - else - sql.Append("timestamp"); - break; - case PrimitiveTypeKind.Decimal: - //TODO: Check if inside min/max - if (column.Precision == null && column.Scale == null) - { - sql.Append("numeric"); - } - else - { - sql.Append("numeric("); - sql.Append(column.Precision ?? 19); - sql.Append(','); - sql.Append(column.Scale ?? 4); - sql.Append(')'); - } - break; - case PrimitiveTypeKind.Double: - sql.Append("float8"); - break; - case PrimitiveTypeKind.Guid: - sql.Append("uuid"); - break; - case PrimitiveTypeKind.Single: - sql.Append("float4"); - break; - case PrimitiveTypeKind.Byte://postgres doesn't support sbyte :( - case PrimitiveTypeKind.SByte://postgres doesn't support sbyte :( - case PrimitiveTypeKind.Int16: - if (setSerial) - sql.Append(column.IsIdentity ? "serial2" : "int2"); - else - sql.Append("int2"); - break; - case PrimitiveTypeKind.Int32: - if (setSerial) - sql.Append(column.IsIdentity ? "serial4" : "int4"); - else - sql.Append("int4"); - break; - case PrimitiveTypeKind.Int64: - if (setSerial) - sql.Append(column.IsIdentity ? "serial8" : "int8"); - else - sql.Append("int8"); - break; - case PrimitiveTypeKind.String: - if (column.IsFixedLength.HasValue && - column.IsFixedLength.Value && - column.MaxLength.HasValue) - { - sql.AppendFormat("char({0})",column.MaxLength.Value); - } - else if (column.MaxLength.HasValue) - { - sql.AppendFormat("varchar({0})", column.MaxLength); - } - else - { - sql.Append("text"); - } - break; - case PrimitiveTypeKind.Time: - if (column.Precision != null) - { - sql.Append("interval("); - sql.Append(column.Precision); - sql.Append(')'); - } - else - { - sql.Append("interval"); - } - break; - case PrimitiveTypeKind.DateTimeOffset: - if (column.Precision != null) - { - sql.Append("timestamptz("); - sql.Append(column.Precision); - sql.Append(')'); - } - else - { - sql.Append("timestamptz"); - } - break; - case PrimitiveTypeKind.Geometry: - sql.Append("point"); - break; - //case PrimitiveTypeKind.Geography: - // break; - //case PrimitiveTypeKind.GeometryPoint: - // break; - //case PrimitiveTypeKind.GeometryLineString: - // break; - //case PrimitiveTypeKind.GeometryPolygon: - // break; - //case PrimitiveTypeKind.GeometryMultiPoint: - // break; - //case PrimitiveTypeKind.GeometryMultiLineString: - // break; - //case PrimitiveTypeKind.GeometryMultiPolygon: - // break; - //case PrimitiveTypeKind.GeometryCollection: - // break; - //case PrimitiveTypeKind.GeographyPoint: - // break; - //case PrimitiveTypeKind.GeographyLineString: - // break; - //case PrimitiveTypeKind.GeographyPolygon: - // break; - //case PrimitiveTypeKind.GeographyMultiPoint: - // break; - //case PrimitiveTypeKind.GeographyMultiLineString: - // break; - //case PrimitiveTypeKind.GeographyMultiPolygon: - // break; - //case PrimitiveTypeKind.GeographyCollection: - // break; - default: - throw new ArgumentException("Unhandled column type:" + column.Type); - } - } - - private void AppendTableName(string tableName, StringBuilder sql) - { - int dotIndex = tableName.IndexOf('.'); + case PrimitiveTypeKind.Binary: + sql.Append("bytea"); + break; + case PrimitiveTypeKind.Boolean: + sql.Append("boolean"); + break; + case PrimitiveTypeKind.DateTime: + sql.Append(column.Precision != null + ? $"timestamp({column.Precision})" + : "timestamp" + ); + break; + case PrimitiveTypeKind.Decimal: + //TODO: Check if inside min/max + if (column.Precision == null && column.Scale == null) + sql.Append("numeric"); + else + { + sql.Append("numeric("); + sql.Append(column.Precision ?? 19); + sql.Append(','); + sql.Append(column.Scale ?? 4); + sql.Append(')'); + } + break; + case PrimitiveTypeKind.Double: + sql.Append("float8"); + break; + case PrimitiveTypeKind.Guid: + sql.Append("uuid"); + break; + case PrimitiveTypeKind.Single: + sql.Append("float4"); + break; + case PrimitiveTypeKind.Byte://postgres doesn't support sbyte :( + case PrimitiveTypeKind.SByte://postgres doesn't support sbyte :( + case PrimitiveTypeKind.Int16: + sql.Append(setSerial + ? column.IsIdentity ? "serial2" : "int2" + : "int2" + ); + break; + case PrimitiveTypeKind.Int32: + sql.Append(setSerial + ? column.IsIdentity ? "serial4" : "int4" + : "int4" + ); + break; + case PrimitiveTypeKind.Int64: + sql.Append(setSerial + ? column.IsIdentity ? "serial8" : "int8" + : "int8" + ); + break; + case PrimitiveTypeKind.String: + if (column.IsFixedLength.HasValue && + column.IsFixedLength.Value && + column.MaxLength.HasValue) + { + sql.Append($"char({column.MaxLength.Value})"); + } + else if (column.MaxLength.HasValue) + sql.Append($"varchar({column.MaxLength})"); + else + sql.Append("text"); + break; + case PrimitiveTypeKind.Time: + if (column.Precision != null) + { + sql.Append("interval("); + sql.Append(column.Precision); + sql.Append(')'); + } + else + sql.Append("interval"); + break; + case PrimitiveTypeKind.DateTimeOffset: + if (column.Precision != null) + { + sql.Append("timestamptz("); + sql.Append(column.Precision); + sql.Append(')'); + } + else + { + sql.Append("timestamptz"); + } + break; + case PrimitiveTypeKind.Geometry: + sql.Append("point"); + break; + //case PrimitiveTypeKind.Geography: + // break; + //case PrimitiveTypeKind.GeometryPoint: + // break; + //case PrimitiveTypeKind.GeometryLineString: + // break; + //case PrimitiveTypeKind.GeometryPolygon: + // break; + //case PrimitiveTypeKind.GeometryMultiPoint: + // break; + //case PrimitiveTypeKind.GeometryMultiLineString: + // break; + //case PrimitiveTypeKind.GeometryMultiPolygon: + // break; + //case PrimitiveTypeKind.GeometryCollection: + // break; + //case PrimitiveTypeKind.GeographyPoint: + // break; + //case PrimitiveTypeKind.GeographyLineString: + // break; + //case PrimitiveTypeKind.GeographyPolygon: + // break; + //case PrimitiveTypeKind.GeographyMultiPoint: + // break; + //case PrimitiveTypeKind.GeographyMultiLineString: + // break; + //case PrimitiveTypeKind.GeographyMultiPolygon: + // break; + //case PrimitiveTypeKind.GeographyCollection: + // break; + default: + throw new ArgumentException("Unhandled column type:" + column.Type); + } + } + + void AppendTableName(string tableName, StringBuilder sql) + { + var dotIndex = tableName.IndexOf('.'); if (dotIndex == -1) { sql.Append('"'); @@ -772,12 +707,10 @@ private void AppendTableName(string tableName, StringBuilder sql) #region Value appenders - private void AppendValue(byte[] values, StringBuilder sql) + void AppendValue(byte[] values, StringBuilder sql) { if (values.Length == 0) - { sql.Append("''"); - } else { sql.Append("E'\\\\"); @@ -787,91 +720,73 @@ private void AppendValue(byte[] values, StringBuilder sql) } } - private void AppendValue(bool value, StringBuilder sql) + void AppendValue(bool value, StringBuilder sql) { sql.Append(value ? "TRUE" : "FALSE"); } - private void AppendValue(DateTime value, StringBuilder sql) + void AppendValue(DateTime value, StringBuilder sql) { sql.Append("'"); sql.Append(new NpgsqlTypes.NpgsqlDateTime(value)); sql.Append("'"); } - private void AppendValue(DateTimeOffset value, StringBuilder sql) + void AppendValue(DateTimeOffset value, StringBuilder sql) { sql.Append("'"); sql.Append(new NpgsqlTypes.NpgsqlDateTime(value.UtcDateTime)); sql.Append("'"); } - private void AppendValue(Guid value, StringBuilder sql) + void AppendValue(Guid value, StringBuilder sql) { sql.Append("'"); sql.Append(value); sql.Append("'"); } - private void AppendValue(string value, StringBuilder sql) + void AppendValue(string value, StringBuilder sql) { sql.Append("'"); sql.Append(value); sql.Append("'"); } - private void AppendValue(TimeSpan value, StringBuilder sql) + void AppendValue(TimeSpan value, StringBuilder sql) { sql.Append("'"); - sql.Append(new NpgsqlTypes.NpgsqlTimeSpan(value).ToString()); + sql.Append(new NpgsqlTypes.NpgsqlTimeSpan(value)); sql.Append("'"); } - private void AppendValue(DbGeometry value, StringBuilder sql) + void AppendValue(DbGeometry value, StringBuilder sql) { sql.Append("'"); sql.Append(value); sql.Append("'"); } - private void AppendValue(object value, StringBuilder sql) + void AppendValue(object value, StringBuilder sql) { if (value is byte[]) - { AppendValue((byte[])value, sql); - } else if (value is bool) - { AppendValue((bool)value, sql); - } else if (value is DateTime) - { AppendValue((DateTime)value, sql); - } else if (value is DateTimeOffset) - { AppendValue((DateTimeOffset)value, sql); - } else if (value is Guid) - { AppendValue((Guid)value, sql); - } else if (value is string) - { AppendValue((string)value, sql); - } else if (value is TimeSpan) - { AppendValue((TimeSpan)value, sql); - } else if (value is DbGeometry) - { AppendValue((DbGeometry)value, sql); - } else - { sql.Append(string.Format(CultureInfo.InvariantCulture, "{0}", value)); - } } #endregion diff --git a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs index b990f72..3dfa842 100644 --- a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs +++ b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs @@ -37,321 +37,306 @@ #endif using System.Xml; using System.Data; +using JetBrains.Annotations; using NpgsqlTypes; namespace Npgsql { internal class NpgsqlProviderManifest : DbXmlEnabledProviderManifest { - private Version _version; - - public Version Version { get { return _version; } } + public Version Version { get; } public NpgsqlProviderManifest(string serverVersion) : base(CreateXmlReaderForResource("Npgsql.NpgsqlProviderManifest.Manifest.xml")) { - if (!Version.TryParse(serverVersion, out _version)) - { - _version = new Version(9, 5); - } + Version version; + Version = Version.TryParse(serverVersion, out version) + ? version + : new Version(9, 5); } - protected override XmlReader GetDbInformation(string informationType) + protected override XmlReader GetDbInformation([NotNull] string informationType) { - XmlReader xmlReader = null; - if (informationType == StoreSchemaDefinition) - { - xmlReader = CreateXmlReaderForResource("Npgsql.NpgsqlSchema.ssdl"); - } -#if !NET40 - else if (informationType == StoreSchemaDefinitionVersion3) - { - xmlReader = CreateXmlReaderForResource("Npgsql.NpgsqlSchemaV3.ssdl"); - } -#endif - else if (informationType == StoreSchemaMapping) - { - xmlReader = CreateXmlReaderForResource("Npgsql.NpgsqlSchema.msl"); - } - - if (xmlReader == null) - throw new ArgumentOutOfRangeException("informationType"); + return CreateXmlReaderForResource("Npgsql.NpgsqlSchema.ssdl"); + if (informationType == StoreSchemaDefinitionVersion3) + return CreateXmlReaderForResource("Npgsql.NpgsqlSchemaV3.ssdl"); + if (informationType == StoreSchemaMapping) + return CreateXmlReaderForResource("Npgsql.NpgsqlSchema.msl"); - return xmlReader; + throw new ArgumentOutOfRangeException(nameof(informationType)); } - private const string MaxLengthFacet = "MaxLength"; - private const string ScaleFacet = "Scale"; - private const string PrecisionFacet = "Precision"; - private const string FixedLengthFacet = "FixedLength"; + const string MaxLengthFacet = "MaxLength"; + const string ScaleFacet = "Scale"; + const string PrecisionFacet = "Precision"; + const string FixedLengthFacet = "FixedLength"; - internal static NpgsqlDbType GetNpgsqlDbType(PrimitiveTypeKind _primitiveType) + internal static NpgsqlDbType GetNpgsqlDbType(PrimitiveTypeKind primitiveType) { - switch (_primitiveType) + switch (primitiveType) { - case PrimitiveTypeKind.Binary: - return NpgsqlDbType.Bytea; - case PrimitiveTypeKind.Boolean: - return NpgsqlDbType.Boolean; - case PrimitiveTypeKind.Byte: - case PrimitiveTypeKind.SByte: - case PrimitiveTypeKind.Int16: - return NpgsqlDbType.Smallint; - case PrimitiveTypeKind.DateTime: - return NpgsqlDbType.Timestamp; - case PrimitiveTypeKind.DateTimeOffset: - return NpgsqlDbType.TimestampTZ; - case PrimitiveTypeKind.Decimal: - return NpgsqlDbType.Numeric; - case PrimitiveTypeKind.Double: - return NpgsqlDbType.Double; - case PrimitiveTypeKind.Int32: - return NpgsqlDbType.Integer; - case PrimitiveTypeKind.Int64: - return NpgsqlDbType.Bigint; - case PrimitiveTypeKind.Single: - return NpgsqlDbType.Real; - case PrimitiveTypeKind.Time: - return NpgsqlDbType.Interval; - case PrimitiveTypeKind.Guid: - return NpgsqlDbType.Uuid; - case PrimitiveTypeKind.String: - // Send strings as unknowns to be compatible with other datatypes than text - return NpgsqlDbType.Unknown; - default: - return NpgsqlDbType.Unknown; + case PrimitiveTypeKind.Binary: + return NpgsqlDbType.Bytea; + case PrimitiveTypeKind.Boolean: + return NpgsqlDbType.Boolean; + case PrimitiveTypeKind.Byte: + case PrimitiveTypeKind.SByte: + case PrimitiveTypeKind.Int16: + return NpgsqlDbType.Smallint; + case PrimitiveTypeKind.DateTime: + return NpgsqlDbType.Timestamp; + case PrimitiveTypeKind.DateTimeOffset: + return NpgsqlDbType.TimestampTZ; + case PrimitiveTypeKind.Decimal: + return NpgsqlDbType.Numeric; + case PrimitiveTypeKind.Double: + return NpgsqlDbType.Double; + case PrimitiveTypeKind.Int32: + return NpgsqlDbType.Integer; + case PrimitiveTypeKind.Int64: + return NpgsqlDbType.Bigint; + case PrimitiveTypeKind.Single: + return NpgsqlDbType.Real; + case PrimitiveTypeKind.Time: + return NpgsqlDbType.Interval; + case PrimitiveTypeKind.Guid: + return NpgsqlDbType.Uuid; + case PrimitiveTypeKind.String: + // Send strings as unknowns to be compatible with other datatypes than text + return NpgsqlDbType.Unknown; + default: + return NpgsqlDbType.Unknown; } } - public override TypeUsage GetEdmType(TypeUsage storeType) + public override TypeUsage GetEdmType([NotNull] TypeUsage storeType) { if (storeType == null) - throw new ArgumentNullException("storeType"); + throw new ArgumentNullException(nameof(storeType)); - string storeTypeName = storeType.EdmType.Name; - PrimitiveType primitiveType = StoreTypeNameToEdmPrimitiveType[storeTypeName]; + var storeTypeName = storeType.EdmType.Name; + var primitiveType = StoreTypeNameToEdmPrimitiveType[storeTypeName]; // TODO: come up with way to determin if unicode is used - bool isUnicode = true; + var isUnicode = true; Facet facet; switch (storeTypeName) { - case "bool": - case "int2": - case "int4": - case "int8": - case "float4": - case "float8": - case "uuid": - return TypeUsage.CreateDefaultTypeUsage(primitiveType); - case "numeric": + case "bool": + case "int2": + case "int4": + case "int8": + case "float4": + case "float8": + case "uuid": + return TypeUsage.CreateDefaultTypeUsage(primitiveType); + case "numeric": + { + byte scale; + byte precision; + if (storeType.Facets.TryGetValue(ScaleFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) { - byte scale; - byte precision; - if (storeType.Facets.TryGetValue(ScaleFacet, false, out facet) && + scale = (byte)facet.Value; + if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - scale = (byte)facet.Value; - if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - precision = (byte)facet.Value; - return TypeUsage.CreateDecimalTypeUsage(primitiveType, precision, scale); - } + precision = (byte)facet.Value; + return TypeUsage.CreateDecimalTypeUsage(primitiveType, precision, scale); } - return TypeUsage.CreateDecimalTypeUsage(primitiveType); } - case "bpchar": - if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, true, (int)facet.Value); - else - return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, true); - case "varchar": - if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, false, (int)facet.Value); - else - return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, false); - case "text": - case "xml": + return TypeUsage.CreateDecimalTypeUsage(primitiveType); + } + case "bpchar": + if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, true, (int)facet.Value); + else + return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, true); + case "varchar": + if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, false, (int)facet.Value); + else return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, false); - case "timestamp": - // TODO: make sure the arguments are correct here - if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateDateTimeTypeUsage(primitiveType, (byte)facet.Value); - } - else - { - return TypeUsage.CreateDateTimeTypeUsage(primitiveType, null); - } - case "date": - return TypeUsage.CreateDateTimeTypeUsage(primitiveType, 0); - case "timestamptz": - if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateDateTimeOffsetTypeUsage(primitiveType, (byte)facet.Value); - } - else - { - return TypeUsage.CreateDateTimeOffsetTypeUsage(primitiveType, null); - } - case "time": - case "interval": - if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + case "text": + case "xml": + return TypeUsage.CreateStringTypeUsage(primitiveType, isUnicode, false); + case "timestamp": + // TODO: make sure the arguments are correct here + if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateDateTimeTypeUsage(primitiveType, (byte)facet.Value); + } + else + { + return TypeUsage.CreateDateTimeTypeUsage(primitiveType, null); + } + case "date": + return TypeUsage.CreateDateTimeTypeUsage(primitiveType, 0); + case "timestamptz": + if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateDateTimeOffsetTypeUsage(primitiveType, (byte)facet.Value); + } + else + { + return TypeUsage.CreateDateTimeOffsetTypeUsage(primitiveType, null); + } + case "time": + case "interval": + if (storeType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateTimeTypeUsage(primitiveType, (byte)facet.Value); + } + else + { + return TypeUsage.CreateTimeTypeUsage(primitiveType, null); + } + case "bytea": + { + if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - return TypeUsage.CreateTimeTypeUsage(primitiveType, (byte)facet.Value); - } - else - { - return TypeUsage.CreateTimeTypeUsage(primitiveType, null); - } - case "bytea": - { - if (storeType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateBinaryTypeUsage(primitiveType, false, (int)facet.Value); - } - return TypeUsage.CreateBinaryTypeUsage(primitiveType, false); - } - case "rowversion": - { - return TypeUsage.CreateBinaryTypeUsage(primitiveType, true, 8); + return TypeUsage.CreateBinaryTypeUsage(primitiveType, false, (int)facet.Value); } - //TypeUsage.CreateBinaryTypeUsage - //TypeUsage.CreateDateTimeTypeUsage - //TypeUsage.CreateDecimalTypeUsage - //TypeUsage.CreateStringTypeUsage + return TypeUsage.CreateBinaryTypeUsage(primitiveType, false); + } + case "rowversion": + { + return TypeUsage.CreateBinaryTypeUsage(primitiveType, true, 8); + } + //TypeUsage.CreateBinaryTypeUsage + //TypeUsage.CreateDateTimeTypeUsage + //TypeUsage.CreateDecimalTypeUsage + //TypeUsage.CreateStringTypeUsage } + throw new NotSupportedException("Not supported store type: " + storeTypeName); } - public override TypeUsage GetStoreType(TypeUsage edmType) + public override TypeUsage GetStoreType([NotNull] TypeUsage edmType) { if (edmType == null) - throw new ArgumentNullException("edmType"); + throw new ArgumentNullException(nameof(edmType)); - PrimitiveType primitiveType = edmType.EdmType as PrimitiveType; + var primitiveType = edmType.EdmType as PrimitiveType; if (primitiveType == null) throw new ArgumentException("Store does not support specified edm type"); // TODO: come up with way to determin if unicode is used - bool isUnicode = true; + var isUnicode = true; Facet facet; switch (primitiveType.PrimitiveTypeKind) { - case PrimitiveTypeKind.Boolean: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["bool"]); - case PrimitiveTypeKind.Int16: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int2"]); - case PrimitiveTypeKind.Int32: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int4"]); - case PrimitiveTypeKind.Int64: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int8"]); - case PrimitiveTypeKind.Single: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["float4"]); - case PrimitiveTypeKind.Double: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["float8"]); - case PrimitiveTypeKind.Decimal: + case PrimitiveTypeKind.Boolean: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["bool"]); + case PrimitiveTypeKind.Int16: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int2"]); + case PrimitiveTypeKind.Int32: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int4"]); + case PrimitiveTypeKind.Int64: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int8"]); + case PrimitiveTypeKind.Single: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["float4"]); + case PrimitiveTypeKind.Double: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["float8"]); + case PrimitiveTypeKind.Decimal: + { + byte scale; + byte precision; + if (edmType.Facets.TryGetValue(ScaleFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) { - byte scale; - byte precision; - if (edmType.Facets.TryGetValue(ScaleFacet, false, out facet) && + scale = (byte)facet.Value; + if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - scale = (byte)facet.Value; - if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - precision = (byte)facet.Value; - return TypeUsage.CreateDecimalTypeUsage(StoreTypeNameToStorePrimitiveType["numeric"], precision, scale); - } + precision = (byte)facet.Value; + return TypeUsage.CreateDecimalTypeUsage(StoreTypeNameToStorePrimitiveType["numeric"], precision, scale); } - return TypeUsage.CreateDecimalTypeUsage(StoreTypeNameToStorePrimitiveType["numeric"]); } - case PrimitiveTypeKind.String: + return TypeUsage.CreateDecimalTypeUsage(StoreTypeNameToStorePrimitiveType["numeric"]); + } + case PrimitiveTypeKind.String: + { + // TODO: could get character, character varying, text + if (edmType.Facets.TryGetValue(FixedLengthFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null && (bool)facet.Value) { - // TODO: could get character, character varying, text - if (edmType.Facets.TryGetValue(FixedLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null && (bool)facet.Value) - { - PrimitiveType characterPrimitive = StoreTypeNameToStorePrimitiveType["bpchar"]; - if (edmType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateStringTypeUsage(characterPrimitive, isUnicode, true, (int)facet.Value); - } - // this may not work well - return TypeUsage.CreateStringTypeUsage(characterPrimitive, isUnicode, true); - } + PrimitiveType characterPrimitive = StoreTypeNameToStorePrimitiveType["bpchar"]; if (edmType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - return TypeUsage.CreateStringTypeUsage(StoreTypeNameToStorePrimitiveType["varchar"], isUnicode, false, (int)facet.Value); + return TypeUsage.CreateStringTypeUsage(characterPrimitive, isUnicode, true, (int)facet.Value); } - // assume text since it is not fixed length and has no max length - return TypeUsage.CreateStringTypeUsage(StoreTypeNameToStorePrimitiveType["text"], isUnicode, false); - } - case PrimitiveTypeKind.DateTime: - if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateDateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["timestamp"], (byte)facet.Value); - } - else - { - return TypeUsage.CreateDateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["timestamp"], null); + // this may not work well + return TypeUsage.CreateStringTypeUsage(characterPrimitive, isUnicode, true); } - case PrimitiveTypeKind.DateTimeOffset: - if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + if (edmType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - return TypeUsage.CreateDateTimeOffsetTypeUsage(StoreTypeNameToStorePrimitiveType["timestamptz"], (byte)facet.Value); - } - else - { - return TypeUsage.CreateDateTimeOffsetTypeUsage(StoreTypeNameToStorePrimitiveType["timestamptz"], null); + return TypeUsage.CreateStringTypeUsage(StoreTypeNameToStorePrimitiveType["varchar"], isUnicode, false, (int)facet.Value); } - case PrimitiveTypeKind.Time: - if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + // assume text since it is not fixed length and has no max length + return TypeUsage.CreateStringTypeUsage(StoreTypeNameToStorePrimitiveType["text"], isUnicode, false); + } + case PrimitiveTypeKind.DateTime: + if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateDateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["timestamp"], (byte)facet.Value); + } + else + { + return TypeUsage.CreateDateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["timestamp"], null); + } + case PrimitiveTypeKind.DateTimeOffset: + if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateDateTimeOffsetTypeUsage(StoreTypeNameToStorePrimitiveType["timestamptz"], (byte)facet.Value); + } + else + { + return TypeUsage.CreateDateTimeOffsetTypeUsage(StoreTypeNameToStorePrimitiveType["timestamptz"], null); + } + case PrimitiveTypeKind.Time: + if (edmType.Facets.TryGetValue(PrecisionFacet, false, out facet) && + !facet.IsUnbounded && facet.Value != null) + { + return TypeUsage.CreateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["interval"], (byte)facet.Value); + } + else + { + return TypeUsage.CreateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["interval"], null); + } + case PrimitiveTypeKind.Binary: + { + if (edmType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && !facet.IsUnbounded && facet.Value != null) { - return TypeUsage.CreateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["interval"], (byte)facet.Value); - } - else - { - return TypeUsage.CreateTimeTypeUsage(StoreTypeNameToStorePrimitiveType["interval"], null); + return TypeUsage.CreateBinaryTypeUsage(StoreTypeNameToStorePrimitiveType["bytea"], false, (int)facet.Value); } - case PrimitiveTypeKind.Binary: - { - if (edmType.Facets.TryGetValue(MaxLengthFacet, false, out facet) && - !facet.IsUnbounded && facet.Value != null) - { - return TypeUsage.CreateBinaryTypeUsage(StoreTypeNameToStorePrimitiveType["bytea"], false, (int)facet.Value); - } - return TypeUsage.CreateBinaryTypeUsage(StoreTypeNameToStorePrimitiveType["bytea"], false); - } - case PrimitiveTypeKind.Guid: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["uuid"]); - case PrimitiveTypeKind.Byte: - case PrimitiveTypeKind.SByte: - return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int2"]); + return TypeUsage.CreateBinaryTypeUsage(StoreTypeNameToStorePrimitiveType["bytea"], false); + } + case PrimitiveTypeKind.Guid: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["uuid"]); + case PrimitiveTypeKind.Byte: + case PrimitiveTypeKind.SByte: + return TypeUsage.CreateDefaultTypeUsage(StoreTypeNameToStorePrimitiveType["int2"]); } throw new NotSupportedException("Not supported edm type: " + edmType); } - private static XmlReader CreateXmlReaderForResource(string resourceName) - { - return XmlReader.Create(System.Reflection.Assembly.GetAssembly(typeof(NpgsqlProviderManifest)).GetManifestResourceStream(resourceName)); - } + static XmlReader CreateXmlReaderForResource(string resourceName) + => XmlReader.Create(System.Reflection.Assembly.GetAssembly(typeof(NpgsqlProviderManifest)).GetManifestResourceStream(resourceName)); public override bool SupportsEscapingLikeArgument(out char escapeCharacter) { @@ -359,35 +344,27 @@ public override bool SupportsEscapingLikeArgument(out char escapeCharacter) return true; } - public override string EscapeLikeArgument(string argument) - { - return argument.Replace("\\","\\\\").Replace("%", "\\%").Replace("_", "\\_"); - } + public override string EscapeLikeArgument([NotNull] string argument) + => argument.Replace("\\","\\\\").Replace("%", "\\%").Replace("_", "\\_"); #if ENTITIES6 - public override bool SupportsInExpression() - { - return true; - } + public override bool SupportsInExpression() => true; public override ReadOnlyCollection GetStoreFunctions() - { - var functions = new List(); - - functions.AddRange( - typeof(NpgsqlTextFunctions).GetTypeInfo() - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .Select(x => new { Method = x, DbFunction = x.GetCustomAttribute() }) - .Where(x => x.DbFunction != null) - .Select(x => CreateFullTextEdmFunction(x.Method, x.DbFunction))); + => typeof(NpgsqlTextFunctions).GetTypeInfo() + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .Select(x => new { Method = x, DbFunction = x.GetCustomAttribute() }) + .Where(x => x.DbFunction != null) + .Select(x => CreateFullTextEdmFunction(x.Method, x.DbFunction)) + .ToList() + .AsReadOnly(); - return functions.AsReadOnly(); - } - - private static EdmFunction CreateFullTextEdmFunction(MethodInfo method, DbFunctionAttribute dbFunctionInfo) + static EdmFunction CreateFullTextEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo) { - if (method == null) throw new ArgumentNullException("method"); - if (dbFunctionInfo == null) throw new ArgumentNullException("dbFunctionInfo"); + if (method == null) + throw new ArgumentNullException(nameof(method)); + if (dbFunctionInfo == null) + throw new ArgumentNullException(nameof(dbFunctionInfo)); return EdmFunction.Create( dbFunctionInfo.FunctionName, @@ -418,24 +395,19 @@ private static EdmFunction CreateFullTextEdmFunction(MethodInfo method, DbFuncti new List()); } - private static EdmType MapTypeToEdmType(Type type) + static EdmType MapTypeToEdmType(Type type) { var fromClrType = PrimitiveType .GetEdmPrimitiveTypes() .FirstOrDefault(t => t.ClrEquivalentType == type); if (fromClrType != null) - { return fromClrType; - } if (type.IsEnum) - { return MapTypeToEdmType(Enum.GetUnderlyingType(type)); - } - throw new NotSupportedException( - string.Format("Unsupported type for mapping to EdmType: {0}", type.FullName)); + throw new NotSupportedException($"Unsupported type for mapping to EdmType: {type.FullName}"); } #endif } diff --git a/src/EntityFramework6.Npgsql/NpgsqlServices.cs b/src/EntityFramework6.Npgsql/NpgsqlServices.cs index 9fe793d..cebda62 100644 --- a/src/EntityFramework6.Npgsql/NpgsqlServices.cs +++ b/src/EntityFramework6.Npgsql/NpgsqlServices.cs @@ -22,8 +22,8 @@ #endregion using System; -using System.Collections.Generic; using System.Text; +using JetBrains.Annotations; #if ENTITIES6 using System.Data.Entity.Core.Common; using System.Data.Entity.Core.Common.CommandTrees; @@ -39,6 +39,8 @@ using DbConnection = System.Data.Common.DbConnection; using DbCommand = System.Data.Common.DbCommand; +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + namespace Npgsql { #if ENTITIES6 @@ -47,38 +49,33 @@ public class NpgsqlServices : DbProviderServices internal class NpgsqlServices : DbProviderServices #endif { - private static readonly NpgsqlServices _instance = new NpgsqlServices(); + public static NpgsqlServices Instance { get; } = new NpgsqlServices(); #if ENTITIES6 public NpgsqlServices() { AddDependencyResolver(new SingletonDependencyResolver>( - () => new NpgsqlMigrationSqlGenerator(), "Npgsql")); + () => new NpgsqlMigrationSqlGenerator(), nameof(Npgsql))); } #endif - public static NpgsqlServices Instance - { - get { return _instance; } - } - - protected override DbCommandDefinition CreateDbCommandDefinition(DbProviderManifest providerManifest, DbCommandTree commandTree) - { - return CreateCommandDefinition(CreateDbCommand(((NpgsqlProviderManifest)providerManifest).Version, commandTree)); - } + protected override DbCommandDefinition CreateDbCommandDefinition([NotNull] DbProviderManifest providerManifest, [NotNull] DbCommandTree commandTree) + => CreateCommandDefinition(CreateDbCommand(((NpgsqlProviderManifest)providerManifest).Version, commandTree)); internal DbCommand CreateDbCommand(Version serverVersion, DbCommandTree commandTree) { if (commandTree == null) - throw new ArgumentNullException("commandTree"); + throw new ArgumentNullException(nameof(commandTree)); - NpgsqlCommand command = new NpgsqlCommand(); + var command = new NpgsqlCommand(); - foreach (KeyValuePair parameter in commandTree.Parameters) + foreach (var parameter in commandTree.Parameters) { - NpgsqlParameter dbParameter = new NpgsqlParameter(); - dbParameter.ParameterName = parameter.Key; - dbParameter.NpgsqlDbType = NpgsqlProviderManifest.GetNpgsqlDbType(((PrimitiveType)parameter.Value.EdmType).PrimitiveTypeKind); + var dbParameter = new NpgsqlParameter + { + ParameterName = parameter.Key, + NpgsqlDbType = NpgsqlProviderManifest.GetNpgsqlDbType(((PrimitiveType)parameter.Value.EdmType).PrimitiveTypeKind) + }; command.Parameters.Add(dbParameter); } @@ -89,76 +86,66 @@ internal DbCommand CreateDbCommand(Version serverVersion, DbCommandTree commandT internal void TranslateCommandTree(Version serverVersion, DbCommandTree commandTree, DbCommand command, bool createParametersForNonSelect = true) { - SqlBaseGenerator sqlGenerator = null; + SqlBaseGenerator sqlGenerator; DbQueryCommandTree select; DbInsertCommandTree insert; DbUpdateCommandTree update; DbDeleteCommandTree delete; if ((select = commandTree as DbQueryCommandTree) != null) - { sqlGenerator = new SqlSelectGenerator(select); - } else if ((insert = commandTree as DbInsertCommandTree) != null) - { sqlGenerator = new SqlInsertGenerator(insert); - } else if ((update = commandTree as DbUpdateCommandTree) != null) - { sqlGenerator = new SqlUpdateGenerator(update); - } else if ((delete = commandTree as DbDeleteCommandTree) != null) - { sqlGenerator = new SqlDeleteGenerator(delete); - } else { // TODO: get a message (unsupported DbCommandTree type) throw new ArgumentException(); } - sqlGenerator._createParametersForConstants = select != null ? false : createParametersForNonSelect; - sqlGenerator._command = (NpgsqlCommand)command; + sqlGenerator.CreateParametersForConstants = select == null && createParametersForNonSelect; + sqlGenerator.Command = (NpgsqlCommand)command; sqlGenerator.Version = serverVersion; sqlGenerator.BuildCommand(command); } - protected override string GetDbProviderManifestToken(DbConnection connection) + protected override string GetDbProviderManifestToken([NotNull] DbConnection connection) { if (connection == null) - throw new ArgumentNullException("connection"); - string serverVersion = ""; - UsingPostgresDBConnection((NpgsqlConnection)connection, conn => - { + throw new ArgumentNullException(nameof(connection)); + + var serverVersion = ""; + UsingPostgresDbConnection((NpgsqlConnection)connection, conn => { serverVersion = conn.ServerVersion; }); return serverVersion; } - protected override DbProviderManifest GetDbProviderManifest(string versionHint) + protected override DbProviderManifest GetDbProviderManifest([NotNull] string versionHint) { if (versionHint == null) - throw new ArgumentNullException("versionHint"); + throw new ArgumentNullException(nameof(versionHint)); return new NpgsqlProviderManifest(versionHint); } #if ENTITIES6 - protected override bool DbDatabaseExists(DbConnection connection, int? commandTimeout, StoreItemCollection storeItemCollection) + protected override bool DbDatabaseExists([NotNull] DbConnection connection, int? commandTimeout, [NotNull] StoreItemCollection storeItemCollection) { - bool exists = false; - UsingPostgresDBConnection((NpgsqlConnection)connection, conn => + var exists = false; + UsingPostgresDbConnection((NpgsqlConnection)connection, conn => { - using (NpgsqlCommand command = new NpgsqlCommand("select count(*) from pg_catalog.pg_database where datname = '" + connection.Database + "';", conn)) - { + using (var command = new NpgsqlCommand("select count(*) from pg_catalog.pg_database where datname = '" + connection.Database + "';", conn)) exists = Convert.ToInt32(command.ExecuteScalar()) > 0; - } }); return exists; } - protected override void DbCreateDatabase(DbConnection connection, int? commandTimeout, StoreItemCollection storeItemCollection) + protected override void DbCreateDatabase([NotNull] DbConnection connection, int? commandTimeout, [NotNull] StoreItemCollection storeItemCollection) { - UsingPostgresDBConnection((NpgsqlConnection)connection, conn => + UsingPostgresDbConnection((NpgsqlConnection)connection, conn => { var sb = new StringBuilder(); sb.Append("CREATE DATABASE \""); @@ -171,28 +158,24 @@ protected override void DbCreateDatabase(DbConnection connection, int? commandTi sb.Append("\""); } - using (NpgsqlCommand command = new NpgsqlCommand(sb.ToString(), conn)) - { + using (var command = new NpgsqlCommand(sb.ToString(), conn)) command.ExecuteNonQuery(); - } }); } - protected override void DbDeleteDatabase(DbConnection connection, int? commandTimeout, StoreItemCollection storeItemCollection) + protected override void DbDeleteDatabase([NotNull] DbConnection connection, int? commandTimeout, [NotNull] StoreItemCollection storeItemCollection) { - UsingPostgresDBConnection((NpgsqlConnection)connection, conn => + UsingPostgresDbConnection((NpgsqlConnection)connection, conn => { //Close all connections in pool or exception "database used by another user appears" NpgsqlConnection.ClearAllPools(); - using (NpgsqlCommand command = new NpgsqlCommand("DROP DATABASE \"" + connection.Database + "\";", conn)) - { + using (var command = new NpgsqlCommand("DROP DATABASE \"" + connection.Database + "\";", conn)) command.ExecuteNonQuery(); - } }); } #endif - private static void UsingPostgresDBConnection(NpgsqlConnection connection, Action action) + static void UsingPostgresDbConnection(NpgsqlConnection connection, Action action) { var connectionBuilder = new NpgsqlConnectionStringBuilder(connection.ConnectionString) { diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/PendingProjectsNode.cs b/src/EntityFramework6.Npgsql/SqlGenerators/PendingProjectsNode.cs index 9f5bbc4..646fba6 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/PendingProjectsNode.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/PendingProjectsNode.cs @@ -21,10 +21,7 @@ // TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. #endregion -using System; using System.Collections.Generic; -using System.Linq; -using System.Text; namespace Npgsql.SqlGenerators { @@ -51,23 +48,18 @@ internal class PendingProjectsNode { public readonly List Selects = new List(); public PendingProjectsNode JoinParent { get; set; } - public string TopName - { - get - { - return Selects[0].AsName; - } - } + public string TopName => Selects[0].AsName; public PendingProjectsNode(string asName, InputExpression exp) { Selects.Add(new NameAndInputExpression(asName, exp)); } + public void Add(string asName, InputExpression exp) { Selects.Add(new NameAndInputExpression(asName, exp)); } - public NameAndInputExpression Last { get { return Selects[Selects.Count - 1]; } } + public NameAndInputExpression Last => Selects[Selects.Count - 1]; } } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs index dd6435e..0dbccdd 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; #if ENTITIES6 using System.Globalization; using System.Data.Entity.Core.Common.CommandTrees; @@ -33,25 +34,33 @@ using System.Data.Metadata.Edm; #endif using System.Linq; -using Npgsql; -using NpgsqlTypes; +using JetBrains.Annotations; namespace Npgsql.SqlGenerators { internal abstract class SqlBaseGenerator : DbExpressionVisitor { - internal NpgsqlCommand _command; - internal bool _createParametersForConstants; - private Version _version; - internal Version Version { get { return _version; } set { _version = value; _useNewPrecedences = value >= new Version(9, 5); } } - private bool _useNewPrecedences; - - protected Dictionary _refToNode = new Dictionary(); - protected HashSet _currentExpressions = new HashSet(); - protected uint _aliasCounter = 0; - protected uint _parameterCount = 0; - - private static Dictionary AggregateFunctionNames = new Dictionary() + internal NpgsqlCommand Command; + internal bool CreateParametersForConstants; + bool _useNewPrecedences; + + protected Dictionary RefToNode = new Dictionary(); + protected HashSet CurrentExpressions = new HashSet(); + protected uint AliasCounter; + protected uint ParameterCount; + + internal Version Version + { + get { return _version; } + set + { + _version = value; + _useNewPrecedences = value >= new Version(9, 5); + } + } + Version _version; + + static readonly Dictionary AggregateFunctionNames = new Dictionary() { {"Avg","avg"}, {"Count","count"}, @@ -66,7 +75,7 @@ internal abstract class SqlBaseGenerator : DbExpressionVisitor BinaryOperatorFunctionNames = new Dictionary() + static readonly Dictionary BinaryOperatorFunctionNames = new Dictionary() { {"@@",Operator.QueryMatch}, {"operator_tsquery_and",Operator.QueryAnd}, @@ -76,363 +85,335 @@ internal abstract class SqlBaseGenerator : DbExpressionVisitor CurrentExpressions.Add(n.Last.Exp); + void LeaveExpression(PendingProjectsNode n) => CurrentExpressions.Remove(n.Last.Exp); - protected string NextAlias() - { - return "Alias" + _aliasCounter++; - } + protected string NextAlias() => "Alias" + AliasCounter++; - private bool IsCompatible(InputExpression child, DbExpressionKind parentKind) + bool IsCompatible(InputExpression child, DbExpressionKind parentKind) { switch (parentKind) { - case DbExpressionKind.Filter: - return - child.Projection == null && - child.GroupBy == null && - child.Skip == null && - child.Limit == null; - case DbExpressionKind.GroupBy: - return - child.Projection == null && - child.GroupBy == null && - child.Distinct == false && - child.OrderBy == null && - child.Skip == null && - child.Limit == null; - case DbExpressionKind.Distinct: - return - child.OrderBy == null && - child.Skip == null && - child.Limit == null; - case DbExpressionKind.Sort: - return - child.Projection == null && - child.GroupBy == null && - child.Skip == null && - child.Limit == null; - case DbExpressionKind.Skip: - return - child.Projection == null && - child.Skip == null && - child.Limit == null; - case DbExpressionKind.Project: - return - child.Projection == null && - child.Distinct == false; - // Limit and NewInstance are always true - default: - throw new ArgumentException("Unexpected parent expression kind"); + case DbExpressionKind.Filter: + return + child.Projection == null && + child.GroupBy == null && + child.Skip == null && + child.Limit == null; + case DbExpressionKind.GroupBy: + return + child.Projection == null && + child.GroupBy == null && + child.Distinct == false && + child.OrderBy == null && + child.Skip == null && + child.Limit == null; + case DbExpressionKind.Distinct: + return + child.OrderBy == null && + child.Skip == null && + child.Limit == null; + case DbExpressionKind.Sort: + return + child.Projection == null && + child.GroupBy == null && + child.Skip == null && + child.Limit == null; + case DbExpressionKind.Skip: + return + child.Projection == null && + child.Skip == null && + child.Limit == null; + case DbExpressionKind.Project: + return + child.Projection == null && + child.Distinct == false; + // Limit and NewInstance are always true + default: + throw new ArgumentException("Unexpected parent expression kind"); } } - private PendingProjectsNode GetInput(DbExpression expression, string childBindingName, string parentBindingName, DbExpressionKind parentKind) + PendingProjectsNode GetInput(DbExpression expression, string childBindingName, string parentBindingName, DbExpressionKind parentKind) { - PendingProjectsNode n = VisitInputWithBinding(expression, childBindingName); + var n = VisitInputWithBinding(expression, childBindingName); if (!IsCompatible(n.Last.Exp, parentKind)) - { n.Selects.Add(new NameAndInputExpression(parentBindingName, new InputExpression(n.Last.Exp, n.Last.AsName))); - } return n; } - private PendingProjectsNode VisitInputWithBinding(DbExpression expression, string bindingName) + PendingProjectsNode VisitInputWithBinding(DbExpression expression, string bindingName) { PendingProjectsNode n; switch (expression.ExpressionKind) { - case DbExpressionKind.Scan: - { - ScanExpression scan = (ScanExpression)expression.Accept(this); - InputExpression input = new InputExpression(scan, bindingName); - n = new PendingProjectsNode(bindingName, input); + case DbExpressionKind.Scan: + { + var scan = (ScanExpression)expression.Accept(this); + var input = new InputExpression(scan, bindingName); + n = new PendingProjectsNode(bindingName, input); - break; - } - case DbExpressionKind.Filter: - { - DbFilterExpression exp = (DbFilterExpression)expression; - n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); - EnterExpression(n); - VisitedExpression pred = exp.Predicate.Accept(this); - if (n.Last.Exp.Where == null) - n.Last.Exp.Where = new WhereExpression(pred); - else - n.Last.Exp.Where.And(pred); - LeaveExpression(n); - - break; - } - case DbExpressionKind.Sort: - { - DbSortExpression exp = (DbSortExpression)expression; - n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); - EnterExpression(n); - n.Last.Exp.OrderBy = new OrderByExpression(); - foreach (var order in exp.SortOrder) - { - n.Last.Exp.OrderBy.AppendSort(order.Expression.Accept(this), order.Ascending); - } - LeaveExpression(n); - - break; - } - case DbExpressionKind.Skip: - { - DbSkipExpression exp = (DbSkipExpression)expression; - n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); - EnterExpression(n); - n.Last.Exp.OrderBy = new OrderByExpression(); - foreach (var order in exp.SortOrder) - { - n.Last.Exp.OrderBy.AppendSort(order.Expression.Accept(this), order.Ascending); - } - n.Last.Exp.Skip = new SkipExpression(exp.Count.Accept(this)); - LeaveExpression(n); - break; - } - case DbExpressionKind.Distinct: - { - DbDistinctExpression exp = (DbDistinctExpression)expression; - string childBindingName = NextAlias(); - - n = VisitInputWithBinding(exp.Argument, childBindingName); - if (!IsCompatible(n.Last.Exp, expression.ExpressionKind)) - { - InputExpression prev = n.Last.Exp; - string prevName = n.Last.AsName; - InputExpression input = new InputExpression(prev, prevName); - n.Selects.Add(new NameAndInputExpression(bindingName, input)); - - // We need to copy all the projected columns so the DISTINCT keyword will work on the correct columns - // A parent project expression is never compatible with this new expression, - // so these are the columns that finally will be projected, as wanted - foreach (ColumnExpression col in prev.Projection.Arguments) - { - input.ColumnsToProject.Add(new StringPair(prevName, col.Name), col.Name); - input.ProjectNewNames.Add(col.Name); - } - } - n.Last.Exp.Distinct = true; - break; - } - case DbExpressionKind.Limit: - { - DbLimitExpression exp = (DbLimitExpression)expression; - n = VisitInputWithBinding(exp.Argument, NextAlias()); - if (n.Last.Exp.Limit != null) - { - FunctionExpression least = new FunctionExpression("LEAST"); - least.AddArgument(n.Last.Exp.Limit.Arg); - least.AddArgument(exp.Limit.Accept(this)); - n.Last.Exp.Limit.Arg = least; - } - else - { - n.Last.Exp.Limit = new LimitExpression(exp.Limit.Accept(this)); - } - break; - } - case DbExpressionKind.NewInstance: + break; + } + case DbExpressionKind.Filter: + { + var exp = (DbFilterExpression)expression; + n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); + EnterExpression(n); + var pred = exp.Predicate.Accept(this); + if (n.Last.Exp.Where == null) + n.Last.Exp.Where = new WhereExpression(pred); + else + n.Last.Exp.Where.And(pred); + LeaveExpression(n); + + break; + } + case DbExpressionKind.Sort: + { + var exp = (DbSortExpression)expression; + n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); + EnterExpression(n); + n.Last.Exp.OrderBy = new OrderByExpression(); + foreach (var order in exp.SortOrder) + n.Last.Exp.OrderBy.AppendSort(order.Expression.Accept(this), order.Ascending); + LeaveExpression(n); + + break; + } + case DbExpressionKind.Skip: + { + var exp = (DbSkipExpression)expression; + n = GetInput(exp.Input.Expression, exp.Input.VariableName, bindingName, expression.ExpressionKind); + EnterExpression(n); + n.Last.Exp.OrderBy = new OrderByExpression(); + foreach (var order in exp.SortOrder) + n.Last.Exp.OrderBy.AppendSort(order.Expression.Accept(this), order.Ascending); + n.Last.Exp.Skip = new SkipExpression(exp.Count.Accept(this)); + LeaveExpression(n); + break; + } + case DbExpressionKind.Distinct: + { + var exp = (DbDistinctExpression)expression; + var childBindingName = NextAlias(); + + n = VisitInputWithBinding(exp.Argument, childBindingName); + if (!IsCompatible(n.Last.Exp, expression.ExpressionKind)) + { + var prev = n.Last.Exp; + var prevName = n.Last.AsName; + var input = new InputExpression(prev, prevName); + n.Selects.Add(new NameAndInputExpression(bindingName, input)); + + // We need to copy all the projected columns so the DISTINCT keyword will work on the correct columns + // A parent project expression is never compatible with this new expression, + // so these are the columns that finally will be projected, as wanted + foreach (ColumnExpression col in prev.Projection.Arguments) { - DbNewInstanceExpression exp = (DbNewInstanceExpression)expression; - if (exp.Arguments.Count == 1 && exp.Arguments[0].ExpressionKind == DbExpressionKind.Element) - { - n = VisitInputWithBinding(((DbElementExpression)exp.Arguments[0]).Argument, NextAlias()); - if (n.Last.Exp.Limit != null) - { - FunctionExpression least = new FunctionExpression("LEAST"); - least.AddArgument(n.Last.Exp.Limit.Arg); - least.AddArgument(new LiteralExpression("1")); - n.Last.Exp.Limit.Arg = least; - } - else - { - n.Last.Exp.Limit = new LimitExpression(new LiteralExpression("1")); - } - } - else if (exp.Arguments.Count >= 1) - { - LiteralExpression result = new LiteralExpression("("); - for (int i = 0; i < exp.Arguments.Count; ++i) - { - DbExpression arg = exp.Arguments[i]; - var visitedColumn = arg.Accept(this); - if (!(visitedColumn is ColumnExpression)) - visitedColumn = new ColumnExpression(visitedColumn, "C", arg.ResultType); - - result.Append(i == 0 ? "SELECT " : " UNION ALL SELECT "); - result.Append(visitedColumn); - } - result.Append(")"); - n = new PendingProjectsNode(bindingName, new InputExpression(result, bindingName)); - } - else - { - TypeUsage type = ((CollectionType)exp.ResultType.EdmType).TypeUsage; - LiteralExpression result = new LiteralExpression("(SELECT "); - result.Append(new CastExpression(new LiteralExpression("NULL"), GetDbType(type.EdmType))); - result.Append(" LIMIT 0)"); - n = new PendingProjectsNode(bindingName, new InputExpression(result, bindingName)); - } - break; + input.ColumnsToProject.Add(new StringPair(prevName, col.Name), col.Name); + input.ProjectNewNames.Add(col.Name); } - case DbExpressionKind.UnionAll: - case DbExpressionKind.Intersect: - case DbExpressionKind.Except: + } + n.Last.Exp.Distinct = true; + break; + } + case DbExpressionKind.Limit: + { + var exp = (DbLimitExpression)expression; + n = VisitInputWithBinding(exp.Argument, NextAlias()); + if (n.Last.Exp.Limit != null) + { + var least = new FunctionExpression("LEAST"); + least.AddArgument(n.Last.Exp.Limit.Arg); + least.AddArgument(exp.Limit.Accept(this)); + n.Last.Exp.Limit.Arg = least; + } + else + n.Last.Exp.Limit = new LimitExpression(exp.Limit.Accept(this)); + break; + } + case DbExpressionKind.NewInstance: + { + var exp = (DbNewInstanceExpression)expression; + if (exp.Arguments.Count == 1 && exp.Arguments[0].ExpressionKind == DbExpressionKind.Element) + { + n = VisitInputWithBinding(((DbElementExpression)exp.Arguments[0]).Argument, NextAlias()); + if (n.Last.Exp.Limit != null) { - DbBinaryExpression exp = (DbBinaryExpression)expression; - DbExpressionKind expKind = exp.ExpressionKind; - List list = new List(); - Action func = null; - func = e => - { - if (e.ExpressionKind == expKind && e.ExpressionKind != DbExpressionKind.Except) - { - DbBinaryExpression binaryExp = (DbBinaryExpression)e; - func(binaryExp.Left); - func(binaryExp.Right); - } - else - { - list.Add(VisitInputWithBinding(e, bindingName + "_" + list.Count).Last.Exp); - } - }; - func(exp.Left); - func(exp.Right); - InputExpression input = new InputExpression(new CombinedProjectionExpression(expression.ExpressionKind, list), bindingName); - n = new PendingProjectsNode(bindingName, input); - break; + var least = new FunctionExpression("LEAST"); + least.AddArgument(n.Last.Exp.Limit.Arg); + least.AddArgument(new LiteralExpression("1")); + n.Last.Exp.Limit.Arg = least; } - case DbExpressionKind.Project: + else + n.Last.Exp.Limit = new LimitExpression(new LiteralExpression("1")); + } + else if (exp.Arguments.Count >= 1) + { + var result = new LiteralExpression("("); + for (var i = 0; i < exp.Arguments.Count; ++i) { - DbProjectExpression exp = (DbProjectExpression)expression; - PendingProjectsNode child = VisitInputWithBinding(exp.Input.Expression, exp.Input.VariableName); - InputExpression input = child.Last.Exp; - bool enterScope = false; - if (!IsCompatible(input, expression.ExpressionKind)) - { - input = new InputExpression(input, child.Last.AsName); - } - else enterScope = true; - - if (enterScope) EnterExpression(child); - - input.Projection = new CommaSeparatedExpression(); - - DbNewInstanceExpression projection = (DbNewInstanceExpression)exp.Projection; - RowType rowType = projection.ResultType.EdmType as RowType; - for (int i = 0; i < rowType.Properties.Count && i < projection.Arguments.Count; ++i) - { - EdmProperty prop = rowType.Properties[i]; - input.Projection.Arguments.Add(new ColumnExpression(projection.Arguments[i].Accept(this), prop.Name, prop.TypeUsage)); - } - - if (enterScope) LeaveExpression(child); - - n = new PendingProjectsNode(bindingName, input); - break; + var arg = exp.Arguments[i]; + var visitedColumn = arg.Accept(this); + if (!(visitedColumn is ColumnExpression)) + visitedColumn = new ColumnExpression(visitedColumn, "C", arg.ResultType); + + result.Append(i == 0 ? "SELECT " : " UNION ALL SELECT "); + result.Append(visitedColumn); } - case DbExpressionKind.GroupBy: + result.Append(")"); + n = new PendingProjectsNode(bindingName, new InputExpression(result, bindingName)); + } + else + { + var type = ((CollectionType)exp.ResultType.EdmType).TypeUsage; + var result = new LiteralExpression("(SELECT "); + result.Append(new CastExpression(new LiteralExpression("NULL"), GetDbType(type.EdmType))); + result.Append(" LIMIT 0)"); + n = new PendingProjectsNode(bindingName, new InputExpression(result, bindingName)); + } + break; + } + case DbExpressionKind.UnionAll: + case DbExpressionKind.Intersect: + case DbExpressionKind.Except: + { + var exp = (DbBinaryExpression)expression; + var expKind = exp.ExpressionKind; + var list = new List(); + Action func = null; + func = e => + { + if (e.ExpressionKind == expKind && e.ExpressionKind != DbExpressionKind.Except) { - DbGroupByExpression exp = (DbGroupByExpression)expression; - PendingProjectsNode child = VisitInputWithBinding(exp.Input.Expression, exp.Input.VariableName); - - // I don't know why the input for GroupBy in EF have two names - _refToNode[exp.Input.GroupVariableName] = child; - - InputExpression input = child.Last.Exp; - bool enterScope = false; - if (!IsCompatible(input, expression.ExpressionKind)) - { - input = new InputExpression(input, child.Last.AsName); - } - else enterScope = true; - - if (enterScope) EnterExpression(child); - - input.Projection = new CommaSeparatedExpression(); - - input.GroupBy = new GroupByExpression(); - RowType rowType = ((CollectionType)(exp.ResultType.EdmType)).TypeUsage.EdmType as RowType; - int columnIndex = 0; - foreach (var key in exp.Keys) - { - VisitedExpression keyColumnExpression = key.Accept(this); - var prop = rowType.Properties[columnIndex]; - input.Projection.Arguments.Add(new ColumnExpression(keyColumnExpression, prop.Name, prop.TypeUsage)); - // have no idea why EF is generating a group by with a constant expression, - // but postgresql doesn't need it. - if (!(key is DbConstantExpression)) - { - input.GroupBy.AppendGroupingKey(keyColumnExpression); - } - ++columnIndex; - } - foreach (var ag in exp.Aggregates) - { - DbFunctionAggregate function = (DbFunctionAggregate)ag; - VisitedExpression functionExpression = VisitFunction(function); - var prop = rowType.Properties[columnIndex]; - input.Projection.Arguments.Add(new ColumnExpression(functionExpression, prop.Name, prop.TypeUsage)); - ++columnIndex; - } - - if (enterScope) LeaveExpression(child); - - n = new PendingProjectsNode(bindingName, input); - break; + var binaryExp = (DbBinaryExpression)e; + func(binaryExp.Left); + func(binaryExp.Right); } - case DbExpressionKind.CrossJoin: - case DbExpressionKind.FullOuterJoin: - case DbExpressionKind.InnerJoin: - case DbExpressionKind.LeftOuterJoin: - case DbExpressionKind.CrossApply: - case DbExpressionKind.OuterApply: - { - InputExpression input = new InputExpression(); - n = new PendingProjectsNode(bindingName, input); + else + list.Add(VisitInputWithBinding(e, bindingName + "_" + list.Count).Last.Exp); + }; + func(exp.Left); + func(exp.Right); + var input = new InputExpression(new CombinedProjectionExpression(expression.ExpressionKind, list), bindingName); + n = new PendingProjectsNode(bindingName, input); + break; + } + case DbExpressionKind.Project: + { + var exp = (DbProjectExpression)expression; + var child = VisitInputWithBinding(exp.Input.Expression, exp.Input.VariableName); + var input = child.Last.Exp; + var enterScope = false; + if (!IsCompatible(input, expression.ExpressionKind)) + input = new InputExpression(input, child.Last.AsName); + else + enterScope = true; - JoinExpression from = VisitJoinChildren(expression, input, n); + if (enterScope) EnterExpression(child); - input.From = from; + input.Projection = new CommaSeparatedExpression(); - break; - } - default: throw new NotImplementedException(); + var projection = (DbNewInstanceExpression)exp.Projection; + var rowType = (RowType)projection.ResultType.EdmType; + for (var i = 0; i < rowType.Properties.Count && i < projection.Arguments.Count; ++i) + { + var prop = rowType.Properties[i]; + input.Projection.Arguments.Add(new ColumnExpression(projection.Arguments[i].Accept(this), prop.Name, prop.TypeUsage)); + } + + if (enterScope) LeaveExpression(child); + + n = new PendingProjectsNode(bindingName, input); + break; + } + case DbExpressionKind.GroupBy: + { + var exp = (DbGroupByExpression)expression; + var child = VisitInputWithBinding(exp.Input.Expression, exp.Input.VariableName); + + // I don't know why the input for GroupBy in EF have two names + RefToNode[exp.Input.GroupVariableName] = child; + + var input = child.Last.Exp; + var enterScope = false; + if (!IsCompatible(input, expression.ExpressionKind)) + input = new InputExpression(input, child.Last.AsName); + else enterScope = true; + + if (enterScope) EnterExpression(child); + + input.Projection = new CommaSeparatedExpression(); + + input.GroupBy = new GroupByExpression(); + var rowType = (RowType)((CollectionType)exp.ResultType.EdmType).TypeUsage.EdmType; + var columnIndex = 0; + foreach (var key in exp.Keys) + { + var keyColumnExpression = key.Accept(this); + var prop = rowType.Properties[columnIndex]; + input.Projection.Arguments.Add(new ColumnExpression(keyColumnExpression, prop.Name, prop.TypeUsage)); + // have no idea why EF is generating a group by with a constant expression, + // but postgresql doesn't need it. + if (!(key is DbConstantExpression)) + input.GroupBy.AppendGroupingKey(keyColumnExpression); + ++columnIndex; + } + foreach (var ag in exp.Aggregates) + { + var function = (DbFunctionAggregate)ag; + var functionExpression = VisitFunction(function); + var prop = rowType.Properties[columnIndex]; + input.Projection.Arguments.Add(new ColumnExpression(functionExpression, prop.Name, prop.TypeUsage)); + ++columnIndex; + } + + if (enterScope) LeaveExpression(child); + + n = new PendingProjectsNode(bindingName, input); + break; + } + case DbExpressionKind.CrossJoin: + case DbExpressionKind.FullOuterJoin: + case DbExpressionKind.InnerJoin: + case DbExpressionKind.LeftOuterJoin: + case DbExpressionKind.CrossApply: + case DbExpressionKind.OuterApply: + { + var input = new InputExpression(); + n = new PendingProjectsNode(bindingName, input); + + var from = VisitJoinChildren(expression, input, n); + + input.From = from; + + break; + } + default: + throw new NotImplementedException(); } - _refToNode[bindingName] = n; + + RefToNode[bindingName] = n; return n; } - private bool IsJoin(DbExpressionKind kind) + bool IsJoin(DbExpressionKind kind) { switch (kind) { - case DbExpressionKind.CrossJoin: - case DbExpressionKind.FullOuterJoin: - case DbExpressionKind.InnerJoin: - case DbExpressionKind.LeftOuterJoin: - case DbExpressionKind.CrossApply: - case DbExpressionKind.OuterApply: - return true; + case DbExpressionKind.CrossJoin: + case DbExpressionKind.FullOuterJoin: + case DbExpressionKind.InnerJoin: + case DbExpressionKind.LeftOuterJoin: + case DbExpressionKind.CrossApply: + case DbExpressionKind.OuterApply: + return true; } return false; } - private JoinExpression VisitJoinChildren(DbExpression expression, InputExpression input, PendingProjectsNode n) + JoinExpression VisitJoinChildren(DbExpression expression, InputExpression input, PendingProjectsNode n) { DbExpressionBinding left, right; DbExpression condition = null; @@ -460,18 +441,16 @@ private JoinExpression VisitJoinChildren(DbExpression expression, InputExpressio return VisitJoinChildren(left.Expression, left.VariableName, right.Expression, right.VariableName, expression.ExpressionKind, condition, input, n); } - private JoinExpression VisitJoinChildren(DbExpression left, string leftName, DbExpression right, string rightName, DbExpressionKind joinType, DbExpression condition, InputExpression input, PendingProjectsNode n) + + JoinExpression VisitJoinChildren(DbExpression left, string leftName, DbExpression right, string rightName, DbExpressionKind joinType, [CanBeNull] DbExpression condition, InputExpression input, PendingProjectsNode n) { - JoinExpression join = new JoinExpression(); - join.JoinType = joinType; + var join = new JoinExpression { JoinType = joinType }; if (IsJoin(left.ExpressionKind)) - { join.Left = VisitJoinChildren(left, input, n); - } else { - PendingProjectsNode l = VisitInputWithBinding(left, leftName); + var l = VisitInputWithBinding(left, leftName); l.JoinParent = n; join.Left = new FromExpression(l.Last.Exp, l.Last.AsName); } @@ -479,7 +458,7 @@ private JoinExpression VisitJoinChildren(DbExpression left, string leftName, DbE if (joinType == DbExpressionKind.OuterApply || joinType == DbExpressionKind.CrossApply) { EnterExpression(n); - PendingProjectsNode r = VisitInputWithBinding(right, rightName); + var r = VisitInputWithBinding(right, rightName); LeaveExpression(n); r.JoinParent = n; join.Right = new FromExpression(r.Last.Exp, r.Last.AsName) { ForceSubquery = true }; @@ -487,12 +466,10 @@ private JoinExpression VisitJoinChildren(DbExpression left, string leftName, DbE else { if (IsJoin(right.ExpressionKind)) - { join.Right = VisitJoinChildren(right, input, n); - } else { - PendingProjectsNode r = VisitInputWithBinding(right, rightName); + var r = VisitInputWithBinding(right, rightName); r.JoinParent = n; join.Right = new FromExpression(r.Last.Exp, r.Last.AsName); } @@ -507,91 +484,77 @@ private JoinExpression VisitJoinChildren(DbExpression left, string leftName, DbE return join; } - public override VisitedExpression Visit(DbVariableReferenceExpression expression) + public override VisitedExpression Visit([NotNull] DbVariableReferenceExpression expression) { //return new VariableReferenceExpression(expression.VariableName, _variableSubstitution); throw new NotImplementedException(); } - public override VisitedExpression Visit(DbUnionAllExpression expression) + public override VisitedExpression Visit([NotNull] DbUnionAllExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbTreatExpression expression) + public override VisitedExpression Visit([NotNull] DbTreatExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbSkipExpression expression) + public override VisitedExpression Visit([NotNull] DbSkipExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbSortExpression expression) + public override VisitedExpression Visit([NotNull] DbSortExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbScanExpression expression) + public override VisitedExpression Visit([NotNull] DbScanExpression expression) { MetadataProperty metadata; string tableName; - string overrideTable = "http://schemas.microsoft.com/ado/2007/12/edm/EntityStoreSchemaGenerator:Name"; + var overrideTable = "http://schemas.microsoft.com/ado/2007/12/edm/EntityStoreSchemaGenerator:Name"; if (expression.Target.MetadataProperties.TryGetValue(overrideTable, false, out metadata) && metadata.Value != null) - { tableName = metadata.Value.ToString(); - } else if (expression.Target.MetadataProperties.TryGetValue("Table", false, out metadata) && metadata.Value != null) - { tableName = metadata.Value.ToString(); - } else - { tableName = expression.Target.Name; - } if (expression.Target.MetadataProperties.Contains("DefiningQuery")) { - MetadataProperty definingQuery = expression.Target.MetadataProperties.GetValue("DefiningQuery", false); + var definingQuery = expression.Target.MetadataProperties.GetValue("DefiningQuery", false); if (definingQuery.Value != null) - { return new ScanExpression("(" + definingQuery.Value + ")", expression.Target); - } } ScanExpression scan; - string overrideSchema = "http://schemas.microsoft.com/ado/2007/12/edm/EntityStoreSchemaGenerator:Schema"; + var overrideSchema = "http://schemas.microsoft.com/ado/2007/12/edm/EntityStoreSchemaGenerator:Schema"; if (expression.Target.MetadataProperties.TryGetValue(overrideSchema, false, out metadata) && metadata.Value != null) - { scan = new ScanExpression(QuoteIdentifier(metadata.Value.ToString()) + "." + QuoteIdentifier(tableName), expression.Target); - } else if (expression.Target.MetadataProperties.TryGetValue("Schema", false, out metadata) && metadata.Value != null) - { scan = new ScanExpression(QuoteIdentifier(metadata.Value.ToString()) + "." + QuoteIdentifier(tableName), expression.Target); - } else - { scan = new ScanExpression(QuoteIdentifier(expression.Target.EntityContainer.Name) + "." + QuoteIdentifier(tableName), expression.Target); - } return scan; } - public override VisitedExpression Visit(DbRelationshipNavigationExpression expression) + public override VisitedExpression Visit([NotNull] DbRelationshipNavigationExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbRefExpression expression) + public override VisitedExpression Visit([NotNull] DbRefExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbQuantifierExpression expression) + public override VisitedExpression Visit([NotNull] DbQuantifierExpression expression) { // TODO: EXISTS or NOT EXISTS depending on expression.ExpressionKind // comes with it's built in test (subselect for EXISTS) @@ -599,49 +562,38 @@ public override VisitedExpression Visit(DbQuantifierExpression expression) throw new NotImplementedException(); } - public override VisitedExpression Visit(DbProjectExpression expression) - { - return VisitInputWithBinding(expression, NextAlias()).Last.Exp; - } + public override VisitedExpression Visit([NotNull] DbProjectExpression expression) + => VisitInputWithBinding(expression, NextAlias()).Last.Exp; - public override VisitedExpression Visit(DbParameterReferenceExpression expression) - { - // use parameter in sql - return new LiteralExpression("@" + expression.ParameterName); - } + // use parameter in sql + public override VisitedExpression Visit([NotNull] DbParameterReferenceExpression expression) + => new LiteralExpression("@" + expression.ParameterName); - public override VisitedExpression Visit(DbOrExpression expression) - { - return OperatorExpression.Build(Operator.Or, _useNewPrecedences, expression.Left.Accept(this), expression.Right.Accept(this)); - } + public override VisitedExpression Visit([NotNull] DbOrExpression expression) + => OperatorExpression.Build(Operator.Or, _useNewPrecedences, expression.Left.Accept(this), expression.Right.Accept(this)); - public override VisitedExpression Visit(DbOfTypeExpression expression) + public override VisitedExpression Visit([NotNull] DbOfTypeExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbNullExpression expression) - { - // select does something different here. But insert, update, delete, and functions can just use - // a NULL literal. - return new LiteralExpression("NULL"); - } + // select does something different here. But insert, update, delete, and functions can just use + // a NULL literal. + public override VisitedExpression Visit([NotNull] DbNullExpression expression) + => new LiteralExpression("NULL"); - public override VisitedExpression Visit(DbNotExpression expression) - { - // argument can be a "NOT EXISTS" or similar operator that can be negated. - // Convert the not if that's the case - VisitedExpression argument = expression.Argument.Accept(this); - return OperatorExpression.Negate(argument, _useNewPrecedences); - } + // argument can be a "NOT EXISTS" or similar operator that can be negated. + // Convert the not if that's the case + public override VisitedExpression Visit([NotNull] DbNotExpression expression) + => OperatorExpression.Negate(expression.Argument.Accept(this), _useNewPrecedences); - public override VisitedExpression Visit(DbNewInstanceExpression expression) + // Handled by VisitInputWithBinding + public override VisitedExpression Visit([NotNull] DbNewInstanceExpression expression) { - // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbLimitExpression expression) + public override VisitedExpression Visit([NotNull] DbLimitExpression expression) { // Normally handled by VisitInputWithBinding @@ -649,7 +601,7 @@ public override VisitedExpression Visit(DbLimitExpression expression) // in which case the child of this expression might be a DbProjectExpression, // then the correct columns will be projected since Limit is compatible with the result of a DbProjectExpression, // which will result in having a Projection on the node after visiting it. - PendingProjectsNode node = VisitInputWithBinding(expression, NextAlias()); + var node = VisitInputWithBinding(expression, NextAlias()); if (node.Last.Exp.Projection == null) { // This DbLimitExpression is (probably) a child of DbElementExpression @@ -661,204 +613,191 @@ public override VisitedExpression Visit(DbLimitExpression expression) // Since this is (probably) a child of DbElementExpression, we want the first column, // so make sure it is propagated from the nearest explicit projection. - CommaSeparatedExpression projection = node.Selects[0].Exp.Projection; - for (int i = 1; i < node.Selects.Count; i++) + var projection = node.Selects[0].Exp.Projection; + for (var i = 1; i < node.Selects.Count; i++) { - ColumnExpression column = (ColumnExpression)projection.Arguments[0]; - + var column = (ColumnExpression)projection.Arguments[0]; node.Selects[i].Exp.ColumnsToProject[new StringPair(node.Selects[i - 1].AsName, column.Name)] = column.Name; } } return node.Last.Exp; } - public override VisitedExpression Visit(DbLikeExpression expression) - { - // LIKE keyword - return OperatorExpression.Build(Operator.Like, _useNewPrecedences, expression.Argument.Accept(this), expression.Pattern.Accept(this)); - } + // LIKE keyword + public override VisitedExpression Visit([NotNull] DbLikeExpression expression) + => OperatorExpression.Build(Operator.Like, _useNewPrecedences, expression.Argument.Accept(this), expression.Pattern.Accept(this)); - public override VisitedExpression Visit(DbJoinExpression expression) + public override VisitedExpression Visit([NotNull] DbJoinExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbIsOfExpression expression) + public override VisitedExpression Visit([NotNull] DbIsOfExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbIsNullExpression expression) - { - return OperatorExpression.Build(Operator.IsNull, _useNewPrecedences, expression.Argument.Accept(this)); - } + public override VisitedExpression Visit([NotNull] DbIsNullExpression expression) + => OperatorExpression.Build(Operator.IsNull, _useNewPrecedences, expression.Argument.Accept(this)); - public override VisitedExpression Visit(DbIsEmptyExpression expression) - { - // NOT EXISTS - return OperatorExpression.Negate(new ExistsExpression(expression.Argument.Accept(this)), _useNewPrecedences); - } + // NOT EXISTS + public override VisitedExpression Visit([NotNull] DbIsEmptyExpression expression) + => OperatorExpression.Negate(new ExistsExpression(expression.Argument.Accept(this)), _useNewPrecedences); - public override VisitedExpression Visit(DbIntersectExpression expression) + public override VisitedExpression Visit([NotNull] DbIntersectExpression expression) { // INTERSECT keyword // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbGroupByExpression expression) - { - // Normally handled by VisitInputWithBinding - - // Otherwise, it is (probably) a child of a DbElementExpression. - // Group by always projects the correct columns. - return VisitInputWithBinding(expression, NextAlias()).Last.Exp; - } + // Normally handled by VisitInputWithBinding + // Otherwise, it is (probably) a child of a DbElementExpression. + // Group by always projects the correct columns. + public override VisitedExpression Visit([NotNull] DbGroupByExpression expression) + => VisitInputWithBinding(expression, NextAlias()).Last.Exp; - public override VisitedExpression Visit(DbRefKeyExpression expression) + public override VisitedExpression Visit([NotNull] DbRefKeyExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbEntityRefExpression expression) + public override VisitedExpression Visit([NotNull] DbEntityRefExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbFunctionExpression expression) - { - // a function call - // may be built in, canonical, or user defined - return VisitFunction(expression.Function, expression.Arguments, expression.ResultType); - } + // a function call + // may be built in, canonical, or user defined + public override VisitedExpression Visit([NotNull] DbFunctionExpression expression) + => VisitFunction(expression.Function, expression.Arguments, expression.ResultType); - public override VisitedExpression Visit(DbFilterExpression expression) + public override VisitedExpression Visit([NotNull] DbFilterExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbExceptExpression expression) + public override VisitedExpression Visit([NotNull] DbExceptExpression expression) { // EXCEPT keyword // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbElementExpression expression) + public override VisitedExpression Visit([NotNull] DbElementExpression expression) { // If child of DbNewInstanceExpression, this is handled in VisitInputWithBinding // a scalar expression (ie ExecuteScalar) // so it will likely be translated into a select //throw new NotImplementedException(); - LiteralExpression scalar = new LiteralExpression("("); + var scalar = new LiteralExpression("("); scalar.Append(expression.Argument.Accept(this)); scalar.Append(")"); return scalar; } - public override VisitedExpression Visit(DbDistinctExpression expression) + public override VisitedExpression Visit([NotNull] DbDistinctExpression expression) { // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbDerefExpression expression) + public override VisitedExpression Visit([NotNull] DbDerefExpression expression) { throw new NotImplementedException(); } - public override VisitedExpression Visit(DbCrossJoinExpression expression) + public override VisitedExpression Visit([NotNull] DbCrossJoinExpression expression) { // join without ON // Handled by VisitInputWithBinding throw new NotImplementedException(); } - public override VisitedExpression Visit(DbConstantExpression expression) + public override VisitedExpression Visit([NotNull] DbConstantExpression expression) { - if (_createParametersForConstants) + if (CreateParametersForConstants) { - NpgsqlParameter parameter = new NpgsqlParameter(); - parameter.ParameterName = "p_" + _parameterCount++; - parameter.NpgsqlDbType = NpgsqlProviderManifest.GetNpgsqlDbType(((PrimitiveType)expression.ResultType.EdmType).PrimitiveTypeKind); - parameter.Value = expression.Value; - _command.Parameters.Add(parameter); + var parameter = new NpgsqlParameter + { + ParameterName = "p_" + ParameterCount++, + NpgsqlDbType = NpgsqlProviderManifest.GetNpgsqlDbType(((PrimitiveType)expression.ResultType.EdmType).PrimitiveTypeKind), + Value = expression.Value + }; + Command.Parameters.Add(parameter); return new LiteralExpression("@" + parameter.ParameterName); } - else - { - return new ConstantExpression(expression.Value, expression.ResultType); - } + + return new ConstantExpression(expression.Value, expression.ResultType); } - public override VisitedExpression Visit(DbComparisonExpression expression) + public override VisitedExpression Visit([NotNull] DbComparisonExpression expression) { Operator comparisonOperator; switch (expression.ExpressionKind) { - case DbExpressionKind.Equals: comparisonOperator = Operator.Equals; break; - case DbExpressionKind.GreaterThan: comparisonOperator = Operator.GreaterThan; break; - case DbExpressionKind.GreaterThanOrEquals: comparisonOperator = Operator.GreaterThanOrEquals; break; - case DbExpressionKind.LessThan: comparisonOperator = Operator.LessThan; break; - case DbExpressionKind.LessThanOrEquals: comparisonOperator = Operator.LessThanOrEquals; break; - case DbExpressionKind.Like: comparisonOperator = Operator.Like; break; - case DbExpressionKind.NotEquals: comparisonOperator = Operator.NotEquals; break; - default: throw new NotSupportedException(); + case DbExpressionKind.Equals: comparisonOperator = Operator.Equals; break; + case DbExpressionKind.GreaterThan: comparisonOperator = Operator.GreaterThan; break; + case DbExpressionKind.GreaterThanOrEquals: comparisonOperator = Operator.GreaterThanOrEquals; break; + case DbExpressionKind.LessThan: comparisonOperator = Operator.LessThan; break; + case DbExpressionKind.LessThanOrEquals: comparisonOperator = Operator.LessThanOrEquals; break; + case DbExpressionKind.Like: comparisonOperator = Operator.Like; break; + case DbExpressionKind.NotEquals: comparisonOperator = Operator.NotEquals; break; + default: throw new NotSupportedException(); } return OperatorExpression.Build(comparisonOperator, _useNewPrecedences, expression.Left.Accept(this), expression.Right.Accept(this)); } - public override VisitedExpression Visit(DbCastExpression expression) - { - return new CastExpression(expression.Argument.Accept(this), GetDbType(expression.ResultType.EdmType)); - } + public override VisitedExpression Visit([NotNull] DbCastExpression expression) + => new CastExpression(expression.Argument.Accept(this), GetDbType(expression.ResultType.EdmType)); protected string GetDbType(EdmType edmType) { - PrimitiveType primitiveType = edmType as PrimitiveType; + var primitiveType = edmType as PrimitiveType; if (primitiveType == null) throw new NotSupportedException(); + switch (primitiveType.PrimitiveTypeKind) { - case PrimitiveTypeKind.Boolean: - return "bool"; - case PrimitiveTypeKind.SByte: - case PrimitiveTypeKind.Byte: - case PrimitiveTypeKind.Int16: - return "int2"; - case PrimitiveTypeKind.Int32: - return "int4"; - case PrimitiveTypeKind.Int64: - return "int8"; - case PrimitiveTypeKind.String: - return "text"; - case PrimitiveTypeKind.Decimal: - return "numeric"; - case PrimitiveTypeKind.Single: - return "float4"; - case PrimitiveTypeKind.Double: - return "float8"; - case PrimitiveTypeKind.DateTime: - return "timestamp"; - case PrimitiveTypeKind.DateTimeOffset: - return "timestamptz"; - case PrimitiveTypeKind.Time: - return "interval"; - case PrimitiveTypeKind.Binary: - return "bytea"; - case PrimitiveTypeKind.Guid: - return "uuid"; + case PrimitiveTypeKind.Boolean: + return "bool"; + case PrimitiveTypeKind.SByte: + case PrimitiveTypeKind.Byte: + case PrimitiveTypeKind.Int16: + return "int2"; + case PrimitiveTypeKind.Int32: + return "int4"; + case PrimitiveTypeKind.Int64: + return "int8"; + case PrimitiveTypeKind.String: + return "text"; + case PrimitiveTypeKind.Decimal: + return "numeric"; + case PrimitiveTypeKind.Single: + return "float4"; + case PrimitiveTypeKind.Double: + return "float8"; + case PrimitiveTypeKind.DateTime: + return "timestamp"; + case PrimitiveTypeKind.DateTimeOffset: + return "timestamptz"; + case PrimitiveTypeKind.Time: + return "interval"; + case PrimitiveTypeKind.Binary: + return "bytea"; + case PrimitiveTypeKind.Guid: + return "uuid"; } throw new NotSupportedException(); } - public override VisitedExpression Visit(DbCaseExpression expression) + public override VisitedExpression Visit([NotNull] DbCaseExpression expression) { - LiteralExpression caseExpression = new LiteralExpression(" CASE "); - for (int i = 0; i < expression.When.Count && i < expression.Then.Count; ++i) + var caseExpression = new LiteralExpression(" CASE "); + for (var i = 0; i < expression.When.Count && i < expression.Then.Count; ++i) { caseExpression.Append(" WHEN ("); caseExpression.Append(expression.When[i].Accept(this)); @@ -866,60 +805,58 @@ public override VisitedExpression Visit(DbCaseExpression expression) caseExpression.Append(expression.Then[i].Accept(this)); caseExpression.Append(")"); } + if (expression.Else is DbNullExpression) - { caseExpression.Append(" END "); - } else { caseExpression.Append(" ELSE ("); caseExpression.Append(expression.Else.Accept(this)); caseExpression.Append(") END "); } + return caseExpression; } - public override VisitedExpression Visit(DbArithmeticExpression expression) + public override VisitedExpression Visit([NotNull] DbArithmeticExpression expression) { Operator arithmeticOperator; switch (expression.ExpressionKind) { - case DbExpressionKind.Divide: - arithmeticOperator = Operator.Div; - break; - case DbExpressionKind.Minus: - arithmeticOperator = Operator.Sub; - break; - case DbExpressionKind.Modulo: - arithmeticOperator = Operator.Mod; - break; - case DbExpressionKind.Multiply: - arithmeticOperator = Operator.Mul; - break; - case DbExpressionKind.Plus: - arithmeticOperator = Operator.Add; - break; - case DbExpressionKind.UnaryMinus: - arithmeticOperator = Operator.UnaryMinus; - break; - default: - throw new NotSupportedException(); + case DbExpressionKind.Divide: + arithmeticOperator = Operator.Div; + break; + case DbExpressionKind.Minus: + arithmeticOperator = Operator.Sub; + break; + case DbExpressionKind.Modulo: + arithmeticOperator = Operator.Mod; + break; + case DbExpressionKind.Multiply: + arithmeticOperator = Operator.Mul; + break; + case DbExpressionKind.Plus: + arithmeticOperator = Operator.Add; + break; + case DbExpressionKind.UnaryMinus: + arithmeticOperator = Operator.UnaryMinus; + break; + default: + throw new NotSupportedException(); } if (expression.ExpressionKind == DbExpressionKind.UnaryMinus) { - System.Diagnostics.Debug.Assert(expression.Arguments.Count == 1); + Debug.Assert(expression.Arguments.Count == 1); return OperatorExpression.Build(arithmeticOperator, _useNewPrecedences, expression.Arguments[0].Accept(this)); } - else - { - System.Diagnostics.Debug.Assert(expression.Arguments.Count == 2); - return OperatorExpression.Build(arithmeticOperator, _useNewPrecedences, expression.Arguments[0].Accept(this), expression.Arguments[1].Accept(this)); - } + + Debug.Assert(expression.Arguments.Count == 2); + return OperatorExpression.Build(arithmeticOperator, _useNewPrecedences, expression.Arguments[0].Accept(this), expression.Arguments[1].Accept(this)); } - public override VisitedExpression Visit(DbApplyExpression expression) + public override VisitedExpression Visit([NotNull] DbApplyExpression expression) { // like a join, but used when the right hand side (the Apply part) is a function. // it lets you return the results of a function call given values from the @@ -930,12 +867,10 @@ public override VisitedExpression Visit(DbApplyExpression expression) throw new NotImplementedException(); } - public override VisitedExpression Visit(DbAndExpression expression) - { - return OperatorExpression.Build(Operator.And, _useNewPrecedences, expression.Left.Accept(this), expression.Right.Accept(this)); - } + public override VisitedExpression Visit([NotNull] DbAndExpression expression) + => OperatorExpression.Build(Operator.And, _useNewPrecedences, expression.Left.Accept(this), expression.Right.Accept(this)); - public override VisitedExpression Visit(DbExpression expression) + public override VisitedExpression Visit([NotNull] DbExpression expression) { // only concrete types visited throw new NotSupportedException(); @@ -944,11 +879,9 @@ public override VisitedExpression Visit(DbExpression expression) public abstract void BuildCommand(DbCommand command); internal static string QuoteIdentifier(string identifier) - { - return "\"" + identifier.Replace("\"", "\"\"") + "\""; - } + => "\"" + identifier.Replace("\"", "\"\"") + "\""; - private VisitedExpression VisitFunction(DbFunctionAggregate functionAggregate) + VisitedExpression VisitFunction(DbFunctionAggregate functionAggregate) { if (functionAggregate.Function.NamespaceName == "Edm") { @@ -956,11 +889,12 @@ private VisitedExpression VisitFunction(DbFunctionAggregate functionAggregate) try { aggregate = new FunctionExpression(AggregateFunctionNames[functionAggregate.Function.Name]); - } catch (KeyNotFoundException) + } + catch (KeyNotFoundException) { throw new NotSupportedException(); } - System.Diagnostics.Debug.Assert(functionAggregate.Arguments.Count == 1); + Debug.Assert(functionAggregate.Arguments.Count == 1); VisitedExpression aggregateArg; if (functionAggregate.Distinct) { @@ -977,159 +911,159 @@ private VisitedExpression VisitFunction(DbFunctionAggregate functionAggregate) throw new NotSupportedException(); } - private VisitedExpression VisitFunction(EdmFunction function, IList args, TypeUsage resultType) + VisitedExpression VisitFunction(EdmFunction function, IList args, TypeUsage resultType) { if (function.NamespaceName == "Edm") { VisitedExpression arg; switch (function.Name) { - // string functions - case "Concat": - System.Diagnostics.Debug.Assert(args.Count == 2); - return OperatorExpression.Build(Operator.Concat, _useNewPrecedences, args[0].Accept(this), args[1].Accept(this)); - case "Contains": - System.Diagnostics.Debug.Assert(args.Count == 2); - FunctionExpression contains = new FunctionExpression("position"); - arg = args[1].Accept(this); - arg.Append(" in "); - arg.Append(args[0].Accept(this)); - contains.AddArgument(arg); - // if position returns zero, then contains is false - return OperatorExpression.Build(Operator.GreaterThan, _useNewPrecedences, contains, new LiteralExpression("0")); - // case "EndsWith": - depends on a reverse function to be able to implement with parameterized queries - case "IndexOf": - System.Diagnostics.Debug.Assert(args.Count == 2); - FunctionExpression indexOf = new FunctionExpression("position"); - arg = args[0].Accept(this); - arg.Append(" in "); - arg.Append(args[1].Accept(this)); - indexOf.AddArgument(arg); - return indexOf; - case "Left": - System.Diagnostics.Debug.Assert(args.Count == 2); - return Substring(args[0].Accept(this), new LiteralExpression(" 1 "), args[1].Accept(this)); - case "Length": - FunctionExpression length = new FunctionExpression("char_length"); - System.Diagnostics.Debug.Assert(args.Count == 1); - length.AddArgument(args[0].Accept(this)); - return new CastExpression(length, GetDbType(resultType.EdmType)); - case "LTrim": - return StringModifier("ltrim", args); - case "Replace": - FunctionExpression replace = new FunctionExpression("replace"); - System.Diagnostics.Debug.Assert(args.Count == 3); - replace.AddArgument(args[0].Accept(this)); - replace.AddArgument(args[1].Accept(this)); - replace.AddArgument(args[2].Accept(this)); - return replace; - // case "Reverse": - case "Right": - System.Diagnostics.Debug.Assert(args.Count == 2); - { - var arg0 = args[0].Accept(this); - var arg1 = args[1].Accept(this); - var start = new FunctionExpression("char_length"); - start.AddArgument(arg0); - // add one before subtracting count since strings are 1 based in postgresql - return Substring(arg0, OperatorExpression.Build(Operator.Sub, _useNewPrecedences, OperatorExpression.Build(Operator.Add, _useNewPrecedences, start, new LiteralExpression("1")), arg1)); - } - case "RTrim": - return StringModifier("rtrim", args); - case "Substring": - System.Diagnostics.Debug.Assert(args.Count == 3); - return Substring(args[0].Accept(this), args[1].Accept(this), args[2].Accept(this)); - case "StartsWith": - System.Diagnostics.Debug.Assert(args.Count == 2); - FunctionExpression startsWith = new FunctionExpression("position"); - arg = args[1].Accept(this); - arg.Append(" in "); - arg.Append(args[0].Accept(this)); - startsWith.AddArgument(arg); - return OperatorExpression.Build(Operator.Equals, _useNewPrecedences, startsWith, new LiteralExpression("1")); - case "ToLower": - return StringModifier("lower", args); - case "ToUpper": - return StringModifier("upper", args); - case "Trim": - return StringModifier("btrim", args); - - // date functions - // date functions - case "AddDays": - case "AddHours": - case "AddMicroseconds": - case "AddMilliseconds": - case "AddMinutes": - case "AddMonths": - case "AddNanoseconds": - case "AddSeconds": - case "AddYears": - return DateAdd(function.Name, args); - case "DiffDays": - case "DiffHours": - case "DiffMicroseconds": - case "DiffMilliseconds": - case "DiffMinutes": - case "DiffMonths": - case "DiffNanoseconds": - case "DiffSeconds": - case "DiffYears": - System.Diagnostics.Debug.Assert(args.Count == 2); - return DateDiff(function.Name, args[0].Accept(this), args[1].Accept(this)); - case "Day": - case "Hour": - case "Minute": - case "Month": - case "Second": - case "Year": - return DatePart(function.Name, args); - case "Millisecond": - return DatePart("milliseconds", args); - case "GetTotalOffsetMinutes": - VisitedExpression timezone = DatePart("timezone", args); - return OperatorExpression.Build(Operator.Div, _useNewPrecedences, timezone, new LiteralExpression("60")); - case "CurrentDateTime": - return new LiteralExpression("LOCALTIMESTAMP"); - case "CurrentUtcDateTime": - LiteralExpression utcNow = new LiteralExpression("CURRENT_TIMESTAMP"); - utcNow.Append(" AT TIME ZONE 'UTC'"); - return utcNow; - case "CurrentDateTimeOffset": - // TODO: this doesn't work yet because the reader - // doesn't return DateTimeOffset. - return new LiteralExpression("CURRENT_TIMESTAMP"); - - // bitwise operators - case "BitwiseAnd": - return BitwiseOperator(args, Operator.BitwiseAnd); - case "BitwiseOr": - return BitwiseOperator(args, Operator.BitwiseOr); - case "BitwiseXor": - return BitwiseOperator(args, Operator.BitwiseXor); - case "BitwiseNot": - System.Diagnostics.Debug.Assert(args.Count == 1); - return OperatorExpression.Build(Operator.BitwiseNot, _useNewPrecedences, args[0].Accept(this)); - - // math operators - case "Abs": - case "Ceiling": - case "Floor": - return UnaryMath(function.Name, args); - case "Round": - return (args.Count == 1) ? UnaryMath(function.Name, args) : BinaryMath(function.Name, args); - case "Power": - return BinaryMath(function.Name, args); - case "Truncate": - return BinaryMath("trunc", args); - - case "NewGuid": - return new FunctionExpression("uuid_generate_v4"); - case "TruncateTime": - return new TruncateTimeExpression("day", args[0].Accept(this)); - - default: - throw new NotSupportedException("NotSupported " + function.Name); + // string functions + case "Concat": + Debug.Assert(args.Count == 2); + return OperatorExpression.Build(Operator.Concat, _useNewPrecedences, args[0].Accept(this), args[1].Accept(this)); + case "Contains": + Debug.Assert(args.Count == 2); + var contains = new FunctionExpression("position"); + arg = args[1].Accept(this); + arg.Append(" in "); + arg.Append(args[0].Accept(this)); + contains.AddArgument(arg); + // if position returns zero, then contains is false + return OperatorExpression.Build(Operator.GreaterThan, _useNewPrecedences, contains, new LiteralExpression("0")); + // case "EndsWith": - depends on a reverse function to be able to implement with parameterized queries + case "IndexOf": + Debug.Assert(args.Count == 2); + var indexOf = new FunctionExpression("position"); + arg = args[0].Accept(this); + arg.Append(" in "); + arg.Append(args[1].Accept(this)); + indexOf.AddArgument(arg); + return indexOf; + case "Left": + Debug.Assert(args.Count == 2); + return Substring(args[0].Accept(this), new LiteralExpression(" 1 "), args[1].Accept(this)); + case "Length": + var length = new FunctionExpression("char_length"); + Debug.Assert(args.Count == 1); + length.AddArgument(args[0].Accept(this)); + return new CastExpression(length, GetDbType(resultType.EdmType)); + case "LTrim": + return StringModifier("ltrim", args); + case "Replace": + var replace = new FunctionExpression("replace"); + Debug.Assert(args.Count == 3); + replace.AddArgument(args[0].Accept(this)); + replace.AddArgument(args[1].Accept(this)); + replace.AddArgument(args[2].Accept(this)); + return replace; + // case "Reverse": + case "Right": + Debug.Assert(args.Count == 2); + { + var arg0 = args[0].Accept(this); + var arg1 = args[1].Accept(this); + var start = new FunctionExpression("char_length"); + start.AddArgument(arg0); + // add one before subtracting count since strings are 1 based in postgresql + return Substring(arg0, OperatorExpression.Build(Operator.Sub, _useNewPrecedences, OperatorExpression.Build(Operator.Add, _useNewPrecedences, start, new LiteralExpression("1")), arg1)); + } + case "RTrim": + return StringModifier("rtrim", args); + case "Substring": + Debug.Assert(args.Count == 3); + return Substring(args[0].Accept(this), args[1].Accept(this), args[2].Accept(this)); + case "StartsWith": + Debug.Assert(args.Count == 2); + var startsWith = new FunctionExpression("position"); + arg = args[1].Accept(this); + arg.Append(" in "); + arg.Append(args[0].Accept(this)); + startsWith.AddArgument(arg); + return OperatorExpression.Build(Operator.Equals, _useNewPrecedences, startsWith, new LiteralExpression("1")); + case "ToLower": + return StringModifier("lower", args); + case "ToUpper": + return StringModifier("upper", args); + case "Trim": + return StringModifier("btrim", args); + + // date functions + // date functions + case "AddDays": + case "AddHours": + case "AddMicroseconds": + case "AddMilliseconds": + case "AddMinutes": + case "AddMonths": + case "AddNanoseconds": + case "AddSeconds": + case "AddYears": + return DateAdd(function.Name, args); + case "DiffDays": + case "DiffHours": + case "DiffMicroseconds": + case "DiffMilliseconds": + case "DiffMinutes": + case "DiffMonths": + case "DiffNanoseconds": + case "DiffSeconds": + case "DiffYears": + Debug.Assert(args.Count == 2); + return DateDiff(function.Name, args[0].Accept(this), args[1].Accept(this)); + case "Day": + case "Hour": + case "Minute": + case "Month": + case "Second": + case "Year": + return DatePart(function.Name, args); + case "Millisecond": + return DatePart("milliseconds", args); + case "GetTotalOffsetMinutes": + var timezone = DatePart("timezone", args); + return OperatorExpression.Build(Operator.Div, _useNewPrecedences, timezone, new LiteralExpression("60")); + case "CurrentDateTime": + return new LiteralExpression("LOCALTIMESTAMP"); + case "CurrentUtcDateTime": + var utcNow = new LiteralExpression("CURRENT_TIMESTAMP"); + utcNow.Append(" AT TIME ZONE 'UTC'"); + return utcNow; + case "CurrentDateTimeOffset": + // TODO: this doesn't work yet because the reader + // doesn't return DateTimeOffset. + return new LiteralExpression("CURRENT_TIMESTAMP"); + + // bitwise operators + case "BitwiseAnd": + return BitwiseOperator(args, Operator.BitwiseAnd); + case "BitwiseOr": + return BitwiseOperator(args, Operator.BitwiseOr); + case "BitwiseXor": + return BitwiseOperator(args, Operator.BitwiseXor); + case "BitwiseNot": + Debug.Assert(args.Count == 1); + return OperatorExpression.Build(Operator.BitwiseNot, _useNewPrecedences, args[0].Accept(this)); + + // math operators + case "Abs": + case "Ceiling": + case "Floor": + return UnaryMath(function.Name, args); + case "Round": + return args.Count == 1 ? UnaryMath(function.Name, args) : BinaryMath(function.Name, args); + case "Power": + return BinaryMath(function.Name, args); + case "Truncate": + return BinaryMath("trunc", args); + + case "NewGuid": + return new FunctionExpression("uuid_generate_v4"); + case "TruncateTime": + return new TruncateTimeExpression("day", args[0].Accept(this)); + + default: + throw new NotSupportedException("NotSupported " + function.Name); } } @@ -1226,59 +1160,58 @@ private VisitedExpression VisitFunction(EdmFunction function, IList args) + VisitedExpression UnaryMath(string funcName, IList args) { - FunctionExpression mathFunction = new FunctionExpression(funcName); - System.Diagnostics.Debug.Assert(args.Count == 1); + var mathFunction = new FunctionExpression(funcName); + Debug.Assert(args.Count == 1); mathFunction.AddArgument(args[0].Accept(this)); return mathFunction; } - private VisitedExpression BinaryMath(string funcName, IList args) + VisitedExpression BinaryMath(string funcName, IList args) { - FunctionExpression mathFunction = new FunctionExpression(funcName); - System.Diagnostics.Debug.Assert(args.Count == 2); + var mathFunction = new FunctionExpression(funcName); + Debug.Assert(args.Count == 2); mathFunction.AddArgument(args[0].Accept(this)); mathFunction.AddArgument(args[1].Accept(this)); return mathFunction; } - private VisitedExpression StringModifier(string modifier, IList args) + VisitedExpression StringModifier(string modifier, IList args) { - FunctionExpression modifierFunction = new FunctionExpression(modifier); - System.Diagnostics.Debug.Assert(args.Count == 1); + var modifierFunction = new FunctionExpression(modifier); + Debug.Assert(args.Count == 1); modifierFunction.AddArgument(args[0].Accept(this)); return modifierFunction; } - private VisitedExpression DatePart(string partName, IList args) + VisitedExpression DatePart(string partName, IList args) { - - FunctionExpression extract_date = new FunctionExpression("cast(extract"); - System.Diagnostics.Debug.Assert(args.Count == 1); + var extractDate = new FunctionExpression("cast(extract"); + Debug.Assert(args.Count == 1); VisitedExpression arg = new LiteralExpression(partName + " FROM "); arg.Append(args[0].Accept(this)); - extract_date.AddArgument(arg); + extractDate.AddArgument(arg); // need to convert to Int32 to match cononical function - extract_date.Append(" as int4)"); - return extract_date; + extractDate.Append(" as int4)"); + return extractDate; } /// @@ -1292,10 +1225,10 @@ private VisitedExpression DatePart(string partName, IList args) /// /// /// - private VisitedExpression DateAdd(string functionName, IList args) + VisitedExpression DateAdd(string functionName, IList args) { - bool nano = false; - string part = functionName.Substring(3); + var nano = false; + var part = functionName.Substring(3); if (part == "Nanoseconds") { @@ -1303,125 +1236,124 @@ private VisitedExpression DateAdd(string functionName, IList args) part = "Microseconds"; } - System.Diagnostics.Debug.Assert(args.Count == 2); - VisitedExpression time = args[0].Accept(this); - VisitedExpression mulLeft = args[1].Accept(this); + Debug.Assert(args.Count == 2); + var time = args[0].Accept(this); + var mulLeft = args[1].Accept(this); if (nano) mulLeft = OperatorExpression.Build(Operator.Div, _useNewPrecedences, mulLeft, new LiteralExpression("1000")); - LiteralExpression mulRight = new LiteralExpression(String.Format("INTERVAL '1 {0}'", part)); + var mulRight = new LiteralExpression($"INTERVAL '1 {part}'"); return OperatorExpression.Build(Operator.Add, _useNewPrecedences, time, OperatorExpression.Build(Operator.Mul, _useNewPrecedences, mulLeft, mulRight)); } - private VisitedExpression DateDiff(string functionName, VisitedExpression start, VisitedExpression end) + VisitedExpression DateDiff(string functionName, VisitedExpression start, VisitedExpression end) { switch (functionName) { - case "DiffDays": - start = new FunctionExpression("date_trunc").AddArgument("'day'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'day'").AddArgument(end); - return new FunctionExpression("date_part").AddArgument("'day'").AddArgument( - OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start) - ).Append("::int4"); - case "DiffHours": - { - start = new FunctionExpression("date_trunc").AddArgument("'hour'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'hour'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return OperatorExpression.Build(Operator.Div, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch).Append("::int4"), new LiteralExpression("3600")); - } - case "DiffMicroseconds": - { - start = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000000")), "int4"); - } - case "DiffMilliseconds": - { - start = new FunctionExpression("date_trunc").AddArgument("'milliseconds'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'milliseconds'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000")), "int4"); - } - case "DiffMinutes": - { - start = new FunctionExpression("date_trunc").AddArgument("'minute'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'minute'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return OperatorExpression.Build(Operator.Div, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch).Append("::int4"), new LiteralExpression("60")); - } - case "DiffMonths": - { - start = new FunctionExpression("date_trunc").AddArgument("'month'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'month'").AddArgument(end); - VisitedExpression age = new FunctionExpression("age").AddArgument(end).AddArgument(start); - - // A month is 30 days and a year is 365.25 days after conversion from interval to seconds. - // After rounding and casting, the result will contain the correct number of months as an int4. - FunctionExpression seconds = new FunctionExpression("extract").AddArgument(new LiteralExpression("epoch from ").Append(age)); - VisitedExpression months = OperatorExpression.Build(Operator.Div, _useNewPrecedences, seconds, new LiteralExpression("2629800.0")); - return new FunctionExpression("round").AddArgument(months).Append("::int4"); - } - case "DiffNanoseconds": - { - // PostgreSQL only supports microseconds precision, so the value will be a multiple of 1000 - // This date_trunc will make sure start and end are of type timestamp, e.g. if the arguments is of type date - start = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000000000")), "int4"); - } - case "DiffSeconds": - { - start = new FunctionExpression("date_trunc").AddArgument("'second'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'second'").AddArgument(end); - LiteralExpression epoch = new LiteralExpression("epoch from "); - OperatorExpression diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); - epoch.Append(diff); - return new FunctionExpression("extract").AddArgument(epoch).Append("::int4"); - } - case "DiffYears": - { - start = new FunctionExpression("date_trunc").AddArgument("'year'").AddArgument(start); - end = new FunctionExpression("date_trunc").AddArgument("'year'").AddArgument(end); - VisitedExpression age = new FunctionExpression("age").AddArgument(end).AddArgument(start); - return new FunctionExpression("date_part").AddArgument("'year'").AddArgument(age).Append("::int4"); - } - default: throw new NotSupportedException("Internal error: unknown function name " + functionName); + case "DiffDays": + start = new FunctionExpression("date_trunc").AddArgument("'day'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'day'").AddArgument(end); + return new FunctionExpression("date_part").AddArgument("'day'").AddArgument( + OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start) + ).Append("::int4"); + case "DiffHours": + { + start = new FunctionExpression("date_trunc").AddArgument("'hour'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'hour'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return OperatorExpression.Build(Operator.Div, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch).Append("::int4"), new LiteralExpression("3600")); + } + case "DiffMicroseconds": + { + start = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000000")), "int4"); + } + case "DiffMilliseconds": + { + start = new FunctionExpression("date_trunc").AddArgument("'milliseconds'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'milliseconds'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000")), "int4"); + } + case "DiffMinutes": + { + start = new FunctionExpression("date_trunc").AddArgument("'minute'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'minute'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return OperatorExpression.Build(Operator.Div, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch).Append("::int4"), new LiteralExpression("60")); + } + case "DiffMonths": + { + start = new FunctionExpression("date_trunc").AddArgument("'month'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'month'").AddArgument(end); + var age = new FunctionExpression("age").AddArgument(end).AddArgument(start); + + // A month is 30 days and a year is 365.25 days after conversion from interval to seconds. + // After rounding and casting, the result will contain the correct number of months as an int4. + var seconds = new FunctionExpression("extract").AddArgument(new LiteralExpression("epoch from ").Append(age)); + var months = OperatorExpression.Build(Operator.Div, _useNewPrecedences, seconds, new LiteralExpression("2629800.0")); + return new FunctionExpression("round").AddArgument(months).Append("::int4"); + } + case "DiffNanoseconds": + { + // PostgreSQL only supports microseconds precision, so the value will be a multiple of 1000 + // This date_trunc will make sure start and end are of type timestamp, e.g. if the arguments is of type date + start = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'microseconds'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return new CastExpression(OperatorExpression.Build(Operator.Mul, _useNewPrecedences, new FunctionExpression("extract").AddArgument(epoch), new LiteralExpression("1000000000")), "int4"); + } + case "DiffSeconds": + { + start = new FunctionExpression("date_trunc").AddArgument("'second'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'second'").AddArgument(end); + var epoch = new LiteralExpression("epoch from "); + var diff = OperatorExpression.Build(Operator.Sub, _useNewPrecedences, end, start); + epoch.Append(diff); + return new FunctionExpression("extract").AddArgument(epoch).Append("::int4"); + } + case "DiffYears": + { + start = new FunctionExpression("date_trunc").AddArgument("'year'").AddArgument(start); + end = new FunctionExpression("date_trunc").AddArgument("'year'").AddArgument(end); + var age = new FunctionExpression("age").AddArgument(end).AddArgument(start); + return new FunctionExpression("date_part").AddArgument("'year'").AddArgument(age).Append("::int4"); + } + default: + throw new NotSupportedException("Internal error: unknown function name " + functionName); } } - private VisitedExpression BitwiseOperator(IList args, Operator oper) + VisitedExpression BitwiseOperator(IList args, Operator oper) { - System.Diagnostics.Debug.Assert(args.Count == 2); + Debug.Assert(args.Count == 2); return OperatorExpression.Build(oper, _useNewPrecedences, args[0].Accept(this), args[1].Accept(this)); } #if ENTITIES6 - public override VisitedExpression Visit(DbInExpression expression) + public override VisitedExpression Visit([NotNull] DbInExpression expression) { - VisitedExpression item = expression.Item.Accept(this); + var item = expression.Item.Accept(this); - ConstantExpression[] elements = new ConstantExpression[expression.List.Count]; - for (int i = 0; i < expression.List.Count; i++) - { + var elements = new ConstantExpression[expression.List.Count]; + for (var i = 0; i < expression.List.Count; i++) elements[i] = (ConstantExpression)expression.List[i].Accept(this); - } return OperatorExpression.Build(Operator.In, _useNewPrecedences, item, new ConstantListExpression(elements)); } - public override VisitedExpression Visit(DbPropertyExpression expression) + public override VisitedExpression Visit([NotNull] DbPropertyExpression expression) { // This is overridden in the other visitors throw new NotImplementedException("New in Entity Framework 6"); diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlDeleteGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlDeleteGenerator.cs index a0cc522..a1ae978 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlDeleteGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlDeleteGenerator.cs @@ -22,7 +22,6 @@ #endregion using System; -using System.Collections.Generic; using System.Data.Common; #if ENTITIES6 using System.Data.Entity.Core.Common.CommandTrees; @@ -34,8 +33,8 @@ namespace Npgsql.SqlGenerators { internal class SqlDeleteGenerator : SqlBaseGenerator { - private DbDeleteCommandTree _commandTree; - private string _tableName; + readonly DbDeleteCommandTree _commandTree; + string _tableName; public SqlDeleteGenerator(DbDeleteCommandTree commandTree) { @@ -44,7 +43,7 @@ public SqlDeleteGenerator(DbDeleteCommandTree commandTree) public override VisitedExpression Visit(DbPropertyExpression expression) { - DbVariableReferenceExpression variable = expression.Instance as DbVariableReferenceExpression; + var variable = expression.Instance as DbVariableReferenceExpression; if (variable == null || variable.VariableName != _tableName) throw new NotSupportedException(); return new PropertyExpression(expression.Property); @@ -53,13 +52,11 @@ public override VisitedExpression Visit(DbPropertyExpression expression) public override void BuildCommand(DbCommand command) { // TODO: handle _commandTree.Returning and _commandTree.Parameters - DeleteExpression delete = new DeleteExpression(); + var delete = new DeleteExpression(); _tableName = _commandTree.Target.VariableName; delete.AppendFrom(_commandTree.Target.Expression.Accept(this)); if (_commandTree.Predicate != null) - { delete.AppendWhere(_commandTree.Predicate.Accept(this)); - } _tableName = null; command.CommandText = delete.ToString(); } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlInsertGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlInsertGenerator.cs index b018bca..80da4b8 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlInsertGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlInsertGenerator.cs @@ -34,8 +34,8 @@ namespace Npgsql.SqlGenerators { internal class SqlInsertGenerator : SqlBaseGenerator { - private DbInsertCommandTree _commandTree; - private string _tableName; + readonly DbInsertCommandTree _commandTree; + string _tableName; public SqlInsertGenerator(DbInsertCommandTree commandTree) { @@ -44,7 +44,7 @@ public SqlInsertGenerator(DbInsertCommandTree commandTree) public override VisitedExpression Visit(DbPropertyExpression expression) { - DbVariableReferenceExpression variable = expression.Instance as DbVariableReferenceExpression; + var variable = expression.Instance as DbVariableReferenceExpression; if (variable == null || variable.VariableName != _tableName) throw new NotSupportedException(); return new PropertyExpression(expression.Property); @@ -53,11 +53,11 @@ public override VisitedExpression Visit(DbPropertyExpression expression) public override void BuildCommand(DbCommand command) { // TODO: handle_commandTree.Parameters - InsertExpression insert = new InsertExpression(); + var insert = new InsertExpression(); _tableName = _commandTree.Target.VariableName; insert.AppendTarget(_commandTree.Target.Expression.Accept(this)); - List columns = new List(); - List values = new List(); + var columns = new List(); + var values = new List(); foreach (DbSetClause clause in _commandTree.SetClauses) { columns.Add(clause.Property.Accept(this)); @@ -66,9 +66,7 @@ public override void BuildCommand(DbCommand command) insert.AppendColumns(columns); insert.AppendValues(values); if (_commandTree.Returning != null) - { - insert.AppendReturning(_commandTree.Returning as DbNewInstanceExpression); - } + insert.AppendReturning((DbNewInstanceExpression)_commandTree.Returning); _tableName = null; command.CommandText = insert.ToString(); } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlSelectGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlSelectGenerator.cs index ef4e8ba..ffecf5b 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlSelectGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlSelectGenerator.cs @@ -23,8 +23,8 @@ using System; using System.Linq; -using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; #if ENTITIES6 using System.Data.Entity.Core.Common.CommandTrees; using System.Data.Entity.Core.Metadata.Edm; @@ -37,7 +37,7 @@ namespace Npgsql.SqlGenerators { internal class SqlSelectGenerator : SqlBaseGenerator { - private DbQueryCommandTree _commandTree; + readonly DbQueryCommandTree _commandTree; public SqlSelectGenerator(DbQueryCommandTree commandTree) { @@ -74,18 +74,18 @@ public override VisitedExpression Visit(DbPropertyExpression expression) * The new name is then propagated down to the root. */ - string name = expression.Property.Name; - string from = (expression.Instance.ExpressionKind == DbExpressionKind.Property) + var name = expression.Property.Name; + var from = expression.Instance.ExpressionKind == DbExpressionKind.Property ? ((DbPropertyExpression)expression.Instance).Property.Name : ((DbVariableReferenceExpression)expression.Instance).VariableName; - PendingProjectsNode node = _refToNode[from]; + var node = RefToNode[from]; from = node.TopName; while (node != null) { foreach (var item in node.Selects) { - if (_currentExpressions.Contains(item.Exp)) + if (CurrentExpressions.Contains(item.Exp)) continue; var use = new StringPair(from, name); @@ -99,9 +99,7 @@ public override VisitedExpression Visit(DbPropertyExpression expression) item.Exp.ProjectNewNames.Add(name); } else - { name = item.Exp.ColumnsToProject[use]; - } from = item.AsName; } node = node.JoinParent; @@ -109,20 +107,18 @@ public override VisitedExpression Visit(DbPropertyExpression expression) return new ColumnReferenceExpression { Variable = from, Name = name }; } + // must provide a NULL of the correct type + // this is necessary for certain types of union queries. public override VisitedExpression Visit(DbNullExpression expression) - { - // must provide a NULL of the correct type - // this is necessary for certain types of union queries. - return new CastExpression(new LiteralExpression("NULL"), GetDbType(expression.ResultType.EdmType)); - } + => new CastExpression(new LiteralExpression("NULL"), GetDbType(expression.ResultType.EdmType)); public override void BuildCommand(DbCommand command) { - System.Diagnostics.Debug.Assert(command is NpgsqlCommand); - System.Diagnostics.Debug.Assert(_commandTree.Query is DbProjectExpression); - VisitedExpression ve = _commandTree.Query.Accept(this); - System.Diagnostics.Debug.Assert(ve is InputExpression); - InputExpression pe = (InputExpression)ve; + Debug.Assert(command is NpgsqlCommand); + Debug.Assert(_commandTree.Query is DbProjectExpression); + var ve = _commandTree.Query.Accept(this); + Debug.Assert(ve is InputExpression); + var pe = (InputExpression)ve; command.CommandText = pe.ToString(); // We retrieve all strings as unknowns in text format in the case the data types aren't really texts @@ -134,17 +130,18 @@ public override void BuildCommand(DbCommand command) return kind == PrimitiveTypeKind.SByte || kind == PrimitiveTypeKind.DateTimeOffset; })) { - ((NpgsqlCommand)command).ObjectResultTypes = pe.Projection.Arguments.Select(a => { + ((NpgsqlCommand)command).ObjectResultTypes = pe.Projection.Arguments.Select(a => + { var kind = ((PrimitiveType)((ColumnExpression)a).ColumnType.EdmType).PrimitiveTypeKind; - if (kind == PrimitiveTypeKind.SByte) + switch (kind) { + case PrimitiveTypeKind.SByte: return typeof(sbyte); - } - else if (kind == PrimitiveTypeKind.DateTimeOffset) - { + case PrimitiveTypeKind.DateTimeOffset: return typeof(DateTimeOffset); + default: + return null; } - return null; }).ToArray(); } } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlUpdateGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlUpdateGenerator.cs index 9912e4e..c109282 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlUpdateGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlUpdateGenerator.cs @@ -22,7 +22,6 @@ #endregion using System; -using System.Collections.Generic; using System.Data.Common; #if ENTITIES6 using System.Data.Entity.Core.Common.CommandTrees; @@ -34,8 +33,8 @@ namespace Npgsql.SqlGenerators { class SqlUpdateGenerator : SqlBaseGenerator { - private DbUpdateCommandTree _commandTree; - private string _tableName; + readonly DbUpdateCommandTree _commandTree; + string _tableName; public SqlUpdateGenerator(DbUpdateCommandTree commandTree) { @@ -44,7 +43,7 @@ public SqlUpdateGenerator(DbUpdateCommandTree commandTree) public override VisitedExpression Visit(DbPropertyExpression expression) { - DbVariableReferenceExpression variable = expression.Instance as DbVariableReferenceExpression; + var variable = expression.Instance as DbVariableReferenceExpression; if (variable == null || variable.VariableName != _tableName) throw new NotSupportedException(); return new PropertyExpression(expression.Property); @@ -53,21 +52,15 @@ public override VisitedExpression Visit(DbPropertyExpression expression) public override void BuildCommand(DbCommand command) { // TODO: handle _commandTree.Parameters - UpdateExpression update = new UpdateExpression(); + var update = new UpdateExpression(); _tableName = _commandTree.Target.VariableName; update.AppendTarget(_commandTree.Target.Expression.Accept(this)); foreach (DbSetClause clause in _commandTree.SetClauses) - { update.AppendSet(clause.Property.Accept(this), clause.Value.Accept(this)); - } if (_commandTree.Predicate != null) - { update.AppendWhere(_commandTree.Predicate.Accept(this)); - } if (_commandTree.Returning != null) - { update.AppendReturning((DbNewInstanceExpression)_commandTree.Returning); - } _tableName = null; command.CommandText = update.ToString(); } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/StringPair.cs b/src/EntityFramework6.Npgsql/SqlGenerators/StringPair.cs index 7cd6779..c605574 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/StringPair.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/StringPair.cs @@ -25,6 +25,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using JetBrains.Annotations; namespace Npgsql.SqlGenerators { @@ -33,33 +34,25 @@ namespace Npgsql.SqlGenerators /// internal class StringPair { - private string _item1; - private string _item2; - - public string Item1 { get { return _item1; } } - public string Item2 { get { return _item2; } } + public string Item1 { get; } + public string Item2 { get; } public StringPair(string s1, string s2) { - _item1 = s1; - _item2 = s2; + Item1 = s1; + Item2 = s2; } - public override bool Equals(object obj) + public override bool Equals([CanBeNull] object obj) { - if (obj == null) - return false; - - StringPair o = obj as StringPair; + var o = obj as StringPair; if (o == null) return false; - return _item1 == o._item1 && _item2 == o._item2; + return Item1 == o.Item1 && Item2 == o.Item2; } public override int GetHashCode() - { - return (_item1 + "." + _item2).GetHashCode(); - } + => (Item1 + "." + Item2).GetHashCode(); } } diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/VisitedExpression.cs b/src/EntityFramework6.Npgsql/SqlGenerators/VisitedExpression.cs index 0f468d0..f8963cb 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/VisitedExpression.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/VisitedExpression.cs @@ -33,8 +33,8 @@ using System.Data.Common.CommandTrees; #endif using NpgsqlTypes; -using System.Data; using System.Globalization; +using JetBrains.Annotations; namespace Npgsql.SqlGenerators { @@ -59,25 +59,23 @@ public VisitedExpression Append(string literal) public override string ToString() { - StringBuilder sqlText = new StringBuilder(); + var sqlText = new StringBuilder(); WriteSql(sqlText); return sqlText.ToString(); } - protected List ExpressionList { get; private set; } + protected List ExpressionList { get; } internal virtual void WriteSql(StringBuilder sqlText) { - foreach (VisitedExpression expression in ExpressionList) - { + foreach (var expression in ExpressionList) expression.WriteSql(sqlText); - } } } internal class LiteralExpression : VisitedExpression { - private string _literal; + readonly string _literal; public LiteralExpression(string literal) { @@ -108,7 +106,7 @@ internal class CommaSeparatedExpression : VisitedExpression internal override void WriteSql(StringBuilder sqlText) { - for (int i = 0; i < Arguments.Count; ++i) + for (var i = 0; i < Arguments.Count; ++i) { if (i != 0) sqlText.Append(", "); @@ -120,15 +118,15 @@ internal override void WriteSql(StringBuilder sqlText) internal class ConstantExpression : VisitedExpression { - private PrimitiveTypeKind _primitiveType; - private object _value; + readonly PrimitiveTypeKind _primitiveType; + readonly object _value; public ConstantExpression(object value, TypeUsage edmType) { if (edmType == null) - throw new ArgumentNullException("edmType"); + throw new ArgumentNullException(nameof(edmType)); if (edmType.EdmType == null || edmType.EdmType.BuiltInTypeKind != BuiltInTypeKind.PrimitiveType) - throw new ArgumentException("Require primitive EdmType", "edmType"); + throw new ArgumentException("Require primitive EdmType", nameof(edmType)); _primitiveType = ((PrimitiveType)edmType.EdmType).PrimitiveTypeKind; _value = value; } @@ -136,120 +134,88 @@ public ConstantExpression(object value, TypeUsage edmType) internal override void WriteSql(StringBuilder sqlText) { var ni = CultureInfo.InvariantCulture.NumberFormat; - object value = _value; + var value = _value; switch (_primitiveType) { - case PrimitiveTypeKind.Binary: - { - sqlText.AppendFormat("decode('{0}', 'base64')", Convert.ToBase64String((byte[])_value)); - } - break; - case PrimitiveTypeKind.DateTime: - sqlText.AppendFormat(ni, "TIMESTAMP '{0:o}'", _value); - break; - case PrimitiveTypeKind.DateTimeOffset: - sqlText.AppendFormat(ni, "TIMESTAMP WITH TIME ZONE '{0:o}'", _value); - break; - case PrimitiveTypeKind.Decimal: - if ((decimal)_value < 0) - { - sqlText.AppendFormat(ni, "({0})::numeric", _value); - } - else - { - sqlText.AppendFormat(ni, "{0}::numeric", _value); - } - break; - case PrimitiveTypeKind.Double: - if (double.IsNaN((double)_value)) - { - sqlText.AppendFormat("'NaN'::float8"); - } - else if (double.IsPositiveInfinity((double)_value)) - { - sqlText.AppendFormat("'Infinity'::float8"); - } - else if (double.IsNegativeInfinity((double)_value)) - { - sqlText.AppendFormat("'-Infinity'::float8"); - } - else if ((double)_value < 0) - { - sqlText.AppendFormat(ni, "({0:r})::float8", _value); - } - else - { - sqlText.AppendFormat(ni, "{0:r}::float8", _value); - } - break; - // PostgreSQL has no support for bytes. int2 is used instead in Npgsql. - case PrimitiveTypeKind.Byte: - value = (short)(byte)_value; - goto case PrimitiveTypeKind.Int16; - case PrimitiveTypeKind.SByte: - value = (short)(sbyte)_value; - goto case PrimitiveTypeKind.Int16; - case PrimitiveTypeKind.Int16: - if ((short)value < 0) - { - sqlText.AppendFormat(ni, "({0})::int2", _value); - } - else - { - sqlText.AppendFormat(ni, "{0}::int2", _value); - } - break; - case PrimitiveTypeKind.Int32: - sqlText.AppendFormat(ni, "{0}", _value); - break; - case PrimitiveTypeKind.Int64: - if ((long)_value < 0) - { - sqlText.AppendFormat(ni, "({0})::int8", _value); - } - else - { - sqlText.AppendFormat(ni, "{0}::int8", _value); - } - break; - case PrimitiveTypeKind.Single: - if (float.IsNaN((float)_value)) - { - sqlText.AppendFormat("'NaN'::float4"); - } - else if (float.IsPositiveInfinity((float)_value)) - { - sqlText.AppendFormat("'Infinity'::float4"); - } - else if (float.IsNegativeInfinity((float)_value)) - { - sqlText.AppendFormat("'-Infinity'::float4"); - } - else if ((float)_value < 0) - { - sqlText.AppendFormat(ni, "({0:r})::float4", _value); - } - else - { - sqlText.AppendFormat(ni, "{0:r}::float4", _value); - } - break; - case PrimitiveTypeKind.Boolean: - sqlText.Append(((bool)_value) ? "TRUE" : "FALSE"); - break; - case PrimitiveTypeKind.Guid: - sqlText.Append('\'').Append((Guid)_value).Append('\''); - sqlText.Append("::uuid"); - break; - case PrimitiveTypeKind.String: - sqlText.Append("E'").Append(((string)_value).Replace(@"\", @"\\").Replace("'", @"\'")).Append("'"); - break; - case PrimitiveTypeKind.Time: - sqlText.AppendFormat(ni, "INTERVAL '{0}'", (NpgsqlTimeSpan)(TimeSpan)_value); - break; - default: - // TODO: must support more constant value types. - throw new NotSupportedException(string.Format("NotSupported: {0} {1}", _primitiveType, _value)); + case PrimitiveTypeKind.Binary: + { + sqlText.Append($"decode('{Convert.ToBase64String((byte[])_value)}', 'base64')"); + } + break; + case PrimitiveTypeKind.DateTime: + sqlText.AppendFormat(ni, "TIMESTAMP '{0:o}'", _value); + break; + case PrimitiveTypeKind.DateTimeOffset: + sqlText.AppendFormat(ni, "TIMESTAMP WITH TIME ZONE '{0:o}'", _value); + break; + case PrimitiveTypeKind.Decimal: + sqlText.AppendFormat(ni, (decimal)_value < 0 + ? "({0})::numeric" + : "{0}::numeric", _value + ); + break; + case PrimitiveTypeKind.Double: + if (double.IsNaN((double)_value)) + sqlText.Append("'NaN'::float8"); + else if (double.IsPositiveInfinity((double)_value)) + sqlText.Append("'Infinity'::float8"); + else if (double.IsNegativeInfinity((double)_value)) + sqlText.Append("'-Infinity'::float8"); + else if ((double)_value < 0) + sqlText.AppendFormat(ni, "({0:r})::float8", _value); + else + sqlText.AppendFormat(ni, "{0:r}::float8", _value); + break; + // PostgreSQL has no support for bytes. int2 is used instead in Npgsql. + case PrimitiveTypeKind.Byte: + value = (short)(byte)_value; + goto case PrimitiveTypeKind.Int16; + case PrimitiveTypeKind.SByte: + value = (short)(sbyte)_value; + goto case PrimitiveTypeKind.Int16; + case PrimitiveTypeKind.Int16: + sqlText.AppendFormat(ni, (short)value < 0 + ? "({0})::int2" + : "{0}::int2", _value + ); + break; + case PrimitiveTypeKind.Int32: + sqlText.AppendFormat(ni, "{0}", _value); + break; + case PrimitiveTypeKind.Int64: + sqlText.AppendFormat(ni, (long)_value < 0 + ? "({0})::int8" + : "{0}::int8", _value + ); + break; + case PrimitiveTypeKind.Single: + if (float.IsNaN((float)_value)) + sqlText.Append("'NaN'::float4"); + else if (float.IsPositiveInfinity((float)_value)) + sqlText.Append("'Infinity'::float4"); + else if (float.IsNegativeInfinity((float)_value)) + sqlText.Append("'-Infinity'::float4"); + else if ((float)_value < 0) + sqlText.AppendFormat(ni, "({0:r})::float4", _value); + else + sqlText.AppendFormat(ni, "{0:r}::float4", _value); + break; + case PrimitiveTypeKind.Boolean: + sqlText.Append((bool)_value ? "TRUE" : "FALSE"); + break; + case PrimitiveTypeKind.Guid: + sqlText.Append('\'').Append((Guid)_value).Append('\''); + sqlText.Append("::uuid"); + break; + case PrimitiveTypeKind.String: + sqlText.Append("E'").Append(((string)_value).Replace(@"\", @"\\").Replace("'", @"\'")).Append("'"); + break; + case PrimitiveTypeKind.Time: + sqlText.AppendFormat(ni, "INTERVAL '{0}'", (NpgsqlTimeSpan)(TimeSpan)_value); + break; + default: + // TODO: must support more constant value types. + throw new NotSupportedException($"NotSupported: {_primitiveType} {_value}"); } base.WriteSql(sqlText); } @@ -268,8 +234,8 @@ public void AppendColumns(IEnumerable columns) return; Append("("); - bool first = true; - foreach (VisitedExpression expression in columns) + var first = true; + foreach (var expression in columns) { if (!first) Append(","); @@ -285,7 +251,7 @@ public void AppendValues(IEnumerable columns) { Append(" VALUES ("); bool first = true; - foreach (VisitedExpression expression in columns) + foreach (var expression in columns) { if (!first) Append(","); @@ -295,20 +261,18 @@ public void AppendValues(IEnumerable columns) Append(")"); } else - { Append(" DEFAULT VALUES"); - } } internal void AppendReturning(DbNewInstanceExpression expression) { Append(" RETURNING ");//Don't put () around columns it will probably have unwanted effect - bool first = true; + var first = true; foreach (var returingProperty in expression.Arguments) { if (!first) Append(","); - Append(SqlBaseGenerator.QuoteIdentifier((returingProperty as DbPropertyExpression).Property.Name)); + Append(SqlBaseGenerator.QuoteIdentifier(((DbPropertyExpression)returingProperty).Property.Name)); first = false; } } @@ -322,7 +286,7 @@ internal override void WriteSql(StringBuilder sqlText) internal class UpdateExpression : VisitedExpression { - private bool _setSeperatorRequired; + bool _setSeperatorRequired; public void AppendTarget(VisitedExpression target) { @@ -331,10 +295,7 @@ public void AppendTarget(VisitedExpression target) public void AppendSet(VisitedExpression property, VisitedExpression value) { - if (_setSeperatorRequired) - Append(","); - else - Append(" SET "); + Append(_setSeperatorRequired ? "," : " SET "); Append(property); Append("="); Append(value); @@ -356,12 +317,12 @@ internal override void WriteSql(StringBuilder sqlText) internal void AppendReturning(DbNewInstanceExpression expression) { Append(" RETURNING ");//Don't put () around columns it will probably have unwanted effect - bool first = true; + var first = true; foreach (var returingProperty in expression.Arguments) { if (!first) Append(","); - Append(SqlBaseGenerator.QuoteIdentifier((returingProperty as DbPropertyExpression).Property.Name)); + Append(SqlBaseGenerator.QuoteIdentifier(((DbPropertyExpression)returingProperty).Property.Name)); first = false; } } @@ -389,31 +350,23 @@ internal override void WriteSql(StringBuilder sqlText) internal class ColumnExpression : VisitedExpression { - private VisitedExpression _column; - private string _columnName; - private TypeUsage _columnType; + internal string Name { get; } + internal TypeUsage ColumnType { get; } + readonly VisitedExpression _column; public ColumnExpression(VisitedExpression column, string columnName, TypeUsage columnType) { _column = column; - _columnName = columnName; - _columnType = columnType; + Name = columnName; + ColumnType = columnType; } - public string Name { get { return _columnName; } } - internal TypeUsage ColumnType { get { return _columnType; ;} } - - public Type CLRType + public Type ClrType { get { - if (_columnType == null) - return null; - PrimitiveType pt = _columnType.EdmType as PrimitiveType; - if (pt != null) - return pt.ClrEquivalentType; - else - return null; + var pt = ColumnType?.EdmType as PrimitiveType; + return pt?.ClrEquivalentType; } } @@ -421,11 +374,11 @@ internal override void WriteSql(StringBuilder sqlText) { _column.WriteSql(sqlText); - ColumnReferenceExpression column = _column as ColumnReferenceExpression; - if (column == null || column.Name != _columnName) + var column = _column as ColumnReferenceExpression; + if (column == null || column.Name != Name) { sqlText.Append(" AS "); - sqlText.Append(SqlBaseGenerator.QuoteIdentifier(_columnName)); + sqlText.Append(SqlBaseGenerator.QuoteIdentifier(Name)); } base.WriteSql(sqlText); @@ -451,17 +404,15 @@ internal override void WriteSql(StringBuilder sqlText) internal class ScanExpression : VisitedExpression { - private string _scanString; - private EntitySetBase _target; + readonly string _scanString; + internal EntitySetBase Target { get; } public ScanExpression(string scanString, EntitySetBase target) { _scanString = scanString; - _target = target; + Target = target; } - internal EntitySetBase Target { get { return _target; } } - internal override void WriteSql(StringBuilder sqlText) { sqlText.Append(_scanString); @@ -480,55 +431,11 @@ internal class InputExpression : VisitedExpression // Either FromExpression or JoinExpression public VisitedExpression From { get; set; } - - private WhereExpression _where; - - public WhereExpression Where - { - get { return _where; } - set - { - _where = value; - } - } - - private GroupByExpression _groupBy; - - public GroupByExpression GroupBy - { - get { return _groupBy; } - set - { - _groupBy = value; - } - } - - private OrderByExpression _orderBy; - - public OrderByExpression OrderBy - { - get { return _orderBy; } - set { _orderBy = value; } - } - - private SkipExpression _skip; - - public SkipExpression Skip - { - get { return _skip; } - set { _skip = value; } - } - - private LimitExpression _limit; - - public LimitExpression Limit - { - get { return _limit; } - set - { - _limit = value; - } - } + public WhereExpression Where { get; set; } + public GroupByExpression GroupBy { get; set; } + public OrderByExpression OrderBy { get; set; } + public SkipExpression Skip { get; set; } + public LimitExpression Limit { get; set; } public InputExpression() { } @@ -540,21 +447,22 @@ public InputExpression(VisitedExpression from, string asName) internal override void WriteSql(StringBuilder sqlText) { sqlText.Append("SELECT "); - if (Distinct) sqlText.Append("DISTINCT "); - if (Projection != null) Projection.WriteSql(sqlText); + if (Distinct) + sqlText.Append("DISTINCT "); + if (Projection != null) + Projection.WriteSql(sqlText); else { if (ColumnsToProject.Count == 0) sqlText.Append("1"); // Could be arbitrary, let's pick 1 else { - bool first = true; + var first = true; foreach (var column in ColumnsToProject) { if (!first) - { sqlText.Append(", "); - } - else first = false; + else + first = false; sqlText.Append(SqlBaseGenerator.QuoteIdentifier(column.Key.Item1)); sqlText.Append("."); sqlText.Append(SqlBaseGenerator.QuoteIdentifier(column.Key.Item2)); @@ -568,38 +476,34 @@ internal override void WriteSql(StringBuilder sqlText) } sqlText.Append(" FROM "); From.WriteSql(sqlText); - if (Where != null) Where.WriteSql(sqlText); - if (GroupBy != null) GroupBy.WriteSql(sqlText); - if (OrderBy != null) OrderBy.WriteSql(sqlText); - if (Skip != null) Skip.WriteSql(sqlText); - if (Limit != null) Limit.WriteSql(sqlText); + Where?.WriteSql(sqlText); + GroupBy?.WriteSql(sqlText); + OrderBy?.WriteSql(sqlText); + Skip?.WriteSql(sqlText); + Limit?.WriteSql(sqlText); base.WriteSql(sqlText); } } internal class FromExpression : VisitedExpression { - private VisitedExpression _from; - private string _name; + readonly VisitedExpression _from; + internal string Name { get; } public FromExpression(VisitedExpression from, string name) { _from = from; - _name = name; - } - - public string Name - { - get { return _name; } + Name = name; } public bool ForceSubquery { get; set; } internal override void WriteSql(StringBuilder sqlText) { - if (_from is InputExpression) + var from = _from as InputExpression; + if (from != null) { - InputExpression input = (InputExpression)_from; + var input = from; if (!ForceSubquery && input.Projection == null && input.Where == null && input.Distinct == false && input.OrderBy == null && input.Skip == null && input.Limit == null) { @@ -617,19 +521,19 @@ internal override void WriteSql(StringBuilder sqlText) sqlText.Append("("); input.WriteSql(sqlText); sqlText.Append(") AS "); - sqlText.Append(SqlBaseGenerator.QuoteIdentifier(_name)); + sqlText.Append(SqlBaseGenerator.QuoteIdentifier(Name)); } } else { - bool wrap = !(_from is LiteralExpression || _from is ScanExpression); + var wrap = !(_from is LiteralExpression || _from is ScanExpression); if (wrap) sqlText.Append("("); _from.WriteSql(sqlText); if (wrap) sqlText.Append(")"); sqlText.Append(" AS "); - sqlText.Append(SqlBaseGenerator.QuoteIdentifier(_name)); + sqlText.Append(SqlBaseGenerator.QuoteIdentifier(Name)); } base.WriteSql(sqlText); } @@ -637,64 +541,54 @@ internal override void WriteSql(StringBuilder sqlText) internal class JoinExpression : VisitedExpression { - private VisitedExpression _left; - private DbExpressionKind _joinType; - private VisitedExpression _right; - private VisitedExpression _condition; + internal VisitedExpression Left { get; set; } + internal DbExpressionKind JoinType { get; set; } + internal VisitedExpression Right { get; set; } + internal VisitedExpression Condition { get; set; } public JoinExpression() { } public JoinExpression(InputExpression left, DbExpressionKind joinType, InputExpression right, VisitedExpression condition) { - _left = left; - _joinType = joinType; - _right = right; - _condition = condition; - } - - public VisitedExpression Left { get { return _left; } set { _left = value; } } - public DbExpressionKind JoinType { get { return _joinType; } set { _joinType = value; } } - public VisitedExpression Right { get { return _right; } set { _right = value; } } - - public VisitedExpression Condition - { - get { return _condition; } - set { _condition = value; } + Left = left; + JoinType = joinType; + Right = right; + Condition = condition; } internal override void WriteSql(StringBuilder sqlText) { - _left.WriteSql(sqlText); - switch (_joinType) + Left.WriteSql(sqlText); + switch (JoinType) { - case DbExpressionKind.InnerJoin: - sqlText.Append(" INNER JOIN "); - break; - case DbExpressionKind.LeftOuterJoin: - sqlText.Append(" LEFT OUTER JOIN "); - break; - case DbExpressionKind.FullOuterJoin: - sqlText.Append(" FULL OUTER JOIN "); - break; - case DbExpressionKind.CrossJoin: - sqlText.Append(" CROSS JOIN "); - break; - case DbExpressionKind.CrossApply: - sqlText.Append(" CROSS JOIN LATERAL "); - break; - case DbExpressionKind.OuterApply: - sqlText.Append(" LEFT OUTER JOIN LATERAL "); - break; - default: - throw new NotSupportedException(); + case DbExpressionKind.InnerJoin: + sqlText.Append(" INNER JOIN "); + break; + case DbExpressionKind.LeftOuterJoin: + sqlText.Append(" LEFT OUTER JOIN "); + break; + case DbExpressionKind.FullOuterJoin: + sqlText.Append(" FULL OUTER JOIN "); + break; + case DbExpressionKind.CrossJoin: + sqlText.Append(" CROSS JOIN "); + break; + case DbExpressionKind.CrossApply: + sqlText.Append(" CROSS JOIN LATERAL "); + break; + case DbExpressionKind.OuterApply: + sqlText.Append(" LEFT OUTER JOIN LATERAL "); + break; + default: + throw new NotSupportedException(); } - _right.WriteSql(sqlText); - if (_joinType == DbExpressionKind.OuterApply) + Right.WriteSql(sqlText); + if (JoinType == DbExpressionKind.OuterApply) sqlText.Append(" ON TRUE"); - else if (_joinType != DbExpressionKind.CrossJoin && _joinType != DbExpressionKind.CrossApply) + else if (JoinType != DbExpressionKind.CrossJoin && JoinType != DbExpressionKind.CrossApply) { sqlText.Append(" ON "); - _condition.WriteSql(sqlText); + Condition.WriteSql(sqlText); } base.WriteSql(sqlText); } @@ -702,7 +596,7 @@ internal override void WriteSql(StringBuilder sqlText) internal class WhereExpression : VisitedExpression { - private VisitedExpression _where; + VisitedExpression _where; public WhereExpression(VisitedExpression where) { @@ -725,7 +619,9 @@ internal void And(VisitedExpression andAlso) internal class PropertyExpression : VisitedExpression { - private EdmMember _property; + readonly EdmMember _property; + public string Name => _property.Name; + public TypeUsage PropertyType => _property.TypeUsage; // used for inserts or updates where the column is not qualified public PropertyExpression(EdmMember property) @@ -733,10 +629,6 @@ public PropertyExpression(EdmMember property) _property = property; } - public string Name { get { return _property.Name; } } - - public TypeUsage PropertyType { get { return _property.TypeUsage; } } - internal override void WriteSql(StringBuilder sqlText) { sqlText.Append(SqlBaseGenerator.QuoteIdentifier(_property.Name)); @@ -745,16 +637,13 @@ internal override void WriteSql(StringBuilder sqlText) // override ToString since we don't want variable substitution or identifier quoting // until writing out the SQL. - public override string ToString() - { - return _property.Name; - } + public override string ToString() => Name; } internal class FunctionExpression : VisitedExpression { - private string _name; - private List _args = new List(); + readonly string _name; + readonly List _args = new List(); public FunctionExpression(string name) { @@ -792,8 +681,8 @@ internal override void WriteSql(StringBuilder sqlText) internal class CastExpression : VisitedExpression { - private VisitedExpression _value; - private string _type; + readonly VisitedExpression _value; + readonly string _type; public CastExpression(VisitedExpression value, string type) { @@ -812,7 +701,7 @@ internal override void WriteSql(StringBuilder sqlText) internal class GroupByExpression : VisitedExpression { - private bool _requiresGroupSeperator; + bool _requiresGroupSeperator; public void AppendGroupingKey(VisitedExpression key) { @@ -832,26 +721,24 @@ internal override void WriteSql(StringBuilder sqlText) internal class LimitExpression : VisitedExpression { - private VisitedExpression _arg; - - public VisitedExpression Arg { get { return _arg; } set { _arg = value; } } + internal VisitedExpression Arg { get; set; } public LimitExpression(VisitedExpression arg) { - _arg = arg; + Arg = arg; } internal override void WriteSql(StringBuilder sqlText) { sqlText.Append(" LIMIT "); - _arg.WriteSql(sqlText); + Arg.WriteSql(sqlText); base.WriteSql(sqlText); } } internal class SkipExpression : VisitedExpression { - private VisitedExpression _arg; + readonly VisitedExpression _arg; public SkipExpression(VisitedExpression arg) { @@ -868,19 +755,13 @@ internal override void WriteSql(StringBuilder sqlText) internal class Operator { - private string op; - private int leftPrecedence; - private int rightPrecedence; - private int newPrecedence; // Since PostgreSQL 9.5, the operator precedence was changed - private UnaryTypes unaryType; - private bool rightAssoc; - - public string Op { get { return op; } } - public int LeftPrecedence { get { return leftPrecedence; } } - public int RightPrecedence { get { return rightPrecedence; } } - public int NewPrecedence { get { return newPrecedence; } } - public UnaryTypes UnaryType { get { return unaryType; } } - public bool RightAssoc { get { return rightAssoc; } } + internal string Op { get; } + internal int LeftPrecedence { get; } + internal int RightPrecedence { get; } + // Since PostgreSQL 9.5, the operator precedence was changed + internal int NewPrecedence { get; } + internal UnaryTypes UnaryType { get; } + internal bool RightAssoc { get; } internal enum UnaryTypes { Binary, @@ -888,32 +769,32 @@ internal enum UnaryTypes { Postfix } - private Operator(string op, int precedence, int newPrecedence) + Operator(string op, int precedence, int newPrecedence) { - this.op = ' ' + op + ' '; - this.leftPrecedence = precedence; - this.rightPrecedence = precedence; - this.newPrecedence = newPrecedence; - this.unaryType = UnaryTypes.Binary; + Op = ' ' + op + ' '; + LeftPrecedence = precedence; + RightPrecedence = precedence; + NewPrecedence = newPrecedence; + UnaryType = UnaryTypes.Binary; } - private Operator(string op, int leftPrecedence, int rightPrecedence, int newPrecedence) + Operator(string op, int leftPrecedence, int rightPrecedence, int newPrecedence) { - this.op = ' ' + op + ' '; - this.leftPrecedence = leftPrecedence; - this.rightPrecedence = rightPrecedence; - this.newPrecedence = newPrecedence; - this.unaryType = UnaryTypes.Binary; + Op = ' ' + op + ' '; + LeftPrecedence = leftPrecedence; + RightPrecedence = rightPrecedence; + NewPrecedence = newPrecedence; + UnaryType = UnaryTypes.Binary; } - private Operator(string op, int precedence, int newPrecedence, UnaryTypes unaryType, bool rightAssoc) + Operator(string op, int precedence, int newPrecedence, UnaryTypes unaryType, bool rightAssoc) { - this.op = unaryType == UnaryTypes.Binary ? ' ' + op + ' ' : unaryType == UnaryTypes.Prefix ? op + ' ' : ' ' + op; - this.leftPrecedence = precedence; - this.rightPrecedence = precedence; - this.newPrecedence = newPrecedence; - this.unaryType = unaryType; - this.rightAssoc = rightAssoc; + Op = unaryType == UnaryTypes.Binary ? ' ' + op + ' ' : unaryType == UnaryTypes.Prefix ? op + ' ' : ' ' + op; + LeftPrecedence = precedence; + RightPrecedence = precedence; + NewPrecedence = newPrecedence; + UnaryType = unaryType; + RightAssoc = rightAssoc; } /* @@ -952,7 +833,7 @@ private Operator(string op, int precedence, int newPrecedence, UnaryTypes unaryT public static readonly Operator NotLike = new Operator("NOT LIKE", 3, 6, 6); public static readonly Operator LessThan = new Operator("<", 5, 5); public static readonly Operator GreaterThan = new Operator(">", 5, 5); - public static readonly new Operator Equals = new Operator("=", 4, 5, UnaryTypes.Binary, true); + public new static readonly Operator Equals = new Operator("=", 4, 5, UnaryTypes.Binary, true); public static readonly Operator Not = new Operator("NOT", 3, 3, UnaryTypes.Prefix, true); public static readonly Operator And = new Operator("AND", 2, 2); public static readonly Operator Or = new Operator("OR", 1, 1); @@ -970,61 +851,53 @@ static Operator() { NegateDict = new Dictionary() { - {IsNull, IsNotNull}, - {IsNotNull, IsNull}, - {LessThanOrEquals, GreaterThan}, - {GreaterThanOrEquals, LessThan}, - {NotEquals, Equals}, - {In, NotIn}, - {NotIn, In}, - {Like, NotLike}, - {NotLike, Like}, - {LessThan, GreaterThanOrEquals}, - {GreaterThan, LessThanOrEquals}, - {Equals, NotEquals} + { IsNull, IsNotNull }, + { IsNotNull, IsNull }, + { LessThanOrEquals, GreaterThan }, + { GreaterThanOrEquals, LessThan }, + { NotEquals, Equals }, + { In, NotIn }, + { NotIn, In }, + { Like, NotLike }, + { NotLike, Like }, + { LessThan, GreaterThanOrEquals }, + { GreaterThan, LessThanOrEquals }, + { Equals, NotEquals } }; } } internal class OperatorExpression : VisitedExpression { - private Operator op; - private bool useNewPrecedences; - private VisitedExpression left; - private VisitedExpression right; + Operator _op; + readonly bool _useNewPrecedences; + readonly VisitedExpression _left; + readonly VisitedExpression _right; - private OperatorExpression(Operator op, bool useNewPrecedences, VisitedExpression left, VisitedExpression right) + OperatorExpression(Operator op, bool useNewPrecedences, [CanBeNull] VisitedExpression left, [CanBeNull] VisitedExpression right) { - this.op = op; - this.useNewPrecedences = useNewPrecedences; - this.left = left; - this.right = right; + _op = op; + _useNewPrecedences = useNewPrecedences; + _left = left; + _right = right; } public static OperatorExpression Build(Operator op, bool useNewPrecedences, VisitedExpression left, VisitedExpression right) { if (op.UnaryType == Operator.UnaryTypes.Binary) - { return new OperatorExpression(op, useNewPrecedences, left, right); - } - else - { - throw new InvalidOperationException("Unary operator with two operands"); - } + throw new InvalidOperationException("Unary operator with two operands"); } public static OperatorExpression Build(Operator op, bool useNewPrecedences, VisitedExpression exp) { - if (op.UnaryType == Operator.UnaryTypes.Prefix) + switch (op.UnaryType) { + case Operator.UnaryTypes.Prefix: return new OperatorExpression(op, useNewPrecedences, null, exp); - } - else if (op.UnaryType == Operator.UnaryTypes.Postfix) - { + case Operator.UnaryTypes.Postfix: return new OperatorExpression(op, useNewPrecedences, exp, null); - } - else - { + default: throw new InvalidOperationException("Binary operator with one operand"); } } @@ -1036,23 +909,21 @@ public static OperatorExpression Build(Operator op, bool useNewPrecedences, Visi /// public static VisitedExpression Negate(VisitedExpression exp, bool useNewPrecedences) { - OperatorExpression expOp = exp as OperatorExpression; + var expOp = exp as OperatorExpression; if (expOp != null) { - Operator op = expOp.op; - Operator newOp = null; + var op = expOp._op; + Operator newOp; if (Operator.NegateDict.TryGetValue(op, out newOp)) { - expOp.op = newOp; + expOp._op = newOp; return expOp; } - if (expOp.op == Operator.Not) - { - return expOp.right; - } + if (expOp._op == Operator.Not) + return expOp._right; } - return OperatorExpression.Build(Operator.Not, useNewPrecedences, exp); + return Build(Operator.Not, useNewPrecedences, exp); } internal override void WriteSql(StringBuilder sqlText) @@ -1060,60 +931,60 @@ internal override void WriteSql(StringBuilder sqlText) WriteSql(sqlText, null); } - private void WriteSql(StringBuilder sqlText, OperatorExpression rightParent) + void WriteSql(StringBuilder sqlText, [CanBeNull] OperatorExpression rightParent) { - OperatorExpression leftOp = left as OperatorExpression; - OperatorExpression rightOp = right as OperatorExpression; + var leftOp = _left as OperatorExpression; + var rightOp = _right as OperatorExpression; bool wrapLeft, wrapRight; - if (!useNewPrecedences) + if (!_useNewPrecedences) { - wrapLeft = leftOp != null && (op.RightAssoc ? leftOp.op.RightPrecedence <= op.LeftPrecedence : leftOp.op.RightPrecedence < op.LeftPrecedence); - wrapRight = rightOp != null && (!op.RightAssoc ? rightOp.op.LeftPrecedence <= op.RightPrecedence : rightOp.op.LeftPrecedence < op.RightPrecedence); + wrapLeft = leftOp != null && (_op.RightAssoc ? leftOp._op.RightPrecedence <= _op.LeftPrecedence : leftOp._op.RightPrecedence < _op.LeftPrecedence); + wrapRight = rightOp != null && (!_op.RightAssoc ? rightOp._op.LeftPrecedence <= _op.RightPrecedence : rightOp._op.LeftPrecedence < _op.RightPrecedence); } else { - wrapLeft = leftOp != null && (op.RightAssoc ? leftOp.op.NewPrecedence <= op.NewPrecedence : leftOp.op.NewPrecedence < op.NewPrecedence); - wrapRight = rightOp != null && (!op.RightAssoc ? rightOp.op.NewPrecedence <= op.NewPrecedence : rightOp.op.NewPrecedence < op.NewPrecedence); + wrapLeft = leftOp != null && (_op.RightAssoc ? leftOp._op.NewPrecedence <= _op.NewPrecedence : leftOp._op.NewPrecedence < _op.NewPrecedence); + wrapRight = rightOp != null && (!_op.RightAssoc ? rightOp._op.NewPrecedence <= _op.NewPrecedence : rightOp._op.NewPrecedence < _op.NewPrecedence); } // Avoid parentheses for prefix operators if possible, // e.g. BitwiseNot: (a & (~ b)) & c is written as a & ~ b & c // but (a + (~ b)) + c must be written as a + (~ b) + c - if (!useNewPrecedences) + if (!_useNewPrecedences) { - if (wrapRight && rightOp.left == null && (rightParent == null || (!rightParent.op.RightAssoc ? rightOp.op.RightPrecedence >= rightParent.op.LeftPrecedence : rightOp.op.RightPrecedence > rightParent.op.LeftPrecedence))) + if (wrapRight && rightOp._left == null && (rightParent == null || (!rightParent._op.RightAssoc ? rightOp._op.RightPrecedence >= rightParent._op.LeftPrecedence : rightOp._op.RightPrecedence > rightParent._op.LeftPrecedence))) wrapRight = false; } else { - if (wrapRight && rightOp.left == null && (rightParent == null || (!rightParent.op.RightAssoc ? rightOp.op.NewPrecedence >= rightParent.op.NewPrecedence : rightOp.op.NewPrecedence > rightParent.op.NewPrecedence))) + if (wrapRight && rightOp._left == null && (rightParent == null || (!rightParent._op.RightAssoc ? rightOp._op.NewPrecedence >= rightParent._op.NewPrecedence : rightOp._op.NewPrecedence > rightParent._op.NewPrecedence))) wrapRight = false; } - if (left != null) + if (_left != null) { if (wrapLeft) sqlText.Append("("); if (leftOp != null && !wrapLeft) leftOp.WriteSql(sqlText, this); else - left.WriteSql(sqlText); + _left.WriteSql(sqlText); if (wrapLeft) sqlText.Append(")"); } - sqlText.Append(op.Op); + sqlText.Append(_op.Op); - if (right != null) + if (_right != null) { if (wrapRight) sqlText.Append("("); if (rightOp != null && !wrapRight) rightOp.WriteSql(sqlText, rightParent); else - right.WriteSql(sqlText); + _right.WriteSql(sqlText); if (wrapRight) sqlText.Append(")"); } @@ -1124,7 +995,7 @@ private void WriteSql(StringBuilder sqlText, OperatorExpression rightParent) internal class ConstantListExpression : VisitedExpression { - private IEnumerable _list; + readonly IEnumerable _list; public ConstantListExpression(IEnumerable list) { @@ -1134,7 +1005,7 @@ public ConstantListExpression(IEnumerable list) internal override void WriteSql(StringBuilder sqlText) { sqlText.Append("("); - bool first = true; + var first = true; foreach (var constant in _list) { if (!first) @@ -1149,12 +1020,16 @@ internal override void WriteSql(StringBuilder sqlText) internal class CombinedProjectionExpression : VisitedExpression { - private List _list; - private string _setOperator; + readonly List _list; + readonly string _setOperator; public CombinedProjectionExpression(DbExpressionKind setOperator, List list) { - _setOperator = setOperator == DbExpressionKind.UnionAll ? "UNION ALL" : setOperator == DbExpressionKind.Except ? "EXCEPT" : "INTERSECT"; + _setOperator = setOperator == DbExpressionKind.UnionAll + ? "UNION ALL" + : setOperator == DbExpressionKind.Except + ? "EXCEPT" + : "INTERSECT"; _list = list; } @@ -1163,9 +1038,7 @@ internal override void WriteSql(StringBuilder sqlText) for (var i = 0; i < _list.Count; i++) { if (i != 0) - { sqlText.Append(' ').Append(_setOperator).Append(' '); - } sqlText.Append('('); _list[i].WriteSql(sqlText); sqlText.Append(')'); @@ -1177,7 +1050,7 @@ internal override void WriteSql(StringBuilder sqlText) internal class ExistsExpression : VisitedExpression { - private VisitedExpression _argument; + readonly VisitedExpression _argument; public ExistsExpression(VisitedExpression argument) { @@ -1195,17 +1068,14 @@ internal override void WriteSql(StringBuilder sqlText) class OrderByExpression : VisitedExpression { - private bool _requiresOrderSeperator; + bool _requiresOrderSeperator; public void AppendSort(VisitedExpression sort, bool ascending) { if (_requiresOrderSeperator) Append(","); Append(sort); - if (ascending) - Append(" ASC "); - else - Append(" DESC "); + Append(ascending ? " ASC " : " DESC "); _requiresOrderSeperator = true; }