diff --git a/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj b/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj index 1d7fb7c..c5b375b 100644 --- a/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj +++ b/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj @@ -69,6 +69,7 @@ + diff --git a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs index 3dfa842..f7d4d21 100644 --- a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs +++ b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs @@ -351,15 +351,15 @@ public override string EscapeLikeArgument([NotNull] string argument) public override bool SupportsInExpression() => true; public override ReadOnlyCollection GetStoreFunctions() - => typeof(NpgsqlTextFunctions).GetTypeInfo() - .GetMethods(BindingFlags.Public | BindingFlags.Static) + => new[] { typeof(NpgsqlTextFunctions).GetTypeInfo(), typeof(NpgsqlTypeFunctions) } + .SelectMany(x => x.GetMethods(BindingFlags.Public | BindingFlags.Static)) .Select(x => new { Method = x, DbFunction = x.GetCustomAttribute() }) .Where(x => x.DbFunction != null) - .Select(x => CreateFullTextEdmFunction(x.Method, x.DbFunction)) + .Select(x => CreateComposableEdmFunction(x.Method, x.DbFunction)) .ToList() .AsReadOnly(); - static EdmFunction CreateFullTextEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo) + static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo) { if (method == null) throw new ArgumentNullException(nameof(method)); diff --git a/src/EntityFramework6.Npgsql/NpgsqlTypeFunctions.cs b/src/EntityFramework6.Npgsql/NpgsqlTypeFunctions.cs new file mode 100644 index 0000000..b567668 --- /dev/null +++ b/src/EntityFramework6.Npgsql/NpgsqlTypeFunctions.cs @@ -0,0 +1,20 @@ +using System; +using System.Data.Entity; + +namespace Npgsql +{ + /// + /// Use this class in LINQ queries to emit type manipulation SQL fragments. + /// + public static class NpgsqlTypeFunctions + { + /// + /// Emits an explicit cast for unknown types sent as strings to their correct postgresql type. + /// + [DbFunction("Npgsql", "cast")] + public static string Cast(string unknownTypeValue, string postgresTypeName) + { + throw new NotSupportedException(); + } + } +} \ No newline at end of file diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs index dc3555c..0cf543c 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs @@ -322,7 +322,14 @@ PendingProjectsNode VisitInputWithBinding(DbExpression expression, string bindin for (var i = 0; i < rowType.Properties.Count && i < projection.Arguments.Count; ++i) { var prop = rowType.Properties[i]; - input.Projection.Arguments.Add(new ColumnExpression(projection.Arguments[i].Accept(this), prop.Name, prop.TypeUsage)); + var argument = projection.Arguments[i].Accept(this); + var constantArgument = projection.Arguments[i] as DbConstantExpression; + if (constantArgument != null && constantArgument.Value is string) + { + argument = new CastExpression(argument, "varchar"); + } + + input.Projection.Arguments.Add(new ColumnExpression(argument, prop.Name, prop.TypeUsage)); } if (enterScope) LeaveExpression(child); @@ -1150,6 +1157,17 @@ VisitedExpression VisitFunction(EdmFunction function, IList args, { return VisitMatchRegex(function, args, resultType); } + else if (functionName == "cast") + { + if (args.Count != 2) + throw new ArgumentException("Invalid number of arguments. Expected 2.", "args"); + + var typeNameExpression = args[1] as DbConstantExpression; + if (typeNameExpression == null) + throw new NotSupportedException("cast type name argument must be a constant expression."); + + return new CastExpression(args[0].Accept(this), typeNameExpression.Value.ToString()); + } } var customFuncCall = new FunctionExpression( diff --git a/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs b/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs index e1e30fc..2ff87fb 100644 --- a/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs +++ b/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs @@ -650,5 +650,74 @@ public void TestScalarValuedStoredFunctions_with_null_StoreFunctionName() Assert.That(echo, Is.EqualTo(1337)); } } + + [Test] + public void TestCastFunction() + { + using (var context = new BloggingContext(ConnectionString)) + { + context.Database.Log = Console.Out.WriteLine; + + var varbitVal = "10011"; + + var blog = new Blog + { + Name = "_", + Posts = new List + { + new Post + { + Content = "Some post content", + Rating = 1, + Title = "Some post Title", + VarbitColumn = varbitVal + } + } + }; + context.Blogs.Add(blog); + context.SaveChanges(); + + Assert.IsTrue( + context.Posts.Select( + p => NpgsqlTypeFunctions.Cast(p.VarbitColumn, "varbit") == varbitVal).First()); + + Assert.IsTrue( + context.Posts.Select( + p => NpgsqlTypeFunctions.Cast(p.VarbitColumn, "varbit") == "10011").First()); + } + } + + [Test] + public void Test_issue_27_select_ef_generated_literals_from_inner_select() + { + using (var context = new BloggingContext(ConnectionString)) + { + context.Database.Log = Console.Out.WriteLine; + + var blog = new Blog { Name = "Hello" }; + context.Users.Add(new Administrator { Blogs = new List { blog } }); + context.Users.Add(new Editor()); + context.SaveChanges(); + + var administrator = context.Users + .Where(x => x is Administrator) // Removing this changes the query to using a UNION which doesn't fail. + .Select( + x => new + { + // causes entity framework to emit a literal discriminator + Computed = x is Administrator + ? "I administrate" + : x is Editor + ? "I edit" + : "Unknown", + // causes an inner select to be emitted thus showing the issue + HasBlog = x.Blogs.Any() + }) + .First(); + + Assert.That(administrator.Computed, Is.EqualTo("I administrate")); + Assert.That(administrator.HasBlog, Is.True); + } + } } } diff --git a/test/EntityFramework6.Npgsql.Tests/Support/EntityFrameworkTestBase.cs b/test/EntityFramework6.Npgsql.Tests/Support/EntityFrameworkTestBase.cs index 366a02f..1fcc3ad 100644 --- a/test/EntityFramework6.Npgsql.Tests/Support/EntityFrameworkTestBase.cs +++ b/test/EntityFramework6.Npgsql.Tests/Support/EntityFrameworkTestBase.cs @@ -104,6 +104,20 @@ public class NoColumnsEntity public int Id { get; set; } } + [Table("Users")] + public abstract class User + { + public int Id { get; set; } + + public IList Blogs { get; set; } + } + + [Table("Editors")] + public class Editor : User { } + + [Table("Administrators")] + public class Administrator : User { } + public class BloggingContext : DbContext { public BloggingContext(string connection) @@ -114,6 +128,9 @@ public BloggingContext(string connection) public DbSet Blogs { get; set; } public DbSet Posts { get; set; } public DbSet NoColumnsEntities { get; set; } + public DbSet Users { get; set; } + public DbSet Editors { get; set; } + public DbSet Administrators { get; set; } [DbFunction("BloggingContext", "ClrStoredAddFunction")] public static int StoredAddFunction(int val1, int val2) @@ -135,6 +152,9 @@ private static DbCompiledModel CreateModel(NpgsqlConnection connection) dbModelBuilder.Entity(); dbModelBuilder.Entity(); dbModelBuilder.Entity(); + dbModelBuilder.Entity(); + dbModelBuilder.Entity(); + dbModelBuilder.Entity(); // Import function var dbModel = dbModelBuilder.Build(connection);