// File: ConvertForToForEach\AbstractConvertForToForEachCodeRefactoringProvider.cs
// Project: ..\..\..\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertForToForEach
{
    /// <summary>
    /// Shared, language-agnostic implementation of the "Convert 'for' to 'foreach'" refactoring.
    /// </summary>
    /// <remarks>
    /// The refactoring is offered only when all of the following hold:
    /// <list type="bullet">
    /// <item>The loop declares exactly one local.</item>
    /// <item>The initializer is the constant <c>0</c>.</item>
    /// <item>The step (if any) is the constant <c>1</c>.</item>
    /// <item>The limit is a <c>Length</c> or <c>Count</c> member access on some collection expression.</item>
    /// <item>That collection can be enumerated.</item>
    /// <item>Every use of the iteration variable in the body is <c>collection[i]</c> or
    /// <c>collection.ElementAt(i)</c>.</item>
    /// </list>
    /// Pulling the loop apart and building the final foreach node are delegated to the language-specific
    /// subclass.
    /// </remarks>
    internal abstract class AbstractConvertForToForEachCodeRefactoringProvider<
        TStatementSyntax,
        TForStatementSyntax,
        TExpressionSyntax,
        TMemberAccessExpressionSyntax,
        TTypeNode,
        TVariableDeclaratorSyntax> : CodeRefactoringProvider
        where TStatementSyntax : SyntaxNode
        where TForStatementSyntax : TStatementSyntax
        where TExpressionSyntax : SyntaxNode
        where TMemberAccessExpressionSyntax : SyntaxNode
        where TTypeNode : SyntaxNode
        where TVariableDeclaratorSyntax : SyntaxNode
    {
        /// <summary>
        /// Returns the text used both as the code action title and as its equivalence key.
        /// </summary>
        protected abstract string GetTitle();

        /// <summary>
        /// Returns the statements making up the body of <paramref name="forStatement"/>.
        /// </summary>
        protected abstract SyntaxList<TStatementSyntax> GetBodyStatements(TForStatementSyntax forStatement);

        /// <summary>
        /// Language-specific check on the single declarator of the body's first local declaration.
        /// </summary>
        /// <remarks>
        /// When it returns <see langword="true"/>, that declarator may be reused as the foreach variable.
        /// </remarks>
        protected abstract bool IsValidVariableDeclarator(TVariableDeclaratorSyntax firstVariable);

        /// <summary>
        /// Extracts the pieces of <paramref name="forStatement"/> this refactoring needs.
        /// </summary>
        /// <param name="forStatement">The loop being analyzed.</param>
        /// <param name="iterationVariable">The loop's iteration variable.</param>
        /// <param name="initializer">The iteration variable's initial value.</param>
        /// <param name="memberAccess">The member access in the loop limit (checked by the caller to be
        /// <c>Length</c>/<c>Count</c>).</param>
        /// <param name="stepValueExpressionOpt">The explicit step expression, or <see langword="null"/> when the
        /// language form has none.</param>
        /// <param name="cancellationToken">Cancellation token.</param>
        /// <returns>
        /// <see langword="false"/> when the loop does not have a supported shape.
        /// </returns>
        protected abstract bool TryGetForStatementComponents(
            TForStatementSyntax forStatement,
            out SyntaxToken iterationVariable,
            [NotNullWhen(true)] out TExpressionSyntax? initializer,
            [NotNullWhen(true)] out TMemberAccessExpressionSyntax? memberAccess,
            out TExpressionSyntax? stepValueExpressionOpt,
            CancellationToken cancellationToken);

        /// <summary>
        /// Builds the foreach node that replaces <paramref name="currentFor"/>.
        /// </summary>
        /// <param name="currentFor">The (already edited) for loop whose body is reused.</param>
        /// <param name="typeNode">Explicit foreach variable type, or <see langword="null"/> to let the language
        /// infer it.</param>
        /// <param name="foreachIdentifier">Identifier of the foreach variable.</param>
        /// <param name="collectionExpression">The collection being iterated.</param>
        /// <param name="iterationVariableType">Element type produced by enumerating the collection.</param>
        protected abstract SyntaxNode ConvertForNode(
            TForStatementSyntax currentFor, TTypeNode? typeNode, SyntaxToken foreachIdentifier,
            TExpressionSyntax collectionExpression, ITypeSymbol iterationVariableType);

        public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
        {
            var (document, textSpan, cancellationToken) = context;
            var forStatement = await context.TryGetRelevantNodeAsync<TForStatementSyntax>().ConfigureAwait(false);
            if (forStatement == null)
                return;

            if (!TryGetForStatementComponents(forStatement,
                    out var iterationVariable, out var initializer, out var memberAccess, out var stepValueExpressionOpt, cancellationToken))
            {
                return;
            }

            // Split `collection.Length` / `collection.Count` into the collection and the accessed name.
            var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
            syntaxFacts.GetPartsOfMemberAccessExpression(memberAccess,
                out var collectionExpressionNode, out var memberAccessNameNode);

            var collectionExpression = (TExpressionSyntax)collectionExpressionNode;
            syntaxFacts.GetNameAndArityOfSimpleName(memberAccessNameNode, out var memberAccessName, out _);
            if (memberAccessName is not nameof(Array.Length) and not nameof(IList.Count))
                return;

            var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

            // Make sure it's a single-variable for loop and that we're not a loop where we're
            // referencing some previously declared symbol.  i.e
            // VB allows:
            //
            //      dim i as integer
            //      for i = 0 to ...
            //
            // We can't convert this as it would change important semantics.
            // NOTE: we could potentially update this if we saw that the variable was not used
            // after the for-loop.  But, for now, we'll just be conservative and assume this means
            // the user wanted the 'i' for some other purpose and we should keep things as is.
            if (semanticModel.GetOperation(forStatement, cancellationToken) is not ILoopOperation { Locals.Length: 1 })
                return;

            // Make sure we're starting at 0.
            var initializerValue = semanticModel.GetConstantValue(initializer, cancellationToken);
            if (initializerValue is not { HasValue: true, Value: 0 })
                return;

            // Make sure we're incrementing by 1.  Flow the cancellation token here as well, consistent with the
            // initializer check above.
            if (stepValueExpressionOpt != null)
            {
                var stepValue = semanticModel.GetConstantValue(stepValueExpressionOpt, cancellationToken);
                if (stepValue is not { HasValue: true, Value: 1 })
                    return;
            }

            var collectionType = semanticModel.GetTypeInfo(collectionExpression, cancellationToken);
            if (collectionType.Type is null or IErrorTypeSymbol)
                return;

            // Member lookups below are filtered by accessibility from this type.
            var containingType = semanticModel.GetEnclosingNamedType(textSpan.Start, cancellationToken);
            if (containingType == null)
                return;

            var ienumerableType = semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T);
            var ienumeratorType = semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerator_T);

            // make sure the collection can be iterated.
            if (!TryGetIterationElementType(
                    containingType, collectionType.Type,
                    ienumerableType, ienumeratorType,
                    out var iterationType))
            {
                return;
            }

            // If the user uses the iteration variable for any other reason, we can't convert this.
            var bodyStatements = GetBodyStatements(forStatement);
            foreach (var statement in bodyStatements)
            {
                if (IterationVariableIsUsedForMoreThanCollectionIndex(statement))
                    return;
            }

            // Looks good.  We can convert this.
            var title = GetTitle();
            context.RegisterRefactoring(
                CodeAction.Create(
                    title,
                    cancellationToken => ConvertForToForEachAsync(
                        document, forStatement, iterationVariable, collectionExpression,
                        containingType, collectionType.Type, iterationType, cancellationToken),
                    title),
                forStatement.Span);

            return;

            // local functions

            // Recursively walks 'current'.  Returns true if any identifier named like the iteration variable is
            // used anywhere other than as the sole argument of `collection[i]` or `collection.ElementAt(i)`.
            bool IterationVariableIsUsedForMoreThanCollectionIndex(SyntaxNode current)
            {
                if (syntaxFacts.IsIdentifierName(current))
                {
                    syntaxFacts.GetNameAndArityOfSimpleName(current, out var name, out _);
                    if (name == iterationVariable.ValueText)
                    {
                        // found a reference.  make sure it's only used inside something like
                        // list[i]

                        var argument = current.Parent;
                        if (!syntaxFacts.IsSimpleArgument(argument))
                            return true;

                        // we support `list[i]` or `list.ElementAt(i)`
                        var argumentList = argument?.Parent;
                        if (argumentList is null)
                            return true;

                        var arguments = syntaxFacts.GetArgumentsOfArgumentList(argumentList);
                        // was used in a multi-dimensional indexing, or multiple argument method call.  Can't convert this.
                        if (arguments.Count != 1)
                            return true;

                        if (!IsGoodElementAccessExpression(argumentList) &&
                            !IsGoodInvocationExpression(argumentList))
                        {
                            // used in something other than accessing into a collection.
                            // can't convert this for-loop.
                            return true;
                        }
                    }

                    // this usage of the for-variable is fine.
                }

                foreach (var child in current.ChildNodesAndTokens())
                {
                    if (child.IsNode)
                    {
                        if (IterationVariableIsUsedForMoreThanCollectionIndex(child.AsNode()!))
                            return true;
                    }
                }

                return false;
            }

            // True when 'argumentList' belongs to an element access on (a syntactic copy of) the collection.
            bool IsGoodElementAccessExpression(SyntaxNode argumentList)
            {
                if (syntaxFacts.IsElementAccessExpression(argumentList.Parent))
                {
                    var expr = syntaxFacts.GetExpressionOfElementAccessExpression(argumentList.Parent);

                    // Have to be indexing into the collection.
                    if (syntaxFacts.AreEquivalent(expr, collectionExpression))
                        return true;
                }

                return false;
            }

            // True when 'argumentList' belongs to a `collection.ElementAt(...)` invocation.
            bool IsGoodInvocationExpression(SyntaxNode argumentList)
            {
                if (syntaxFacts.IsInvocationExpression(argumentList.Parent))
                {
                    var invokedExpression = syntaxFacts.GetExpressionOfInvocationExpression(argumentList.Parent);
                    if (syntaxFacts.IsMemberAccessExpression(invokedExpression))
                    {
                        syntaxFacts.GetPartsOfMemberAccessExpression(invokedExpression, out var accessedExpression, out var accessedName);
                        syntaxFacts.GetNameAndArityOfSimpleName(accessedName, out var memberName, out _);

                        // Have to be indexing into the collection.
                        if (memberName == nameof(Enumerable.ElementAt) &&
                            syntaxFacts.AreEquivalent(accessedExpression, collectionExpression))
                        {
                            return true;
                        }
                    }
                }

                return false;
            }
        }

        /// <summary>
        /// Returns the members of <paramref name="type"/> and its base types that are accessible from
        /// <paramref name="containingType"/> and named <paramref name="memberName"/>.
        /// </summary>
        private static IEnumerable<TSymbol> TryFindMembersInThisOrBaseTypes<TSymbol>(
            INamedTypeSymbol containingType, ITypeSymbol type, string memberName) where TSymbol : class, ISymbol
        {
            var members = type.GetAccessibleMembersInThisAndBaseTypes<TSymbol>(containingType);
            return members.Where(m => m.Name == memberName);
        }

        /// <summary>
        /// Returns the first member found by <see cref="TryFindMembersInThisOrBaseTypes{TSymbol}"/>, or
        /// <see langword="null"/> if there is none.
        /// </summary>
        private static TSymbol? TryFindMemberInThisOrBaseTypes<TSymbol>(
            INamedTypeSymbol containingType, ITypeSymbol type, string memberName) where TSymbol : class, ISymbol
        {
            return TryFindMembersInThisOrBaseTypes<TSymbol>(containingType, type, memberName).FirstOrDefault();
        }

        /// <summary>
        /// Determines the element type a foreach over <paramref name="collectionType"/> would produce.
        /// </summary>
        /// <remarks>
        /// Sources are tried in order:
        /// <list type="number">
        /// <item>Single-dimensional arrays.</item>
        /// <item>An accessible <c>GetEnumerator</c> method on the type hierarchy.</item>
        /// <item>An implemented <c>IEnumerable&lt;T&gt;</c>.</item>
        /// </list>
        /// </remarks>
        /// <returns>
        /// <see langword="false"/> when the collection cannot be iterated this way.
        /// </returns>
        private static bool TryGetIterationElementType(
            INamedTypeSymbol containingType, ITypeSymbol collectionType,
            INamedTypeSymbol ienumerableType, INamedTypeSymbol ienumeratorType,
            [NotNullWhen(true)] out ITypeSymbol? iterationType)
        {
            if (collectionType is IArrayTypeSymbol arrayType)
            {
                iterationType = arrayType.ElementType;

                // We only support single-dimensional array iteration.
                return arrayType.Rank == 1;
            }

            // Check in the class/struct hierarchy first.
            var getEnumeratorMethod = TryFindMemberInThisOrBaseTypes<IMethodSymbol>(
                containingType, collectionType, WellKnownMemberNames.GetEnumeratorMethodName);
            if (getEnumeratorMethod != null)
            {
                return TryGetIterationElementTypeFromGetEnumerator(
                    containingType, getEnumeratorMethod, ienumeratorType, out iterationType);
            }

            // couldn't find .GetEnumerator on the class/struct.  Check the interface hierarchy.
            var instantiatedIEnumerableType = collectionType.GetAllInterfacesIncludingThis().FirstOrDefault(
                t => Equals(t.OriginalDefinition, ienumerableType));

            if (instantiatedIEnumerableType != null)
            {
                iterationType = instantiatedIEnumerableType.TypeArguments[0];
                return true;
            }

            iterationType = null;
            return false;
        }

        /// <summary>
        /// Determines the element type from the return type of <paramref name="getEnumeratorMethod"/>.
        /// </summary>
        /// <remarks>
        /// Uses an accessible <c>Current</c> property first.  Otherwise falls back to an implemented
        /// <c>IEnumerator&lt;T&gt;</c>.
        /// </remarks>
        private static bool TryGetIterationElementTypeFromGetEnumerator(
            INamedTypeSymbol containingType, IMethodSymbol getEnumeratorMethod,
            INamedTypeSymbol ienumeratorType, [NotNullWhen(true)] out ITypeSymbol? iterationType)
        {
            var getEnumeratorReturnType = getEnumeratorMethod.ReturnType;

            // Check in the class/struct hierarchy first.
            var currentProperty = TryFindMemberInThisOrBaseTypes<IPropertySymbol>(
                containingType, getEnumeratorReturnType, WellKnownMemberNames.CurrentPropertyName);
            if (currentProperty != null)
            {
                iterationType = currentProperty.Type;
                return true;
            }

            // couldn't find .Current on the class/struct.  Check the interface hierarchy.
            var instantiatedIEnumeratorType = getEnumeratorReturnType.GetAllInterfacesIncludingThis().FirstOrDefault(
                t => Equals(t.OriginalDefinition, ienumeratorType));

            if (instantiatedIEnumeratorType != null)
            {
                iterationType = instantiatedIEnumeratorType.TypeArguments[0];
                return true;
            }

            iterationType = null;
            return false;
        }

        /// <summary>
        /// Performs the rewrite.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// <list type="number">
        /// <item>Picks (or generates) the foreach variable.</item>
        /// <item>Replaces every <c>collection[i]</c> / <c>collection.ElementAt(i)</c> in the loop with that
        /// variable.</item>
        /// <item>Annotates risky spots with warnings.</item>
        /// <item>Swaps the for node for the foreach produced by
        /// <see cref="ConvertForNode(TForStatementSyntax, TTypeNode?, SyntaxToken, TExpressionSyntax, ITypeSymbol)"/>.</item>
        /// </list>
        /// </remarks>
        private async Task<Document> ConvertForToForEachAsync(
            Document document,
            TForStatementSyntax forStatement,
            SyntaxToken iterationVariable,
            TExpressionSyntax collectionExpression,
            INamedTypeSymbol containingType,
            ITypeSymbol collectionType,
            ITypeSymbol iterationType,
            CancellationToken cancellationToken)
        {
            var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
            var semanticFacts = document.GetRequiredLanguageService<ISemanticFactsService>();
            var generator = SyntaxGenerator.GetGenerator(document);

            var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            var editor = new SyntaxEditor(root, generator);

            // create dummy "list[i]" and "list.ElementAt(i)" expressions.  We'll use this to find all places to replace
            // in the current for statement.
            var indexExpression = generator.ElementAccessExpression(collectionExpression, generator.IdentifierName(iterationVariable));
            var elementAtExpression = generator.InvocationExpression(
                generator.MemberAccessExpression(collectionExpression, generator.IdentifierName(nameof(Enumerable.ElementAt))),
                generator.IdentifierName(iterationVariable));

            // See if the first statement in the for loop is of the form:
            //      var x = list[i]
            //
            // If so, we'll use that local (and its explicit type, if any) as the iteration variable for the new
            // foreach statement.  Note: only the element-access form is recognized here, not `list.ElementAt(i)`.
            var (typeNode, foreachIdentifier, declarationStatement) = TryDeconstructInitialDeclaration();

            if (typeNode == null)
            {
                // user didn't provide an explicit type.  Check if the index-type of the collection
                // is different from the .Current type of the enumerator.  If so, add an explicit
                // type so that the foreach will coerce the types accordingly.
                var indexerType = GetIndexerType(containingType, collectionType, semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T));
                if (!Equals(indexerType, iterationType))
                {
                    typeNode = (TTypeNode)generator.TypeExpression(
                        indexerType ?? semanticModel.Compilation.GetSpecialType(SpecialType.System_Object));
                }
            }

            // If we couldn't find an appropriate existing variable to use as the foreach
            // variable, then generate one automatically.
            if (foreachIdentifier.RawKind == 0)
            {
                foreachIdentifier = semanticFacts.GenerateUniqueName(
                    semanticModel, forStatement, container: null, baseName: "v", usedNames: Enumerable.Empty<string>(), cancellationToken);
                foreachIdentifier = foreachIdentifier.WithAdditionalAnnotations(RenameAnnotation.Create());
            }

            // Create the expression we'll use to replace all matches in the for-body.  The rename annotation is
            // stripped so that only the foreach declaration itself carries it.
            var foreachIdentifierReference = foreachIdentifier.WithoutAnnotations(RenameAnnotation.Kind).WithoutTrivia();

            // Walk the for statement, replacing any matches we find.
            FindAndReplaceMatches(forStatement);

            // Finally, remove the declaration statement if we found one.  Move all its leading
            // trivia to the next statement.
            if (declarationStatement != null)
            {
                editor.RemoveNode(declarationStatement,
                    SyntaxGenerator.DefaultRemoveOptions | SyntaxRemoveOptions.KeepLeadingTrivia);
            }

            editor.ReplaceNode(
                forStatement,
                (currentFor, _) => ConvertForNode(
                    (TForStatementSyntax)currentFor, typeNode, foreachIdentifier,
                    collectionExpression, iterationType));

            return document.WithSyntaxRoot(editor.GetChangedRoot());

            // local functions

            // Returns the (optional type, identifier, statement) of a leading `var x = list[i]` declaration, or
            // `default` (null type, RawKind == 0 identifier, null statement) when there is no such declaration.
            (TTypeNode?, SyntaxToken, TStatementSyntax?) TryDeconstructInitialDeclaration()
            {
                var bodyStatements = GetBodyStatements(forStatement);

                if (bodyStatements.Count >= 1)
                {
                    var firstStatement = bodyStatements[0];
                    if (syntaxFacts.IsLocalDeclarationStatement(firstStatement))
                    {
                        var variables = syntaxFacts.GetVariablesOfLocalDeclarationStatement(firstStatement);
                        if (variables.Count == 1)
                        {
                            var firstVariable = (TVariableDeclaratorSyntax)variables[0];
                            if (IsValidVariableDeclarator(firstVariable))
                            {
                                var initializer = syntaxFacts.GetInitializerOfVariableDeclarator(firstVariable);
                                if (initializer != null)
                                {
                                    var firstVariableInitializer = syntaxFacts.GetValueOfEqualsValueClause(initializer);
                                    if (syntaxFacts.AreEquivalent(firstVariableInitializer, indexExpression))
                                    {
                                        var type = (TTypeNode?)syntaxFacts.GetTypeOfVariableDeclarator(firstVariable)?.WithoutLeadingTrivia();
                                        var identifier = syntaxFacts.GetIdentifierOfVariableDeclarator(firstVariable);
                                        var statement = firstStatement;
                                        return (type, identifier, statement);
                                    }
                                }
                            }
                        }
                    }
                }

                return default;
            }

            // Recursively finds nodes semantically equivalent to the collection.  Each one is either replaced (when
            // it is the target of `collection[i]` / `collection.ElementAt(i)`) or annotated with a warning when it
            // is used in a way that might modify the collection.
            void FindAndReplaceMatches(SyntaxNode current)
            {
                if (SemanticEquivalence.AreEquivalent(semanticModel, current, collectionExpression))
                {
                    if (syntaxFacts.AreEquivalent(current.Parent, indexExpression))
                    {
                        // Found a match.  replace with iteration variable.
                        var indexMatch = current.GetRequiredParent();
                        Replace(indexMatch);
                    }
                    else if (syntaxFacts.AreEquivalent(current.Parent?.Parent, elementAtExpression))
                    {
                        // Found a match.  replace with iteration variable.
                        var indexMatch = current.GetRequiredParent().GetRequiredParent();
                        Replace(indexMatch);
                    }
                    else
                    {
                        // Collection was used for some other purpose.  If it's passed as an argument
                        // to something, or is written to, or has a method invoked on it, we'll warn
                        // that it's potentially changing and may break if you switch to a foreach loop.
                        var shouldWarn = syntaxFacts.IsArgument(current.Parent);
                        shouldWarn |= semanticFacts.IsWrittenTo(semanticModel, current, cancellationToken);
                        shouldWarn |=
                            syntaxFacts.IsMemberAccessExpression(current.Parent) &&
                            syntaxFacts.IsInvocationExpression(current.Parent.Parent);

                        if (shouldWarn)
                        {
                            // This is a "collection may change" warning.  The function-boundary warning is
                            // reserved for replaced index expressions inside lambdas/local functions (see Replace).
                            editor.ReplaceNode(
                                current,
                                (node, _) => node.WithAdditionalAnnotations(
                                    WarningAnnotation.Create(FeaturesResources.Warning_colon_Collection_may_be_modified_during_iteration)));
                        }
                    }

                    return;
                }

                foreach (var child in current.ChildNodesAndTokens())
                {
                    if (child.IsNode)
                        FindAndReplaceMatches(child.AsNode()!);
                }
            }

            // True when 'node' sits inside a local function or anonymous function that is itself nested within the
            // for statement.  Using the foreach variable there means the lambda captures it.
            bool CrossesFunctionBoundary(SyntaxNode node)
            {
                var containingFunction = node.AncestorsAndSelf().FirstOrDefault(
                    n => syntaxFacts.IsLocalFunctionStatement(n) || syntaxFacts.IsAnonymousFunctionExpression(n));

                if (containingFunction == null)
                    return false;

                return containingFunction.AncestorsAndSelf().Contains(forStatement);
            }

            // Replaces an index/ElementAt expression with the foreach variable, adding warnings when the
            // expression was written to or is used across a function boundary.
            void Replace(SyntaxNode indexMatch)
            {
                var replacementToken = foreachIdentifierReference;

                if (semanticFacts.IsWrittenTo(semanticModel, indexMatch, cancellationToken))
                {
                    replacementToken = replacementToken.WithAdditionalAnnotations(
                        WarningAnnotation.Create(FeaturesResources.Warning_colon_Collection_was_modified_during_iteration));
                }

                if (CrossesFunctionBoundary(indexMatch))
                {
                    replacementToken = replacementToken.WithAdditionalAnnotations(
                        WarningAnnotation.Create(FeaturesResources.Warning_colon_Iteration_variable_crossed_function_boundary));
                }

                editor.ReplaceNode(
                    indexMatch,
                    generator.IdentifierName(replacementToken).WithTriviaFrom(indexMatch));
            }
        }

        /// <summary>
        /// Returns the type produced by indexing <paramref name="collectionType"/> with a single <c>int</c>.
        /// </summary>
        /// <remarks>
        /// Falls back to the element type of an implemented <c>IEnumerable&lt;T&gt;</c>.  Returns
        /// <see langword="null"/> when neither is found, or for multi-dimensional arrays.
        /// </remarks>
        private static ITypeSymbol? GetIndexerType(
            INamedTypeSymbol containingType,
            ITypeSymbol collectionType,
            INamedTypeSymbol ienumerableType)
        {
            if (collectionType is IArrayTypeSymbol arrayType)
                return arrayType.Rank == 1 ? arrayType.ElementType : null;

            var indexer = collectionType
                .GetAccessibleMembersInThisAndBaseTypes<IPropertySymbol>(containingType)
                .Where(IsViableIndexer)
                .FirstOrDefault();

            if (indexer?.Type != null)
                return indexer.Type;

            // Interfaces don't expose inherited-interface members through the base-type walk above, so search
            // every implemented interface explicitly.
            if (collectionType.IsInterfaceType())
            {
                var interfaces = collectionType.GetAllInterfacesIncludingThis();
                indexer = interfaces.SelectMany(i => i.GetMembers().OfType<IPropertySymbol>().Where(IsViableIndexer)).FirstOrDefault();

                if (indexer?.Type != null)
                    return indexer.Type;
            }

            foreach (var interfaceType in collectionType.GetAllInterfacesIncludingThis())
            {
                if (Equals(interfaceType.OriginalDefinition, ienumerableType))
                    return interfaceType.TypeArguments[0];
            }

            return null;
        }

        /// <summary>
        /// True for an indexer taking exactly one <see cref="int"/> parameter.
        /// </summary>
        private static bool IsViableIndexer(IPropertySymbol property)
            => property is { IsIndexer: true, Parameters: [{ Type.SpecialType: SpecialType.System_Int32 }] };
    }
}