Skip to content

Commit

Permalink
Use null propagation to optimize away IS NOT NULL checks
Browse files Browse the repository at this point in the history
When a CASE expression simply replicates SQL null propagation, simplify it.
  • Loading branch information
ranma42 committed Jun 30, 2024
1 parent cde8072 commit 4e1a381
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,80 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
return elseResult ?? _sqlExpressionFactory.Constant(null, caseExpression.Type, caseExpression.TypeMapping);
}

// optimize expressions such as expr != null ? expr : null
// TODO: optimize expr == null ? null : expr
if (testIsCondition && whenClauses is [var clause] && IsNull(elseResult))
{
HashSet<SqlExpression> nullPropagatedOperands = [];

NullPropagatedOperands(clause.Result, nullPropagatedOperands);
var test = DropNotNullChecks(clause.Test, nullPropagatedOperands);

if (test is null)
{
return clause.Result;
}

whenClauses = [new(test, clause.Result)];
}

return caseExpression.Update(operand, whenClauses, elseResult);

static SqlExpression? DropNotNullChecks(SqlExpression expression, HashSet<SqlExpression> nullPropagatedOperands)
{
if (expression is SqlUnaryExpression { OperatorType: ExpressionType.NotEqual } isNotNull
&& nullPropagatedOperands.Contains(isNotNull.Operand))
{
return null; // true
}
else if (expression is SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } binary)
{
var left = DropNotNullChecks(binary.Left, nullPropagatedOperands);
var right = DropNotNullChecks(binary.Right, nullPropagatedOperands);

return left is null ? right
: right is null ? left
: binary.Update(left, right);
}
else
{
return expression;
}
}

static void NullPropagatedOperands(SqlExpression expression, HashSet<SqlExpression> operands)
{
operands.Add(expression);

if (expression is SqlUnaryExpression unary
&& unary.OperatorType is ExpressionType.Not or ExpressionType.Negate or ExpressionType.Convert)
{
NullPropagatedOperands(unary.Operand, operands);
}
else if (expression is SqlBinaryExpression binary)
{
NullPropagatedOperands(binary.Left, operands);
NullPropagatedOperands(binary.Right, operands);
}
else if (expression is SqlFunctionExpression { IsNullable: true } func)
{
if (func.InstancePropagatesNullability == true)
{
NullPropagatedOperands(func.Instance!, operands);
}

if (!func.IsNiladic)
{
for (var i = 0; i < func.ArgumentsPropagateNullability.Count; i++)
{
if (func.ArgumentsPropagateNullability[i])
{
NullPropagatedOperands(func.Arguments[i], operands);
}
}
}
}
}
}

/// <summary>
Expand Down

0 comments on commit 4e1a381

Please sign in to comment.