Skip to content

Commit 5114a6d

Browse files
committed
Fix LINQ conversion of .Contains()
1 parent 949b368 commit 5114a6d

File tree

5 files changed

+88
-4
lines changed

5 files changed

+88
-4
lines changed

src/CommunityToolkit.Datasync.Client/Query/Linq/ExpressionExtensions.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
// a generalized "nullable" option here to allow us to do that.
77
#nullable disable
88

9+
using System.ComponentModel;
910
using System.Diagnostics.CodeAnalysis;
1011
using System.Linq.Expressions;
12+
using System.Reflection;
1113

1214
namespace CommunityToolkit.Datasync.Client.Query.Linq;
1315

@@ -17,6 +19,30 @@ namespace CommunityToolkit.Datasync.Client.Query.Linq;
1719
/// </summary>
1820
internal static class ExpressionExtensions
1921
{
22+
private static readonly MethodInfo Contains;
23+
private static readonly MethodInfo SequenceEqual;
24+
25+
static ExpressionExtensions()
26+
{
27+
Dictionary<string, List<MethodInfo>> queryableMethodGroups = typeof(Enumerable)
28+
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
29+
.GroupBy(mi => mi.Name)
30+
.ToDictionary(e => e.Key, l => l.ToList());
31+
32+
MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
33+
=> queryableMethodGroups[name].Single(mi => ((genericParameterCount == 0 && !mi.IsGenericMethod)
34+
|| (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount))
35+
&& mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(
36+
parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));
37+
38+
Contains = GetMethod(
39+
nameof(Enumerable.Contains), 1,
40+
types => [typeof(IEnumerable<>).MakeGenericType(types[0]), types[0]]);
41+
SequenceEqual = GetMethod(
42+
nameof(Enumerable.SequenceEqual), 1,
43+
types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]);
44+
}
45+
2046
/// <summary>
2147
/// Walk the expression and compute all the subtrees that are not dependent on any
2248
/// of the expressions parameters.
@@ -127,6 +153,7 @@ internal static bool IsValidLambdaExpression(this MethodCallExpression expressio
127153
/// <returns>The partially evaluated expression</returns>
128154
internal static Expression PartiallyEvaluate(this Expression expression)
129155
{
156+
expression = expression.RemoveSpanImplicitCast();
130157
List<Expression> subtrees = expression.FindIndependentSubtrees();
131158
return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
132159
{
@@ -143,6 +170,63 @@ internal static Expression PartiallyEvaluate(this Expression expression)
143170
});
144171
}
145172

173+
internal static Expression RemoveSpanImplicitCast(this Expression expression)
174+
{
175+
return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
176+
{
177+
if (expr is MethodCallExpression methodCall)
178+
{
179+
MethodInfo method = methodCall.Method;
180+
181+
if (method.DeclaringType == typeof(MemoryExtensions))
182+
{
183+
switch (method.Name)
184+
{
185+
case nameof(MemoryExtensions.Contains)
186+
when methodCall.Arguments is [Expression arg0, Expression arg1] && TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0):
187+
{
188+
Expression unwrappedExpr = Expression.Call(
189+
Contains.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]),
190+
unwrappedArg0, arg1);
191+
return recurse(unwrappedExpr);
192+
}
193+
194+
case nameof(MemoryExtensions.SequenceEqual)
195+
when methodCall.Arguments is [Expression arg0, Expression arg1]
196+
&& TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0)
197+
&& TryUnwrapSpanImplicitCast(arg1, out Expression unwrappedArg1):
198+
{
199+
Expression unwrappedExpr = Expression.Call(
200+
SequenceEqual.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]),
201+
unwrappedArg0, unwrappedArg1);
202+
return recurse(unwrappedExpr);
203+
}
204+
}
205+
206+
static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result)
207+
{
208+
if (expression is MethodCallExpression
209+
{
210+
Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType },
211+
Arguments: [Expression unwrapped]
212+
}
213+
&& implicitCastDeclaringType.GetGenericTypeDefinition() is Type genericTypeDefinition
214+
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
215+
{
216+
result = unwrapped;
217+
return true;
218+
}
219+
220+
result = null;
221+
return false;
222+
}
223+
}
224+
}
225+
226+
return recurse(expr);
227+
});
228+
}
229+
146230
/// <summary>
147231
/// Remove the quote from quoted expressions.
148232
/// </summary>

tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncPullQuery_Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
11771177
);
11781178
}
11791179

1180-
[Fact(Skip = "OData v8.4 does not allow string.contains")]
1180+
[Fact]
11811181
public void Linq_Where_String_Contains()
11821182
{
11831183
string[] ratings = ["A", "B"];

tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncQueryable_Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
14161416
);
14171417
}
14181418

1419-
[Fact(Skip = "OData v8.4 does not allow string.contains")]
1419+
[Fact]
14201420
public void Linq_Where_String_Contains()
14211421
{
14221422
string[] ratings = ["A", "B"];

tests/CommunityToolkit.Datasync.Client.Test/Service/DatasyncServiceClient_Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3547,7 +3547,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
35473547
);
35483548
}
35493549

3550-
[Fact(Skip = "OData v8.4 does not allow string.contains")]
3550+
[Fact]
35513551
public void Linq_Where_String_Contains()
35523552
{
35533553
string[] ratings = ["A", "B"];

tests/CommunityToolkit.Datasync.Client.Test/Service/Integration_Query_Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ await KitchenSinkQueryTest(
540540
// );
541541
//}
542542

543-
[Fact(Skip = "OData v8.4 does not allow string.contains")]
543+
[Fact]
544544
public async Task KitchenSinkQueryTest_020()
545545
{
546546
SeedKitchenSinkWithCountryData();

0 commit comments

Comments
 (0)