File: AbstractUseNullPropagationDiagnosticAnalyzer.cs
Web Access
Project: ..\..\..\src\CodeStyle\Core\Analyzers\Microsoft.CodeAnalysis.CodeStyle.csproj (Microsoft.CodeAnalysis.CodeStyle)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.UseNullPropagation
{
    /// <summary>
    /// Well-known keys for the property bag attached to use-null-propagation diagnostics.
    /// </summary>
    internal static class UseNullPropagationConstants
    {
        /// <summary>
        /// Diagnostic property key (its value is the member's own name).  In this file it is used by the
        /// analyzer when the expression matched against the null check has type <see cref="System.Nullable{T}"/>.
        /// </summary>
        public const string WhenPartIsNullable =
            nameof(WhenPartIsNullable);
    }
 
    /// <summary>
    /// Looks for code snippets similar to <c>x == null ? null : x.Y()</c> and reports a diagnostic suggesting they be
    /// simplified to <c>x?.Y()</c>.  This form is also supported:
    /// <code>
    /// if (x != null)
    ///     x.Y();
    /// </code>
    /// </summary>
    /// <remarks>
    /// This type only reports diagnostics; it never edits code.  The <c>if</c>-statement form is analyzed by
    /// <c>AnalyzeIfStatement</c>, which is declared elsewhere (this class is partial).
    /// </remarks>
    internal abstract partial class AbstractUseNullPropagationDiagnosticAnalyzer<
        TSyntaxKind,
        TExpressionSyntax,
        TStatementSyntax,
        TConditionalExpressionSyntax,
        TBinaryExpressionSyntax,
        TInvocationExpressionSyntax,
        TConditionalAccessExpressionSyntax,
        TElementAccessExpressionSyntax,
        TMemberAccessExpressionSyntax,
        TIfStatementSyntax,
        TExpressionStatementSyntax> : AbstractBuiltInCodeStyleDiagnosticAnalyzer
        where TSyntaxKind : struct
        where TExpressionSyntax : SyntaxNode
        where TStatementSyntax : SyntaxNode
        where TConditionalExpressionSyntax : TExpressionSyntax
        where TBinaryExpressionSyntax : TExpressionSyntax
        where TInvocationExpressionSyntax : TExpressionSyntax
        where TConditionalAccessExpressionSyntax : TExpressionSyntax
        where TElementAccessExpressionSyntax : TExpressionSyntax
        where TMemberAccessExpressionSyntax : TExpressionSyntax
        where TIfStatementSyntax : TStatementSyntax
        where TExpressionStatementSyntax : TStatementSyntax
    {
        /// <summary>
        /// Diagnostic properties attached when the expression matched against the null check is a
        /// <see cref="System.Nullable{T}"/>.  The value is the empty string (presumably consumers only test for the
        /// key's presence — confirm against the code fix).
        /// </summary>
        private static readonly ImmutableDictionary<string, string?> s_whenPartIsNullableProperties =
            ImmutableDictionary<string, string?>.Empty.Add(UseNullPropagationConstants.WhenPartIsNullable, "");

        protected AbstractUseNullPropagationDiagnosticAnalyzer()
            : base(IDEDiagnosticIds.UseNullPropagationDiagnosticId,
                   EnforceOnBuildValues.UseNullPropagation,
                   CodeStyleOptions2.PreferNullPropagation,
                   new LocalizableResourceString(nameof(AnalyzersResources.Use_null_propagation), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)),
                   new LocalizableResourceString(nameof(AnalyzersResources.Null_check_can_be_simplified), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
        {
        }

        public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
            => DiagnosticAnalyzerCategory.SemanticSpanAnalysis;

        /// <summary>
        /// Returns <see langword="false"/> to skip analysis of <paramref name="compilation"/> entirely; checked once per
        /// compilation before any node actions are registered.
        /// </summary>
        protected abstract bool ShouldAnalyze(Compilation compilation);

        /// <summary>The language-specific syntax kind used to register the <c>if</c>-statement analysis.</summary>
        protected abstract TSyntaxKind IfStatementSyntaxKind { get; }

        /// <summary>Language-specific syntax helpers used throughout the analysis.</summary>
        protected abstract ISyntaxFacts GetSyntaxFacts();

        /// <summary>
        /// Returns whether <paramref name="node"/> lies inside an expression tree, where <c>?.</c> is not available.
        /// <paramref name="expressionTypeOpt"/> is the compilation's <c>Expression&lt;T&gt;</c> type, or
        /// <see langword="null"/> if it could not be found.
        /// </summary>
        protected abstract bool IsInExpressionTree(SemanticModel semanticModel, SyntaxNode node, INamedTypeSymbol? expressionTypeOpt, CancellationToken cancellationToken);

        /// <summary>
        /// Language-specific recognition of null checks that are neither a reference (in)equality binary expression nor a
        /// <c>ReferenceEquals</c> call (presumably pattern-based checks such as <c>x is null</c> — see implementations).
        /// On success <paramref name="conditionPartToCheck"/> is the expression compared against null, and
        /// <paramref name="isEquals"/> is <see langword="true"/> for an "is null" check and <see langword="false"/> for an
        /// "is not null" check.
        /// </summary>
        protected abstract bool TryAnalyzePatternCondition(
            ISyntaxFacts syntaxFacts, TExpressionSyntax conditionNode,
            [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck, out bool isEquals);

        /// <summary>
        /// For each compilation that passes <see cref="ShouldAnalyze"/>, looks up the symbols both analyses need
        /// (<c>Expression&lt;T&gt;</c> and the public two-parameter <c>object.ReferenceEquals</c>) once, then registers the
        /// ternary-conditional and <c>if</c>-statement node actions that capture them.
        /// </summary>
        protected override void InitializeWorker(AnalysisContext context)
        {
            context.RegisterCompilationStartAction(context =>
            {
                if (!ShouldAnalyze(context.Compilation))
                    return;

                var expressionType = context.Compilation.ExpressionOfTType();

                var objectType = context.Compilation.GetSpecialType(SpecialType.System_Object);
                var referenceEqualsMethod = objectType?.GetMembers(nameof(ReferenceEquals))
                                                          .OfType<IMethodSymbol>()
                                                          .FirstOrDefault(m => m.DeclaredAccessibility == Accessibility.Public &&
                                                                               m.Parameters.Length == 2);

                var syntaxKinds = GetSyntaxFacts().SyntaxKinds;
                context.RegisterSyntaxNodeAction(
                    context => AnalyzeTernaryConditionalExpression(context, expressionType, referenceEqualsMethod),
                    syntaxKinds.Convert<TSyntaxKind>(syntaxKinds.TernaryConditionalExpression));
                context.RegisterSyntaxNodeAction(
                    context => AnalyzeIfStatement(context, referenceEqualsMethod),
                    IfStatementSyntaxKind);
            });
        }

        /// <summary>
        /// Reports a diagnostic on a conditional expression of the form <c>x == null ? null : x.Y</c> or
        /// <c>x != null ? x.Y : null</c> when it can be rewritten with <c>?.</c>.
        /// </summary>
        /// <param name="context">The node context; its node is a <typeparamref name="TConditionalExpressionSyntax"/>.</param>
        /// <param name="expressionType">The compilation's <c>Expression&lt;T&gt;</c> type, if found.</param>
        /// <param name="referenceEqualsMethod">The public two-parameter <c>object.ReferenceEquals</c>, if found.</param>
        private void AnalyzeTernaryConditionalExpression(
            SyntaxNodeAnalysisContext context,
            INamedTypeSymbol? expressionType,
            IMethodSymbol? referenceEqualsMethod)
        {
            var cancellationToken = context.CancellationToken;
            var conditionalExpression = (TConditionalExpressionSyntax)context.Node;

            var option = context.GetAnalyzerOptions().PreferNullPropagation;
            if (!option.Value)
                return;

            var syntaxFacts = GetSyntaxFacts();
            syntaxFacts.GetPartsOfConditionalExpression(
                conditionalExpression, out var condition, out var whenTrue, out var whenFalse);

            var conditionNode = (TExpressionSyntax)condition;

            var whenTrueNode = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(whenTrue);
            var whenFalseNode = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(whenFalse);

            if (!TryAnalyzeCondition(
                    context, syntaxFacts, referenceEqualsMethod, conditionNode,
                    out var conditionPartToCheck, out var isEquals))
            {
                return;
            }

            // Needs to be of the form:
            //      x == null ? null : ...    or
            //      x != null ? ...  : null;
            // i.e. the branch taken when 'x' is null must be the null literal, and the other branch is analyzed.
            var nullBranch = isEquals ? whenTrueNode : whenFalseNode;
            if (!syntaxFacts.IsNullLiteralExpression(nullBranch))
                return;

            var whenPartToCheck = isEquals ? whenFalseNode : whenTrueNode;

            var semanticModel = context.SemanticModel;
            var whenPartMatch = GetWhenPartMatch(syntaxFacts, semanticModel, conditionPartToCheck, whenPartToCheck, cancellationToken);
            if (whenPartMatch == null)
                return;

            // can't use ?. on a pointer
            var whenPartType = semanticModel.GetTypeInfo(whenPartMatch, cancellationToken).Type;
            if (whenPartType is IPointerTypeSymbol)
                return;

            // User has something like:  If(str is nothing, nothing, str.Length)
            // In this case, converting to str?.Length changes the type of this from int to int?, so bail out for any
            // value-typed result.  But for a nullable type, such as  If(c is nothing, nothing, c.nullable),
            // converting to c?.nullable doesn't affect the type.
            var type = semanticModel.GetTypeInfo(conditionalExpression, cancellationToken).Type;
            if (type is { IsValueType: true } && type.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T)
                return;

            if (syntaxFacts.IsSimpleMemberAccessExpression(whenPartToCheck))
            {
                // `x == null ? null : x.M` cannot be converted to `x?.M` when M is a method symbol: the branch is then a
                // method-group reference rather than an invocation or a property/field access.
                syntaxFacts.GetPartsOfMemberAccessExpression(whenPartToCheck, out _, out var name);
                if (semanticModel.GetSymbolInfo(name, cancellationToken).GetAnySymbol() is IMethodSymbol)
                    return;
            }

            // ?. is not available in expression-trees.  Disallow the fix in that case.
            if (IsInExpressionTree(semanticModel, conditionNode, expressionType, cancellationToken))
                return;

            var conditionalLocation = conditionalExpression.GetLocation();
            var locations = ImmutableArray.Create(
                conditionalLocation,
                conditionPartToCheck.GetLocation(),
                whenPartToCheck.GetLocation());

            // Flag, via the diagnostic's properties, that the matched receiver is a Nullable<T>.
            var whenPartIsNullable = whenPartType?.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;
            var properties = whenPartIsNullable
                ? s_whenPartIsNullableProperties
                : ImmutableDictionary<string, string?>.Empty;

            context.ReportDiagnostic(DiagnosticHelper.Create(
                Descriptor,
                conditionalLocation,
                option.Notification.Severity,
                locations,
                properties));
        }

        /// <summary>
        /// Recognizes a null check in <paramref name="condition"/>.  Parentheses and at most one leading logical-not
        /// are removed first; the remainder is analyzed as a reference (in)equality binary expression, a
        /// <c>ReferenceEquals</c> invocation, or (for anything else) a language-specific pattern.
        /// </summary>
        /// <param name="conditionPartToCheck">On success, the operand that is compared against null.</param>
        /// <param name="isEquals">
        /// <see langword="true"/> when the condition is true for a null operand (e.g. <c>x == null</c> or
        /// <c>!(x != null)</c>), <see langword="false"/> otherwise.  A leading logical-not flips the inner result.
        /// </param>
        private bool TryAnalyzeCondition(
            SyntaxNodeAnalysisContext context,
            ISyntaxFacts syntaxFacts,
            IMethodSymbol? referenceEqualsMethod,
            TExpressionSyntax condition,
            [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck,
            out bool isEquals)
        {
            condition = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(condition);
            var conditionIsNegated = false;
            if (syntaxFacts.IsLogicalNotExpression(condition))
            {
                conditionIsNegated = true;
                condition = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(
                    syntaxFacts.GetOperandOfPrefixUnaryExpression(condition));
            }

            var result = condition switch
            {
                TBinaryExpressionSyntax binaryExpression => TryAnalyzeBinaryExpressionCondition(
                        syntaxFacts, binaryExpression, out conditionPartToCheck, out isEquals),

                TInvocationExpressionSyntax invocation => TryAnalyzeInvocationCondition(
                        context, syntaxFacts, referenceEqualsMethod, invocation,
                        out conditionPartToCheck, out isEquals),

                _ => TryAnalyzePatternCondition(syntaxFacts, condition, out conditionPartToCheck, out isEquals),
            };

            if (conditionIsNegated)
                isEquals = !isEquals;

            return result;
        }

        /// <summary>
        /// Recognizes <c>a == null</c> / <c>null == a</c> (and the not-equals forms) using the language's reference
        /// equality/inequality syntax kinds.  Succeeds only when exactly one operand is the null literal.
        /// </summary>
        private static bool TryAnalyzeBinaryExpressionCondition(
            ISyntaxFacts syntaxFacts, TBinaryExpressionSyntax condition,
            [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck, out bool isEquals)
        {
            var syntaxKinds = syntaxFacts.SyntaxKinds;
            isEquals = syntaxKinds.ReferenceEqualsExpression == condition.RawKind;
            var isNotEquals = syntaxKinds.ReferenceNotEqualsExpression == condition.RawKind;
            if (!isEquals && !isNotEquals)
            {
                conditionPartToCheck = null;
                return false;
            }

            syntaxFacts.GetPartsOfBinaryExpression(condition, out var conditionLeft, out var conditionRight);
            conditionPartToCheck = GetConditionPartToCheck(syntaxFacts, (TExpressionSyntax)conditionLeft, (TExpressionSyntax)conditionRight);
            return conditionPartToCheck != null;
        }

        /// <summary>
        /// Recognizes <c>ReferenceEquals(a, null)</c> or <c>X.ReferenceEquals(null, a)</c> where exactly one of the two
        /// arguments is the null literal.  Cheap syntactic checks run first; the final check binds the invocation and
        /// requires it to resolve to <paramref name="referenceEqualsMethod"/>.  <paramref name="isEquals"/> is always
        /// <see langword="true"/>.
        /// </summary>
        private static bool TryAnalyzeInvocationCondition(
            SyntaxNodeAnalysisContext context,
            ISyntaxFacts syntaxFacts,
            IMethodSymbol? referenceEqualsMethod,
            TInvocationExpressionSyntax invocation,
            [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck,
            out bool isEquals)
        {
            conditionPartToCheck = null;
            isEquals = true;

            if (referenceEqualsMethod == null)
                return false;

            var expression = syntaxFacts.GetExpressionOfInvocationExpression(invocation);
            var nameNode = syntaxFacts.IsIdentifierName(expression)
                ? expression
                : syntaxFacts.IsSimpleMemberAccessExpression(expression)
                    ? syntaxFacts.GetNameOfMemberAccessExpression(expression)
                    : null;

            if (!syntaxFacts.IsIdentifierName(nameNode))
            {
                return false;
            }

            syntaxFacts.GetNameAndArityOfSimpleName(nameNode, out var name, out _);
            if (!syntaxFacts.StringComparer.Equals(name, nameof(ReferenceEquals)))
            {
                return false;
            }

            var arguments = syntaxFacts.GetArgumentsOfInvocationExpression(invocation);
            if (arguments.Count != 2)
            {
                return false;
            }

            var conditionLeft = (TExpressionSyntax)syntaxFacts.GetExpressionOfArgument(arguments[0]);
            var conditionRight = (TExpressionSyntax)syntaxFacts.GetExpressionOfArgument(arguments[1]);
            if (conditionLeft == null || conditionRight == null)
            {
                return false;
            }

            conditionPartToCheck = GetConditionPartToCheck(syntaxFacts, conditionLeft, conditionRight);
            if (conditionPartToCheck == null)
            {
                return false;
            }

            // Only now pay for binding: the name could refer to some other, user-defined ReferenceEquals.
            var semanticModel = context.SemanticModel;
            var cancellationToken = context.CancellationToken;
            var symbol = semanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol;
            return referenceEqualsMethod.Equals(symbol);
        }

        /// <summary>
        /// Given the two operands of an equality check, returns the one that is not the null literal when exactly one
        /// side is the null literal; otherwise (both null, as in <c>null == null</c>, or neither null) returns
        /// <see langword="null"/>.
        /// </summary>
        private static TExpressionSyntax? GetConditionPartToCheck(
            ISyntaxFacts syntaxFacts, TExpressionSyntax conditionLeft, TExpressionSyntax conditionRight)
        {
            var conditionLeftIsNull = syntaxFacts.IsNullLiteralExpression(conditionLeft);
            var conditionRightIsNull = syntaxFacts.IsNullLiteralExpression(conditionRight);

            // Both sides null (nothing to do here) or neither side null (not a null check).
            if (conditionLeftIsNull == conditionRightIsNull)
                return null;

            return conditionRightIsNull ? conditionLeft : conditionRight;
        }

        /// <summary>
        /// Walks from <paramref name="whenPart"/> inward through invocations, simple member accesses, conditional
        /// accesses and element accesses, and returns the first receiver of a simple member access or element access
        /// that is syntactically equivalent to <paramref name="expressionToMatch"/> (after an <c>(object)</c> cast on
        /// <paramref name="expressionToMatch"/> is removed).  Returns <see langword="null"/> if no such receiver exists.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <c>current</c> is tested before its parentheses are removed, so a receiver inside a
        /// parenthesized access such as <c>(x.Y).Z</c> is never matched.  Presumably this avoids suggesting
        /// <c>(x?.Y).Z</c>, which would not short-circuit <c>.Z</c> — confirm before changing.
        /// </remarks>
        internal static TExpressionSyntax? GetWhenPartMatch(
            ISyntaxFacts syntaxFacts,
            SemanticModel semanticModel,
            TExpressionSyntax expressionToMatch,
            TExpressionSyntax whenPart,
            CancellationToken cancellationToken)
        {
            expressionToMatch = RemoveObjectCastIfAny(syntaxFacts, semanticModel, expressionToMatch, cancellationToken);
            var current = whenPart;
            while (true)
            {
                var unwrapped = Unwrap(syntaxFacts, current);
                if (unwrapped == null)
                    return null;

                if (syntaxFacts.IsSimpleMemberAccessExpression(current) || current is TElementAccessExpressionSyntax)
                {
                    if (syntaxFacts.AreEquivalent(unwrapped, expressionToMatch))
                        return unwrapped;
                }

                current = unwrapped;
            }
        }

        /// <summary>
        /// If <paramref name="node"/> is a cast whose target type is <see cref="object"/>, returns the cast's operand;
        /// otherwise returns <paramref name="node"/> unchanged.
        /// </summary>
        private static TExpressionSyntax RemoveObjectCastIfAny(
            ISyntaxFacts syntaxFacts, SemanticModel semanticModel, TExpressionSyntax node, CancellationToken cancellationToken)
        {
            if (syntaxFacts.IsCastExpression(node))
            {
                syntaxFacts.GetPartsOfCastExpression(node, out var type, out var expression);
                var typeSymbol = semanticModel.GetTypeInfo(type, cancellationToken).Type;

                if (typeSymbol?.SpecialType == SpecialType.System_Object)
                    return (TExpressionSyntax)expression;
            }

            return node;
        }

        /// <summary>
        /// After removing parentheses, returns the inner expression of an invocation (the invoked expression), a simple
        /// member access or element access (the receiver, which may be <see langword="null"/>), or a conditional access
        /// (the expression before <c>?</c>).  Returns <see langword="null"/> for any other node.
        /// </summary>
        private static TExpressionSyntax? Unwrap(ISyntaxFacts syntaxFacts, TExpressionSyntax node)
        {
            node = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(node);

            if (node is TInvocationExpressionSyntax invocation)
                return (TExpressionSyntax)syntaxFacts.GetExpressionOfInvocationExpression(invocation);

            if (syntaxFacts.IsSimpleMemberAccessExpression(node))
                return (TExpressionSyntax?)syntaxFacts.GetExpressionOfMemberAccessExpression(node);

            if (node is TConditionalAccessExpressionSyntax conditionalAccess)
                return (TExpressionSyntax)syntaxFacts.GetExpressionOfConditionalAccessExpression(conditionalAccess);

            if (node is TElementAccessExpressionSyntax elementAccess)
                return (TExpressionSyntax?)syntaxFacts.GetExpressionOfElementAccessExpression(elementAccess);

            return null;
        }
    }
}