File: ConvertLinq\ConvertForEachToLinqQuery\DefaultConverter.cs
Web Access
Project: ..\..\..\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.ConvertLinq.ConvertForEachToLinqQuery;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertLinq.ConvertForEachToLinqQuery
{
    /// <summary>
    /// Converter that replaces the original <c>foreach</c> statement with a new <c>foreach</c> whose source is the
    /// expression produced by <c>CreateQueryExpressionOrLinqInvocation</c> and whose body is the statements
    /// captured in <c>ForEachInfo.Statements</c>.
    /// </summary>
    internal sealed class DefaultConverter : AbstractConverter
    {
        // Shared "var" type syntax used for every generated loop variable declaration.
        private static readonly TypeSyntax VarNameIdentifier = SyntaxFactory.IdentifierName("var");

        /// <summary>
        /// Creates a converter over the given foreach analysis result.
        /// </summary>
        /// <param name="forEachInfo">Analysis data forwarded unchanged to the base converter.</param>
        public DefaultConverter(ForEachInfo<ForEachStatementSyntax, StatementSyntax> forEachInfo)
            : base(forEachInfo)
        {
        }

        /// <summary>
        /// Replaces <c>ForEachInfo.ForEachStatement</c> in <paramref name="editor"/> with the generated
        /// <c>foreach</c> statement, annotated for formatting.
        /// </summary>
        /// <param name="editor">Editor that receives the node replacement.</param>
        /// <param name="convertToQuery">
        /// Forwarded to <c>CreateQueryExpressionOrLinqInvocation</c>; presumably selects query syntax over a
        /// method invocation chain (confirm against the base class).
        /// </param>
        /// <param name="cancellationToken">Not observed by this implementation.</param>
        public override void Convert(SyntaxEditor editor, bool convertToQuery, CancellationToken cancellationToken)
        {
            // Collect the names of all symbols read by the statements that become the new loop body.
            var variableNamesReadInside = new HashSet<string>(ForEachInfo.Statements
                .SelectMany(statement => ForEachInfo.SemanticModel.AnalyzeDataFlow(statement).ReadInside)
                .Select(symbol => symbol.Name));

            // Filter out identifiers which are not used in statements. The result is materialized once so the
            // filter is not re-run every time the identifiers are counted or projected below.
            var identifiersUsedInStatements = ForEachInfo.Identifiers
                .Where(identifier => variableNamesReadInside.Contains(identifier.ValueText))
                .ToImmutableArray();

            // If there is a single statement and it is a block, leave it as is.
            // Otherwise, wrap with a block.
            var block = WrapWithBlockIfNecessary(
                ForEachInfo.Statements.SelectAsArray(statement => statement.KeepCommentsAndAddElasticMarkers()));

            editor.ReplaceNode(
                ForEachInfo.ForEachStatement,
                CreateDefaultReplacementStatement(identifiersUsedInStatements, block, convertToQuery)
                    .WithAdditionalAnnotations(Formatter.Annotation));
        }

        /// <summary>
        /// Builds the replacement <c>foreach</c> statement. The shape of the loop variable and of the select
        /// expression depends on how many identifiers the body actually uses: none, exactly one, or several.
        /// </summary>
        /// <param name="identifiers">Identifiers that are read by the loop body.</param>
        /// <param name="block">Body of the generated loop.</param>
        /// <param name="convertToQuery">Forwarded to <c>CreateQueryExpressionOrLinqInvocation</c>.</param>
        /// <returns>A foreach statement or, for several identifiers, a deconstructing foreach statement.</returns>
        private StatementSyntax CreateDefaultReplacementStatement(
            ImmutableArray<SyntaxToken> identifiers,
            BlockSyntax block,
            bool convertToQuery)
        {
            switch (identifiers)
            {
                case []:
                    // Generate foreach (var _ in ... select new { })
                    return SyntaxFactory.ForEachStatement(
                        VarNameIdentifier,
                        SyntaxFactory.Identifier("_"),
                        CreateSource(SyntaxFactory.AnonymousObjectCreationExpression()),
                        block);

                case [var singleIdentifier]:
                    // Generate foreach (var singleIdentifier in from ... select singleIdentifier)
                    return SyntaxFactory.ForEachStatement(
                        VarNameIdentifier,
                        singleIdentifier,
                        CreateSource(SyntaxFactory.IdentifierName(singleIdentifier)),
                        block);

                default:
                    {
                        var tupleForSelectExpression = SyntaxFactory.TupleExpression(
                            SyntaxFactory.SeparatedList(identifiers.Select(
                                identifier => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(identifier)))));
                        var declaration = SyntaxFactory.DeclarationExpression(
                            VarNameIdentifier,
                            SyntaxFactory.ParenthesizedVariableDesignation(
                                SyntaxFactory.SeparatedList<VariableDesignationSyntax>(identifiers.Select(
                                    identifier => SyntaxFactory.SingleVariableDesignation(identifier)))));

                        // Generate foreach (var (a, b) in ... select (a, b))
                        return SyntaxFactory.ForEachVariableStatement(
                            declaration,
                            CreateSource(tupleForSelectExpression),
                            block);
                    }
            }

            // Every shape passes the same (empty) leading/trailing token sequences; only the select expression varies.
            ExpressionSyntax CreateSource(ExpressionSyntax selectExpression)
                => CreateQueryExpressionOrLinqInvocation(
                    selectExpression,
                    Enumerable.Empty<SyntaxToken>(),
                    Enumerable.Empty<SyntaxToken>(),
                    convertToQuery);
        }

        /// <summary>
        /// Returns the sole statement when it is already a block; otherwise wraps all statements in a new block.
        /// </summary>
        private static BlockSyntax WrapWithBlockIfNecessary(ImmutableArray<StatementSyntax> statements)
            => statements is [BlockSyntax block] ? block : SyntaxFactory.Block(statements);
    }
}