File: AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs
Web Access
Project: ..\..\..\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
 
namespace Microsoft.CodeAnalysis.SimplifyLinqExpression
{
    /// <summary>
    /// Language-agnostic analyzer that reports a diagnostic on a call to one of the LINQ methods named in
    /// <c>s_nonEnumerableReturningLinqMethodNames</c> (source-only overloads) when it is the next invocation in a
    /// chain after <c>Enumerable.Where(source, Func&lt;T, bool&gt;)</c>, e.g. <c>source.Where(x =&gt; x != null).First()</c>.
    /// </summary>
    /// <typeparam name="TInvocationExpressionSyntax">The language's invocation expression syntax node type.</typeparam>
    /// <typeparam name="TMemberAccessExpressionSyntax">The language's member access expression syntax node type.</typeparam>
    internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer<TInvocationExpressionSyntax, TMemberAccessExpressionSyntax> : AbstractBuiltInCodeStyleDiagnosticAnalyzer
        where TInvocationExpressionSyntax : SyntaxNode
        where TMemberAccessExpressionSyntax : SyntaxNode
    {
        /// <summary>
        /// Names of the <see cref="Enumerable"/> methods whose single-parameter (source-only) overloads are
        /// looked for as the call following a <c>Where</c>.
        /// </summary>
        private static readonly IImmutableSet<string> s_nonEnumerableReturningLinqMethodNames =
            ImmutableHashSet.Create(
                nameof(Enumerable.First),
                nameof(Enumerable.Last),
                nameof(Enumerable.Single),
                nameof(Enumerable.Any),
                nameof(Enumerable.Count),
                nameof(Enumerable.SingleOrDefault),
                nameof(Enumerable.FirstOrDefault),
                nameof(Enumerable.LastOrDefault));

        /// <summary>Language-specific syntax helpers used to take apart invocation and member access nodes.</summary>
        protected abstract ISyntaxFacts SyntaxFacts { get; }

        public AbstractSimplifyLinqExpressionDiagnosticAnalyzer()
            : base(IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId,
                   EnforceOnBuildValues.SimplifyLinqExpression,
                   option: null,
                   title: new LocalizableResourceString(nameof(AnalyzersResources.Simplify_LINQ_expression), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
        {
        }

        /// <summary>
        /// Implemented per language. Expected to return the invocation chained on the result of
        /// <paramref name="invocation"/> (e.g. <c>First()</c> in <c>Where(...).First()</c>), or <see langword="null"/>
        /// when there is none.
        /// </summary>
        protected abstract IInvocationOperation? TryGetNextInvocationInChain(IInvocationOperation invocation);

        public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
            => DiagnosticAnalyzerCategory.SemanticSpanAnalysis;

        protected override void InitializeWorker(AnalysisContext context)
            => context.RegisterCompilationStartAction(OnCompilationStart);

        /// <summary>
        /// Resolves <see cref="Enumerable"/>, its non-indexed <c>Where</c> overload and the tracked source-only
        /// LINQ methods once per compilation. It registers the invocation analysis only when all of them are found.
        /// </summary>
        private void OnCompilationStart(CompilationStartAnalysisContext context)
        {
            if (!TryGetEnumerableTypeSymbol(context.Compilation, out var enumerableType))
            {
                return;
            }

            if (!TryGetLinqWhereExtensionMethod(enumerableType, out var whereMethodSymbol))
            {
                return;
            }

            if (!TryGetLinqMethodsThatDoNotReturnEnumerables(enumerableType, out var linqMethodSymbols))
            {
                return;
            }

            context.RegisterOperationAction(
                context => AnalyzeInvocationOperation(context, enumerableType, whereMethodSymbol, linqMethodSymbols),
                OperationKind.Invocation);

            return;

            static bool TryGetEnumerableTypeSymbol(Compilation compilation, [NotNullWhen(true)] out INamedTypeSymbol? enumerableType)
            {
                // typeof(...) is never null; FullName is annotated nullable, hence the '!'.
                enumerableType = compilation.GetTypeByMetadataName(typeof(Enumerable).FullName!);
                return enumerableType is not null;
            }

            static bool TryGetLinqWhereExtensionMethod(INamedTypeSymbol enumerableType, [NotNullWhen(true)] out IMethodSymbol? whereMethod)
            {
                foreach (var whereMethodSymbol in enumerableType.GetMembers(nameof(Enumerable.Where)).OfType<IMethodSymbol>())
                {
                    var parameters = whereMethodSymbol.Parameters;

                    if (parameters is [_, { Type: INamedTypeSymbol { Arity: 2 } }])
                    {
                        // This is the Where overload that does not take an index (i.e. Where(source, Func<T, bool>) vs Where(source, Func<T, int, bool>))
                        whereMethod = whereMethodSymbol;
                        return true;
                    }
                }

                whereMethod = null;
                return false;
            }

            static bool TryGetLinqMethodsThatDoNotReturnEnumerables(INamedTypeSymbol enumerableType, out ImmutableArray<IMethodSymbol> linqMethods)
            {
                using var _ = ArrayBuilder<IMethodSymbol>.GetInstance(out var linqMethodSymbolsBuilder);
                foreach (var method in enumerableType.GetMembers().OfType<IMethodSymbol>())
                {
                    // Only the overloads taking just the source (e.g. First(source), not First(source, predicate)).
                    if (s_nonEnumerableReturningLinqMethodNames.Contains(method.Name) &&
                        method.Parameters is { Length: 1 })
                    {
                        linqMethodSymbolsBuilder.Add(method);
                    }
                }

                linqMethods = linqMethodSymbolsBuilder.ToImmutable();
                return !linqMethods.IsEmpty;
            }
        }

        /// <summary>
        /// Reports a diagnostic on the next invocation after a <c>Where</c> call when that invocation is one of
        /// <paramref name="linqMethods"/>. The receiver type must not itself declare a member with the same name,
        /// unless the receiver type is <paramref name="enumerableType"/>.
        /// </summary>
        public void AnalyzeInvocationOperation(OperationAnalysisContext context, INamedTypeSymbol enumerableType, IMethodSymbol whereMethod, ImmutableArray<IMethodSymbol> linqMethods)
        {
            if (context.Operation is not IInvocationOperation invocation ||
                !IsWhereLinqMethod(invocation))
            {
                // We only care about Where methods on linq expressions.
                return;
            }

            // Checked after the cheap symbol comparison above: syntax diagnostics cover the whole subtree of the
            // node, so computing them for every invocation in the compilation is wasteful.
            if (invocation.Syntax.GetDiagnostics().Any(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error))
            {
                // Do not analyze linq methods that contain diagnostics.
                return;
            }

            if (TryGetNextInvocationInChain(invocation) is not IInvocationOperation nextInvocation ||
                !IsInvocationNonEnumerableReturningLinqMethod(nextInvocation))
            {
                // The Where call is not followed by one of the tracked methods (as in Where(x => x is not null).First()).
                return;
            }

            if (TryGetSymbolOfMemberAccess(invocation) is not INamedTypeSymbol targetTypeSymbol ||
                TryGetMethodName(nextInvocation) is not string name)
            {
                return;
            }

            // NOTE(review): MemberNames only lists members declared on targetTypeSymbol itself, so a type that
            // inherits e.g. Count from List<T> is not excluded here — confirm whether that is intended.
            if (!targetTypeSymbol.Equals(enumerableType, SymbolEqualityComparer.Default) &&
                targetTypeSymbol.MemberNames.Contains(name))
            {
                // Do not offer to transpose if there is already a member on the collection named the same as the linq extension method
                // example: list.Where(x => x != null).Count() cannot be changed to list.Count(x => x != null) as List<T> already has a member named Count
                return;
            }

            context.ReportDiagnostic(
                DiagnosticHelper.Create(
                    Descriptor,
                    nextInvocation.Syntax.GetLocation(),
                    Descriptor.GetEffectiveSeverity(context.Compilation.Options),
                    additionalLocations: null,
                    properties: null));

            return;

            bool IsWhereLinqMethod(IInvocationOperation invocation)
                => whereMethod.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default);

            bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocation)
                => linqMethods.Any(static (m, invocation) => m.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default), invocation);

            // Type of the receiver in 'receiver.Method(...)', or null when the call is not made through a member access
            // or the receiver's type is not a named type.
            INamedTypeSymbol? TryGetSymbolOfMemberAccess(IInvocationOperation invocation)
            {
                if (invocation.Syntax is TInvocationExpressionSyntax invocationNode &&
                    SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax memberAccess &&
                    SyntaxFacts.GetExpressionOfMemberAccessExpression(memberAccess) is SyntaxNode expression)
                {
                    return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol;
                }

                return null;
            }

            // Name of the invoked method when it is called through a member access, otherwise null.
            // The symbol name is used instead of the name node's text: GetText() includes trivia and type arguments
            // (e.g. "Count<int>") and keeps VB casing, any of which would make the MemberNames check above miss.
            string? TryGetMethodName(IInvocationOperation invocation)
            {
                if (invocation.Syntax is TInvocationExpressionSyntax invocationNode &&
                    SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax)
                {
                    return invocation.TargetMethod.Name;
                }

                return null;
            }
        }
    }
}