Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enum and comparison translators for cosmos #35405

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EFCore.Cosmos/EFCore.Cosmos.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>$(PackageTags);CosmosDb;SQL API</PackageTags>
<ImplicitUsings>true</ImplicitUsings>
<NoWarn>$(NoWarn);EF9100</NoWarn> <!-- Precomiled query is experimental -->
<NoWarn>$(NoWarn);EF9101</NoWarn> <!-- Metrics is experimental -->
<NoWarn>$(NoWarn);EF9102</NoWarn> <!-- Paging is experimental -->
<NoWarn>$(NoWarn);EF9103</NoWarn> <!-- Vector search is experimental -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public CosmosMethodCallTranslatorProvider(
new CosmosRegexTranslator(sqlExpressionFactory),
new CosmosStringMethodTranslator(sqlExpressionFactory),
new CosmosTypeCheckingTranslator(sqlExpressionFactory),
new CosmosVectorSearchTranslator(sqlExpressionFactory, typeMappingSource)
new CosmosVectorSearchTranslator(sqlExpressionFactory, typeMappingSource),
//new LikeTranslator(sqlExpressionFactory),
//new EnumHasFlagTranslator(sqlExpressionFactory),
new CosmosEnumMethodTranslator(sqlExpressionFactory),
//new GetValueOrDefaultTranslator(sqlExpressionFactory),
//new ComparisonTranslator(sqlExpressionFactory),
new CosmosComparisonTranslator(sqlExpressionFactory),
]);
}

Expand Down
42 changes: 42 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,48 @@ protected override Expression VisitSqlConditional(SqlConditionalExpression sqlCo
return sqlConditionalExpression;
}

/// <summary>
/// Generates SQL for a CASE clause CASE/WHEN construct.
/// </summary>
/// <param name="caseExpression">The <see cref="CaseExpression" /> for which to generate SQL.</param>
protected override Expression VisitCase(CaseExpression caseExpression)
{
//using (_sqlBuilder.Indent())
{
foreach (var whenClause in caseExpression.WhenClauses)
{
_sqlBuilder.Append("IIF(");

if (caseExpression.Operand != null)
{
Visit(caseExpression.Operand);
_sqlBuilder.Append(" = ");
}

Visit(whenClause.Test);

_sqlBuilder.Append(", ");

Visit(whenClause.Result);

_sqlBuilder.Append(", ");
}

if (caseExpression.ElseResult != null)
{
Visit(caseExpression.ElseResult);
}
else
{
_sqlBuilder.Append("null");
}

_sqlBuilder.Append(new string(')', caseExpression.WhenClauses.Count));
}

return caseExpression;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query;
using static Microsoft.EntityFrameworkCore.Query.QueryHelpers;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand Down Expand Up @@ -44,7 +45,7 @@ private static readonly MethodInfo StringEqualsWithStringComparisonStatic
private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!;

private readonly IModel _model = queryCompilationContext.Model;
private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor = new();
//private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor = new();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -101,7 +102,7 @@ protected virtual void AddTranslationErrorDetails(string details)
return null;
}

_sqlVerifyingExpressionVisitor.Visit(translation);
//_sqlVerifyingExpressionVisitor.Visit(translation);
}

return translation;
Expand Down Expand Up @@ -345,7 +346,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
return extensionExpression;

case QueryParameterExpression queryParameter:
return new SqlParameterExpression(queryParameter.Name, queryParameter.Type, null);
// If we're precompiling a query, nullability information about reference type parameters has been extracted by the
// funcletizer and stored on the query compilation context; use that information when creating the SqlParameterExpression.
if (queryParameter.IsNonNullableReferenceType)
{
/*Check.DebugAssert(
_queryCompilationContext.IsPrecompiling,
"Parameters can only be known to has non-nullable reference types in query precompilation.");*/
return new SqlParameterExpression(
invariantName: queryParameter.Name,
name: queryParameter.Name,
queryParameter.Type,
nullable: false,
queryParameter.ShouldBeConstantized,
typeMapping: null);
}

return new SqlParameterExpression(
invariantName: queryParameter.Name,
name: queryParameter.Name,
queryParameter.Type,
queryParameter.Type.IsNullableType(),
queryParameter.ShouldBeConstantized,
typeMapping: null);

case StructuralTypeShaperExpression shaper:
return new EntityReferenceExpression(shaper);
Expand Down
168 changes: 168 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/Expressions/CaseExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;

/// <summary>
/// <para>
/// An expression that represents a CASE statement in a SQL tree.
/// </para>
/// <para>
/// This type is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
public class CaseExpression : SqlExpression
{
private readonly List<CaseWhenClause> _whenClauses = [];

/// <summary>
/// Creates a new instance of the <see cref="CaseExpression" /> class which represents a simple CASE expression.
/// </summary>
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <see cref="WhenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <see cref="WhenClauses" /> matches, if any.</param>
public CaseExpression(
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
{
Operand = operand;
_whenClauses.AddRange(whenClauses);
ElseResult = elseResult;
}

/// <summary>
/// Creates a new instance of the <see cref="CaseExpression" /> class which represents a searched CASE expression.
/// </summary>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to evaluate condition and get result from.</param>
/// <param name="elseResult">A value to return if no <see cref="WhenClauses" /> matches, if any.</param>
public CaseExpression(
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: this(null, whenClauses, elseResult)
{
}

/// <summary>
/// The value to compare in <see cref="WhenClauses" />.
/// </summary>
public virtual SqlExpression? Operand { get; }

/// <summary>
/// The list of <see cref="CaseWhenClause" /> to match <see cref="Operand" /> or evaluate condition to get result.
/// </summary>
public virtual IReadOnlyList<CaseWhenClause> WhenClauses
=> _whenClauses;

/// <summary>
/// The value to return if none of the <see cref="WhenClauses" /> matches.
/// </summary>
public virtual SqlExpression? ElseResult { get; }

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var operand = (SqlExpression?)visitor.Visit(Operand);
var changed = operand != Operand;
var whenClauses = new List<CaseWhenClause>();
foreach (var whenClause in WhenClauses)
{
var test = (SqlExpression)visitor.Visit(whenClause.Test);
var result = (SqlExpression)visitor.Visit(whenClause.Result);

if (test != whenClause.Test
|| result != whenClause.Result)
{
changed = true;
whenClauses.Add(new CaseWhenClause(test, result));
}
else
{
whenClauses.Add(whenClause);
}
}

var elseResult = (SqlExpression?)visitor.Visit(ElseResult);
changed |= elseResult != ElseResult;

return changed
? new CaseExpression(operand, whenClauses, elseResult)
: this;
}

/// <summary>
/// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
/// return this expression.
/// </summary>
/// <param name="operand">The <see cref="Operand" /> property of the result.</param>
/// <param name="whenClauses">The <see cref="WhenClauses" /> property of the result.</param>
/// <param name="elseResult">The <see cref="ElseResult" /> property of the result.</param>
/// <returns>This expression if no children changed, or an expression with the updated children.</returns>
public virtual CaseExpression Update(
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult)
=> operand != Operand || !whenClauses.SequenceEqual(WhenClauses) || elseResult != ElseResult
? new CaseExpression(operand, whenClauses, elseResult)
: this;

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append("CASE");
if (Operand != null)
{
expressionPrinter.Append(" ");
expressionPrinter.Visit(Operand);
}

using (expressionPrinter.Indent())
{
foreach (var whenClause in WhenClauses)
{
expressionPrinter.AppendLine().Append("WHEN ");
expressionPrinter.Visit(whenClause.Test);
expressionPrinter.Append(" THEN ");
expressionPrinter.Visit(whenClause.Result);
}

if (ElseResult != null)
{
expressionPrinter.AppendLine().Append("ELSE ");
expressionPrinter.Visit(ElseResult);
}
}

expressionPrinter.AppendLine().Append("END");
}

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is CaseExpression caseExpression
&& Equals(caseExpression));

private bool Equals(CaseExpression caseExpression)
=> base.Equals(caseExpression)
&& (Operand?.Equals(caseExpression.Operand) ?? caseExpression.Operand == null)
&& WhenClauses.SequenceEqual(caseExpression.WhenClauses)
&& (ElseResult?.Equals(caseExpression.ElseResult) ?? caseExpression.ElseResult == null);

/// <inheritdoc />
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(base.GetHashCode());
hash.Add(Operand);
for (var i = 0; i < WhenClauses.Count; i++)
{
hash.Add(WhenClauses[i]);
}

hash.Add(ElseResult);
return hash.ToHashCode();
}
}
52 changes: 52 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/Expressions/CaseWhenClause.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;

/// <summary>
/// <para>
/// An object that represents a WHEN...THEN... construct in a SQL tree.
/// </para>
/// <para>
/// This type is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
public class CaseWhenClause
{
/// <summary>
/// Creates a new instance of the <see cref="CaseWhenClause" /> class.
/// </summary>
/// <param name="test">A value to compare with <see cref="CaseExpression.Operand" /> or condition to evaluate.</param>
/// <param name="result">A value to return if test succeeds.</param>
public CaseWhenClause(SqlExpression test, SqlExpression result)
{
Test = test;
Result = result;
}

/// <summary>
/// The value to compare with <see cref="CaseExpression.Operand" /> or the condition to evaluate.
/// </summary>
public virtual SqlExpression Test { get; }

/// <summary>
/// The value to return if <see cref="Test" /> succeeds.
/// </summary>
public virtual SqlExpression Result { get; }

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is CaseWhenClause caseWhenClause
&& Equals(caseWhenClause));

private bool Equals(CaseWhenClause caseWhenClause)
=> Test.Equals(caseWhenClause.Test)
&& Result.Equals(caseWhenClause.Result);

/// <inheritdoc />
public override int GetHashCode()
=> HashCode.Combine(Test, Result);
}
Loading
Loading