Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLite: implement MAX/MIN/ORDER BY for decimal #35606

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,28 @@ protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVis
LambdaExpression keySelector,
bool ascending)
{
var translation = base.TranslateOrderBy(source, keySelector, ascending);
var translation = TranslateLambdaExpression(source, keySelector);
if (translation == null)
{
return null;
}

var orderingExpression = ((SelectExpression)translation.QueryExpression).Orderings.Last();
var orderingExpressionType = GetProviderType(orderingExpression.Expression);
var orderingExpressionType = GetProviderType(translation);
if (orderingExpressionType == typeof(DateTimeOffset)
|| orderingExpressionType == typeof(decimal)
|| orderingExpressionType == typeof(TimeSpan)
|| orderingExpressionType == typeof(ulong))
{
throw new NotSupportedException(
SqliteStrings.OrderByNotSupported(orderingExpressionType.ShortDisplayName()));
}
else if (orderingExpressionType == typeof(decimal))
{
translation = new CollateExpression(translation, "EF_DECIMAL");
}

return translation;
((SelectExpression)source.QueryExpression).ApplyOrdering(new OrderingExpression(translation, ascending));

return source;
}

/// <summary>
Expand All @@ -164,24 +168,28 @@ protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVis
LambdaExpression keySelector,
bool ascending)
{
var translation = base.TranslateThenBy(source, keySelector, ascending);
var translation = TranslateLambdaExpression(source, keySelector);
if (translation == null)
{
return null;
}

var orderingExpression = ((SelectExpression)translation.QueryExpression).Orderings.Last();
var orderingExpressionType = GetProviderType(orderingExpression.Expression);
var orderingExpressionType = GetProviderType(translation);
if (orderingExpressionType == typeof(DateTimeOffset)
|| orderingExpressionType == typeof(decimal)
|| orderingExpressionType == typeof(TimeSpan)
|| orderingExpressionType == typeof(ulong))
{
throw new NotSupportedException(
SqliteStrings.OrderByNotSupported(orderingExpressionType.ShortDisplayName()));
}
else if (orderingExpressionType == typeof(decimal))
{
translation = new CollateExpression(translation, "EF_DECIMAL");
}

((SelectExpression)source.QueryExpression).AppendOrdering(new OrderingExpression(translation, ascending));

return translation;
return source;
}

/// <summary>
Expand Down Expand Up @@ -467,9 +475,9 @@ protected override ShapedQueryExpression TransformJsonQueryToTable(JsonQueryExpr
Tables:
[
TableValuedFunctionExpression
{
Name: "json_each", Schema: null, IsBuiltIn: true, Arguments: [var jsonArrayColumn]
} jsonEachExpression
{
Name: "json_each", Schema: null, IsBuiltIn: true, Arguments: [var jsonArrayColumn]
} jsonEachExpression
],
Predicate: null,
GroupBy: [],
Expand Down Expand Up @@ -529,16 +537,16 @@ protected override ShapedQueryExpression TransformJsonQueryToTable(JsonQueryExpr
protected override bool IsNaturallyOrdered(SelectExpression selectExpression)
{
return selectExpression is
{
Tables: [var mainTable, ..],
Orderings:
{
Tables: [var mainTable, ..],
Orderings:
[
{
Expression: ColumnExpression { Name: JsonEachKeyColumnName } orderingColumn,
IsAscending: true
}
{
Expression: ColumnExpression { Name: JsonEachKeyColumnName } orderingColumn,
IsAscending: true
}
]
}
}
&& orderingColumn.TableAlias == mainTable.Alias
&& IsJsonEachKeyColumn(selectExpression, orderingColumn);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,23 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
&& source.Selector is SqlExpression maxSqlExpression:
var maxArgumentType = GetProviderType(maxSqlExpression);
if (maxArgumentType == typeof(DateTimeOffset)
|| maxArgumentType == typeof(decimal)
|| maxArgumentType == typeof(TimeSpan)
|| maxArgumentType == typeof(ulong))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), maxArgumentType.ShortDisplayName()));
}
else if (maxArgumentType == typeof(decimal))
{
maxSqlExpression = CombineTerms(source, maxSqlExpression);
return _sqlExpressionFactory.Function(
"ef_max",
[maxSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);
}

break;

Expand All @@ -86,13 +96,23 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
&& source.Selector is SqlExpression minSqlExpression:
var minArgumentType = GetProviderType(minSqlExpression);
if (minArgumentType == typeof(DateTimeOffset)
|| minArgumentType == typeof(decimal)
|| minArgumentType == typeof(TimeSpan)
|| minArgumentType == typeof(ulong))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), minArgumentType.ShortDisplayName()));
}
else if (minArgumentType == typeof(decimal))
{
minSqlExpression = CombineTerms(source, minSqlExpression);
return _sqlExpressionFactory.Function(
"ef_min",
[minSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
minSqlExpression.Type,
minSqlExpression.TypeMapping);
}

break;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ private void InitializeDbConnection(DbConnection connection)
: acc.sum / acc.count,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_max",
seed: null,
(decimal? max, decimal? value) => max is null
? value
: value is null ? max : decimal.Max(max.Value, value.Value),
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_min",
seed: null,
(decimal? min, decimal? value) => min is null
? value
: value is null ? min : decimal.Min(min.Value, value.Value),
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_sum",
seed: null,
Expand All @@ -163,6 +179,10 @@ private void InitializeDbConnection(DbConnection connection)
? value
: sum.Value + value.Value,
isDeterministic: true);

sqliteConnection.CreateCollation(
"EF_DECIMAL",
(x, y) => decimal.Compare(decimal.Parse(x), decimal.Parse(y)));
}
else
{
Expand Down
135 changes: 115 additions & 20 deletions test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -862,10 +862,7 @@ public virtual void Cant_query_Min_of_converted_types()
.Where(e => e.PartitionId == 200)
.GroupBy(_ => true);

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => query.Select(g => g.Min(e => e.TestNullableDecimal)).ToList()).Message);
Assert.Equal(2.000000000000001m, query.Select(g => g.Min(e => e.TestNullableDecimal)).Single());

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), typeof(DateTimeOffset).ShortDisplayName()),
Expand Down Expand Up @@ -915,10 +912,7 @@ public virtual void Cant_query_Max_of_converted_types()
.Where(e => e.PartitionId == 201)
.GroupBy(_ => true);

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => query.Select(g => g.Max(e => e.TestNullableDecimal)).ToList()).Message);
Assert.Equal(10.000000000000001m, query.Select(g => g.Max(e => e.TestNullableDecimal)).Single());

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), typeof(DateTimeOffset).ShortDisplayName()),
Expand Down Expand Up @@ -1406,12 +1400,6 @@ public virtual void Cant_query_OrderBy_of_converted_types()
.Where(e => e.PartitionId == 205);

var ex = Assert.Throws<NotSupportedException>(
() => query
.OrderBy(e => e.TestNullableDecimal)
.First());
Assert.Equal(SqliteStrings.OrderByNotSupported("decimal"), ex.Message);

ex = Assert.Throws<NotSupportedException>(
() => query
.OrderBy(e => e.TestNullableDateTimeOffset)
.First());
Expand Down Expand Up @@ -1463,12 +1451,6 @@ public virtual void Cant_query_ThenBy_of_converted_types()
.OrderBy(e => e.PartitionId);

var ex = Assert.Throws<NotSupportedException>(
() => query
.ThenBy(e => e.TestNullableDecimal)
.First());
Assert.Equal(SqliteStrings.OrderByNotSupported("decimal"), ex.Message);

ex = Assert.Throws<NotSupportedException>(
() => query
.ThenBy(e => e.TestNullableDateTimeOffset)
.First());
Expand All @@ -1487,6 +1469,119 @@ public virtual void Cant_query_ThenBy_of_converted_types()
Assert.Equal(SqliteStrings.OrderByNotSupported("ulong"), ex.Message);
}


[ConditionalFact]
public virtual void Can_query_OrderBy_of_converted_types()
{
using var context = CreateContext();
var min = new BuiltInNullableDataTypes
{
Id = 221,
PartitionId = 207,
TestNullableDecimal = 2.000000000000001m,
TestNullableDateTimeOffset = new DateTimeOffset(2018, 1, 1, 12, 0, 0, TimeSpan.Zero),
TestNullableTimeSpan = TimeSpan.FromDays(2),
TestNullableUnsignedInt64 = 0
};
context.Add(min);

var max = new BuiltInNullableDataTypes
{
Id = 222,
PartitionId = 207,
TestNullableDecimal = 10.000000000000001m,
TestNullableDateTimeOffset = new DateTimeOffset(2018, 1, 1, 11, 0, 0, TimeSpan.FromHours(-2)),
TestNullableTimeSpan = TimeSpan.FromDays(10),
TestNullableUnsignedInt64 = long.MaxValue + 1ul
};
context.Add(max);

context.SaveChanges();

Fixture.TestSqlLoggerFactory.Clear();

var query = context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 207);

var results = query
.OrderBy(e => e.TestNullableDecimal)
.Select(e => e.Id)
.First();

AssertSql(
"""
SELECT "b"."Id"
FROM "BuiltInNullableDataTypes" AS "b"
WHERE "b"."PartitionId" = 207
ORDER BY "b"."TestNullableDecimal" COLLATE EF_DECIMAL
LIMIT 1
""");

var expectedResults = query.AsEnumerable()
.OrderBy(e => e.TestNullableDecimal)
.Select(e => e.Id)
.First();

Assert.Equal(expectedResults, results);
}

[ConditionalFact]
public virtual void Can_query_ThenBy_of_converted_types()
{
using var context = CreateContext();
var min = new BuiltInNullableDataTypes
{
Id = 223,
PartitionId = 208,
TestNullableDecimal = 2.000000000000001m,
TestNullableDateTimeOffset = new DateTimeOffset(2018, 1, 1, 12, 0, 0, TimeSpan.Zero),
TestNullableTimeSpan = TimeSpan.FromDays(2),
TestNullableUnsignedInt64 = 0
};
context.Add(min);

var max = new BuiltInNullableDataTypes
{
Id = 224,
PartitionId = 208,
TestNullableDecimal = 10.000000000000001m,
TestNullableDateTimeOffset = new DateTimeOffset(2018, 1, 1, 11, 0, 0, TimeSpan.FromHours(-2)),
TestNullableTimeSpan = TimeSpan.FromDays(10),
TestNullableUnsignedInt64 = long.MaxValue + 1ul
};
context.Add(max);

context.SaveChanges();

Fixture.TestSqlLoggerFactory.Clear();

var query = context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 208);

var results = query
.OrderBy(e => e.PartitionId)
.ThenBy(e => e.TestNullableDecimal)
.Select(e => e.Id)
.First();

AssertSql(
"""
SELECT "b"."Id"
FROM "BuiltInNullableDataTypes" AS "b"
WHERE "b"."PartitionId" = 208
ORDER BY "b"."PartitionId", "b"."TestNullableDecimal" COLLATE EF_DECIMAL
LIMIT 1
""");

var expectedResults = query.AsEnumerable()
.OrderBy(e => e.PartitionId)
.ThenBy(e => e.TestNullableDecimal)
.Select(e => e.Id)
.First();

Assert.Equal(expectedResults, results);
}

[ConditionalFact]
public virtual void Can_query_using_char_ToLower()
{
Expand Down
28 changes: 20 additions & 8 deletions test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,28 @@ GROUP BY "p"."Category"
}

public override async Task Max_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Max", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(
() => base.Max_Grouped_from_LINQ_101(async))).Message);
{
await base.Max_Grouped_from_LINQ_101(async);

AssertSql(
"""
SELECT "p"."Category", ef_max("p"."UnitPrice") AS "MostExpensivePrice"
FROM "ProductForLinq" AS "p"
GROUP BY "p"."Category"
""");
}

public override async Task Min_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Min", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(
() => base.Min_Grouped_from_LINQ_101(async))).Message);
{
await base.Min_Grouped_from_LINQ_101(async);

AssertSql(
"""
SELECT "p"."Category", ef_min("p"."UnitPrice") AS "CheapestPrice"
FROM "ProductForLinq" AS "p"
GROUP BY "p"."Category"
""");
}

public override async Task Whats_new_2021_sample_3(bool async)
=> await base.Whats_new_2021_sample_3(async);
Expand Down
Loading