@@ -450,7 +450,7 @@ void GatherCollections(SemanticModel model, IEnumerable<ITypeSymbol> types, Hash
450
450
451
451
void GatherCollection ( SemanticModel model , ITypeSymbol type , HashSet < CollectionInfo > collections )
452
452
{
453
- if ( type . TypeKind != TypeKind . Class ) return ;
453
+ if ( type . TypeKind != TypeKind . Class && type . TypeKind != TypeKind . Interface ) return ;
454
454
if ( type is INamedTypeSymbol namedTypeSymbol )
455
455
{
456
456
if ( namedTypeSymbol . IsGenericType )
@@ -462,24 +462,63 @@ void GatherCollection(SemanticModel model, ITypeSymbol type, HashSet<CollectionI
462
462
}
463
463
464
464
var fullName = namedTypeSymbol . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
465
- if ( fullName . StartsWith ( "global::System.Collections.Generic.HashSet<" )
466
- || fullName . StartsWith ( "global::System.Collections.Generic.List<" ) )
465
+ ReadOnlySpan < string > collectionToInstance =
466
+ [
467
+ "global::System.Collections.Generic.List<" , "global::System.Collections.Generic.List<" ,
468
+ "global::System.Collections.Generic.IList<" , "global::System.Collections.Generic.List<" ,
469
+ "global::System.Collections.Generic.IReadOnlyList<" , "global::System.Collections.Generic.List<" ,
470
+ "global::System.Collections.Generic.HashSet<" , "global::System.Collections.Generic.HashSet<" ,
471
+ "global::System.Collections.Generic.ISet<" , "global::System.Collections.Generic.HashSet<" ,
472
+ "global::System.Collections.Generic.IReadOnlySet<" , "global::System.Collections.Generic.HashSet<" ,
473
+ "global::System.Collections.Generic.IEnumerable<" , "global::System.Collections.Generic.List<" ,
474
+ ] ;
475
+ string ? instance = null ;
476
+ for ( var i = 0 ; i < collectionToInstance . Length ; i += 2 )
477
+ {
478
+ if ( fullName . StartsWith ( collectionToInstance [ i ] , StringComparison . Ordinal ) )
479
+ {
480
+ instance = collectionToInstance [ i + 1 ] + fullName . Substring ( collectionToInstance [ i ] . Length ) ;
481
+ break ;
482
+ }
483
+ }
484
+
485
+ if ( instance != null )
467
486
{
468
487
var elementType = namedTypeSymbol . TypeArguments [ 0 ] ;
469
- var collectionInfo = new CollectionInfo ( fullName ,
488
+ var collectionInfo = new CollectionInfo ( fullName , instance ,
470
489
elementType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ,
471
490
elementType . IsReferenceType , null , null ) ;
472
491
collections . Add ( collectionInfo ) ;
473
492
}
474
- else if ( fullName . StartsWith ( "global::System.Collections.Generic.Dictionary<" ) )
493
+ else
475
494
{
476
- var keyType = namedTypeSymbol . TypeArguments [ 0 ] ;
477
- var valueType = namedTypeSymbol . TypeArguments [ 1 ] ;
478
- var collectionInfo = new CollectionInfo ( fullName ,
479
- keyType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ,
480
- keyType . IsReferenceType , valueType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ,
481
- valueType . IsReferenceType ) ;
482
- collections . Add ( collectionInfo ) ;
495
+ ReadOnlySpan < string > collectionToInstance2 =
496
+ [
497
+ "global::System.Collections.Generic.Dictionary<" , "global::System.Collections.Generic.Dictionary<" ,
498
+ "global::System.Collections.Generic.IDictionary<" , "global::System.Collections.Generic.Dictionary<" ,
499
+ "global::System.Collections.Generic.IReadOnlyDictionary<" ,
500
+ "global::System.Collections.Generic.Dictionary<" ,
501
+ ] ;
502
+ instance = null ;
503
+ for ( var i = 0 ; i < collectionToInstance2 . Length ; i += 2 )
504
+ {
505
+ if ( fullName . StartsWith ( collectionToInstance2 [ i ] , StringComparison . Ordinal ) )
506
+ {
507
+ instance = collectionToInstance2 [ i + 1 ] + fullName . Substring ( collectionToInstance2 [ i ] . Length ) ;
508
+ break ;
509
+ }
510
+ }
511
+
512
+ if ( instance != null )
513
+ {
514
+ var keyType = namedTypeSymbol . TypeArguments [ 0 ] ;
515
+ var valueType = namedTypeSymbol . TypeArguments [ 1 ] ;
516
+ var collectionInfo = new CollectionInfo ( fullName , instance ,
517
+ keyType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ,
518
+ keyType . IsReferenceType , valueType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ,
519
+ valueType . IsReferenceType ) ;
520
+ collections . Add ( collectionInfo ) ;
521
+ }
483
522
}
484
523
}
485
524
}
@@ -667,12 +706,12 @@ internal static void Register4BTDB()
667
706
668
707
static object Create{{ idx }} (uint capacity)
669
708
{
670
- return new {{ collection . FullName }} ((int)capacity);
709
+ return new {{ collection . InstantiableFullName }} ((int)capacity);
671
710
}
672
711
673
712
static void Add{{ idx }} (object c, ref byte key, ref byte value)
674
713
{
675
- Unsafe.As<{{ collection . FullName }} >(c).Add(Unsafe.As<byte, {{ collection . KeyType }} >(ref key), Unsafe.As<byte, {{ collection . ValueType }} >(ref value));
714
+ Unsafe.As<{{ collection . InstantiableFullName }} >(c).Add(Unsafe.As<byte, {{ collection . KeyType }} >(ref key), Unsafe.As<byte, {{ collection . ValueType }} >(ref value));
676
715
}
677
716
678
717
""" ) ;
@@ -692,12 +731,12 @@ internal static void Register4BTDB()
692
731
693
732
static object Create{{ idx }} (uint capacity)
694
733
{
695
- return new {{ collection . FullName }} ((int)capacity);
734
+ return new {{ collection . InstantiableFullName }} ((int)capacity);
696
735
}
697
736
698
737
static void Add{{ idx }} (object c, ref byte value)
699
738
{
700
- Unsafe.As<{{ collection . FullName }} >(c).Add(Unsafe.As<byte, {{ collection . KeyType }} >(ref value));
739
+ Unsafe.As<{{ collection . InstantiableFullName }} >(c).Add(Unsafe.As<byte, {{ collection . KeyType }} >(ref value));
701
740
}
702
741
703
742
""" ) ;
@@ -1192,7 +1231,13 @@ record GenerationInfo(
1192
1231
Location ? Location
1193
1232
) ;
1194
1233
1195
- record CollectionInfo ( string FullName , string KeyType , bool KeyIsReference , string ? ValueType , bool ? ValueIsReference ) ;
1234
+ record CollectionInfo (
1235
+ string FullName ,
1236
+ string InstantiableFullName ,
1237
+ string KeyType ,
1238
+ bool KeyIsReference ,
1239
+ string ? ValueType ,
1240
+ bool ? ValueIsReference ) ;
1196
1241
1197
1242
record ParameterInfo ( string Name , string Type , bool IsReference , bool Optional , string ? DefaultValue ) ;
1198
1243
0 commit comments