Skip to content

Commit 34225d9

Browse files
committed
Slightly simplify some logic
1 parent 1599911 commit 34225d9

File tree

3 files changed

+54
-46
lines changed

3 files changed

+54
-46
lines changed

Generators/Internal/MockClassGenerator.cs

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ internal class MockClassGenerator {
88
SymbolDisplayGlobalNamespaceStyle.OmittedAsContaining
99
);
1010

11-
private readonly MockTargetDiscovery _mockTargetDiscovery = new();
11+
private readonly MockTargetModelFactory _modelFactory;
1212
private readonly MockMemberGenerator _mockMemberGenerator = new();
1313

14+
public MockClassGenerator(MockTargetModelFactory modelFactory) {
15+
_modelFactory = modelFactory;
16+
}
17+
1418
[PerformanceSensitive("")]
1519
public string Generate(MockTarget target) {
1620
var targetTypeNamespace = target.Type.ContainingNamespace.ToDisplayString(TargetTypeNamespaceDisplayFormat);
@@ -45,7 +49,7 @@ public string Generate(MockTarget target) {
4549
.WriteLine(Indents.Type, "public interface ", callsInterfaceName, " {");
4650

4751
#pragma warning disable HAA0401 // Possible allocation of reference type enumerator - TODO
48-
foreach (var member in _mockTargetDiscovery.GetMembersToMock(target)) {
52+
foreach (var member in _modelFactory.GetMockTargetMembers(target)) {
4953
#pragma warning restore HAA0401
5054
mainWriter.WriteLine();
5155
_mockMemberGenerator.WriteMemberMocks(
@@ -99,7 +103,7 @@ private string GenerateTypeParametersAsString(MockTarget target) {
99103
writer.Write("<");
100104
var index = 0;
101105
foreach (var parameter in parameters) {
102-
_mockTargetDiscovery.EnsureNoUnsupportedConstraints(parameter);
106+
_modelFactory.EnsureNoUnsupportedConstraints(parameter);
103107
if (index > 0)
104108
writer.Write(", ");
105109
writer.Write(parameter.Name);

Generators/Internal/MockTargetDiscovery.cs renamed to Generators/Internal/MockTargetModelFactory.cs

+12-7
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66
using SourceMock.Generators.Internal.Models;
77

88
namespace SourceMock.Generators.Internal {
9-
internal class MockTargetDiscovery {
9+
internal class MockTargetModelFactory {
1010
private static readonly SymbolDisplayFormat TargetTypeDisplayFormat = SymbolDisplayFormat.FullyQualifiedFormat
1111
.WithMiscellaneousOptions(SymbolDisplayFormat.FullyQualifiedFormat.MiscellaneousOptions | SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier);
1212

13+
public MockTarget GetMockTarget(INamedTypeSymbol type) {
14+
var fullName = GetFullTypeName(type, NullableAnnotation.None);
15+
return new MockTarget(type, fullName);
16+
}
17+
1318
[PerformanceSensitive("")]
14-
public IEnumerable<MockTargetMember> GetMembersToMock(MockTarget target) {
19+
public IEnumerable<MockTargetMember> GetMockTargetMembers(MockTarget target) {
1520
#pragma warning disable HAA0502 // Explicit allocation -- unavoidable for now, can be pooled later (or removed if we handle them differently)
1621
var seen = new HashSet<string>();
1722
#pragma warning restore HAA0502
@@ -20,7 +25,7 @@ public IEnumerable<MockTargetMember> GetMembersToMock(MockTarget target) {
2025
foreach (var member in target.Type.GetMembers()) {
2126
seen.Add(member.Name);
2227

23-
if (GetTargetMember(member, memberId) is not {} discovered)
28+
if (GetMockTargetMember(member, memberId) is not {} discovered)
2429
continue;
2530

2631
yield return discovered;
@@ -31,7 +36,7 @@ public IEnumerable<MockTargetMember> GetMembersToMock(MockTarget target) {
3136
foreach (var member in @interface.GetMembers()) {
3237
if (!seen.Add(member.Name))
3338
throw Exceptions.NotSupported($"Type member {@interface.Name}.{member.Name} is hidden or overloaded by another type member. This is not yet supported.");
34-
if (GetTargetMember(member, memberId) is not { } discovered)
39+
if (GetMockTargetMember(member, memberId) is not { } discovered)
3540
continue;
3641

3742
yield return discovered;
@@ -41,8 +46,8 @@ public IEnumerable<MockTargetMember> GetMembersToMock(MockTarget target) {
4146
}
4247

4348
[PerformanceSensitive("")]
44-
private MockTargetMember? GetTargetMember(ISymbol member, int uniqueMemberId) => member switch {
45-
IMethodSymbol method => GetTargetMethod(method, uniqueMemberId),
49+
private MockTargetMember? GetMockTargetMember(ISymbol member, int uniqueMemberId) => member switch {
50+
IMethodSymbol method => GetMockTargetMethod(method, uniqueMemberId),
4651

4752
IPropertySymbol property => new(
4853
property, property.Name, property.Type,
@@ -58,7 +63,7 @@ public IEnumerable<MockTargetMember> GetMembersToMock(MockTarget target) {
5863
_ => throw Exceptions.MemberNotSupported(member)
5964
};
6065

61-
private MockTargetMember? GetTargetMethod(IMethodSymbol method, int uniqueMemberId) {
66+
private MockTargetMember? GetMockTargetMethod(IMethodSymbol method, int uniqueMemberId) {
6267
if (method.MethodKind != MethodKind.Ordinary)
6368
return null;
6469

Generators/MockGenerator.cs

+35-36
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ private static class DiagnosticDescriptors {
2727

2828
private readonly GeneratorCache<(IAssemblySymbol assembly, string? excludePattern), ImmutableArray<(string name, SourceText source)>> _mockedAssemblyCache = new("MockedAssemblyCache");
2929
private readonly GeneratorCache<INamedTypeSymbol, (string name, SourceText source)> _mockedTypeCache = new("MockedTypeCache", NamedTypeSymbolCacheKeyEqualityComparer.Default);
30-
private readonly MockClassGenerator _classGenerator = new();
30+
private readonly MockTargetModelFactory _modelFactory = new();
31+
private readonly MockClassGenerator _classGenerator;
3132

3233
public MockGenerator() {
3334
GeneratorLog.Log("MockGenerator constructor");
35+
_classGenerator = new(_modelFactory);
3436
}
3537

3638
public void Initialize(GeneratorInitializationContext context) {
@@ -62,47 +64,55 @@ private void ProcessAssemblyAttribute(AttributeData attribute, in GeneratorExecu
6264
case KnownTypes.GenerateMocksForAssemblyOfAttribute.Name:
6365
if (!KnownTypes.GenerateMocksForAssemblyOfAttribute.NamespaceMatches(attributeClass.ContainingNamespace))
6466
return;
65-
GenerateMocksForAttributeTargetAssembly(attribute, context);
67+
ProcessGenerateMocksForAssemblyAttribute(attribute, context);
6668
break;
6769

6870
case KnownTypes.GenerateMocksForTypesAttribute.Name:
6971
if (!KnownTypes.GenerateMocksForTypesAttribute.NamespaceMatches(attributeClass.ContainingNamespace))
7072
return;
71-
72-
if (attribute.ConstructorArguments.ElementAtOrDefault(0) is not { Kind: TypedConstantKind.Array, Values: var typeofConstants })
73-
return;
74-
75-
foreach (var typeofConstant in typeofConstants) {
76-
if (typeofConstant is not { Value: INamedTypeSymbol type })
77-
continue;
78-
GenerateMockForType(new MockTarget(type, GetFullTypeName(type)), assemblyCacheBuilder: null, context);
79-
}
73+
ProcessGenerateMocksForTypesAttribute(attribute, context);
8074
break;
8175
}
8276
}
8377

8478
[PerformanceSensitive("")]
85-
private void GenerateMocksForAttributeTargetAssembly(AttributeData attribute, in GeneratorExecutionContext context) {
79+
private void ProcessGenerateMocksForAssemblyAttribute(AttributeData attribute, GeneratorExecutionContext context) {
8680
// intermediate code state? just in case
8781
if (attribute.ConstructorArguments.ElementAtOrDefault(0).Value is not INamedTypeSymbol anyTypeInAssembly)
8882
return;
8983

90-
var targetAssembly = anyTypeInAssembly.ContainingAssembly;
9184
string? excludePattern = null;
9285
foreach (var named in attribute.NamedArguments) {
9386
if (named.Key == KnownTypes.GenerateMocksForAssemblyOfAttribute.NamedParameters.ExcludeRegex)
9487
excludePattern = named.Value.Value as string;
9588
}
9689

97-
if (_mockedAssemblyCache.TryGetValue((targetAssembly, excludePattern), out var sources)) {
98-
GeneratorLog.Log("Using cached mocks for assembly " + targetAssembly.Name);
90+
GenerateMocksForAssembly(anyTypeInAssembly.ContainingAssembly, excludePattern, attribute.ApplicationSyntaxReference, context);
91+
}
92+
93+
[PerformanceSensitive("")]
94+
private void ProcessGenerateMocksForTypesAttribute(AttributeData attribute, GeneratorExecutionContext context) {
95+
if (attribute.ConstructorArguments.ElementAtOrDefault(0) is not { Kind: TypedConstantKind.Array, Values: var typeConstants })
96+
return;
97+
98+
foreach (var typeConstant in typeConstants) {
99+
if (typeConstant is not { Value: INamedTypeSymbol type })
100+
continue;
101+
GenerateMockForType(_modelFactory.GetMockTarget(type), assemblyCacheBuilder: null, context);
102+
}
103+
}
104+
105+
[PerformanceSensitive("")]
106+
private void GenerateMocksForAssembly(IAssemblySymbol assembly, string? excludePattern, SyntaxReference? errorSyntaxReference, in GeneratorExecutionContext context) {
107+
if (_mockedAssemblyCache.TryGetValue((assembly, excludePattern), out var sources)) {
108+
GeneratorLog.Log("Using cached mocks for assembly " + assembly.Name);
99109
foreach (var (name, source) in sources) {
100110
context.AddSource(name, source);
101111
}
102112
return;
103113
}
104114

105-
GeneratorLog.Log("Generating mocks for assembly " + targetAssembly.Name);
115+
GeneratorLog.Log("Generating mocks for assembly " + assembly.Name);
106116

107117
Regex? excludeRegex;
108118
try {
@@ -111,20 +121,19 @@ private void GenerateMocksForAttributeTargetAssembly(AttributeData attribute, in
111121
#pragma warning restore HAA0502
112122
}
113123
catch (ArgumentException ex) {
114-
var attributeSyntax = attribute.ApplicationSyntaxReference;
115124
#pragma warning disable HAA0101 // Array allocation for params parameter -- Exceptional case: OK to allocate
116125
context.ReportDiagnostic(Diagnostic.Create(
117126
DiagnosticDescriptors.RegexPatternFailedToParse,
118-
attributeSyntax?.SyntaxTree.GetLocation(attributeSyntax.Span),
127+
errorSyntaxReference?.SyntaxTree.GetLocation(errorSyntaxReference.Span),
119128
excludePattern, ex.Message
120129
));
121130
#pragma warning restore HAA0101 // Array allocation for params parameter
122131
return;
123132
}
124133

125134
var assemblyCacheBuilder = ImmutableArray.CreateBuilder<(string, SourceText)>();
126-
GenerateMocksForNamespace(targetAssembly.GlobalNamespace, excludeRegex, assemblyCacheBuilder, context);
127-
_mockedAssemblyCache.TryAdd((targetAssembly, excludePattern), assemblyCacheBuilder.ToImmutable());
135+
GenerateMocksForNamespace(assembly.GlobalNamespace, excludeRegex, assemblyCacheBuilder, context);
136+
_mockedAssemblyCache.TryAdd((assembly, excludePattern), assemblyCacheBuilder.ToImmutable());
128137
}
129138

130139
[PerformanceSensitive("")]
@@ -139,9 +148,10 @@ in GeneratorExecutionContext context
139148
#pragma warning restore HAA0401
140149
switch (member) {
141150
case INamedTypeSymbol type:
142-
if (!ShouldIncludeInMocksForAssembly(type, excludeRegex, out var fullName, context))
151+
var target = _modelFactory.GetMockTarget(type);
152+
if (!ShouldIncludeInMocksForAssembly(target, excludeRegex, context))
143153
continue;
144-
GenerateMockForType(new MockTarget(type, fullName!), assemblyCacheBuilder, context);
154+
GenerateMockForType(target, assemblyCacheBuilder, context);
145155
break;
146156

147157
case INamespaceSymbol nested:
@@ -152,13 +162,8 @@ in GeneratorExecutionContext context
152162
}
153163

154164
[PerformanceSensitive("")]
155-
private bool ShouldIncludeInMocksForAssembly(
156-
INamedTypeSymbol type,
157-
Regex? excludeRegex,
158-
out string? fullName,
159-
in GeneratorExecutionContext context
160-
) {
161-
fullName = null;
165+
private bool ShouldIncludeInMocksForAssembly(MockTarget target, Regex? excludeRegex, in GeneratorExecutionContext context) {
166+
var type = target.Type;
162167
if (type.TypeKind != TypeKind.Interface)
163168
return false;
164169

@@ -169,18 +174,12 @@ in GeneratorExecutionContext context
169174
return false;
170175
}
171176

172-
fullName = GetFullTypeName(type);
173-
if (excludeRegex != null && excludeRegex.IsMatch(fullName))
177+
if (excludeRegex != null && excludeRegex.IsMatch(target.FullTypeName))
174178
return false;
175179

176180
return true;
177181
}
178182

179-
[PerformanceSensitive("")]
180-
private static string GetFullTypeName(INamedTypeSymbol type) {
181-
return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
182-
}
183-
184183
[PerformanceSensitive("")]
185184
private void GenerateMockForType(
186185
MockTarget target,

0 commit comments

Comments
 (0)