From 62a62cf953dd43400e0121a77c3a5c996149a67c Mon Sep 17 00:00:00 2001 From: JKamsker Date: Thu, 10 Oct 2024 02:37:59 +0200 Subject: [PATCH] Fixes #60: String array + Comparer --- .../Issue-60-StringEquality-Enumerables.cs | 48 ++++++++++++++ .../UnitTest1.cs | 10 +-- Generator.Equals.Runtime/Attributes.cs | 18 ++--- .../Classes/StringArrayEquality.cs | 49 ++++++++++++++ .../Classes/UnorderedEquality.Sample.cs | 1 + .../Classes/UnorderedEquality.cs | 16 ++--- Generator.Equals/EqualityGeneratorBase.cs | 40 ++++++++++-- .../{ => Extensions}/SymbolHelpers.cs | 12 ++++ .../Models/EqualityMemberModelTransformer.cs | 65 +++++++++++-------- 9 files changed, 206 insertions(+), 53 deletions(-) create mode 100644 Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs create mode 100644 Generator.Equals.Tests/Classes/StringArrayEquality.cs rename Generator.Equals/{ => Extensions}/SymbolHelpers.cs (91%) diff --git a/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs b/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs new file mode 100644 index 0000000..df7a341 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs @@ -0,0 +1,48 @@ +using Microsoft.CodeAnalysis.CSharp; +using SourceGeneratorTestHelpers; + +namespace Generator.Equals.DynamicGenerationTests.Issues; + +public class Issue_60_StringEquality_Enumerables +{ + [Fact] + public void Test3_Struct_UnorderedEquality() + { + // StringComparer.OrdinalIgnoreCase; + + var input = SourceText.CSharp( + """ + using System; + using System.Collections.Generic; + using Generator.Equals; + + [Equatable] + public partial class Resource + { + [UnorderedEquality] + [StringEqualityAttribute(StringComparison.OrdinalIgnoreCase)] + public string[] Tags { get; set; } = Array.Empty(); + + } + """ + ); + + var result = IncrementalGenerator.Run + ( + input, + new CSharpParseOptions(), + UnitTest1.References + ); + + var gensource = result.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + Assert.Contains("new global::Generator.Equals.UnorderedEqualityComparer(StringComparer.OrdinalIgnoreCase)", + gensource.FirstOrDefault()?.ToString()); + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/UnitTest1.cs b/Generator.Equals.DynamicGenerationTests/UnitTest1.cs index 9d7ada3..ff07bf1 100644 --- a/Generator.Equals.DynamicGenerationTests/UnitTest1.cs +++ b/Generator.Equals.DynamicGenerationTests/UnitTest1.cs @@ -9,7 +9,7 @@ namespace Generator.Equals.DynamicGenerationTests; public class UnitTest1 { - public static readonly List References2 = + public static readonly List References = AppDomain.CurrentDomain.GetAssemblies() .Where(_ => !_.IsDynamic && !string.IsNullOrWhiteSpace(_.Location)) .Select(_ => MetadataReference.CreateFromFile(_.Location)) @@ -45,7 +45,7 @@ public partial record MyRecord( new CSharpParseOptions() { }, - References2 + References ); var gensource = result.Results @@ -114,7 +114,7 @@ class LengthEqualityComparer : IEqualityComparer new CSharpParseOptions() { }, - References2 + References ); var gensource = result.Results @@ -153,7 +153,7 @@ public partial struct Sample new CSharpParseOptions() { }, - References2 + References ); var gensource = result.Results @@ -202,7 +202,7 @@ public Sample(string name, int age, bool flag) new CSharpParseOptions() { }, - References2 + References ); var gensource = result.Results diff --git a/Generator.Equals.Runtime/Attributes.cs b/Generator.Equals.Runtime/Attributes.cs index e41b04f..02c9203 100644 --- a/Generator.Equals.Runtime/Attributes.cs +++ b/Generator.Equals.Runtime/Attributes.cs @@ -5,7 +5,7 @@ namespace Generator.Equals { [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct)] public class EquatableAttribute : Attribute { @@ -21,49 +21,49 @@ public class EquatableAttribute : Attribute } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class DefaultEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class OrderedEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class IgnoreEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class UnorderedEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class ReferenceEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class SetEqualityAttribute : Attribute { } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class StringEqualityAttribute : Attribute { @@ -76,7 +76,7 @@ public StringEqualityAttribute(StringComparison comparisonType) } [GeneratedCode("Generator.Equals", "1.0.0.0")] - [Conditional("GENERATOR_EQUALS")] + //[Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] public class CustomEqualityAttribute : Attribute { diff --git a/Generator.Equals.Tests/Classes/StringArrayEquality.cs b/Generator.Equals.Tests/Classes/StringArrayEquality.cs new file mode 100644 index 0000000..7de0090 --- /dev/null +++ b/Generator.Equals.Tests/Classes/StringArrayEquality.cs @@ -0,0 +1,49 @@ +using System; + +namespace Generator.Equals.Tests.Classes; + +public partial class StringArrayEquality +{ + [Equatable] + public partial class Sample + { + [UnorderedEquality, StringEquality(StringComparison.OrdinalIgnoreCase)] + public string[] Tags { get; set; } + } +} + +public partial class StringArrayEquality +{ + public class EqualsTests : EqualityTestCase + { + public override object Factory1() + { + return new Sample + { + Tags = new[] { "a", "b", "c" } + }; + } + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // Order doesnt matter + public class OrderDoesntMatterEqualsTest : EqualityTestCase + { + public override bool Expected => true; + + public override object Factory1() => new Sample + { + Tags = new[] { "a", "b", "c" } + }; + + public override object Factory2() => new Sample + { + Tags = new[] { "c", "b", "a" } + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } +} \ No newline at end of file diff --git a/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs b/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs index 9a5aa26..49ce179 100644 --- a/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs +++ b/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; namespace Generator.Equals.Tests.Classes diff --git a/Generator.Equals.Tests/Classes/UnorderedEquality.cs b/Generator.Equals.Tests/Classes/UnorderedEquality.cs index 07eb641..8e68dcf 100644 --- a/Generator.Equals.Tests/Classes/UnorderedEquality.cs +++ b/Generator.Equals.Tests/Classes/UnorderedEquality.cs @@ -18,30 +18,30 @@ public override object Factory1() Properties = Enumerable .Range(1, 1000) .OrderBy(_ => randomSort.NextDouble()) - .ToList() + .ToList(), }; } - public override bool EqualsOperator(object value1, object value2) => (Sample) value1 == (Sample) value2; - public override bool NotEqualsOperator(object value1, object value2) => (Sample) value1 != (Sample) value2; + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; } - + public class NotEqualsTest : EqualityTestCase { public override bool Expected => false; public override object Factory1() => new Sample { - Properties = Enumerable.Range(1, 1000).ToList() + Properties = Enumerable.Range(1, 1000).ToList(), }; public override object Factory2() => new Sample { - Properties = Enumerable.Range(1, 1001).ToList() + Properties = Enumerable.Range(1, 1001).ToList(), }; - public override bool EqualsOperator(object value1, object value2) => (Sample) value1 == (Sample) value2; - public override bool NotEqualsOperator(object value1, object value2) => (Sample) value1 != (Sample) value2; + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; } } } \ No newline at end of file diff --git a/Generator.Equals/EqualityGeneratorBase.cs b/Generator.Equals/EqualityGeneratorBase.cs index 3b60403..3793563 100644 --- a/Generator.Equals/EqualityGeneratorBase.cs +++ b/Generator.Equals/EqualityGeneratorBase.cs @@ -50,16 +50,32 @@ private static void BuildEquality(EqualityMemberModel memberModel, IndentedTextW case EqualityType.IgnoreEquality: break; - case EqualityType.UnorderedEquality when !memberModel.IsDictionary: + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: false, StringComparer: not null and not "" }: + + writer.WriteLine( + $"&& new global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer}).Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: false, StringComparer: null }: writer.WriteLine( $"&& global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); break; - case EqualityType.UnorderedEquality when memberModel.IsDictionary: + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: true, StringComparer: null }: writer.WriteLine( $"&& global::Generator.Equals.DictionaryEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); break; + case EqualityType.OrderedEquality + when memberModel is { StringComparer: not null and not "" }: + + writer.WriteLine( + $"&& new global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer}).Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + case EqualityType.OrderedEquality: writer.WriteLine( $"&& global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); @@ -137,13 +153,29 @@ private static void BuildHashCode(EqualityMemberModel memberModel, IndentedTextW case EqualityType.IgnoreEquality: break; - case EqualityType.UnorderedEquality when memberModel.IsDictionary: + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: null, IsDictionary: true }: + BuildHashCodeAdd($"global::Generator.Equals.DictionaryEqualityComparer<{memberModel.TypeName}>.Default"); break; - case EqualityType.UnorderedEquality when !memberModel.IsDictionary: + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: null, IsDictionary: false}: + BuildHashCodeAdd($"global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>.Default"); break; + + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: not null and not "", IsDictionary: false }: + + BuildHashCodeAdd($"new global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer})"); + break; + + case EqualityType.OrderedEquality + when memberModel is { StringComparer: not null and not "" }: + + BuildHashCodeAdd($"new global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer})"); + break; case EqualityType.OrderedEquality: BuildHashCodeAdd($"global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>.Default"); diff --git a/Generator.Equals/SymbolHelpers.cs b/Generator.Equals/Extensions/SymbolHelpers.cs similarity index 91% rename from Generator.Equals/SymbolHelpers.cs rename to Generator.Equals/Extensions/SymbolHelpers.cs index b9253c9..621b4d3 100644 --- a/Generator.Equals/SymbolHelpers.cs +++ b/Generator.Equals/Extensions/SymbolHelpers.cs @@ -73,6 +73,18 @@ public static bool HasAttribute(this ISymbol symbol, INamedTypeSymbol attribute) ? null : new DictionaryArgumentsResult(res); } + + public static bool IsStringArray(this ITypeSymbol typeSymbol) + { + // Check if the symbol is an array + if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) + { + // Check if the element type is string + return arrayTypeSymbol.ElementType.SpecialType == SpecialType.System_String; + } + + return false; + } } public record DictionaryArgumentsResult(ImmutableArray? Arguments) : ArgumentsResult(Arguments) diff --git a/Generator.Equals/Models/EqualityMemberModelTransformer.cs b/Generator.Equals/Models/EqualityMemberModelTransformer.cs index 9e80541..4060d9d 100644 --- a/Generator.Equals/Models/EqualityMemberModelTransformer.cs +++ b/Generator.Equals/Models/EqualityMemberModelTransformer.cs @@ -16,17 +16,16 @@ public static ImmutableArray BuildEqualityModels( ) { var isRecord = symbol.IsRecord; - + var members = symbol.GetPropertiesAndFields(); var models = members .Where(member => filter == null || filter(member)) - + // ignore equalitycontract if the type is a record .Where(member => !isRecord || member.Name != "EqualityContract") - .Select(member => member switch { - IPropertySymbol propertySymbol + IPropertySymbol propertySymbol => BuildEqualityModel(propertySymbol, propertySymbol.Type, attributesMetadata, explicitMode), IFieldSymbol fieldSymbol => BuildEqualityModel(fieldSymbol, fieldSymbol.Type, attributesMetadata, explicitMode), _ => throw new NotSupportedException($"Member of type {member.GetType()} not supported") @@ -36,7 +35,7 @@ IPropertySymbol propertySymbol return models; } - + public static EqualityMemberModel BuildEqualityModel( ISymbol memberSymbol, ITypeSymbol typeSymbol, @@ -46,7 +45,7 @@ bool explicitMode { var propertyName = memberSymbol.ToFQF(); var typeName = typeSymbol.ToNullableFQF(); - + // IgnoreEquality if (memberSymbol.HasAttribute(attributesMetadata.IgnoreEquality)) { @@ -56,8 +55,31 @@ bool explicitMode }; } - // Check for different equality attributes and map them to the model - if (memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality)) + + if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) + { + var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; + var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value); + + if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, out var enumMemberName)) + { + throw new Exception("Unexpected StringComparison value."); + } + + + // Special case: We do this comparison through either OrderedEquality or UnorderedEquality + if (typeSymbol.IsStringArray() && typeSymbol.GetIEnumerableTypeArguments() is { } args) + { + var equalityType = memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality) + ? EqualityType.UnorderedEquality + : EqualityType.OrderedEquality; + + return new EqualityMemberModel(propertyName, args.Name, equalityType, stringComparer: enumMemberName); + } + + return new EqualityMemberModel(propertyName, typeName, EqualityType.StringEquality, stringComparer: enumMemberName); + } + else if (memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality)) { var args = typeSymbol.GetIDictionaryTypeArguments() ?? typeSymbol.GetIEnumerableTypeArguments()!; @@ -81,18 +103,6 @@ bool explicitMode var types = typeSymbol.GetIEnumerableTypeArguments()!; return new EqualityMemberModel(propertyName, types.Name, EqualityType.SetEquality); } - else if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) - { - var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; - var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value); - - if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, out var enumMemberName)) - { - throw new Exception("Unexpected StringComparison value."); - } - - return new EqualityMemberModel(propertyName, typeName, EqualityType.StringEquality, stringComparer: enumMemberName); - } else if (memberSymbol.HasAttribute(attributesMetadata.CustomEquality)) { var attribute = memberSymbol.GetAttribute(attributesMetadata.CustomEquality); @@ -101,19 +111,20 @@ bool explicitMode var comparerTypeName = comparerType.ToFQF(); var comparerMemberName = (string)attribute?.ConstructorArguments[1].Value!; - var hasDefault = comparerType.GetMembers().Any(x => x.Name == comparerMemberName && x.IsStatic) - || comparerType.GetPropertiesAndFields().Any(x => x.Name == comparerMemberName && x.IsStatic); + var hasDefault = comparerType.GetMembers().Any(x => x.Name == comparerMemberName && x.IsStatic) + || comparerType.GetPropertiesAndFields().Any(x => x.Name == comparerMemberName && x.IsStatic); - return new EqualityMemberModel(propertyName, typeName, EqualityType.CustomEquality, comparerTypeName, comparerMemberName) + return new EqualityMemberModel(propertyName, typeName, EqualityType.CustomEquality, comparerTypeName, + comparerMemberName) { ComparerHasStaticInstance = hasDefault }; - } - + var isIgnored = (explicitMode && !memberSymbol.HasAttribute(attributesMetadata.DefaultEquality)); - - return new EqualityMemberModel(propertyName, typeName, isIgnored ? EqualityType.IgnoreEquality : EqualityType.DefaultEquality) + + return new EqualityMemberModel(propertyName, typeName, + isIgnored ? EqualityType.IgnoreEquality : EqualityType.DefaultEquality) { Ignored = isIgnored };