File: AbstractUseNullPropagationDiagnosticAnalyzer_IfStatement.cs
Web Access
Project: ..\..\..\src\CodeStyle\Core\Analyzers\Microsoft.CodeAnalysis.CodeStyle.csproj (Microsoft.CodeAnalysis.CodeStyle)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.UseNullPropagation;
 
internal abstract partial class AbstractUseNullPropagationDiagnosticAnalyzer<
    TSyntaxKind,
    TExpressionSyntax,
    TStatementSyntax,
    TConditionalExpressionSyntax,
    TBinaryExpressionSyntax,
    TInvocationExpressionSyntax,
    TConditionalAccessExpressionSyntax,
    TElementAccessExpressionSyntax,
    TMemberAccessExpressionSyntax,
    TIfStatementSyntax,
    TExpressionStatementSyntax>
{
    /// <summary>
    /// Language-specific hook that attempts to split <paramref name="ifStatement"/> into its condition and the
    /// statement to run when that condition holds.
    /// </summary>
    /// <param name="ifStatement">The if-statement being analyzed.</param>
    /// <param name="condition">When this returns <see langword="true"/>, the non-null condition expression.</param>
    /// <param name="trueStatement">When this returns <see langword="true"/>, the non-null true-branch statement.</param>
    /// <returns>
    /// <see langword="true"/> if both parts were produced.  Which if-statement shapes are accepted is decided by
    /// the implementing language and is not visible from this file.
    /// </returns>
    protected abstract bool TryGetPartsOfIfStatement(
        TIfStatementSyntax ifStatement, [NotNullWhen(true)] out TExpressionSyntax? condition, [NotNullWhen(true)] out TStatementSyntax? trueStatement);

    /// <summary>
    /// Analyzes an if-statement of the form <c>if (&lt;expr&gt; != null) &lt;expr&gt;.Name(...);</c> and, when
    /// every check below passes, reports <c>Descriptor</c> on the first token of the if-statement.
    /// </summary>
    /// <param name="context">
    /// The analysis context; <see cref="SyntaxNodeAnalysisContext.Node"/> is cast to
    /// <typeparamref name="TIfStatementSyntax"/>.
    /// </param>
    /// <param name="referenceEqualsMethod">
    /// Forwarded unchanged to <c>TryAnalyzeCondition</c>.  Presumably the symbol for
    /// <c>object.ReferenceEquals</c> so that such calls are recognized as null checks -- confirm against that
    /// helper.
    /// </param>
    private void AnalyzeIfStatement(
        SyntaxNodeAnalysisContext context,
        IMethodSymbol? referenceEqualsMethod)
    {
        var cancellationToken = context.CancellationToken;
        var option = context.GetAnalyzerOptions().PreferNullPropagation;
        if (!option.Value)
            return;

        var syntaxFacts = GetSyntaxFacts();
        var ifStatement = (TIfStatementSyntax)context.Node;

        // The true-statement of the if-statement has to be a statement of the form `<expr1>.Name(...)`;
        if (!TryGetPartsOfIfStatement(ifStatement, out var condition, out var trueStatement))
            return;

        if (trueStatement is not TExpressionStatementSyntax expressionStatement)
            return;

        // Now see if the `if (<condition>)` looks like an appropriate null check.
        if (!TryAnalyzeCondition(context, syntaxFacts, referenceEqualsMethod, condition, out var conditionPartToCheck, out var isEquals))
            return;

        // Ok, we have `if (<expr2> == null)` or `if (<expr2> != null)` (or some similar form of that).
        // `conditionPartToCheck` will be `<expr2>` here.  We only support `if (<expr2> != null)`.  Fail out if we
        // have the alternate form.
        if (isEquals)
            return;

        var semanticModel = context.SemanticModel;
        var whenPartMatch = GetWhenPartMatch(
            syntaxFacts, semanticModel, conditionPartToCheck,
            (TExpressionSyntax)syntaxFacts.GetExpressionOfExpressionStatement(expressionStatement),
            cancellationToken);
        if (whenPartMatch is null)
            return;

        // If we have:
        //
        // D D { get; }
        //
        // public void Test()
        // {
        //     if (D != null)
        //     {
        //         D.Method(D);
        //     }
        // }
        //
        // Then `D.Method` is actually an access of a static member, and cannot be converted to `D?.Method`.
        if (whenPartMatch.Parent is TMemberAccessExpressionSyntax memberAccess)
        {
            var memberSymbol = semanticModel.GetSymbolInfo(memberAccess, cancellationToken).GetAnySymbol();
            if (memberSymbol?.IsStatic is true)
                return;
        }

        // can't use ?. on a pointer
        var whenPartType = semanticModel.GetTypeInfo(whenPartMatch, cancellationToken).Type;
        if (whenPartType is IPointerTypeSymbol or IFunctionPointerTypeSymbol)
            return;

        // Reuse the type computed above instead of asking the semantic model again (the second query also
        // ignored the cancellation token).  The flag selects the property bag attached to the diagnostic.
        var whenPartIsNullable = whenPartType?.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;
        var properties = whenPartIsNullable
            ? s_whenPartIsNullableProperties
            : ImmutableDictionary<string, string?>.Empty;

        // Primary location is the first token of the if-statement; the additional locations are, in order, the
        // whole if-statement, its true-statement, and the matched expression inside that statement.
        context.ReportDiagnostic(DiagnosticHelper.Create(
            Descriptor,
            ifStatement.GetFirstToken().GetLocation(),
            option.Notification.Severity,
            ImmutableArray.Create(
                ifStatement.GetLocation(),
                trueStatement.GetLocation(),
                whenPartMatch.GetLocation()),
            properties));
    }
}