File: ConvertIfToSwitch\AbstractConvertIfToSwitchCodeRefactoringProvider.Analyzer.cs
Web Access
Project: ..\..\..\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertIfToSwitch
{
    using static BinaryOperatorKind;
 
    internal abstract partial class AbstractConvertIfToSwitchCodeRefactoringProvider<
        TIfStatementSyntax, TExpressionSyntax, TIsExpressionSyntax, TPatternSyntax>
    {
        // Match the following pattern which can be safely converted to switch statement
        //
        //    <if-statement-sequence>
        //        : if (<section-expr>) { <unreachable-end-point> }, <if-statement-sequence>
        //        | if (<section-expr>) { <unreachable-end-point> }, ( return | throw )
        //        | <if-statement>
        //
        //    <if-statement>
        //        : if (<section-expr>) { _ } else <if-statement>
        //        | if (<section-expr>) { _ } else { <if-statement-sequence> }
        //        | if (<section-expr>) { _ } else { _ }
        //        | if (<section-expr>) { _ }
        //
        //    <section-expr>
        //        : <section-expr> || <pattern-expr>
        //        | <pattern-expr>
        //
        //    <pattern-expr>
        //        : <pattern-expr> && <expr>                         // C#
        //        | <expr0> is <pattern>                             // C#
        //        | <expr0> is <type>                                // C#
        //        | <expr0> == <const-expr>                          // C#, VB
        //        | <expr0> <comparison-op> <const>                  // C#, VB
        //        | ( <expr0> >= <const> | <const> <= <expr0> )
        //           && ( <expr0> <= <const> | <const> >= <expr0> )  // C#, VB
        //        | ( <expr0> <= <const> | <const> >= <expr0> )
        //           && ( <expr0> >= <const> | <const> <= <expr0> )  // C#, VB
        //
        internal abstract class Analyzer
        {
            /// <summary>
            /// Language-specific gate consulted by <see cref="ParseIfStatement"/> before an if-statement is
            /// turned into a switch section; returning <see langword="false"/> rejects the statement.
            /// </summary>
            public abstract bool CanConvert(IConditionalOperation operation);

            /// <summary>
            /// Language-specific check applied to the body of an else-less if-statement. Only when it returns
            /// <see langword="true"/> are the statements that follow the if-statement considered as further
            /// sections of the same switch.
            /// </summary>
            public abstract bool HasUnreachableEndPoint(IOperation operation);

            /// <summary>
            /// Language-specific check invoked by <see cref="CheckConstantType"/> with the constant operand's
            /// semantic model and syntax together with the recorded type of the switch target.
            /// NOTE(review): presumably answers whether the constant converts implicitly to
            /// <paramref name="targetType"/>; the implementations are not visible from here.
            /// </summary>
            public abstract bool CanImplicitlyConvert(SemanticModel semanticModel, SyntaxNode syntax, ITypeSymbol targetType);
 
            /// <summary>
            /// Syntax of the expression that every section is tested against; it becomes the switch target.
            /// </summary>
            /// <remarks>
            /// Remains unset until <see cref="CheckTargetExpression"/> records the first candidate it is handed.
            /// </remarks>
            private SyntaxNode _switchTargetExpression = null!;

            /// <summary>
            /// Type of <see cref="_switchTargetExpression"/>, recorded at the same moment as the expression.
            /// </summary>
            private ITypeSymbol? _switchTargetType;

            /// <summary>
            /// Syntax service used to decide whether two expressions are equivalent.
            /// </summary>
            private readonly ISyntaxFacts _syntaxFacts;
 
            /// <summary>
            /// Creates an analyzer bound to the given syntax service and set of supported features.
            /// </summary>
            /// <param name="syntaxFacts">Syntax service used for expression-equivalence checks.</param>
            /// <param name="features">Constructs the analyzer is allowed to produce.</param>
            protected Analyzer(ISyntaxFacts syntaxFacts, Feature features)
                => (_syntaxFacts, Features) = (syntaxFacts, features);
 
            /// <summary>
            /// The set of constructs this analyzer is allowed to produce.
            /// </summary>
            public Feature Features { get; }

            /// <summary>
            /// Returns <see langword="true"/> when at least one flag in <paramref name="feature"/> is present
            /// in <see cref="Features"/>. Passing a combination of flags therefore asks "any of", not "all of".
            /// </summary>
            public bool Supports(Feature feature)
            {
                var overlap = Features & feature;
                return overlap != Feature.None;
            }
 
            /// <summary>
            /// Analyzes a run of statements and, when they match the supported if-statement shapes, produces
            /// the switch sections they correspond to.
            /// </summary>
            /// <param name="operations">The candidate statements, starting at the first if-statement.</param>
            /// <returns>
            /// The analyzed sections (followed by a label-less default section when a default body was found)
            /// and the shared target expression, or <see langword="default"/> when parsing fails.
            /// </returns>
            public (ImmutableArray<AnalyzedSwitchSection>, SyntaxNode TargetExpression) AnalyzeIfStatementSequence(ReadOnlySpan<IOperation> operations)
            {
                using var _ = ArrayBuilder<AnalyzedSwitchSection>.GetInstance(out var sections);

                var parsed = ParseIfStatementSequence(operations, sections, out var defaultBody);
                if (!parsed)
                {
                    return default;
                }

                // The default body is modelled as one more section that carries no labels.
                if (defaultBody is not null)
                {
                    var defaultSection = new AnalyzedSwitchSection(labels: default, defaultBody, defaultBody.Syntax);
                    sections.Add(defaultSection);
                }

                // A successful parse always goes through CheckTargetExpression at least once.
                RoslynDebug.Assert(_switchTargetExpression is not null);
                return (sections.ToImmutable(), _switchTargetExpression);
            }
 
            // Tree to parse:
            //
            //    <if-statement-sequence>
            //        : if (<section-expr>) { <unreachable-end-point> }, <if-statement-sequence>
            //        | if (<section-expr>) { <unreachable-end-point> }, ( return | throw )
            //        | <if-statement>
            //
            /// <summary>
            /// Consumes as many leading else-less if-statements with an unreachable end point as possible, then
            /// tries to close the sequence with either one more if-statement or a trailing value-returning
            /// <c>return</c> / <c>throw</c>, which becomes the default body.
            /// </summary>
            /// <param name="operations">Statements to examine, in source order.</param>
            /// <param name="sections">Receives one section per converted if-statement.</param>
            /// <param name="defaultBodyOpt">The operation to use as the default body, or <see langword="null"/>.</param>
            /// <returns><see langword="true"/> if at least one if-statement was converted.</returns>
            private bool ParseIfStatementSequence(ReadOnlySpan<IOperation> operations, ArrayBuilder<AnalyzedSwitchSection> sections, out IOperation? defaultBodyOpt)
            {
                defaultBodyOpt = null;

                // Greedily take the leading if-statements whose bodies never fall through.
                var consumed = 0;
                for (; consumed < operations.Length; consumed++)
                {
                    if (operations[consumed] is not IConditionalOperation { WhenFalse: null } conditional ||
                        !HasUnreachableEndPoint(conditional.WhenTrue) ||
                        !ParseIfStatement(conditional, sections, out _))
                    {
                        break;
                    }
                }

                if (consumed == 0)
                {
                    // No such prefix exists; the only remaining option is a single ordinary if-statement.
                    return operations.Length > 0 && ParseIfStatement(operations[0], sections, out defaultBodyOpt);
                }

                if (consumed == operations.Length)
                {
                    // Every statement was part of the prefix; there is nothing left to act as a default.
                    return true;
                }

                // The statement right after the prefix either continues the chain as an ordinary if-statement
                // or, when it is a return/throw carrying a value, supplies the default body.
                var trailing = operations[consumed];
                if (ParseIfStatement(trailing, sections, out defaultBodyOpt))
                {
                    return true;
                }

                if (trailing is IReturnOperation { ReturnedValue: not null } or IThrowOperation { Exception: not null })
                {
                    defaultBodyOpt = trailing;
                }

                return true;
            }
 
            // Tree to parse:
            //
            //    <if-statement>
            //        : if (<section-expr>) { _ } else <if-statement>
            //        | if (<section-expr>) { _ } else { <if-statement-sequence> }
            //        | if (<section-expr>) { _ } else { _ }
            //        | if (<section-expr>) { _ }
            //
            /// <summary>
            /// Converts a single if-statement, and recursively its <c>else</c> branch, into switch sections.
            /// </summary>
            /// <param name="operation">The statement to convert.</param>
            /// <param name="sections">Receives the section built from the statement and from any chained else-ifs.</param>
            /// <param name="defaultBodyOpt">
            /// The <c>else</c> branch that could not itself be converted, or <see langword="null"/> if there is none.
            /// </param>
            /// <returns><see langword="true"/> if <paramref name="operation"/> was converted.</returns>
            private bool ParseIfStatement(IOperation operation, ArrayBuilder<AnalyzedSwitchSection> sections, out IOperation? defaultBodyOpt)
            {
                defaultBodyOpt = null;

                if (operation is not IConditionalOperation conditional || !CanConvert(conditional))
                {
                    return false;
                }

                var section = ParseSwitchSection(conditional);
                if (section is null)
                {
                    return false;
                }

                sections.Add(section);

                var elseBranch = conditional.WhenFalse;
                if (elseBranch is not null && !ParseIfStatementOrBlock(elseBranch, sections, out defaultBodyOpt))
                {
                    // The else branch is not a convertible if-statement, so the whole branch is the default.
                    defaultBodyOpt = elseBranch;
                }

                return true;
            }
 
            /// <summary>
            /// Parses an <c>else</c> branch: a block is treated as a statement sequence, anything else as a
            /// single if-statement candidate.
            /// </summary>
            private bool ParseIfStatementOrBlock(IOperation op, ArrayBuilder<AnalyzedSwitchSection> sections, out IOperation? defaultBodyOpt)
            {
                if (op is IBlockOperation block)
                {
                    var statements = block.Operations.AsSpan();
                    return ParseIfStatementSequence(statements, sections, out defaultBodyOpt);
                }

                return ParseIfStatement(op, sections, out defaultBodyOpt);
            }
 
            /// <summary>
            /// Builds the section for one if-statement: its condition supplies the labels and its
            /// <c>WhenTrue</c> branch supplies the body.
            /// </summary>
            /// <returns>The section, or <see langword="null"/> if the condition cannot be expressed as labels.</returns>
            private AnalyzedSwitchSection? ParseSwitchSection(IConditionalOperation operation)
            {
                using var _ = ArrayBuilder<AnalyzedSwitchLabel>.GetInstance(out var labels);

                return ParseSwitchLabels(operation.Condition, labels)
                    ? new AnalyzedSwitchSection(labels.ToImmutable(), operation.WhenTrue, operation.Syntax)
                    : null;
            }
 
            // Tree to parse:
            //
            //    <section-expr>
            //        : <section-expr> || <pattern-expr>
            //        | <pattern-expr>
            //
            /// <summary>
            /// Splits a condition on its <c>||</c> operators and turns every operand into a label, appending
            /// them to <paramref name="labels"/> in source order.
            /// </summary>
            /// <returns><see langword="false"/> as soon as one operand cannot be turned into a label.</returns>
            private bool ParseSwitchLabels(IOperation operation, ArrayBuilder<AnalyzedSwitchLabel> labels)
            {
                var labelOperation = operation;

                if (operation is IBinaryOperation { OperatorKind: ConditionalOr } disjunction)
                {
                    // The left operand may itself be a chain of ||, so handle it recursively first;
                    // the right operand is the label handled by this invocation.
                    if (!ParseSwitchLabels(disjunction.LeftOperand, labels))
                    {
                        return false;
                    }

                    labelOperation = disjunction.RightOperand;
                }

                if (ParseSwitchLabel(labelOperation) is { } label)
                {
                    labels.Add(label);
                    return true;
                }

                return false;
            }
 
            /// <summary>
            /// Turns one condition into a label made of a pattern plus the guards collected while parsing it.
            /// </summary>
            /// <returns>The label, or <see langword="null"/> if no pattern could be produced.</returns>
            private AnalyzedSwitchLabel? ParseSwitchLabel(IOperation operation)
            {
                using var _ = ArrayBuilder<TExpressionSyntax>.GetInstance(out var guards);

                return ParsePattern(operation, guards) is { } pattern
                    ? new AnalyzedSwitchLabel(pattern, guards.ToImmutable())
                    : null;
            }
 
            /// <summary>
            /// Outcome of <see cref="DetermineConstant"/>: which side of a binary comparison holds the constant.
            /// </summary>
            private enum ConstantResult
            {
                /// <summary>
                /// Neither operand qualified as the constant side of the comparison.
                /// </summary>
                None,
                /// <summary>
                /// The left operand is the constant.
                /// </summary>
                Left,
                /// <summary>
                /// The right operand is the constant.
                /// </summary>
                Right,
            }
 
            /// <summary>
            /// Decides which operand of <paramref name="op"/> is the constant and which is the switch target.
            /// The right operand is tried as the constant first, then the left.
            /// </summary>
            /// <returns>
            /// The side holding the constant, or <see cref="ConstantResult.None"/> when neither arrangement
            /// passes the constant, target-expression and constant-type checks.
            /// </returns>
            private ConstantResult DetermineConstant(IBinaryOperation op)
            {
                var lhs = op.LeftOperand;
                var rhs = op.RightOperand;

                // CheckTargetExpression has to run before CheckConstantType: the latter reads the target
                // type that the former records.
                if (IsConstant(rhs) && CheckTargetExpression(lhs) && CheckConstantType(rhs))
                {
                    return ConstantResult.Right;
                }

                if (IsConstant(lhs) && CheckTargetExpression(rhs) && CheckConstantType(lhs))
                {
                    return ConstantResult.Left;
                }

                return ConstantResult.None;
            }
 
            // Tree to parse:
            //
            //    <pattern-expr>
            //        : <pattern-expr> && <expr>                         // C#
            //        | <expr0> is <pattern>                             // C#
            //        | <expr0> is <type>                                // C#
            //        | <expr0> == <const-expr>                          // C#, VB
            //        | <expr0> <comparison-op> <const>                  // C#, VB
            //        | ( <expr0> >= <const> | <const> <= <expr0> )
            //           && ( <expr0> <= <const> | <const> >= <expr0> )  //     VB
            //        | ( <expr0> <= <const> | <const> >= <expr0> )
            //           && ( <expr0> >= <const> | <const> <= <expr0> )  //     VB
            //
            /// <summary>
            /// Tries to express a single condition as a switch pattern. Each construct is attempted only when
            /// the corresponding <see cref="Feature"/> flag is supported, and the cases are tried in the order
            /// written.
            /// </summary>
            /// <param name="operation">The condition to convert.</param>
            /// <param name="guards">
            /// Receives the right-hand operands of <c>&amp;&amp;</c> that are kept as case guards rather than
            /// folded into the pattern.
            /// </param>
            /// <returns>The pattern, or <see langword="null"/> if the condition has no supported shape.</returns>
            private AnalyzedPattern? ParsePattern(IOperation operation, ArrayBuilder<TExpressionSyntax> guards)
            {
                switch (operation)
                {
                    // A lower-bound and an upper-bound comparison on the same expression joined by &&
                    // collapse into one range, provided both bound syntaxes are TExpressionSyntax.
                    case IBinaryOperation { OperatorKind: ConditionalAnd } op
                        when Supports(Feature.RangePattern) && GetRangeBounds(op) is (TExpressionSyntax lower, TExpressionSyntax higher):
                        return new AnalyzedPattern.Range(lower, higher);

                    // Equality against a constant: only the constant operand is kept in the pattern;
                    // DetermineConstant has already matched the other operand against the switch target.
                    case IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals } op:
                        return DetermineConstant(op) switch
                        {
                            ConstantResult.Left when op.LeftOperand.Syntax is TExpressionSyntax left
                                => new AnalyzedPattern.Constant(left),
                            ConstantResult.Right when op.RightOperand.Syntax is TExpressionSyntax right
                                => new AnalyzedPattern.Constant(right),
                            _ => null
                        };

                    // Inequality is gated by its own feature flag but shares the relational representation.
                    case IBinaryOperation { OperatorKind: NotEquals } op
                        when Supports(Feature.InequalityPattern):
                        return ParseRelationalPattern(op);

                    case IBinaryOperation op
                        when Supports(Feature.RelationalPattern) && IsRelationalOperator(op.OperatorKind):
                        return ParseRelationalPattern(op);

                    // Check this below the cases that produce Relational/Ranges.  We would prefer to use those if
                    // available before utilizing a CaseGuard.
                    // Supports is an any-of test, so this case is entered when either feature is available.
                    case IBinaryOperation { OperatorKind: ConditionalAnd } op
                        when Supports(Feature.AndPattern | Feature.CaseGuard):
                        {
                            // The left operand must always be a pattern in its own right.
                            var leftPattern = ParsePattern(op.LeftOperand, guards);
                            if (leftPattern == null)
                                return null;

                            if (Supports(Feature.AndPattern))
                            {
                                // Remember how many guards exist so a failed attempt on the right operand
                                // can be undone.
                                var guardCount = guards.Count;
                                var rightPattern = ParsePattern(op.RightOperand, guards);
                                if (rightPattern != null)
                                    return new AnalyzedPattern.And(leftPattern, rightPattern);

                                // Making a pattern out of the RHS didn't work.  Reset the guards back to where we started.
                                guards.Count = guardCount;
                            }

                            // Fall back to keeping the right operand verbatim as a guard on the left pattern.
                            if (Supports(Feature.CaseGuard) && op.RightOperand.Syntax is TExpressionSyntax node)
                            {
                                guards.Add(node);
                                return leftPattern;
                            }

                            return null;
                        }

                    // <expr0> is <type>: the whole is-expression syntax is carried in the pattern.
                    case IIsTypeOperation op
                        when Supports(Feature.IsTypePattern) && CheckTargetExpression(op.ValueOperand) && op.Syntax is TIsExpressionSyntax node:
                        return new AnalyzedPattern.Type(node);

                    // <expr0> is <pattern>: the pattern already present in source is reused as is.
                    case IIsPatternOperation op
                        when Supports(Feature.SourcePattern) && CheckTargetExpression(op.Value) && op.Pattern.Syntax is TPatternSyntax pattern:
                        return new AnalyzedPattern.Source(pattern);

                    // Parentheses are transparent: analyze what they wrap.
                    case IParenthesizedOperation op:
                        return ParsePattern(op.Operand, guards);
                }

                return null;
            }
 
            /// <summary>
            /// Builds a relational pattern from a comparison between the switch target and a constant.
            /// When the constant is written on the left, the operator is mirrored so that the pattern reads
            /// as if the target were on the left.
            /// </summary>
            /// <returns>The pattern, or <see langword="null"/> if no operand qualifies as the constant.</returns>
            private AnalyzedPattern? ParseRelationalPattern(IBinaryOperation op)
            {
                var constantSide = DetermineConstant(op);

                if (constantSide == ConstantResult.Left && op.LeftOperand.Syntax is TExpressionSyntax leftConstant)
                {
                    return new AnalyzedPattern.Relational(Flip(op.OperatorKind), leftConstant);
                }

                if (constantSide == ConstantResult.Right && op.RightOperand.Syntax is TExpressionSyntax rightConstant)
                {
                    return new AnalyzedPattern.Relational(op.OperatorKind, rightConstant);
                }

                return null;
            }
 
            /// <summary>
            /// Classification produced by <see cref="GetRangeBound"/> for one comparison of a candidate range.
            /// </summary>
            private enum BoundKind
            {
                /// <summary>
                /// Not a range bound.
                /// </summary>
                None,
                /// <summary>
                /// The comparison supplies the lower bound of a range pattern.
                /// </summary>
                Lower,
                /// <summary>
                /// The comparison supplies the higher bound of a range pattern.
                /// </summary>
                Higher,
            }
 
            /// <summary>
            /// Recognizes a lower-bound comparison and a higher-bound comparison, in either order, that are
            /// the two operands of <paramref name="op"/> and that both test the switch target.
            /// </summary>
            /// <returns>
            /// The syntax of the lower and higher bound values, or <see langword="default"/> when
            /// <paramref name="op"/> is not such a pair.
            /// </returns>
            private (SyntaxNode Lower, SyntaxNode Higher) GetRangeBounds(IBinaryOperation op)
            {
                if (op.LeftOperand is not IBinaryOperation first || op.RightOperand is not IBinaryOperation second)
                {
                    return default;
                }

                var low = GetRangeBound(first);
                var high = GetRangeBound(second);

                // Accept the bounds in either order by normalizing "higher, lower" to "lower, higher".
                if (low.Kind == BoundKind.Higher && high.Kind == BoundKind.Lower)
                {
                    (low, high) = (high, low);
                }

                if (low.Kind != BoundKind.Lower || high.Kind != BoundKind.Higher)
                {
                    return default;
                }

                // Both comparisons must test the same expression, and that expression must be
                // (or now become) the switch target.
                var sameExpression = _syntaxFacts.AreEquivalent(low.Expression.Syntax, high.Expression.Syntax);
                if (!sameExpression || !CheckTargetExpression(low.Expression))
                {
                    return default;
                }

                return (low.Value.Syntax, high.Value.Syntax);
            }
 
            /// <summary>
            /// Classifies a <c>&lt;=</c> or <c>&gt;=</c> comparison against a constant as a lower or higher bound.
            /// The left operand is checked for being the constant before the right one.
            /// </summary>
            /// <returns>
            /// The bound kind, the non-constant operand (<c>Expression</c>) and the constant operand
            /// (<c>Value</c>); <see langword="default"/> for any other operator or when neither operand is constant.
            /// </returns>
            private static (BoundKind Kind, IOperation Expression, IOperation Value) GetRangeBound(IBinaryOperation op)
            {
                var kind = op.OperatorKind;
                if (kind != LessThanOrEqual && kind != GreaterThanOrEqual)
                {
                    return default;
                }

                var lhs = op.LeftOperand;
                var rhs = op.RightOperand;

                if (IsConstant(lhs))
                {
                    // 5 <= i is a lower bound, 5 >= i is a higher bound
                    var bound = kind == LessThanOrEqual ? BoundKind.Lower : BoundKind.Higher;
                    return (bound, rhs, lhs);
                }

                if (IsConstant(rhs))
                {
                    // i <= 5 is a higher bound, i >= 5 is a lower bound
                    var bound = kind == LessThanOrEqual ? BoundKind.Higher : BoundKind.Lower;
                    return (bound, lhs, rhs);
                }

                return default;
            }
 
            /// <summary>
            /// Mirrors a comparison operator so that its operands can be swapped: <c>&lt;</c> and <c>&gt;</c>
            /// trade places, as do <c>&lt;=</c> and <c>&gt;=</c>, while <c>!=</c> is its own mirror image.
            /// Any other operator kind is reported through <c>ExceptionUtilities.UnexpectedValue</c>.
            /// </summary>
            private static BinaryOperatorKind Flip(BinaryOperatorKind operatorKind)
            {
                switch (operatorKind)
                {
                    case LessThan:
                        return GreaterThan;
                    case LessThanOrEqual:
                        return GreaterThanOrEqual;
                    case GreaterThanOrEqual:
                        return LessThanOrEqual;
                    case GreaterThan:
                        return LessThan;
                    case NotEquals:
                        return NotEquals;
                    default:
                        throw ExceptionUtilities.UnexpectedValue(operatorKind);
                }
            }
 
            /// <summary>
            /// Returns <see langword="true"/> for the four ordering comparisons
            /// (<c>&lt;</c>, <c>&lt;=</c>, <c>&gt;=</c>, <c>&gt;</c>) and <see langword="false"/> for every other operator.
            /// </summary>
            private static bool IsRelationalOperator(BinaryOperatorKind operatorKind)
                => operatorKind is LessThan or LessThanOrEqual or GreaterThanOrEqual or GreaterThan;
 
            /// <summary>
            /// Returns <see langword="true"/> if <paramref name="operation"/>, once any wrapping conversions
            /// are stripped, has a constant value.
            /// </summary>
            private static bool IsConstant(IOperation operation)
            {
                // Constants do not propagate to conversions, so look through them at the operand.
                var current = operation;
                while (current is IConversionOperation conversion)
                {
                    current = conversion.Operand;
                }

                return current.ConstantValue.HasValue;
            }
 
            /// <summary>
            /// Checks that <paramref name="operation"/> (ignoring conversions) is the switch target. The first
            /// expression ever checked is adopted as the target, together with its type; every later one must
            /// be syntactically equivalent to it.
            /// </summary>
            /// <returns>
            /// <see langword="true"/> if the expression was adopted as, or is equivalent to, the switch target.
            /// </returns>
            private bool CheckTargetExpression(IOperation operation)
            {
                var unwrapped = operation.WalkDownConversion();
                var candidate = unwrapped.Syntax;

                if (_switchTargetExpression is not null)
                {
                    return _syntaxFacts.AreEquivalent(candidate, _switchTargetExpression);
                }

                // Nothing has been chosen yet, so the first expression we meet becomes the target.
                RoslynDebug.Assert(_switchTargetType is null);

                _switchTargetExpression = candidate;
                _switchTargetType = unwrapped.Type;
                return true;
            }
 
            /// <summary>
            /// Forwards the constant <paramref name="operation"/> and the recorded switch target type to
            /// <see cref="CanImplicitlyConvert"/>. Expects <see cref="CheckTargetExpression"/> to have
            /// recorded a target type beforehand; this is asserted.
            /// </summary>
            private bool CheckConstantType(IOperation operation)
            {
                var semanticModel = operation.SemanticModel;
                RoslynDebug.AssertNotNull(semanticModel);

                var targetType = _switchTargetType;
                RoslynDebug.AssertNotNull(targetType);

                return CanImplicitlyConvert(semanticModel, operation.Syntax, targetType);
            }
        }
 
        /// <summary>
        /// Bit flags naming the constructs an <see cref="Analyzer"/> may produce. Flags are combined
        /// bitwise and queried through <see cref="Analyzer.Supports"/>.
        /// </summary>
        [Flags]
        internal enum Feature
        {
            None = 0,
            // VB/C# 9.0 features
            // Gates turning <, <=, >=, > comparisons against a constant into relational patterns.
            RelationalPattern = 1,
            // VB features
            // Gates turning a != comparison against a constant into a pattern.
            InequalityPattern = 1 << 1,
            // Gates collapsing a lower-bound && higher-bound pair into a single range.
            RangePattern = 1 << 2,
            // C# 7.0 features
            // Gates reusing the pattern of an existing is-pattern expression.
            SourcePattern = 1 << 3,
            // Gates converting an is-type expression.
            IsTypePattern = 1 << 4,
            // Gates keeping the right operand of && as a case guard.
            CaseGuard = 1 << 5,
            // C# 8.0 features
            // Not consulted in this part of the class.
            SwitchExpression = 1 << 6,
            // C# 9.0 features
            // Not consulted in this part of the class.
            OrPattern = 1 << 7,
            // Gates combining both operands of && into a single 'and' pattern.
            AndPattern = 1 << 8,
            // Not consulted in this part of the class.
            TypePattern = 1 << 9,
        }
    }
}