File: CSharpUseLocalFunctionDiagnosticAnalyzer.cs
Web Access
Project: ..\..\..\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.CSharp.UseLocalFunction
{
    /// <summary>
    /// Looks for code of the form:
    /// 
    ///     Func&lt;int, int&gt; fib = n =>
    ///     {
    ///         if (n &lt;= 2)
    ///             return 1;
    ///             
    ///         return fib(n - 1) + fib(n - 2);
    ///     };
    ///     
    /// and converts it to:
    /// 
    ///     int fib(int n)
    ///     {
    ///         if (n &lt;= 2)
    ///             return 1;
    ///             
    ///         return fib(n - 1) + fib(n - 2);
    ///     }
    /// </summary>
    /// <remarks>
    /// Three declaration shapes are recognized:
    /// <c>Type t = &lt;anonymous function&gt;</c>,
    /// <c>var t = (Type)(&lt;anonymous function&gt;)</c>, and
    /// <c>Type t = null; t = &lt;anonymous function&gt;</c>.
    /// </remarks>
    [DiagnosticAnalyzer(LanguageNames.CSharp)]
    internal class CSharpUseLocalFunctionDiagnosticAnalyzer : AbstractBuiltInCodeStyleDiagnosticAnalyzer
    {
        public CSharpUseLocalFunctionDiagnosticAnalyzer()
            : base(IDEDiagnosticIds.UseLocalFunctionDiagnosticId,
                   EnforceOnBuildValues.UseLocalFunction,
                   CSharpCodeStyleOptions.PreferLocalOverAnonymousFunction,
                   new LocalizableResourceString(
                       nameof(CSharpAnalyzersResources.Use_local_function), CSharpAnalyzersResources.ResourceManager, typeof(CSharpAnalyzersResources)))
        {
        }

        /// <summary>
        /// Registers, per compilation targeting C# 7.0 or later, a syntax-node action over every
        /// simple lambda, parenthesized lambda and anonymous method.
        /// </summary>
        protected override void InitializeWorker(AnalysisContext context)
        {
            context.RegisterCompilationStartAction(compilationContext =>
            {
                var compilation = compilationContext.Compilation;

                // Local functions are only available in C# 7.0 and above.  Don't offer this refactoring
                // in projects targeting a lesser version.
                if (compilation.LanguageVersion() < LanguageVersion.CSharp7)
                    return;

                var expressionType = compilation.GetTypeByMetadataName(typeof(Expression<>).FullName!);

                // Register on the compilation-scoped context, not the outer 'context': the action must be
                // tied to this compilation so that it only runs for compilations that passed the
                // language-version check above, and is not re-registered on the outer context each time a
                // compilation starts.
                compilationContext.RegisterSyntaxNodeAction(ctx => SyntaxNodeAction(ctx, expressionType),
                    SyntaxKind.SimpleLambdaExpression, SyntaxKind.ParenthesizedLambdaExpression, SyntaxKind.AnonymousMethodExpression);
            });
        }

        /// <summary>
        /// Analyzes one anonymous function and reports a diagnostic when the local it initializes
        /// could be rewritten as a local function.
        /// </summary>
        /// <param name="syntaxContext">Context for the anonymous function node being analyzed.</param>
        /// <param name="expressionType">
        /// The <c>System.Linq.Expressions.Expression&lt;T&gt;</c> type of the compilation, or
        /// <see langword="null"/> if it could not be resolved. Used to reject references that
        /// occur inside expression trees.
        /// </param>
        private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTypeSymbol? expressionType)
        {
            var styleOption = syntaxContext.GetCSharpAnalyzerOptions().PreferLocalOverAnonymousFunction;
            if (!styleOption.Value)
            {
                // Bail immediately if the user has disabled this feature.
                return;
            }

            var severity = styleOption.Notification.Severity;
            var anonymousFunction = (AnonymousFunctionExpressionSyntax)syntaxContext.Node;

            if (!CheckForPattern(anonymousFunction, out var localDeclaration))
            {
                return;
            }

            if (localDeclaration.Declaration.Variables.Count != 1)
            {
                return;
            }

            if (localDeclaration.Parent is not BlockSyntax block)
            {
                return;
            }

            // If there are errors on the declaration we can't reliably tell that the refactoring will
            // be accurate, so don't provide any code diagnostics.  (SyntaxNode.GetDiagnostics only
            // surfaces the syntax diagnostics attached to the tree.)
            if (localDeclaration.GetDiagnostics().Any(d => d.Severity == DiagnosticSeverity.Error))
            {
                return;
            }

            var semanticModel = syntaxContext.SemanticModel;
            var cancellationToken = syntaxContext.CancellationToken;
            var local = semanticModel.GetDeclaredSymbol(localDeclaration.Declaration.Variables[0], cancellationToken);
            if (local == null)
            {
                return;
            }

            var delegateType = semanticModel.GetTypeInfo(anonymousFunction, cancellationToken).ConvertedType as INamedTypeSymbol;
            if (!delegateType.IsDelegateType() ||
                delegateType.DelegateInvokeMethod == null ||
                !CanReplaceDelegateWithLocalFunction(delegateType, localDeclaration, semanticModel, cancellationToken))
            {
                return;
            }

            if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionType, local, block, anonymousFunction, out var referenceLocations, cancellationToken))
                return;

            // Looks good!  The additional locations are, in order: the declaration, the anonymous
            // function, then every reference that will need rewriting.
            var additionalLocations = ImmutableArray.Create(
                localDeclaration.GetLocation(),
                anonymousFunction.GetLocation()).AddRange(referenceLocations);

            if (severity.WithDefaultSeverity(DiagnosticSeverity.Hidden) < ReportDiagnostic.Hidden)
            {
                // If the diagnostic is not hidden, then just place the user visible part
                // on the local being initialized with the lambda.
                Report(localDeclaration.Declaration.Variables[0].Identifier.GetLocation());
            }
            else
            {
                // If the diagnostic is hidden, place it on the entire construct.
                Report(localDeclaration.GetLocation());

                // In the 'Type t = null; t = <anonymous function>' shape the anonymous function lives
                // in a different statement than the declaration, so report on that statement too.
                var anonymousFunctionStatement = anonymousFunction.GetAncestor<StatementSyntax>();
                if (anonymousFunctionStatement != null && localDeclaration != anonymousFunctionStatement)
                {
                    Report(anonymousFunctionStatement.GetLocation());
                }
            }

            // Reports one diagnostic at 'location' carrying the shared severity and additional locations.
            void Report(Location location)
                => syntaxContext.ReportDiagnostic(DiagnosticHelper.Create(
                    Descriptor,
                    location,
                    severity,
                    additionalLocations,
                    properties: null));
        }

        /// <summary>
        /// Determines whether <paramref name="anonymousFunction"/> initializes a local in one of the
        /// supported shapes. On success, <paramref name="localDeclaration"/> is the declaring statement.
        /// </summary>
        private static bool CheckForPattern(
            AnonymousFunctionExpressionSyntax anonymousFunction,
            [NotNullWhen(true)] out LocalDeclarationStatementSyntax? localDeclaration)
        {
            // Look for:
            //
            // Type t = <anonymous function>
            // var t = (Type)(<anonymous function>)
            //
            // Type t = null;
            // t = <anonymous function>
            return CheckForSimpleLocalDeclarationPattern(anonymousFunction, out localDeclaration) ||
                   CheckForCastedLocalDeclarationPattern(anonymousFunction, out localDeclaration) ||
                   CheckForLocalDeclarationAndAssignment(anonymousFunction, out localDeclaration);
        }

        /// <summary>
        /// Matches <c>Type t = &lt;anonymous function&gt;</c> where the declared type is spelled out
        /// (a <c>var</c> declaration is not matched by this shape).
        /// </summary>
        private static bool CheckForSimpleLocalDeclarationPattern(
            AnonymousFunctionExpressionSyntax anonymousFunction,
            [NotNullWhen(true)] out LocalDeclarationStatementSyntax? localDeclaration)
        {
            // Type t = <anonymous function>
            if (anonymousFunction.IsParentKind(SyntaxKind.EqualsValueClause) &&
                anonymousFunction.Parent.IsParentKind(SyntaxKind.VariableDeclarator) &&
                anonymousFunction.Parent.Parent.IsParentKind(SyntaxKind.VariableDeclaration) &&
                anonymousFunction.Parent.Parent.Parent.IsParentKind(SyntaxKind.LocalDeclarationStatement, out localDeclaration) &&
                !localDeclaration.Declaration.Type.IsVar)
            {
                return true;
            }

            localDeclaration = null;
            return false;
        }

        /// <summary>
        /// Scans <paramref name="block"/> for uses of <paramref name="local"/> after the anonymous
        /// function and decides whether every use is compatible with a local function.
        /// </summary>
        /// <param name="referenceLocations">
        /// On success, the locations of all references that will need rewriting (the invocation for
        /// <c>t(...)</c> / <c>t.Invoke(...)</c>, otherwise the reference itself). Empty on failure.
        /// </param>
        /// <returns>
        /// <see langword="false"/> if the local is written to, used in a binary expression, has a
        /// member other than <c>Invoke</c> accessed, is converted to a non-delegate type, or is
        /// referenced inside an expression tree.
        /// </returns>
        private static bool CanReplaceAnonymousWithLocalFunction(
            SemanticModel semanticModel, INamedTypeSymbol? expressionTypeOpt, ISymbol local, BlockSyntax block,
            AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> referenceLocations, CancellationToken cancellationToken)
        {
            referenceLocations = ImmutableArray<Location>.Empty;

            // Check all the references to the anonymous function and disallow the conversion if
            // they're used in certain ways.  The pooled builder is returned on every exit path.
            var references = ArrayBuilder<Location>.GetInstance();
            try
            {
                var anonymousFunctionStart = anonymousFunction.SpanStart;
                foreach (var descendentNode in block.DescendantNodes())
                {
                    if (descendentNode.Span.Start <= anonymousFunctionStart)
                    {
                        // This node starts at or before the anonymous function, so it is not treated
                        // as a use of the local.  This also skips the 't' on the left of
                        // 't = <anonymous function>', which would otherwise be rejected as a write.
                        continue;
                    }

                    if (descendentNode is not IdentifierNameSyntax identifierName ||
                        identifierName.Identifier.ValueText != local.Name ||
                        !local.Equals(semanticModel.GetSymbolInfo(identifierName, cancellationToken).GetAnySymbol()))
                    {
                        continue;
                    }

                    if (identifierName.IsWrittenTo(semanticModel, cancellationToken))
                    {
                        // Can't change this to a local function if it is assigned to.
                        return false;
                    }

                    var nodeToCheck = identifierName.WalkUpParentheses();
                    if (nodeToCheck.Parent is BinaryExpressionSyntax)
                    {
                        // Can't change this if they're doing things like delegate addition with
                        // the lambda.
                        return false;
                    }

                    if (nodeToCheck.Parent is InvocationExpressionSyntax invocationExpression)
                    {
                        references.Add(invocationExpression.GetLocation());
                    }
                    else if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression)
                    {
                        if (memberAccessExpression.Parent is InvocationExpressionSyntax explicitInvocationExpression &&
                            memberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.DelegateInvokeName)
                        {
                            references.Add(explicitInvocationExpression.GetLocation());
                        }
                        else
                        {
                            // They're doing something like "del.ToString()".  Can't do this with a
                            // local function.
                            return false;
                        }
                    }
                    else
                    {
                        references.Add(nodeToCheck.GetLocation());
                    }

                    var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType;
                    if (!convertedType.IsDelegateType())
                    {
                        // We can't change this anonymous function into a local function if it is
                        // converted to a non-delegate type (i.e. converted to 'object' or 
                        // 'System.Delegate'). Local functions are not convertible to these types.  
                        // They're only convertible to other delegate types.
                        return false;
                    }

                    if (nodeToCheck.IsInExpressionTree(semanticModel, expressionTypeOpt, cancellationToken))
                    {
                        // Can't reference a local function inside an expression tree.
                        return false;
                    }
                }

                referenceLocations = references.ToImmutable();
                return true;
            }
            finally
            {
                references.Free();
            }
        }

        /// <summary>
        /// Matches <c>var t = (Type)(&lt;anonymous function&gt;)</c>: a single-variable local
        /// declaration whose initializer, ignoring parentheses, is a cast whose operand, ignoring
        /// parentheses, is <paramref name="anonymousFunction"/>.
        /// </summary>
        private static bool CheckForCastedLocalDeclarationPattern(
            AnonymousFunctionExpressionSyntax anonymousFunction,
            [NotNullWhen(true)] out LocalDeclarationStatementSyntax? localDeclaration)
        {
            // var t = (Type)(<anonymous function>)
            var containingStatement = anonymousFunction.GetAncestor<StatementSyntax>();
            if (containingStatement.IsKind(SyntaxKind.LocalDeclarationStatement, out localDeclaration) &&
                localDeclaration.Declaration.Variables.Count == 1)
            {
                var initializer = localDeclaration.Declaration.Variables[0].Initializer;
                if (initializer != null &&
                    initializer.Value.WalkDownParentheses() is CastExpressionSyntax castExpression &&
                    castExpression.Expression.WalkDownParentheses() == anonymousFunction)
                {
                    return true;
                }
            }

            localDeclaration = null;
            return false;
        }

        /// <summary>
        /// Matches a declaration immediately followed by an assignment of the anonymous function:
        /// <c>Type t = null; t = &lt;anonymous function&gt;</c>. The declaration must declare a single
        /// variable with the same name as the assignment target, and have no initializer or a
        /// <c>null</c> / <c>default</c> / <c>default(T)</c> initializer.
        /// </summary>
        private static bool CheckForLocalDeclarationAndAssignment(
            AnonymousFunctionExpressionSyntax anonymousFunction,
            [NotNullWhen(true)] out LocalDeclarationStatementSyntax? localDeclaration)
        {
            // Type t = null;
            // t = <anonymous function>
            if (anonymousFunction.Parent is AssignmentExpressionSyntax(SyntaxKind.SimpleAssignmentExpression) assignment &&
                assignment.Parent is ExpressionStatementSyntax expressionStatement &&
                expressionStatement.Parent is BlockSyntax block &&
                assignment.Left is IdentifierNameSyntax identifierName)
            {
                var expressionStatementIndex = block.Statements.IndexOf(expressionStatement);
                if (expressionStatementIndex >= 1)
                {
                    var previousStatement = block.Statements[expressionStatementIndex - 1];
                    if (previousStatement.IsKind(SyntaxKind.LocalDeclarationStatement, out localDeclaration) &&
                        localDeclaration.Declaration.Variables.Count == 1)
                    {
                        var variableDeclarator = localDeclaration.Declaration.Variables[0];
                        var hasTrivialInitializer =
                            variableDeclarator.Initializer == null ||
                            variableDeclarator.Initializer.Value.Kind() is
                                SyntaxKind.NullLiteralExpression or
                                SyntaxKind.DefaultLiteralExpression or
                                SyntaxKind.DefaultExpression;

                        if (hasTrivialInitializer &&
                            variableDeclarator.Identifier.ValueText == identifierName.Identifier.ValueText)
                        {
                            return true;
                        }
                    }
                }
            }

            localDeclaration = null;
            return false;
        }

        /// <summary>
        /// Returns <see langword="false"/> when <paramref name="delegateType"/> is nested inside a
        /// generic type and a scope enclosing the local declares a type parameter with the same name
        /// as one of the delegate's type parameters. Scopes are visited from the enclosing symbol
        /// outward until the delegate's containing type is reached.
        /// </summary>
        /// <remarks>
        /// NOTE(review): presumably this guards against the generated local-function signature
        /// binding those type-parameter names to the wrong symbols — confirm against the code fix.
        /// </remarks>
        private static bool CanReplaceDelegateWithLocalFunction(
            INamedTypeSymbol delegateType,
            LocalDeclarationStatementSyntax localDeclaration,
            SemanticModel semanticModel,
            CancellationToken cancellationToken)
        {
            var delegateContainingType = delegateType.ContainingType;
            if (delegateContainingType is null || !delegateContainingType.IsGenericType)
            {
                return true;
            }

            var delegateTypeParamNames = delegateType.GetAllTypeParameters().Select(p => p.Name).ToImmutableHashSet();
            var localEnclosingSymbol = semanticModel.GetEnclosingSymbol(localDeclaration.SpanStart, cancellationToken);
            while (localEnclosingSymbol != null)
            {
                if (localEnclosingSymbol.Equals(delegateContainingType))
                {
                    return true;
                }

                var typeParams = localEnclosingSymbol.GetTypeParameters();
                if (typeParams.Any(static (p, delegateTypeParamNames) => delegateTypeParamNames.Contains(p.Name), delegateTypeParamNames))
                {
                    return false;
                }

                localEnclosingSymbol = localEnclosingSymbol.ContainingType;
            }

            return true;
        }

        public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
            => DiagnosticAnalyzerCategory.SemanticSpanAnalysis;
    }
}