diff --git a/src/EFCore.Cosmos/EFCore.Cosmos.csproj b/src/EFCore.Cosmos/EFCore.Cosmos.csproj
index 5a078af20b3..6f02ab3d015 100644
--- a/src/EFCore.Cosmos/EFCore.Cosmos.csproj
+++ b/src/EFCore.Cosmos/EFCore.Cosmos.csproj
@@ -9,6 +9,7 @@
true$(PackageTags);CosmosDb;SQL APItrue
+ $(NoWarn);EF9100$(NoWarn);EF9101$(NoWarn);EF9102$(NoWarn);EF9103
diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs
index 352e3d443e1..2d33f24304b 100644
--- a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs
+++ b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs
@@ -36,11 +36,11 @@ public CosmosMethodCallTranslatorProvider(
new CosmosRegexTranslator(sqlExpressionFactory),
new CosmosStringMethodTranslator(sqlExpressionFactory),
new CosmosTypeCheckingTranslator(sqlExpressionFactory),
- new CosmosVectorSearchTranslator(sqlExpressionFactory, typeMappingSource)
+ new CosmosVectorSearchTranslator(sqlExpressionFactory, typeMappingSource),
//new LikeTranslator(sqlExpressionFactory),
- //new EnumHasFlagTranslator(sqlExpressionFactory),
+ new CosmosEnumMethodTranslator(sqlExpressionFactory),
//new GetValueOrDefaultTranslator(sqlExpressionFactory),
- //new ComparisonTranslator(sqlExpressionFactory),
+ new CosmosComparisonTranslator(sqlExpressionFactory),
]);
}
diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
index 9e9c8151d2f..6bb836d6b44 100644
--- a/src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
+++ b/src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
@@ -658,6 +658,48 @@ protected override Expression VisitSqlConditional(SqlConditionalExpression sqlCo
return sqlConditionalExpression;
}
+ ///
+ /// Generates SQL for a CASE clause CASE/WHEN construct.
+ ///
+ /// The for which to generate SQL.
+ protected override Expression VisitCase(CaseExpression caseExpression)
+ {
+ //using (_sqlBuilder.Indent())
+ {
+ foreach (var whenClause in caseExpression.WhenClauses)
+ {
+ _sqlBuilder.Append("IIF(");
+
+ if (caseExpression.Operand != null)
+ {
+ Visit(caseExpression.Operand);
+ _sqlBuilder.Append(" = ");
+ }
+
+ Visit(whenClause.Test);
+
+ _sqlBuilder.Append(", ");
+
+ Visit(whenClause.Result);
+
+ _sqlBuilder.Append(", ");
+ }
+
+ if (caseExpression.ElseResult != null)
+ {
+ Visit(caseExpression.ElseResult);
+ }
+ else
+ {
+ _sqlBuilder.Append("null");
+ }
+
+ _sqlBuilder.Append(new string(')', caseExpression.WhenClauses.Count));
+ }
+
+ return caseExpression;
+ }
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs
index e46d7344f52..d582460ab09 100644
--- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs
@@ -5,6 +5,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Internal;
+using Microsoft.EntityFrameworkCore.Query;
using static Microsoft.EntityFrameworkCore.Query.QueryHelpers;
namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
@@ -44,7 +45,7 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic
private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!;
private readonly IModel _model = queryCompilationContext.Model;
- private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor = new();
+ //private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor = new();
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -101,7 +102,7 @@ protected virtual void AddTranslationErrorDetails(string details)
return null;
}
- _sqlVerifyingExpressionVisitor.Visit(translation);
+ //_sqlVerifyingExpressionVisitor.Visit(translation);
}
return translation;
@@ -345,7 +346,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
return extensionExpression;
case QueryParameterExpression queryParameter:
- return new SqlParameterExpression(queryParameter.Name, queryParameter.Type, null);
+ // If we're precompiling a query, nullability information about reference type parameters has been extracted by the
+ // funcletizer and stored on the query compilation context; use that information when creating the SqlParameterExpression.
+ if (queryParameter.IsNonNullableReferenceType)
+ {
+ /*Check.DebugAssert(
+ _queryCompilationContext.IsPrecompiling,
+ "Parameters can only be known to has non-nullable reference types in query precompilation.");*/
+ return new SqlParameterExpression(
+ invariantName: queryParameter.Name,
+ name: queryParameter.Name,
+ queryParameter.Type,
+ nullable: false,
+ queryParameter.ShouldBeConstantized,
+ typeMapping: null);
+ }
+
+ return new SqlParameterExpression(
+ invariantName: queryParameter.Name,
+ name: queryParameter.Name,
+ queryParameter.Type,
+ queryParameter.Type.IsNullableType(),
+ queryParameter.ShouldBeConstantized,
+ typeMapping: null);
case StructuralTypeShaperExpression shaper:
return new EntityReferenceExpression(shaper);
diff --git a/src/EFCore.Cosmos/Query/Internal/Expressions/CaseExpression.cs b/src/EFCore.Cosmos/Query/Internal/Expressions/CaseExpression.cs
new file mode 100644
index 00000000000..a4016b96a81
--- /dev/null
+++ b/src/EFCore.Cosmos/Query/Internal/Expressions/CaseExpression.cs
@@ -0,0 +1,168 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
+
+///
+///
+/// An expression that represents a CASE statement in a SQL tree.
+///
+///
+/// This type is typically used by database providers (and other extensions). It is generally
+/// not used in application code.
+///
+///
+public class CaseExpression : SqlExpression
+{
+ private readonly List _whenClauses = [];
+
+ ///
+ /// Creates a new instance of the class which represents a simple CASE expression.
+ ///
+ /// An expression to compare with in .
+ /// A list of to compare or evaluate and get result from.
+ /// A value to return if no matches, if any.
+ public CaseExpression(
+ SqlExpression? operand,
+ IReadOnlyList whenClauses,
+ SqlExpression? elseResult = null)
+ : base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
+ {
+ Operand = operand;
+ _whenClauses.AddRange(whenClauses);
+ ElseResult = elseResult;
+ }
+
+ ///
+ /// Creates a new instance of the class which represents a searched CASE expression.
+ ///
+ /// A list of to evaluate condition and get result from.
+ /// A value to return if no matches, if any.
+ public CaseExpression(
+ IReadOnlyList whenClauses,
+ SqlExpression? elseResult = null)
+ : this(null, whenClauses, elseResult)
+ {
+ }
+
+ ///
+ /// The value to compare in .
+ ///
+ public virtual SqlExpression? Operand { get; }
+
+ ///
+ /// The list of to match or evaluate condition to get result.
+ ///
+ public virtual IReadOnlyList WhenClauses
+ => _whenClauses;
+
+ ///
+ /// The value to return if none of the matches.
+ ///
+ public virtual SqlExpression? ElseResult { get; }
+
+ ///
+ protected override Expression VisitChildren(ExpressionVisitor visitor)
+ {
+ var operand = (SqlExpression?)visitor.Visit(Operand);
+ var changed = operand != Operand;
+ var whenClauses = new List();
+ foreach (var whenClause in WhenClauses)
+ {
+ var test = (SqlExpression)visitor.Visit(whenClause.Test);
+ var result = (SqlExpression)visitor.Visit(whenClause.Result);
+
+ if (test != whenClause.Test
+ || result != whenClause.Result)
+ {
+ changed = true;
+ whenClauses.Add(new CaseWhenClause(test, result));
+ }
+ else
+ {
+ whenClauses.Add(whenClause);
+ }
+ }
+
+ var elseResult = (SqlExpression?)visitor.Visit(ElseResult);
+ changed |= elseResult != ElseResult;
+
+ return changed
+ ? new CaseExpression(operand, whenClauses, elseResult)
+ : this;
+ }
+
+ ///
+ /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
+ /// return this expression.
+ ///
+ /// The property of the result.
+ /// The property of the result.
+ /// The property of the result.
+ /// This expression if no children changed, or an expression with the updated children.
+ public virtual CaseExpression Update(
+ SqlExpression? operand,
+ IReadOnlyList whenClauses,
+ SqlExpression? elseResult)
+ => operand != Operand || !whenClauses.SequenceEqual(WhenClauses) || elseResult != ElseResult
+ ? new CaseExpression(operand, whenClauses, elseResult)
+ : this;
+
+ ///
+ protected override void Print(ExpressionPrinter expressionPrinter)
+ {
+ expressionPrinter.Append("CASE");
+ if (Operand != null)
+ {
+ expressionPrinter.Append(" ");
+ expressionPrinter.Visit(Operand);
+ }
+
+ using (expressionPrinter.Indent())
+ {
+ foreach (var whenClause in WhenClauses)
+ {
+ expressionPrinter.AppendLine().Append("WHEN ");
+ expressionPrinter.Visit(whenClause.Test);
+ expressionPrinter.Append(" THEN ");
+ expressionPrinter.Visit(whenClause.Result);
+ }
+
+ if (ElseResult != null)
+ {
+ expressionPrinter.AppendLine().Append("ELSE ");
+ expressionPrinter.Visit(ElseResult);
+ }
+ }
+
+ expressionPrinter.AppendLine().Append("END");
+ }
+
+ ///
+ public override bool Equals(object? obj)
+ => obj != null
+ && (ReferenceEquals(this, obj)
+ || obj is CaseExpression caseExpression
+ && Equals(caseExpression));
+
+ private bool Equals(CaseExpression caseExpression)
+ => base.Equals(caseExpression)
+ && (Operand?.Equals(caseExpression.Operand) ?? caseExpression.Operand == null)
+ && WhenClauses.SequenceEqual(caseExpression.WhenClauses)
+ && (ElseResult?.Equals(caseExpression.ElseResult) ?? caseExpression.ElseResult == null);
+
+ ///
+ public override int GetHashCode()
+ {
+ var hash = new HashCode();
+ hash.Add(base.GetHashCode());
+ hash.Add(Operand);
+ for (var i = 0; i < WhenClauses.Count; i++)
+ {
+ hash.Add(WhenClauses[i]);
+ }
+
+ hash.Add(ElseResult);
+ return hash.ToHashCode();
+ }
+}
diff --git a/src/EFCore.Cosmos/Query/Internal/Expressions/CaseWhenClause.cs b/src/EFCore.Cosmos/Query/Internal/Expressions/CaseWhenClause.cs
new file mode 100644
index 00000000000..724f39b8162
--- /dev/null
+++ b/src/EFCore.Cosmos/Query/Internal/Expressions/CaseWhenClause.cs
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
+
+///
+///
+/// An object that represents a WHEN...THEN... construct in a SQL tree.
+///
+///
+/// This type is typically used by database providers (and other extensions). It is generally
+/// not used in application code.
+///
+///
+public class CaseWhenClause
+{
+ ///
+ /// Creates a new instance of the class.
+ ///
+ /// A value to compare with or condition to evaluate.
+ /// A value to return if test succeeds.
+ public CaseWhenClause(SqlExpression test, SqlExpression result)
+ {
+ Test = test;
+ Result = result;
+ }
+
+ ///
+ /// The value to compare with or the condition to evaluate.
+ ///
+ public virtual SqlExpression Test { get; }
+
+ ///
+ /// The value to return if succeeds.
+ ///
+ public virtual SqlExpression Result { get; }
+
+ ///
+ public override bool Equals(object? obj)
+ => obj != null
+ && (ReferenceEquals(this, obj)
+ || obj is CaseWhenClause caseWhenClause
+ && Equals(caseWhenClause));
+
+ private bool Equals(CaseWhenClause caseWhenClause)
+ => Test.Equals(caseWhenClause.Test)
+ && Result.Equals(caseWhenClause.Result);
+
+ ///
+ public override int GetHashCode()
+ => HashCode.Combine(Test, Result);
+}
diff --git a/src/EFCore.Cosmos/Query/Internal/Expressions/SqlParameterExpression.cs b/src/EFCore.Cosmos/Query/Internal/Expressions/SqlParameterExpression.cs
index fe317979f5a..e3318bdf919 100644
--- a/src/EFCore.Cosmos/Query/Internal/Expressions/SqlParameterExpression.cs
+++ b/src/EFCore.Cosmos/Query/Internal/Expressions/SqlParameterExpression.cs
@@ -3,6 +3,8 @@
// ReSharper disable once CheckNamespace
+using System.Data.Common;
+
namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
///
@@ -11,16 +13,66 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
-public sealed class SqlParameterExpression(string name, Type type, CoreTypeMapping? typeMapping)
- : SqlExpression(type, typeMapping)
+public sealed class SqlParameterExpression: SqlExpression
{
///
- /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
- /// the same compatibility standards as public APIs. It may be changed or removed without notice in
- /// any release. You should only use it directly in your code with extreme caution and knowing that
- /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ /// Creates a new instance of the class.
+ ///
+ /// The parameter name.
+ /// The of the expression.
+ /// The associated with the expression.
+ public SqlParameterExpression(string name, Type type, CoreTypeMapping? typeMapping)
+ : this(invariantName: name, name: name, type.UnwrapNullableType(), type.IsNullableType(), shouldBeConstantized: false, typeMapping)
+ {
+ }
+
+ ///
+ /// Creates a new instance of the class.
+ ///
+ /// The name of the parameter as it is recorded in .
+ ///
+ /// The name of the parameter as it will be set on and inside the SQL as a placeholder
+ /// (before any additional placeholder character prefixing).
+ ///
+ /// The of the expression.
+ /// Whether this parameter can have null values.
+ /// Whether the user has indicated that this query parameter should be inlined as a constant.
+ /// The associated with the expression.
+ public SqlParameterExpression(
+ string invariantName,
+ string name,
+ Type type,
+ bool nullable,
+ bool shouldBeConstantized,
+ CoreTypeMapping? typeMapping)
+ : base(type.UnwrapNullableType(), typeMapping)
+ {
+ InvariantName = invariantName;
+ Name = name;
+ IsNullable = nullable;
+ ShouldBeConstantized = shouldBeConstantized;
+ }
+
+ ///
+ /// The name of the parameter as it is recorded in .
+ ///
+ public string InvariantName { get; }
+
+ ///
+ /// The name of the parameter as it will be set on and inside the SQL as a placeholder
+ /// (before any additional placeholder character prefixing).
+ ///
+ public string Name { get; }
+
+ ///
+ /// The bool value indicating if this parameter can have null values.
+ ///
+ public bool IsNullable { get; }
+
+ ///
+ /// Whether the user has indicated that this query parameter should be inlined as a constant.
///
- public string Name { get; } = name;
+ public bool ShouldBeConstantized { get; }
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -29,7 +81,7 @@ public sealed class SqlParameterExpression(string name, Type type, CoreTypeMappi
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public SqlExpression ApplyTypeMapping(CoreTypeMapping? typeMapping)
- => new SqlParameterExpression(Name, Type, typeMapping ?? TypeMapping);
+ => new SqlParameterExpression(InvariantName, Name, Type, IsNullable, ShouldBeConstantized, typeMapping);
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
diff --git a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs
index 12c46524fc0..4cb7034cc26 100644
--- a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs
+++ b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs
@@ -282,4 +282,26 @@ public interface ISqlExpressionFactory
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
SqlExpression Constant(object? value, Type type, CoreTypeMapping? typeMapping = null);
+
+ ///
+ /// Creates a new which represent a CASE statement in a SQL tree.
+ ///
+ /// An expression to compare with in .
+ /// A list of to compare or evaluate and get result from.
+ /// A value to return if no matches, if any.
+ /// An optional expression that can be re-used if it matches the new expression.
+ /// An expression representing a CASE statement in a SQL tree.
+ SqlExpression Case(
+ SqlExpression? operand,
+ IReadOnlyList whenClauses,
+ SqlExpression? elseResult,
+ SqlExpression? existingExpression = null);
+
+ ///
+ /// Creates a new which represent a CASE statement in a SQL tree.
+ ///
+ /// A list of to evaluate condition and get result from.
+ /// A value to return if no matches, if any.
+ /// An expression representing a CASE statement in a SQL tree.
+ SqlExpression Case(IReadOnlyList whenClauses, SqlExpression? elseResult);
}
diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
index b4cfb0938a5..762672b86f6 100644
--- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
+++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
@@ -43,6 +43,7 @@ public class SqlExpressionFactory(ITypeMappingSource typeMappingSource, IModel m
null or { TypeMapping: not null } => sqlExpression,
ScalarSubqueryExpression e => e.ApplyTypeMapping(typeMapping),
+ CaseExpression e => ApplyTypeMappingOnCase(e, typeMapping),
SqlConditionalExpression sqlConditionalExpression => ApplyTypeMappingOnSqlConditional(sqlConditionalExpression, typeMapping),
SqlBinaryExpression sqlBinaryExpression => ApplyTypeMappingOnSqlBinary(sqlBinaryExpression, typeMapping),
SqlUnaryExpression sqlUnaryExpression => ApplyTypeMappingOnSqlUnary(sqlUnaryExpression, typeMapping),
@@ -105,6 +106,24 @@ when sqlUnaryExpression.IsLogicalNot():
return new SqlUnaryExpression(sqlUnaryExpression.OperatorType, operand, resultType, resultTypeMapping);
}
+ private SqlExpression ApplyTypeMappingOnCase(
+ CaseExpression caseExpression,
+ CoreTypeMapping? typeMapping)
+ {
+ var whenClauses = new List();
+ foreach (var caseWhenClause in caseExpression.WhenClauses)
+ {
+ whenClauses.Add(
+ new CaseWhenClause(
+ caseWhenClause.Test,
+ ApplyTypeMapping(caseWhenClause.Result, typeMapping)));
+ }
+
+ var elseResult = ApplyTypeMapping(caseExpression.ElseResult, typeMapping);
+
+ return caseExpression.Update(caseExpression.Operand, whenClauses, elseResult);
+ }
+
private SqlExpression ApplyTypeMappingOnSqlBinary(
SqlBinaryExpression sqlBinaryExpression,
CoreTypeMapping? typeMapping)
@@ -728,4 +747,150 @@ public virtual SqlExpression Constant(object value, CoreTypeMapping? typeMapping
///
public virtual SqlExpression Constant(object? value, Type type, CoreTypeMapping? typeMapping = null)
=> new SqlConstantExpression(value, type, typeMapping);
+
+ ///
+ public virtual SqlExpression Case(
+ SqlExpression? operand,
+ IReadOnlyList whenClauses,
+ SqlExpression? elseResult,
+ SqlExpression? existingExpression = null)
+ {
+ CoreTypeMapping? testTypeMapping;
+ if (operand == null)
+ {
+ testTypeMapping = _boolTypeMapping;
+ }
+ else
+ {
+ testTypeMapping = operand.TypeMapping
+ ?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
+ // Since we never look at type of Operand/Test after this place,
+ // we need to find actual typeMapping based on non-object type.
+ ?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
+ .Where(t => t != typeof(object)).Select(t => typeMappingSource.FindMapping(t, model))
+ .FirstOrDefault();
+
+ operand = ApplyTypeMapping(operand, testTypeMapping);
+ }
+
+ var resultTypeMapping = elseResult?.TypeMapping
+ ?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);
+
+ elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);
+
+ var typeMappedWhenClauses = new List();
+ foreach (var caseWhenClause in whenClauses)
+ {
+ var test = caseWhenClause.Test;
+
+ if (operand == null && test is CaseExpression { Operand: null, WhenClauses: [var nestedSingleClause] } testExpr)
+ {
+ if (nestedSingleClause.Result is SqlConstantExpression { Value: true }
+ && testExpr.ElseResult is null or SqlConstantExpression { Value: false or null })
+ {
+ // WHEN CASE
+ // WHEN x THEN TRUE
+ // ELSE FALSE/NULL
+ // END THEN y
+ // simplifies to
+ // WHEN x THEN y
+ test = nestedSingleClause.Test;
+ }
+ else if (nestedSingleClause.Result is SqlConstantExpression { Value: false or null }
+ && testExpr.ElseResult is SqlConstantExpression { Value: true })
+ {
+ // same for the negated results
+ test = Not(nestedSingleClause.Test);
+ }
+ }
+
+ typeMappedWhenClauses.Add(
+ new CaseWhenClause(
+ ApplyTypeMapping(test, testTypeMapping),
+ ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
+ }
+
+ if (operand is null && elseResult is CaseExpression { Operand: null } nestedCaseExpression)
+ {
+ typeMappedWhenClauses.AddRange(nestedCaseExpression.WhenClauses);
+ elseResult = nestedCaseExpression.ElseResult;
+ }
+
+ typeMappedWhenClauses = typeMappedWhenClauses
+ .Where(c => !IsSkipped(c))
+ .TakeUpTo(IsMatched)
+ .DistinctBy(c => c.Test)
+ .ToList();
+
+ // CASE
+ // ...
+ // WHEN TRUE THEN a
+ // ELSE b
+ // END
+ // simplifies to
+ // CASE
+ // ...
+ // ELSE a
+ // END
+ if (typeMappedWhenClauses.Count > 0 && IsMatched(typeMappedWhenClauses[^1]))
+ {
+ elseResult = typeMappedWhenClauses[^1].Result;
+ typeMappedWhenClauses.RemoveAt(typeMappedWhenClauses.Count - 1);
+ }
+
+ var nullResult = Constant(null, elseResult?.Type ?? whenClauses[0].Result.Type, resultTypeMapping);
+
+ // if there are no whenClauses left (e.g. their tests evaluated to false):
+ // - if there is Else block, return it
+ // - if there is no Else block, return null
+ if (typeMappedWhenClauses.Count == 0)
+ {
+ return elseResult ?? nullResult;
+ }
+
+ // omit `ELSE NULL` (this makes it easier to compare/reuse expressions)
+ if (elseResult is SqlConstantExpression { Value: null })
+ {
+ elseResult = null;
+ }
+
+ // CASE
+ // ...
+ // WHEN x THEN CASE
+ // WHEN y THEN a
+ // ELSE b
+ // END
+ // ELSE b
+ // END
+ // simplifies to
+ // CASE
+ // ...
+ // WHEN x AND y THEN a
+ // ELSE b
+ // END
+ if (operand == null
+ && typeMappedWhenClauses[^1].Result is CaseExpression { Operand: null, WhenClauses: [var lastClause] } lastCase
+ && Equals(elseResult, lastCase.ElseResult))
+ {
+ typeMappedWhenClauses[^1] = new CaseWhenClause(AndAlso(typeMappedWhenClauses[^1].Test, lastClause.Test), lastClause.Result);
+ elseResult = lastCase.ElseResult;
+ }
+
+ return existingExpression is CaseExpression expr
+ && operand == expr.Operand
+ && typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
+ && elseResult == expr.ElseResult
+ ? expr
+ : new CaseExpression(operand, typeMappedWhenClauses, elseResult);
+
+ bool IsSkipped(CaseWhenClause clause)
+ => operand is null && clause.Test is SqlConstantExpression { Value: false or null };
+
+ bool IsMatched(CaseWhenClause clause)
+ => operand is null && clause.Test is SqlConstantExpression { Value: true };
+ }
+
+ ///
+ public virtual SqlExpression Case(IReadOnlyList whenClauses, SqlExpression? elseResult)
+ => Case(operand: null, whenClauses, elseResult);
}
diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs
index 22d2b286bb1..76a276ff2dd 100644
--- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs
+++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs
@@ -22,6 +22,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
ShapedQueryExpression shapedQueryExpression
=> shapedQueryExpression.UpdateQueryExpression(Visit(shapedQueryExpression.QueryExpression)),
+ CaseExpression caseExpression => VisitCase(caseExpression),
SelectExpression selectExpression => VisitSelect(selectExpression),
ProjectionExpression projectionExpression => VisitProjection(projectionExpression),
EntityProjectionExpression entityProjectionExpression => VisitEntityProjection(entityProjectionExpression),
@@ -53,6 +54,13 @@ ShapedQueryExpression shapedQueryExpression
_ => base.VisitExtension(extensionExpression)
};
+ ///
+ /// Visits the children of the case expression.
+ ///
+ /// The expression to visit.
+ /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression.
+ protected abstract Expression VisitCase(CaseExpression caseExpression);
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosComparisonTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosComparisonTranslator.cs
new file mode 100644
index 00000000000..454976935e2
--- /dev/null
+++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosComparisonTranslator.cs
@@ -0,0 +1,81 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+
+// ReSharper disable once CheckNamespace
+
+using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
+
+namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
+
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+public class CosmosComparisonTranslator : IMethodCallTranslator
+{
+ private readonly ISqlExpressionFactory _sqlExpressionFactory;
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public CosmosComparisonTranslator(ISqlExpressionFactory sqlExpressionFactory)
+ => _sqlExpressionFactory = sqlExpressionFactory;
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual SqlExpression? Translate(
+ SqlExpression? instance,
+ MethodInfo method,
+ IReadOnlyList arguments,
+ IDiagnosticsLogger logger)
+ {
+ if (method.ReturnType == typeof(int))
+ {
+ SqlExpression? left = null;
+ SqlExpression? right = null;
+ if (method.Name == nameof(string.Compare)
+ && arguments.Count == 2
+ && arguments[0].Type == arguments[1].Type)
+ {
+ left = arguments[0];
+ right = arguments[1];
+ }
+ else if (method.Name == nameof(string.CompareTo)
+ && arguments.Count == 1
+ && instance != null
+ && instance.Type == arguments[0].Type)
+ {
+ left = instance;
+ right = arguments[0];
+ }
+
+ if (left != null
+ && right != null)
+ {
+ return _sqlExpressionFactory.Case(
+ new[]
+ {
+ new CaseWhenClause(
+ _sqlExpressionFactory.Equal(left, right), _sqlExpressionFactory.Constant(0)),
+ new CaseWhenClause(
+ _sqlExpressionFactory.GreaterThan(left, right), _sqlExpressionFactory.Constant(1)),
+ new CaseWhenClause(
+ _sqlExpressionFactory.LessThan(left, right), _sqlExpressionFactory.Constant(-1))
+ },
+ null);
+ }
+ }
+
+ return null;
+ }
+}
diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosEnumMethodTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosEnumMethodTranslator.cs
new file mode 100644
index 00000000000..8ce86f82eba
--- /dev/null
+++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosEnumMethodTranslator.cs
@@ -0,0 +1,87 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
+
+// ReSharper disable once CheckNamespace
+namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
+
+///
+/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+/// the same compatibility standards as public APIs. It may be changed or removed without notice in
+/// any release. You should only use it directly in your code with extreme caution and knowing that
+/// doing so can result in application failures when updating to a new Entity Framework Core release.
+///
+public class CosmosEnumMethodTranslator : IMethodCallTranslator
+{
+ private static readonly MethodInfo HasFlagMethodInfo
+ = typeof(Enum).GetRuntimeMethod(nameof(Enum.HasFlag), [typeof(Enum)])!;
+
+ private static readonly MethodInfo ToStringMethodInfo
+ = typeof(object).GetRuntimeMethod(nameof(ToString), [])!;
+
+ private readonly ISqlExpressionFactory _sqlExpressionFactory;
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public CosmosEnumMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
+ => _sqlExpressionFactory = sqlExpressionFactory;
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual SqlExpression? Translate(
+ SqlExpression? instance,
+ MethodInfo method,
+ IReadOnlyList arguments,
+ IDiagnosticsLogger logger)
+ {
+ if (Equals(method, HasFlagMethodInfo)
+ && instance != null)
+ {
+ var argument = arguments[0];
+ return instance.Type != argument.Type
+ ? null
+ : _sqlExpressionFactory.Equal(_sqlExpressionFactory.And(instance, argument), argument);
+ }
+
+ if (Equals(method, ToStringMethodInfo)
+ && instance is { Type.IsEnum: true, TypeMapping.Converter: ValueConverter converter }
+ && converter.GetType() is { IsGenericType: true } converterType)
+ {
+ switch (converterType)
+ {
+ case not null when converterType.GetGenericTypeDefinition() == typeof(EnumToNumberConverter<,>):
+ var whenClauses = Enum.GetValues(instance.Type)
+ .Cast