// File: IntroduceVariable\CSharpIntroduceVariableService_IntroduceLocal.cs
// Project: ..\..\..\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.IntroduceVariable
{
    internal partial class CSharpIntroduceVariableService
    {
        /// <summary>
        /// Introduces a new local variable initialized with <paramref name="expression"/> and rewrites the code that
        /// used the expression so that it references the new local instead.
        /// </summary>
        /// <remarks>
        /// The declaration is generated into the nearest enclosing block, arrow expression clause (expression-bodied
        /// member) or lambda. Arrow expression clauses and expression-bodied lambdas are converted to block bodies so
        /// that the new declaration statement has somewhere to live.
        /// </remarks>
        /// <param name="document">The semantic document containing <paramref name="expression"/>.</param>
        /// <param name="expression">The expression to extract into the new local.</param>
        /// <param name="allOccurrences">Whether all matching occurrences, not only <paramref name="expression"/>, are replaced.</param>
        /// <param name="isConstant">Whether the new local is declared with the <c>const</c> modifier.</param>
        /// <param name="cancellationToken">Token used to cancel the operation.</param>
        /// <returns>The document with the new local declared and the occurrence(s) rewritten.</returns>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="expression"/> is not inside a block, arrow expression clause or lambda.
        /// </exception>
        protected override async Task<Document> IntroduceLocalAsync(
            SemanticDocument document,
            ExpressionSyntax expression,
            bool allOccurrences,
            bool isConstant,
            CancellationToken cancellationToken)
        {
            var containerToGenerateInto = expression.Ancestors().FirstOrDefault(s =>
                s is BlockSyntax or ArrowExpressionClauseSyntax or LambdaExpressionSyntax);

            // Every supported insertion point is one of the containers above. Fail fast with the same exception the
            // switch below uses for an unsupported container, instead of dereferencing a null container further down.
            if (containerToGenerateInto is null)
            {
                throw new InvalidOperationException();
            }

            var newLocalNameToken = GenerateUniqueLocalName(
                document, expression, isConstant, containerToGenerateInto, cancellationToken);
            var newLocalName = SyntaxFactory.IdentifierName(newLocalNameToken);

            var modifiers = isConstant
                ? SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.ConstKeyword))
                : default;

            var declarationStatement = SyntaxFactory.LocalDeclarationStatement(
                modifiers,
                SyntaxFactory.VariableDeclaration(
                    GetTypeSyntax(document, expression, cancellationToken),
                    SyntaxFactory.SingletonSeparatedList(SyntaxFactory.VariableDeclarator(
                        newLocalNameToken.WithAdditionalAnnotations(RenameAnnotation.Create()),
                        null,
                        SyntaxFactory.EqualsValueClause(expression.WithoutTrivia())))));

            // If we're inserting into a multi-line parent, then add a newline after the local-var
            // we're adding.  That way we don't end up having it and the starting statement be on
            // the same line (which will cause indentation to be computed incorrectly).
            var text = await document.Document.GetTextAsync(cancellationToken).ConfigureAwait(false);
            if (!text.AreOnSameLine(containerToGenerateInto.GetFirstToken(), containerToGenerateInto.GetLastToken()))
            {
                declarationStatement = declarationStatement.WithAppendedTrailingTrivia(SyntaxFactory.ElasticCarriageReturnLineFeed);
            }

            switch (containerToGenerateInto)
            {
                case BlockSyntax block:
                    return await IntroduceLocalDeclarationIntoBlockAsync(
                        document, block, expression, newLocalName, declarationStatement, allOccurrences, cancellationToken).ConfigureAwait(false);

                case ArrowExpressionClauseSyntax arrowExpression:
                    // this will be null for expression-bodied properties & indexer (not for individual getters & setters, those do have a symbol),
                    // both of which are a shorthand for the getter and always return a value
                    var method = document.SemanticModel.GetDeclaredSymbol(arrowExpression.Parent, cancellationToken) as IMethodSymbol;
                    var createReturnStatement = true;

                    if (method is not null)
                        createReturnStatement = !method.ReturnsVoid && !method.IsAsyncReturningVoidTask(document.SemanticModel.Compilation);

                    return RewriteExpressionBodiedMemberAndIntroduceLocalDeclaration(
                        document, arrowExpression, expression, newLocalName,
                        declarationStatement, allOccurrences, createReturnStatement, cancellationToken);

                case LambdaExpressionSyntax lambda:
                    return IntroduceLocalDeclarationIntoLambda(
                        document, lambda, expression, newLocalName, declarationStatement,
                        allOccurrences, cancellationToken);
            }

            throw new InvalidOperationException();
        }

        /// <summary>
        /// Converts an expression-bodied lambda into a block-bodied lambda. The new block starts with
        /// <paramref name="declarationStatement"/>. Where needed, it is followed by a statement built from the
        /// rewritten original body (see <c>GetNewBlockBodyForLambda</c> for the exact shapes).
        /// </summary>
        private Document IntroduceLocalDeclarationIntoLambda(
            SemanticDocument document,
            LambdaExpressionSyntax oldLambda,
            ExpressionSyntax expression,
            IdentifierNameSyntax newLocalName,
            LocalDeclarationStatementSyntax declarationStatement,
            bool allOccurrences,
            CancellationToken cancellationToken)
        {
            // Expected to be an expression body: a lambda with a block body would normally have had that block
            // chosen as the insertion container instead.
            var oldBody = (ExpressionSyntax)oldLambda.Body;
            var isEntireLambdaBodySelected = oldBody.Equals(expression.WalkUpParentheses());

            var rewrittenBody = Rewrite(
                document, expression, newLocalName, document, oldBody, allOccurrences, cancellationToken);

            var shouldIncludeReturnStatement = ShouldIncludeReturnStatement(document, oldLambda, cancellationToken);
            var newBody = GetNewBlockBodyForLambda(
                declarationStatement, isEntireLambdaBodySelected, rewrittenBody, shouldIncludeReturnStatement);

            // Add an elastic newline so that the formatter will place this new lambda body across multiple lines.
            newBody = newBody.WithOpenBraceToken(newBody.OpenBraceToken.WithAppendedTrailingTrivia(SyntaxFactory.ElasticCarriageReturnLineFeed))
                             .WithAdditionalAnnotations(Formatter.Annotation);

            var newLambda = oldLambda.WithBody(newBody);

            var newRoot = document.Root.ReplaceNode(oldLambda, newLambda);
            return document.Document.WithSyntaxRoot(newRoot);
        }

        /// <summary>
        /// Decides whether the block body generated for <paramref name="oldLambda"/> must end in a <c>return</c>.
        /// Returns <see langword="false"/> when the lambda converts to a delegate whose invoke method returns
        /// <c>void</c>. Also returns <see langword="false"/> for an <c>async</c> lambda whose delegate returns the
        /// non-generic <c>Task</c> or <c>ValueTask</c>. Returns <see langword="true"/> otherwise, including when the
        /// converted type is not a delegate.
        /// </summary>
        private static bool ShouldIncludeReturnStatement(
            SemanticDocument document,
            LambdaExpressionSyntax oldLambda,
            CancellationToken cancellationToken)
        {
            if (document.SemanticModel.GetTypeInfo(oldLambda, cancellationToken).ConvertedType is INamedTypeSymbol delegateType &&
                delegateType.DelegateInvokeMethod != null)
            {
                if (delegateType.DelegateInvokeMethod.ReturnsVoid)
                {
                    return false;
                }

                // Async lambdas with a Task or ValueTask return type don't need a return statement.
                // e.g.:
                //     Func<int, Task> f = async x => await M2();
                //
                // After refactoring:
                //     Func<int, Task> f = async x =>
                //     {
                //         Task task = M2();
                //         await task;
                //     };
                var compilation = document.SemanticModel.Compilation;
                var delegateReturnType = delegateType.DelegateInvokeMethod.ReturnType;
                if (oldLambda.AsyncKeyword != default && delegateReturnType != null)
                {
                    if ((compilation.TaskType() != null && delegateReturnType.Equals(compilation.TaskType())) ||
                        (compilation.ValueTaskType() != null && delegateReturnType.Equals(compilation.ValueTaskType())))
                    {
                        return false;
                    }
                }
            }

            return true;
        }

        /// <summary>
        /// Builds the block body that replaces an expression-bodied lambda's body. The cases are described inline.
        /// </summary>
        private static BlockSyntax GetNewBlockBodyForLambda(
            LocalDeclarationStatementSyntax declarationStatement,
            bool isEntireLambdaBodySelected,
            ExpressionSyntax rewrittenBody,
            bool includeReturnStatement)
        {
            if (includeReturnStatement)
            {
                // Case 1: The lambda has a non-void return type.
                // e.g.:
                //     Func<int, int> f = x => [|x + 1|];
                //
                // After refactoring:
                //     Func<int, int> f = x =>
                //     {
                //         var v = x + 1;
                //         return v;
                //     };
                return SyntaxFactory.Block(declarationStatement, SyntaxFactory.ReturnStatement(rewrittenBody));
            }

            // For lambdas with void return types, we don't need to include the rewritten body if the entire lambda body
            // was originally selected for refactoring, as the rewritten body should already be encompassed within the
            // declaration statement.
            if (isEntireLambdaBodySelected)
            {
                // Case 2a: The lambda has a void return type, and the user selects the entire lambda body.
                // e.g.:
                //     Action<int> goo = x => [|x.ToString()|];
                //
                // After refactoring:
                //     Action<int> goo = x =>
                //     {
                //         string v = x.ToString();
                //     };
                return SyntaxFactory.Block(declarationStatement);
            }

            // Case 2b: The lambda has a void return type, and the user didn't select the entire lambda body.
            // e.g.:
            //     Task.Run(() => File.Copy("src", [|Path.Combine("dir", "file")|]));
            //
            // After refactoring:
            //     Task.Run(() =>
            //     {
            //         string destFileName = Path.Combine("dir", "file");
            //         File.Copy("src", destFileName);
            //     });
            return SyntaxFactory.Block(
                declarationStatement,
                SyntaxFactory.ExpressionStatement(rewrittenBody, SyntaxFactory.Token(SyntaxKind.SemicolonToken)));
        }

        /// <summary>
        /// Generates the type syntax used to declare the new local. The type symbol comes from <c>GetTypeSymbol</c>,
        /// which is defined elsewhere.
        /// </summary>
        private static TypeSyntax GetTypeSyntax(SemanticDocument document, ExpressionSyntax expression, CancellationToken cancellationToken)
        {
            var typeSymbol = GetTypeSymbol(document, expression, cancellationToken);
            return typeSymbol.GenerateTypeSyntax();
        }

        /// <summary>
        /// Replaces an expression-bodied member's <c>=&gt; expr</c> with a block of the form
        /// <c>{ declaration; return expr'; }</c>, or <c>{ declaration; expr'; }</c> when
        /// <paramref name="createReturnStatement"/> is <see langword="false"/>. Here <c>expr'</c> is the rewritten
        /// expression. The arrow clause's leading trivia and the arrow token's trailing trivia are carried over to
        /// the new block.
        /// </summary>
        private Document RewriteExpressionBodiedMemberAndIntroduceLocalDeclaration(
            SemanticDocument document,
            ArrowExpressionClauseSyntax arrowExpression,
            ExpressionSyntax expression,
            NameSyntax newLocalName,
            LocalDeclarationStatementSyntax declarationStatement,
            bool allOccurrences,
            bool createReturnStatement,
            CancellationToken cancellationToken)
        {
            var oldBody = arrowExpression;
            var oldParentingNode = oldBody.Parent;
            var leadingTrivia = oldBody.GetLeadingTrivia()
                                       .AddRange(oldBody.ArrowToken.TrailingTrivia);

            var newExpression = Rewrite(document, expression, newLocalName, document, oldBody.Expression, allOccurrences, cancellationToken);

            var convertedStatement = createReturnStatement
                ? SyntaxFactory.ReturnStatement(newExpression)
                : (StatementSyntax)SyntaxFactory.ExpressionStatement(newExpression);

            var newBody = SyntaxFactory.Block(declarationStatement, convertedStatement)
                                       .WithLeadingTrivia(leadingTrivia)
                                       .WithTrailingTrivia(oldBody.GetTrailingTrivia());

            // Add an elastic newline so that the formatter will place this new block across multiple lines.
            newBody = newBody.WithOpenBraceToken(newBody.OpenBraceToken.WithAppendedTrailingTrivia(SyntaxFactory.ElasticCarriageReturnLineFeed))
                             .WithAdditionalAnnotations(Formatter.Annotation);

            var newRoot = document.Root.ReplaceNode(oldParentingNode, WithBlockBody(oldParentingNode, newBody));
            return document.Document.WithSyntaxRoot(newRoot);
        }

        /// <summary>
        /// Returns <paramref name="node"/> with its expression body removed and <paramref name="body"/> installed as
        /// its block body. The trailing semicolon token is cleared and the original node's trivia is kept.
        /// Properties and indexers get a single <c>get</c> accessor whose body is <paramref name="body"/>.
        /// </summary>
        private static SyntaxNode WithBlockBody(SyntaxNode node, BlockSyntax body)
        {
            switch (node)
            {
                case BasePropertyDeclarationSyntax baseProperty:
                    var accessorList = SyntaxFactory.AccessorList(SyntaxFactory.SingletonList(
                        SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration, body)));
                    return baseProperty
                        .TryWithExpressionBody(null)
                        .WithAccessorList(accessorList)
                        .TryWithSemicolonToken(SyntaxFactory.Token(SyntaxKind.None))
                        .WithTriviaFrom(baseProperty);
                case AccessorDeclarationSyntax accessor:
                    return accessor
                        .WithExpressionBody(null)
                        .WithBody(body)
                        .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.None))
                        .WithTriviaFrom(accessor);
                case BaseMethodDeclarationSyntax baseMethod:
                    return baseMethod
                        .WithExpressionBody(null)
                        .WithBody(body)
                        .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.None))
                        .WithTriviaFrom(baseMethod);
                case LocalFunctionStatementSyntax localFunction:
                    return localFunction
                        .WithExpressionBody(null)
                        .WithBody(body)
                        .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.None))
                        .WithTriviaFrom(localFunction);
                default:
                    throw ExceptionUtilities.UnexpectedValue(node);
            }
        }

        /// <summary>
        /// Inserts <paramref name="declarationStatement"/> into the innermost block-like node (block or switch
        /// section) that contains every affected statement. The declaration is placed before the first affected
        /// statement, and the matching occurrence(s) in that node are rewritten.
        /// </summary>
        private async Task<Document> IntroduceLocalDeclarationIntoBlockAsync(
            SemanticDocument document,
            BlockSyntax block,
            ExpressionSyntax expression,
            NameSyntax newLocalName,
            LocalDeclarationStatementSyntax declarationStatement,
            bool allOccurrences,
            CancellationToken cancellationToken)
        {
            declarationStatement = declarationStatement.WithAdditionalAnnotations(Formatter.Annotation);

            SyntaxNode scope = block;

            // If we're within a non-static local function, our scope for the new local declaration is expanded to include the enclosing member.
            var localFunction = block.GetAncestor<LocalFunctionStatementSyntax>();
            if (localFunction != null && !localFunction.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.StaticKeyword)))
            {
                scope = block.GetAncestor<MemberDeclarationSyntax>();
            }

            var matches = FindMatches(document, expression, document, scope, allOccurrences, cancellationToken);
            Debug.Assert(matches.Contains(expression));

            (document, matches) = await ComplexifyParentingStatementsAsync(document, matches, cancellationToken).ConfigureAwait(false);

            // Our original expression should have been one of the matches, which were tracked as part
            // of complexification, so we can retrieve the latest version of the expression here.
            expression = document.Root.GetCurrentNode(expression);

            var root = document.Root;
            ISet<StatementSyntax> allAffectedStatements = new HashSet<StatementSyntax>(matches.SelectMany(expr => GetApplicableStatementAncestors(expr)));

            SyntaxNode innermostCommonBlock;

            var innermostStatements = new HashSet<StatementSyntax>(matches.Select(expr => GetApplicableStatementAncestors(expr).First()));
            if (innermostStatements.Count == 1)
            {
                // if there was only one match, or all the matches came from the same statement
                var statement = innermostStatements.Single();

                // and the statement is an embedded statement without a block, we want to generate one
                // around this statement rather than continue going up to find an actual block
                if (!IsBlockLike(statement.Parent))
                {
                    root = root.TrackNodes(allAffectedStatements.Concat(new SyntaxNode[] { expression, statement }));
                    root = root.ReplaceNode(root.GetCurrentNode(statement),
                        SyntaxFactory.Block(root.GetCurrentNode(statement)).WithAdditionalAnnotations(Formatter.Annotation));

                    expression = root.GetCurrentNode(expression);
                    allAffectedStatements = allAffectedStatements.Select(root.GetCurrentNode).ToSet();

                    statement = root.GetCurrentNode(statement);
                }

                innermostCommonBlock = statement.Parent;
            }
            else
            {
                innermostCommonBlock = innermostStatements.FindInnermostCommonNode(IsBlockLike);
            }

            var firstStatementAffectedIndex = GetFirstStatementAffectedIndex(innermostCommonBlock, matches, GetStatements(innermostCommonBlock).IndexOf(allAffectedStatements.Contains));

            var newInnerMostBlock = Rewrite(
                document, expression, newLocalName, document, innermostCommonBlock, allOccurrences, cancellationToken);

            var statements = InsertWithinTriviaOfNext(GetStatements(newInnerMostBlock), declarationStatement, firstStatementAffectedIndex);
            var finalInnerMostBlock = WithStatements(newInnerMostBlock, statements);

            var newRoot = root.ReplaceNode(innermostCommonBlock, finalInnerMostBlock);
            return document.Document.WithSyntaxRoot(newRoot);
        }

        /// <summary>
        /// Yields the statement ancestors of <paramref name="expr"/> (including itself if it is a statement),
        /// innermost first. An <c>if</c> statement that is the direct child of an <c>else</c> clause is skipped.
        /// </summary>
        private static IEnumerable<StatementSyntax> GetApplicableStatementAncestors(ExpressionSyntax expr)
        {
            foreach (var statement in expr.GetAncestorsOrThis<StatementSyntax>())
            {
                // When determining where to put a local, we don't want to put it between the `else`
                // and `if` of a compound if-statement.

                if (statement.Kind() == SyntaxKind.IfStatement &&
                    statement.IsParentKind(SyntaxKind.ElseClause))
                {
                    continue;
                }

                yield return statement;
            }
        }

        /// <summary>
        /// Adjusts <paramref name="firstStatementAffectedIndex"/> so that the new declaration also precedes every call
        /// to a local function that contains one of <paramref name="matches"/>. Returns the index, within
        /// <paramref name="innermostCommonBlock"/>'s statements, at which the declaration should be inserted.
        /// </summary>
        private static int GetFirstStatementAffectedIndex(SyntaxNode innermostCommonBlock, ISet<ExpressionSyntax> matches, int firstStatementAffectedIndex)
        {
            // If a local function is involved, we have to make sure the new declaration is placed:
            //     1. Before all calls to local functions that use the variable.
            //     2. Before the local function(s) themselves.
            //     3. Before all matches, i.e. places in the code where the new declaration will replace existing code.
            // Cases (2) and (3) are already covered by the 'firstStatementAffectedIndex' parameter. Thus, all we have to do is ensure we consider (1) when
            // determining where to place our new declaration.

            // Find the names of all the local functions within the scope that will use the new declaration. They are
            // collected into a set once, so that each invocation examined below costs a single lookup instead of a
            // re-walk of the whole block.
            var localFunctionIdentifiers = new HashSet<string>(
                innermostCommonBlock.DescendantNodes()
                    .OfType<LocalFunctionStatementSyntax>()
                    .Where(localFunctionStatement => matches.Any(match => match.Span.OverlapsWith(localFunctionStatement.Span)))
                    .Select(localFunctionStatement => localFunctionStatement.Identifier.ValueText));

            if (localFunctionIdentifiers.Count == 0)
            {
                return firstStatementAffectedIndex;
            }

            // Find the start positions of all calls to the applicable local functions within the scope, earliest first.
            // Invocations through a simple member access (e.g. 'x.M()') are not considered.
            var localFunctionCallStarts = innermostCommonBlock.DescendantNodes()
                .OfType<InvocationExpressionSyntax>()
                .Where(invocation =>
                    !invocation.Expression.IsKind(SyntaxKind.SimpleMemberAccessExpression) &&
                    invocation.Expression.GetRightmostName() is { } rightmostName &&
                    localFunctionIdentifiers.Contains(rightmostName.Identifier.ValueText))
                .Select(invocation => invocation.SpanStart)
                .OrderBy(start => start);

            var statementsInBlock = GetStatements(innermostCommonBlock);

            // Use the earliest call that lies inside one of the block's statements. If that statement comes before all
            // local function declarations and all matches, place our new declaration there. A call may lie outside
            // every statement, e.g. in a 'when' clause of a switch section's case label. Its index would be -1, and
            // that value must never become the insertion index, so such calls are skipped. For a plain block, the
            // earliest call is always inside a statement, so only the first iteration ever runs.
            foreach (var callStart in localFunctionCallStarts)
            {
                var callStatementIndex = statementsInBlock.IndexOf(s => s.Span.Contains(callStart));
                if (callStatementIndex >= 0)
                {
                    return Math.Min(callStatementIndex, firstStatementAffectedIndex);
                }
            }

            return firstStatementAffectedIndex;
        }

        /// <summary>
        /// Inserts <paramref name="newStatement"/> at <paramref name="statementIndex"/>. If a statement currently sits
        /// at that index and its leading trivia contains an end-of-line, the trivia up to and including the last
        /// end-of-line moves onto the new statement. The remainder stays on the existing statement.
        /// </summary>
        private static SyntaxList<StatementSyntax> InsertWithinTriviaOfNext(
            SyntaxList<StatementSyntax> oldStatements,
            StatementSyntax newStatement,
            int statementIndex)
        {
            var nextStatement = oldStatements.ElementAtOrDefault(statementIndex);
            if (nextStatement == null)
                return oldStatements.Insert(statementIndex, newStatement);

            // Grab all the trivia before the line the next statement is on and move it to the new node.

            var nextStatementLeading = nextStatement.GetLeadingTrivia();
            var precedingEndOfLine = nextStatementLeading.LastOrDefault(t => t.Kind() == SyntaxKind.EndOfLineTrivia);
            if (precedingEndOfLine == default)
            {
                return oldStatements.ReplaceRange(
                    nextStatement, new[] { newStatement, nextStatement });
            }

            var endOfLineIndex = nextStatementLeading.IndexOf(precedingEndOfLine) + 1;

            return oldStatements.ReplaceRange(
                nextStatement, new[]
                {
                    newStatement.WithLeadingTrivia(nextStatementLeading.Take(endOfLineIndex)),
                    nextStatement.WithLeadingTrivia(nextStatementLeading.Skip(endOfLineIndex)),
                });
        }

        /// <summary>Whether <paramref name="node"/> directly holds a statement list (a block or a switch section).</summary>
        private static bool IsBlockLike(SyntaxNode node) => node is BlockSyntax or SwitchSectionSyntax;

        /// <summary>Returns the statements of a block-like node; throws for any other node.</summary>
        private static SyntaxList<StatementSyntax> GetStatements(SyntaxNode blockLike)
            => blockLike switch
            {
                BlockSyntax block => block.Statements,
                SwitchSectionSyntax switchSection => switchSection.Statements,
                _ => throw ExceptionUtilities.UnexpectedValue(blockLike),
            };

        /// <summary>Returns a copy of a block-like node with its statements replaced; throws for any other node.</summary>
        private static SyntaxNode WithStatements(SyntaxNode blockLike, SyntaxList<StatementSyntax> statements)
            => blockLike switch
            {
                BlockSyntax block => block.WithStatements(statements),
                SwitchSectionSyntax switchSection => switchSection.WithStatements(statements),
                _ => throw ExceptionUtilities.UnexpectedValue(blockLike),
            };
    }
}