diff --git a/src/Analyzers/SetExplicitMockBehaviorAnalyzer.cs b/src/Analyzers/SetExplicitMockBehaviorAnalyzer.cs index 5670732..3ba6a61 100644 --- a/src/Analyzers/SetExplicitMockBehaviorAnalyzer.cs +++ b/src/Analyzers/SetExplicitMockBehaviorAnalyzer.cs @@ -1,4 +1,5 @@ -using Microsoft.CodeAnalysis.Operations; +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis.Operations; namespace Moq.Analyzers; @@ -43,94 +44,114 @@ private static void RegisterCompilationStartAction(CompilationStartAnalysisConte } // Look for the MockBehavior type and provide it to Analyze to avoid looking it up multiple times. - INamedTypeSymbol? mockBehaviorSymbol = knownSymbols.MockBehavior; - if (mockBehaviorSymbol is null) + if (knownSymbols.MockBehavior is null) { return; } - // Look for the Mock.Of() method and provide it to Analyze to avoid looking it up multiple times. - ImmutableArray ofMethods = knownSymbols.MockOf; + context.RegisterOperationAction(context => AnalyzeObjectCreation(context, knownSymbols), OperationKind.ObjectCreation); - ImmutableArray mockTypes = - new INamedTypeSymbol?[] { knownSymbols.Mock1, knownSymbols.MockRepository } - .WhereNotNull() - .ToImmutableArray(); - - context.RegisterOperationAction( - context => AnalyzeNewObject(context, mockTypes, mockBehaviorSymbol), - OperationKind.ObjectCreation); - - if (!ofMethods.IsEmpty) - { - context.RegisterOperationAction( - context => AnalyzeInvocation(context, ofMethods, mockBehaviorSymbol), - OperationKind.Invocation); - } + context.RegisterOperationAction(context => AnalyzeInvocation(context, knownSymbols), OperationKind.Invocation); } - private static void AnalyzeNewObject(OperationAnalysisContext context, ImmutableArray mockTypes, INamedTypeSymbol mockBehaviorSymbol) + private static void AnalyzeObjectCreation(OperationAnalysisContext context, MoqKnownSymbols knownSymbols) { - if (context.Operation is not IObjectCreationOperation creationOperation) + if (context.Operation is not IObjectCreationOperation creation) { return; } - if (creationOperation.Type is not INamedTypeSymbol namedType) + if (creation.Type is null || + creation.Constructor is null || + !(creation.Type.IsInstanceOf(knownSymbols.Mock1) || creation.Type.IsInstanceOf(knownSymbols.MockRepository))) { + // We could expand this check to include any method that accepts a MockBehavior parameter. + // Leaving it narrowly scoped for now to avoid false positives and potential performance problems. return; } - if (!namedType.IsInstanceOf(mockTypes)) + AnalyzeCore(context, creation.Constructor, creation.Arguments, knownSymbols); + } + + private static void AnalyzeInvocation(OperationAnalysisContext context, MoqKnownSymbols knownSymbols) + { + if (context.Operation is not IInvocationOperation invocation) { return; } - foreach (IArgumentOperation argument in creationOperation.Arguments) + if (!invocation.TargetMethod.IsInstanceOf(knownSymbols.MockOf, out IMethodSymbol? match)) { - if (argument.Value is IFieldReferenceOperation fieldReferenceOperation) - { - ISymbol field = fieldReferenceOperation.Member; - if (field.ContainingType.IsInstanceOf(mockBehaviorSymbol) && IsExplicitBehavior(field.Name)) - { - return; - } - } + // We could expand this check to include any method that accepts a MockBehavior parameter. + // Leaving it narrowly scoped for now to avoid false positives and potential performance problems. + return; } - context.ReportDiagnostic(creationOperation.CreateDiagnostic(Rule)); + AnalyzeCore(context, match, invocation.Arguments, knownSymbols); } - private static void AnalyzeInvocation(OperationAnalysisContext context, ImmutableArray wellKnownOfMethods, INamedTypeSymbol mockBehaviorSymbol) + [SuppressMessage("Design", "MA0051:Method is too long", Justification = "Should be fixed. Ignoring for now to avoid additional churn as part of larger refactor.")] + private static void AnalyzeCore(OperationAnalysisContext context, IMethodSymbol target, ImmutableArray arguments, MoqKnownSymbols knownSymbols) { - if (context.Operation is not IInvocationOperation invocationOperation) + // Check if the target method has a parameter of type MockBehavior + IParameterSymbol? mockParameter = target.Parameters.DefaultIfNotSingle(parameter => parameter.Type.IsInstanceOf(knownSymbols.MockBehavior)); + + // If the target method doesn't have a MockBehavior parameter, check if there's an overload that does + if (mockParameter is null && target.TryGetOverloadWithParameterOfType(knownSymbols.MockBehavior!, out IMethodSymbol? methodMatch, out _, cancellationToken: context.CancellationToken)) { + if (!methodMatch.TryGetParameterOfType(knownSymbols.MockBehavior!, out IParameterSymbol? parameterMatch, cancellationToken: context.CancellationToken)) + { + return; + } + + ImmutableDictionary properties = new DiagnosticEditProperties + { + TypeOfEdit = DiagnosticEditProperties.EditType.Insert, + EditPosition = parameterMatch.Ordinal, + }.ToImmutableDictionary(); + + // Using a method that doesn't accept a MockBehavior parameter, however there's an overload that does + context.ReportDiagnostic(context.Operation.CreateDiagnostic(Rule, properties)); return; } - IMethodSymbol targetMethod = invocationOperation.TargetMethod; - if (!targetMethod.IsInstanceOf(wellKnownOfMethods)) + IArgumentOperation? mockArgument = arguments.DefaultIfNotSingle(argument => argument.Parameter.IsInstanceOf(mockParameter)); + + // Is the behavior set via a default value? + if (mockArgument?.ArgumentKind == ArgumentKind.DefaultValue && mockArgument.Value.WalkDownConversion().ConstantValue.Value == knownSymbols.MockBehaviorDefault?.ConstantValue) { - return; + if (!target.TryGetParameterOfType(knownSymbols.MockBehavior!, out IParameterSymbol? parameterMatch, cancellationToken: context.CancellationToken)) + { + return; + } + + ImmutableDictionary properties = new DiagnosticEditProperties + { + TypeOfEdit = DiagnosticEditProperties.EditType.Insert, + EditPosition = parameterMatch.Ordinal, + }.ToImmutableDictionary(); + + context.ReportDiagnostic(context.Operation.CreateDiagnostic(Rule, properties)); } - foreach (IArgumentOperation argument in invocationOperation.Arguments) + // NOTE: This logic can't handle indirection (e.g. var x = MockBehavior.Default; new Mock(x);). We can't use the constant value either, + // as Loose and Default share the same enum value: `1`. Being more accurate I believe requires data flow analysis. + // + // The operation specifies a MockBehavior; is it MockBehavior.Default? + if (mockArgument?.DescendantsAndSelf().OfType().Any(argument => argument.Member.IsInstanceOf(knownSymbols.MockBehaviorDefault)) == true) { - if (argument.Value is IFieldReferenceOperation fieldReferenceOperation) + if (!target.TryGetParameterOfType(knownSymbols.MockBehavior!, out IParameterSymbol? parameterMatch, cancellationToken: context.CancellationToken)) { - ISymbol field = fieldReferenceOperation.Member; - if (field.ContainingType.IsInstanceOf(mockBehaviorSymbol) && IsExplicitBehavior(field.Name)) - { - return; - } + return; } - } - context.ReportDiagnostic(invocationOperation.CreateDiagnostic(Rule)); - } + ImmutableDictionary properties = new DiagnosticEditProperties + { + TypeOfEdit = DiagnosticEditProperties.EditType.Replace, + EditPosition = parameterMatch.Ordinal, + }.ToImmutableDictionary(); - private static bool IsExplicitBehavior(string symbolName) - { - return string.Equals(symbolName, "Loose", StringComparison.Ordinal) || string.Equals(symbolName, "Strict", StringComparison.Ordinal); + context.ReportDiagnostic(context.Operation.CreateDiagnostic(Rule, properties)); + } } } diff --git a/src/Analyzers/SquiggleCop.Baseline.yaml b/src/Analyzers/SquiggleCop.Baseline.yaml index 2af6ab6..ef9bfdd 100644 --- a/src/Analyzers/SquiggleCop.Baseline.yaml +++ b/src/Analyzers/SquiggleCop.Baseline.yaml @@ -359,7 +359,7 @@ - {Id: HAA0603, Title: Delegate allocation from a method group, Category: Performance, DefaultSeverity: Warning, IsEnabledByDefault: true, EffectiveSeverities: [Error], IsEverSuppressed: false} - {Id: HAA0604, Title: Delegate allocation from a method group, Category: Performance, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: false} - {Id: IDE0004, Title: Remove Unnecessary Cast, Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: true} -- {Id: IDE0005, Title: Using directive is unnecessary., Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: false} +- {Id: IDE0005, Title: Using directive is unnecessary., Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: true} - {Id: IDE0005_gen, Title: Using directive is unnecessary., Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: true} - {Id: IDE0007, Title: Use implicit type, Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: false} - {Id: IDE0008, Title: Use explicit type, Category: Style, DefaultSeverity: Note, IsEnabledByDefault: true, EffectiveSeverities: [Note], IsEverSuppressed: false} diff --git a/src/CodeFixes/CallbackSignatureShouldMatchMockedMethodCodeFix.cs b/src/CodeFixes/CallbackSignatureShouldMatchMockedMethodFixer.cs similarity index 96% rename from src/CodeFixes/CallbackSignatureShouldMatchMockedMethodCodeFix.cs rename to src/CodeFixes/CallbackSignatureShouldMatchMockedMethodFixer.cs index 3fe5007..0476b19 100644 --- a/src/CodeFixes/CallbackSignatureShouldMatchMockedMethodCodeFix.cs +++ b/src/CodeFixes/CallbackSignatureShouldMatchMockedMethodFixer.cs @@ -9,9 +9,9 @@ namespace Moq.CodeFixes; /// /// Fixes for CallbackSignatureShouldMatchMockedMethodAnalyzer (Moq1100). /// -[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(CallbackSignatureShouldMatchMockedMethodCodeFix))] +[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(CallbackSignatureShouldMatchMockedMethodFixer))] [Shared] -public class CallbackSignatureShouldMatchMockedMethodCodeFix : CodeFixProvider +public class CallbackSignatureShouldMatchMockedMethodFixer : CodeFixProvider { /// public sealed override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(DiagnosticIds.BadCallbackParameters); diff --git a/src/CodeFixes/CodeFixContextExtensions.cs b/src/CodeFixes/CodeFixContextExtensions.cs new file mode 100644 index 0000000..9616c4a --- /dev/null +++ b/src/CodeFixes/CodeFixContextExtensions.cs @@ -0,0 +1,14 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeFixes; + +namespace Moq.CodeFixes; + +internal static class CodeFixContextExtensions +{ + public static bool TryGetEditProperties(this CodeFixContext context, [NotNullWhen(true)] out DiagnosticEditProperties? editProperties) + { + ImmutableDictionary properties = context.Diagnostics[0].Properties; + + return DiagnosticEditProperties.TryGetFromImmutableDictionary(properties, out editProperties); + } +} diff --git a/src/CodeFixes/SetExplicitMockBehaviorFixer.cs b/src/CodeFixes/SetExplicitMockBehaviorFixer.cs new file mode 100644 index 0000000..a6aa162 --- /dev/null +++ b/src/CodeFixes/SetExplicitMockBehaviorFixer.cs @@ -0,0 +1,107 @@ +using System.Composition; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.Editing; +using Microsoft.CodeAnalysis.Simplification; + +namespace Moq.CodeFixes; + +/// +/// Fixes for . +/// +[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(SetExplicitMockBehaviorFixer))] +[Shared] +public class SetExplicitMockBehaviorFixer : CodeFixProvider +{ + private enum BehaviorType + { + Loose, + Strict, + } + + /// + public override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(DiagnosticIds.SetExplicitMockBehavior); + + /// + public override FixAllProvider? GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer; + + /// + public override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + SyntaxNode? root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + SyntaxNode? nodeToFix = root?.FindNode(context.Span, getInnermostNodeForTie: true); + + if (!context.TryGetEditProperties(out DiagnosticEditProperties? editProperties)) + { + return; + } + + if (nodeToFix is null) + { + return; + } + + context.RegisterCodeFix(new SetExplicitMockBehaviorCodeAction("Set MockBehavior (Loose)", context.Document, nodeToFix, BehaviorType.Loose, editProperties.TypeOfEdit, editProperties.EditPosition), context.Diagnostics); + context.RegisterCodeFix(new SetExplicitMockBehaviorCodeAction("Set MockBehavior (Strict)", context.Document, nodeToFix, BehaviorType.Strict, editProperties.TypeOfEdit, editProperties.EditPosition), context.Diagnostics); + } + + private sealed class SetExplicitMockBehaviorCodeAction : CodeAction + { + private readonly Document _document; + private readonly SyntaxNode _nodeToFix; + private readonly BehaviorType _behaviorType; + private readonly DiagnosticEditProperties.EditType _editType; + private readonly int _position; + + public SetExplicitMockBehaviorCodeAction(string title, Document document, SyntaxNode nodeToFix, BehaviorType behaviorType, DiagnosticEditProperties.EditType editType, int position) + { + Title = title; + _document = document; + _nodeToFix = nodeToFix; + _behaviorType = behaviorType; + _editType = editType; + _position = position; + } + + public override string Title { get; } + + public override string? EquivalenceKey => Title; + + protected override async Task GetChangedDocumentAsync(CancellationToken cancellationToken) + { + DocumentEditor editor = await DocumentEditor.CreateAsync(_document, cancellationToken).ConfigureAwait(false); + SemanticModel? model = await _document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); + IOperation? operation = model?.GetOperation(_nodeToFix, cancellationToken); + + MoqKnownSymbols knownSymbols = new(editor.SemanticModel.Compilation); + + if (knownSymbols.MockBehavior is null + || knownSymbols.MockBehaviorDefault is null + || knownSymbols.MockBehaviorLoose is null + || knownSymbols.MockBehaviorStrict is null + || operation is null) + { + return _document; + } + + SyntaxNode behavior = _behaviorType switch + { + BehaviorType.Loose => editor.Generator.MemberAccessExpression(knownSymbols.MockBehaviorLoose), + BehaviorType.Strict => editor.Generator.MemberAccessExpression(knownSymbols.MockBehaviorStrict), + _ => throw new InvalidOperationException(), + }; + + SyntaxNode argument = editor.Generator.Argument(behavior); + + SyntaxNode newNode = _editType switch + { + DiagnosticEditProperties.EditType.Insert => editor.Generator.InsertArguments(operation, _position, argument), + DiagnosticEditProperties.EditType.Replace => editor.Generator.ReplaceArgument(operation, _position, argument), + _ => throw new InvalidOperationException(), + }; + + editor.ReplaceNode(_nodeToFix, newNode.WithAdditionalAnnotations(Simplifier.Annotation)); + return editor.GetChangedDocument(); + } + } +} diff --git a/src/CodeFixes/SyntaxGeneratorExtensions.cs b/src/CodeFixes/SyntaxGeneratorExtensions.cs new file mode 100644 index 0000000..cbb2d6a --- /dev/null +++ b/src/CodeFixes/SyntaxGeneratorExtensions.cs @@ -0,0 +1,71 @@ +using Microsoft.CodeAnalysis.Editing; + +namespace Moq.CodeFixes; + +internal static class SyntaxGeneratorExtensions +{ + public static SyntaxNode MemberAccessExpression(this SyntaxGenerator generator, IFieldSymbol fieldSymbol) + { + return generator.MemberAccessExpression(generator.TypeExpression(fieldSymbol.Type), generator.IdentifierName(fieldSymbol.Name)); + } + + public static SyntaxNode InsertArguments(this SyntaxGenerator generator, IOperation operation, int index, params SyntaxNode[] items) + { + // Ideally we could modify argument lists only using the IOperation APIs, but I haven't figured out a way to do that yet. + return generator.InsertArguments(operation.Syntax, index, items); + } + + public static SyntaxNode InsertArguments(this SyntaxGenerator generator, SyntaxNode syntax, int index, params SyntaxNode[] items) + { + if (Array.Exists(items, item => item is not ArgumentSyntax)) + { + throw new ArgumentException("Must all be of type ArgumentSyntax", nameof(items)); + } + + if (syntax is InvocationExpressionSyntax invocation) + { + SeparatedSyntaxList arguments = invocation.ArgumentList.Arguments; + arguments = arguments.InsertRange(index, items.OfType()); + return invocation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); + } + + if (syntax is ObjectCreationExpressionSyntax creation) + { + SeparatedSyntaxList arguments = creation.ArgumentList?.Arguments ?? []; + arguments = arguments.InsertRange(index, items.OfType()); + return creation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); + } + + throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(ObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); + } + + public static SyntaxNode ReplaceArgument(this SyntaxGenerator generator, IOperation operation, int index, SyntaxNode item) + { + // Ideally we could modify argument lists only using the IOperation APIs, but I haven't figured out a way to do that yet. + return generator.ReplaceArgument(operation.Syntax, index, item); + } + + public static SyntaxNode ReplaceArgument(this SyntaxGenerator generator, SyntaxNode syntax, int index, SyntaxNode item) + { + if (item is not ArgumentSyntax argument) + { + throw new ArgumentException("Must be of type ArgumentSyntax", nameof(item)); + } + + if (syntax is InvocationExpressionSyntax invocation) + { + SeparatedSyntaxList arguments = invocation.ArgumentList.Arguments; + arguments = arguments.RemoveAt(index).Insert(index, argument); + return invocation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); + } + + if (syntax is ObjectCreationExpressionSyntax creation) + { + SeparatedSyntaxList arguments = creation.ArgumentList?.Arguments ?? []; + arguments = arguments.RemoveAt(index).Insert(index, argument); + return creation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); + } + + throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(ObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); + } +} diff --git a/src/Common/DiagnosticEditProperties.cs b/src/Common/DiagnosticEditProperties.cs new file mode 100644 index 0000000..b485c54 --- /dev/null +++ b/src/Common/DiagnosticEditProperties.cs @@ -0,0 +1,87 @@ +using System.Diagnostics.CodeAnalysis; +using System.Globalization; + +namespace Moq.Analyzers.Common; + +internal record class DiagnosticEditProperties +{ + internal static readonly string EditTypeKey = nameof(EditTypeKey); + internal static readonly string EditPositionKey = nameof(EditPositionKey); + + /// + /// The type of edit for the code fix to perform. + /// + internal enum EditType + { + /// + /// Insert a new parameter, moving the existing parameters to position N+1. + /// + Insert, + + /// + /// Replace the parameter without changing the overall number of parameters. + /// + Replace, + } + + /// + /// Gets the type of edit operation to perform. + /// + public EditType TypeOfEdit { get; init; } + + /// + /// Gets the zero-based position where the edit should be applied. + /// + public int EditPosition { get; init; } + + /// + /// Returns the current object as an . + /// + /// The current objbect as an immutable dictionary. + public ImmutableDictionary ToImmutableDictionary() + { + return new Dictionary(StringComparer.Ordinal) + { + { EditTypeKey, TypeOfEdit.ToString() }, + { EditPositionKey, EditPosition.ToString(CultureInfo.InvariantCulture) }, + }.ToImmutableDictionary(); + } + + /// + /// Tries to convert an immuatble dictionary to a . + /// + /// The dictionary to try to convert. + /// The output edit properties if parsing suceeded, otherwise null. + /// true if parsing succeeded; false otherwise. + public static bool TryGetFromImmutableDictionary(ImmutableDictionary dictionary, [NotNullWhen(true)] out DiagnosticEditProperties? editProperties) + { + editProperties = null; + if (!dictionary.TryGetValue(EditTypeKey, out string? editTypeString)) + { + return false; + } + + if (!dictionary.TryGetValue(EditPositionKey, out string? editPositionString)) + { + return false; + } + + if (!Enum.TryParse(editTypeString, out EditType editType)) + { + return false; + } + + if (!int.TryParse(editPositionString, NumberStyles.Integer, CultureInfo.InvariantCulture, out int editPosition)) + { + return false; + } + + editProperties = new DiagnosticEditProperties + { + TypeOfEdit = editType, + EditPosition = editPosition, + }; + + return true; + } +} diff --git a/src/Common/EnumerableExtensions.cs b/src/Common/EnumerableExtensions.cs index d02d035..d523b48 100644 --- a/src/Common/EnumerableExtensions.cs +++ b/src/Common/EnumerableExtensions.cs @@ -1,11 +1,28 @@ -namespace Moq.Analyzers.Common; +using System.Diagnostics.CodeAnalysis; + +namespace Moq.Analyzers.Common; internal static class EnumerableExtensions { /// public static TSource? DefaultIfNotSingle(this IEnumerable source) { - return source.DefaultIfNotSingle(_ => true); + return source.DefaultIfNotSingle(static _ => true); + } + + /// + public static TSource? DefaultIfNotSingle(this ImmutableArray source) + { + return source.DefaultIfNotSingle(static _ => true); + } + + /// + /// The collection to enumerate. + /// A function to test each element for a condition. + [SuppressMessage("Performance", "ECS0900:Minimize boxing and unboxing", Justification = "Should revisit. Suppressing for now to unblock refactor.")] + public static TSource? DefaultIfNotSingle(this ImmutableArray source, Func predicate) + { + return source.AsEnumerable().DefaultIfNotSingle(predicate); } /// diff --git a/src/Common/IMethodSymbolExtensions.cs b/src/Common/IMethodSymbolExtensions.cs new file mode 100644 index 0000000..1a6a974 --- /dev/null +++ b/src/Common/IMethodSymbolExtensions.cs @@ -0,0 +1,99 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Moq.Analyzers.Common; + +internal static class IMethodSymbolExtensions +{ + /// + /// Get all overloads of a given . + /// + /// The method to inspect for overloads. + /// + /// The to use for the comparison. Defaults to . + /// + /// + /// A collection of representing the overloads of the given method. + /// + public static IEnumerable Overloads(this IMethodSymbol? method, SymbolEqualityComparer? comparer = null) + { + comparer ??= SymbolEqualityComparer.Default; + + IEnumerable? methods = method?.ContainingType?.GetMembers(method.Name).OfType(); + + if (methods is not null) + { + foreach (IMethodSymbol member in methods) + { + if (!comparer.Equals(member, method)) + { + yield return member; + } + } + } + } + + /// + /// Check if, given a set of overloads, any overload has a parameter of the given type. + /// + /// The method to inspect for overloads. + /// The set of candidate methods to check. + /// The type to check for in the parameters. + /// The matching method overload. if no matches. + /// The matching parameter. if no matches. + /// The to use for equality. + /// A to use to cancel long running operations. + /// if a method in has a parameter of type . Otherwise . + public static bool TryGetOverloadWithParameterOfType(this IMethodSymbol method, IEnumerable overloads, INamedTypeSymbol type, [NotNullWhen(true)] out IMethodSymbol? methodMatch, [NotNullWhen(true)] out IParameterSymbol? parameterMatch, SymbolEqualityComparer? comparer = null, CancellationToken cancellationToken = default) + { + comparer ??= SymbolEqualityComparer.Default; + + foreach (IMethodSymbol overload in overloads) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (comparer.Equals(method, overload)) + { + continue; + } + + foreach (IParameterSymbol parameter in overload.Parameters) + { + if (comparer.Equals(parameter.Type, type)) + { + methodMatch = overload; + parameterMatch = parameter; + return true; + } + } + } + + methodMatch = null; + parameterMatch = null; + return false; + } + + /// + public static bool TryGetOverloadWithParameterOfType(this IMethodSymbol method, INamedTypeSymbol type, [NotNullWhen(true)] out IMethodSymbol? methodMatch, [NotNullWhen(true)] out IParameterSymbol? parameterMatch, SymbolEqualityComparer? comparer = null, CancellationToken cancellationToken = default) + { + return method.TryGetOverloadWithParameterOfType(method.Overloads(), type, out methodMatch, out parameterMatch, comparer, cancellationToken); + } + + public static bool TryGetParameterOfType(this IMethodSymbol method, INamedTypeSymbol type, [NotNullWhen(true)] out IParameterSymbol? match, SymbolEqualityComparer? comparer = null, CancellationToken cancellationToken = default) + { + comparer ??= SymbolEqualityComparer.Default; + + foreach (IParameterSymbol parameter in method.Parameters) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (comparer.Equals(parameter.Type, type)) + { + match = parameter; + return true; + } + } + + match = null; + return false; + } +} diff --git a/src/Common/IOperationExtensions.cs b/src/Common/IOperationExtensions.cs new file mode 100644 index 0000000..17d0250 --- /dev/null +++ b/src/Common/IOperationExtensions.cs @@ -0,0 +1,21 @@ +using Microsoft.CodeAnalysis.Operations; + +namespace Moq.Analyzers.Common; + +internal static class IOperationExtensions +{ + /// + /// Walks down consecutive conversion operations until an operand is reached that isn't a conversion operation. + /// + /// The starting operation. + /// The inner non conversion operation or the starting operation if it wasn't a conversion operation. + public static IOperation WalkDownConversion(this IOperation operation) + { + while (operation is IConversionOperation conversionOperation) + { + operation = conversionOperation.Operand; + } + + return operation; + } +} diff --git a/src/Common/ISymbolExtensions.cs b/src/Common/ISymbolExtensions.cs index 6cc992e..f446b27 100644 --- a/src/Common/ISymbolExtensions.cs +++ b/src/Common/ISymbolExtensions.cs @@ -1,4 +1,5 @@ -using System.Runtime.CompilerServices; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; namespace Moq.Analyzers.Common; @@ -26,19 +27,25 @@ public static bool IsInstanceOf(this ISymbol? symbol, TSymbol? other, S { symbolEqualityComparer ??= SymbolEqualityComparer.Default; - if (symbol is IMethodSymbol methodSymbol) + if (symbol is IMethodSymbol method) { - return symbolEqualityComparer.Equals(methodSymbol.OriginalDefinition, other); + return symbolEqualityComparer.Equals(method.OriginalDefinition, other); } - if (symbol is INamedTypeSymbol namedTypeSymbol) + if (symbol is IParameterSymbol parameter && other is IParameterSymbol otherParameter) { - if (namedTypeSymbol.IsGenericType) + return parameter.Ordinal == otherParameter.Ordinal + && symbolEqualityComparer.Equals(parameter.OriginalDefinition, otherParameter.OriginalDefinition); + } + + if (symbol is INamedTypeSymbol namedType) + { + if (namedType.IsGenericType) { - namedTypeSymbol = namedTypeSymbol.ConstructedFrom; + namedType = namedType.ConstructedFrom; } - return symbolEqualityComparer.Equals(namedTypeSymbol, other); + return symbolEqualityComparer.Equals(namedType, other); } return symbolEqualityComparer.Equals(symbol, other); @@ -49,13 +56,38 @@ public static bool IsInstanceOf(this ISymbol? symbol, TSymbol? other, S /// /// The symbols to compare to. Returns if matches any of others. /// + /// + /// The matching symbol if is an instance of any of . otherwise. + /// /// The to use for equality. - public static bool IsInstanceOf(this ISymbol symbol, ImmutableArray others, SymbolEqualityComparer? symbolEqualityComparer = null) + public static bool IsInstanceOf(this ISymbol symbol, ImmutableArray others, [NotNullWhen(true)] out TSymbol? matchingSymbol, SymbolEqualityComparer? symbolEqualityComparer = null) where TSymbol : class, ISymbol { symbolEqualityComparer ??= SymbolEqualityComparer.Default; - return others.Any(other => symbol.IsInstanceOf(other, symbolEqualityComparer)); + foreach (TSymbol other in others) + { + if (symbol.IsInstanceOf(other, symbolEqualityComparer)) + { + matchingSymbol = other; + return true; + } + } + + matchingSymbol = null; + return false; + } + + /// + /// The symbol to compare. + /// + /// The symbols to compare to. Returns if matches any of others. + /// + /// The to use for equality. + public static bool IsInstanceOf(this ISymbol symbol, ImmutableArray others, SymbolEqualityComparer? symbolEqualityComparer = null) + where TSymbol : class, ISymbol + { + return symbol.IsInstanceOf(others, out _, symbolEqualityComparer); } public static bool IsConstructor(this ISymbol symbol) diff --git a/src/Common/InvocationExpressionSyntaxExtensions.cs b/src/Common/InvocationExpressionSyntaxExtensions.cs index 87e432e..05239c6 100644 --- a/src/Common/InvocationExpressionSyntaxExtensions.cs +++ b/src/Common/InvocationExpressionSyntaxExtensions.cs @@ -1,4 +1,6 @@ -namespace Moq.Analyzers.Common; +using System.Diagnostics.CodeAnalysis; + +namespace Moq.Analyzers.Common; /// /// Extension methods for s. diff --git a/tests/Moq.Analyzers.Test/CallbackSignatureShouldMatchMockedMethodCodeFixTests.cs b/tests/Moq.Analyzers.Test/CallbackSignatureShouldMatchMockedMethodCodeFixTests.cs index f0f7b87..5356752 100644 --- a/tests/Moq.Analyzers.Test/CallbackSignatureShouldMatchMockedMethodCodeFixTests.cs +++ b/tests/Moq.Analyzers.Test/CallbackSignatureShouldMatchMockedMethodCodeFixTests.cs @@ -1,4 +1,4 @@ -using Verifier = Moq.Analyzers.Test.Helpers.CodeFixVerifier; +using Verifier = Moq.Analyzers.Test.Helpers.CodeFixVerifier; namespace Moq.Analyzers.Test; diff --git a/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorAnalyzerTests.cs b/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorAnalyzerTests.cs deleted file mode 100644 index 0af47f6..0000000 --- a/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorAnalyzerTests.cs +++ /dev/null @@ -1,58 +0,0 @@ -using Verifier = Moq.Analyzers.Test.Helpers.AnalyzerVerifier; - -namespace Moq.Analyzers.Test; - -public class SetExplicitMockBehaviorAnalyzerTests -{ - public static IEnumerable TestData() - { - IEnumerable mockConstructors = new object[][] - { - ["""{|Moq1400:new Mock()|};"""], - ["""{|Moq1400:new Mock(MockBehavior.Default)|};"""], - ["""new Mock(MockBehavior.Loose);"""], - ["""new Mock(MockBehavior.Strict);"""], - }.WithNamespaces().WithMoqReferenceAssemblyGroups(); - - IEnumerable fluentBuilders = new object[][] - { - ["""{|Moq1400:Mock.Of()|};"""], - ["""{|Moq1400:Mock.Of(MockBehavior.Default)|};"""], - ["""Mock.Of(MockBehavior.Loose);"""], - ["""Mock.Of(MockBehavior.Strict);"""], - }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); - - IEnumerable mockRepositories = new object[][] - { - ["""{|Moq1400:new MockRepository(MockBehavior.Default)|};"""], - ["""new MockRepository(MockBehavior.Loose);"""], - ["""new MockRepository(MockBehavior.Strict);"""], - }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); - - return mockConstructors.Union(fluentBuilders).Union(mockRepositories); - } - - [Theory] - [MemberData(nameof(TestData))] - public async Task ShouldAnalyzeMocksWithoutExplictMockBehavior(string referenceAssemblyGroup, string @namespace, string mock) - { - await Verifier.VerifyAnalyzerAsync( - $$""" - {{@namespace}} - - public interface ISample - { - int Calculate(int a, int b); - } - - internal class UnitTest - { - private void Test() - { - {{mock}} - } - } - """, - referenceAssemblyGroup); - } -} diff --git a/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs b/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs new file mode 100644 index 0000000..a412524 --- /dev/null +++ b/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs @@ -0,0 +1,136 @@ +using Verifier = Moq.Analyzers.Test.Helpers.CodeFixVerifier; + +namespace Moq.Analyzers.Test; + +public class SetExplicitMockBehaviorCodeFixTests +{ + private readonly ITestOutputHelper _output; + + public SetExplicitMockBehaviorCodeFixTests(ITestOutputHelper output) + { + _output = output; + } + + public static IEnumerable TestData() + { + IEnumerable mockConstructors = new object[][] + { + [ + """{|Moq1400:new Mock()|};""", + """new Mock(MockBehavior.Loose);""", + ], + [ + """{|Moq1400:new Mock(MockBehavior.Default)|};""", + """new Mock(MockBehavior.Loose);""", + ], + [ + """new Mock(MockBehavior.Loose);""", + """new Mock(MockBehavior.Loose);""", + ], + [ + """new Mock(MockBehavior.Strict);""", + """new Mock(MockBehavior.Strict);""", + ], + }.WithNamespaces().WithMoqReferenceAssemblyGroups(); + + IEnumerable mockConstructorsWithExpressions = new object[][] + { + [ + """{|Moq1400:new Mock(() => new Calculator())|};""", + """new Mock(() => new Calculator(), MockBehavior.Loose);""", + ], + [ + """{|Moq1400:new Mock(() => new Calculator(), MockBehavior.Default)|};""", + """new Mock(() => new Calculator(), MockBehavior.Loose);""", + ], + [ + """new Mock(() => new Calculator(), MockBehavior.Loose);""", + """new Mock(() => new Calculator(), MockBehavior.Loose);""", + ], + [ + """new Mock(() => new Calculator(), MockBehavior.Strict);""", + """new Mock(() => new Calculator(), MockBehavior.Strict);""", + ], + }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); + + IEnumerable fluentBuilders = new object[][] + { + [ + """{|Moq1400:Mock.Of()|};""", + """Mock.Of(MockBehavior.Loose);""", + ], + [ + """{|Moq1400:Mock.Of(MockBehavior.Default)|};""", + """Mock.Of(MockBehavior.Loose);""", + ], + [ + """Mock.Of(MockBehavior.Loose);""", + """Mock.Of(MockBehavior.Loose);""", + ], + [ + """Mock.Of(MockBehavior.Strict);""", + """Mock.Of(MockBehavior.Strict);""", + ], + }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); + + IEnumerable mockRepositories = new object[][] + { + [ + """{|Moq1400:new MockRepository(MockBehavior.Default)|};""", + """new MockRepository(MockBehavior.Loose);""", + ], + [ + """new MockRepository(MockBehavior.Loose);""", + """new MockRepository(MockBehavior.Loose);""", + ], + [ + """new MockRepository(MockBehavior.Strict);""", + """new MockRepository(MockBehavior.Strict);""", + ], + }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); + + return mockConstructors.Union(mockConstructorsWithExpressions).Union(fluentBuilders).Union(mockRepositories); + } + + [Theory] + [MemberData(nameof(TestData))] + public async Task ShouldAnalyzeMocksWithoutExplicitMockBehavior(string referenceAssemblyGroup, string @namespace, string original, string quickFix) + { + static string Template(string ns, string mock) => + $$""" + {{ns}} + + public interface ISample + { + int Calculate(int a, int b); + } + + public class Calculator + { + public int Calculate(int a, int b) + { + return a + b; + } + } + + internal class UnitTest + { + private void Test() + { + {{mock}} + } + } + """; + + string o = Template(@namespace, original); + string f = Template(@namespace, quickFix); + + _output.WriteLine("Original:"); + _output.WriteLine(o); + _output.WriteLine(string.Empty); + _output.WriteLine("Fixed:"); + _output.WriteLine(f); + + await Verifier.VerifyCodeFixAsync(o, f, referenceAssemblyGroup); + } +}