66// a generalized "nullable" option here to allow us to do that.
77#nullable disable
88
9+ using System . ComponentModel ;
910using System . Diagnostics . CodeAnalysis ;
1011using System . Linq . Expressions ;
12+ using System . Reflection ;
1113
1214namespace CommunityToolkit . Datasync . Client . Query . Linq ;
1315
@@ -17,6 +19,30 @@ namespace CommunityToolkit.Datasync.Client.Query.Linq;
1719/// </summary>
1820internal static class ExpressionExtensions
1921{
22+ private static readonly MethodInfo Contains ;
23+ private static readonly MethodInfo SequenceEqual ;
24+
25+ static ExpressionExtensions ( )
26+ {
27+ Dictionary < string , List < MethodInfo > > queryableMethodGroups = typeof ( Enumerable )
28+ . GetMethods ( BindingFlags . Public | BindingFlags . Static | BindingFlags . DeclaredOnly )
29+ . GroupBy ( mi => mi . Name )
30+ . ToDictionary ( e => e . Key , l => l . ToList ( ) ) ;
31+
32+ MethodInfo GetMethod ( string name , int genericParameterCount , Func < Type [ ] , Type [ ] > parameterGenerator )
33+ => queryableMethodGroups [ name ] . Single ( mi => ( ( genericParameterCount == 0 && ! mi . IsGenericMethod )
34+ || ( mi . IsGenericMethod && mi . GetGenericArguments ( ) . Length == genericParameterCount ) )
35+ && mi . GetParameters ( ) . Select ( e => e . ParameterType ) . SequenceEqual (
36+ parameterGenerator ( mi . IsGenericMethod ? mi . GetGenericArguments ( ) : [ ] ) ) ) ;
37+
38+ Contains = GetMethod (
39+ nameof ( Enumerable . Contains ) , 1 ,
40+ types => [ typeof ( IEnumerable < > ) . MakeGenericType ( types [ 0 ] ) , types [ 0 ] ] ) ;
41+ SequenceEqual = GetMethod (
42+ nameof ( Enumerable . SequenceEqual ) , 1 ,
43+ types => [ typeof ( IEnumerable < > ) . MakeGenericType ( types [ 0 ] ) , typeof ( IEnumerable < > ) . MakeGenericType ( types [ 0 ] ) ] ) ;
44+ }
45+
2046 /// <summary>
2147 /// Walk the expression and compute all the subtrees that are not dependent on any
2248 /// of the expressions parameters.
@@ -127,6 +153,7 @@ internal static bool IsValidLambdaExpression(this MethodCallExpression expressio
127153 /// <returns>The partially evaluated expression</returns>
128154 internal static Expression PartiallyEvaluate ( this Expression expression )
129155 {
156+ expression = expression . RemoveSpanImplicitCast ( ) ;
130157 List < Expression > subtrees = expression . FindIndependentSubtrees ( ) ;
131158 return VisitorHelper . VisitAll ( expression , ( Expression expr , Func < Expression , Expression > recurse ) =>
132159 {
@@ -143,6 +170,63 @@ internal static Expression PartiallyEvaluate(this Expression expression)
143170 } ) ;
144171 }
145172
173+ internal static Expression RemoveSpanImplicitCast ( this Expression expression )
174+ {
175+ return VisitorHelper . VisitAll ( expression , ( Expression expr , Func < Expression , Expression > recurse ) =>
176+ {
177+ if ( expr is MethodCallExpression methodCall )
178+ {
179+ MethodInfo method = methodCall . Method ;
180+
181+ if ( method . DeclaringType == typeof ( MemoryExtensions ) )
182+ {
183+ switch ( method . Name )
184+ {
185+ case nameof ( MemoryExtensions . Contains )
186+ when methodCall . Arguments is [ Expression arg0 , Expression arg1 ] && TryUnwrapSpanImplicitCast ( arg0 , out Expression unwrappedArg0 ) :
187+ {
188+ Expression unwrappedExpr = Expression . Call (
189+ Contains . MakeGenericMethod ( methodCall . Method . GetGenericArguments ( ) [ 0 ] ) ,
190+ unwrappedArg0 , arg1 ) ;
191+ return recurse ( unwrappedExpr ) ;
192+ }
193+
194+ case nameof ( MemoryExtensions . SequenceEqual )
195+ when methodCall . Arguments is [ Expression arg0 , Expression arg1 ]
196+ && TryUnwrapSpanImplicitCast ( arg0 , out Expression unwrappedArg0 )
197+ && TryUnwrapSpanImplicitCast ( arg1 , out Expression unwrappedArg1 ) :
198+ {
199+ Expression unwrappedExpr = Expression . Call (
200+ SequenceEqual . MakeGenericMethod ( methodCall . Method . GetGenericArguments ( ) [ 0 ] ) ,
201+ unwrappedArg0 , unwrappedArg1 ) ;
202+ return recurse ( unwrappedExpr ) ;
203+ }
204+ }
205+
206+ static bool TryUnwrapSpanImplicitCast ( Expression expression , out Expression result )
207+ {
208+ if ( expression is MethodCallExpression
209+ {
210+ Method : { Name : "op_Implicit" , DeclaringType : { IsGenericType : true } implicitCastDeclaringType } ,
211+ Arguments : [ Expression unwrapped ]
212+ }
213+ && implicitCastDeclaringType . GetGenericTypeDefinition ( ) is Type genericTypeDefinition
214+ && ( genericTypeDefinition == typeof ( Span < > ) || genericTypeDefinition == typeof ( ReadOnlySpan < > ) ) )
215+ {
216+ result = unwrapped ;
217+ return true ;
218+ }
219+
220+ result = null ;
221+ return false ;
222+ }
223+ }
224+ }
225+
226+ return recurse ( expr ) ;
227+ } ) ;
228+ }
229+
146230 /// <summary>
147231 /// Remove the quote from quoted expressions.
148232 /// </summary>
0 commit comments