File: InheritanceDistanceComparer.cs
Web Access
Project: ..\..\..\src\CodeStyle\Core\CodeFixes\Microsoft.CodeAnalysis.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
 
namespace Microsoft.CodeAnalysis.CodeFixes.AddExplicitCast
{
    /// <summary>
    /// The item is the pair of target argument expression and its conversion type
    /// <para/>
    /// Sort pairs using conversion types by inheritance distance from the base type in ascending order,
    /// i.e., less specific type has higher priority because it has less probability to make mistakes
    /// <para/>
    /// For example:
    /// class Base { }
    /// class Derived1 : Base { }
    /// class Derived2 : Derived1 { }
    /// 
    /// void Foo(Derived1 d1) { }
    /// void Foo(Derived2 d2) { }
    /// 
    /// Base b = new Derived1();
    /// Foo([||]b);
    /// 
    /// operations:
    /// 1. Convert type to 'Derived1'
    /// 2. Convert type to 'Derived2'
    /// 
    /// 'Derived1' is less specific than 'Derived2' compared to 'Base'
    /// </summary>
    internal sealed class InheritanceDistanceComparer<TExpressionSyntax>
    : IComparer<(TExpressionSyntax syntax, ITypeSymbol symbol)>
    where TExpressionSyntax : SyntaxNode
    {
        private readonly SemanticModel _semanticModel;
 
        public InheritanceDistanceComparer(SemanticModel semanticModel)
        {
            _semanticModel = semanticModel;
        }
 
        public int Compare((TExpressionSyntax syntax, ITypeSymbol symbol) x,
            (TExpressionSyntax syntax, ITypeSymbol symbol) y)
        {
            // if the argument is different, keep the original order
            if (!x.syntax.Equals(y.syntax))
            {
                return 0;
            }
            else
            {
                var baseType = _semanticModel.GetTypeInfo(x.syntax).Type;
                var xDist = GetInheritanceDistance(baseType, x.symbol);
                var yDist = GetInheritanceDistance(baseType, y.symbol);
                return xDist.CompareTo(yDist);
            }
        }
 
        /// <summary>
        /// Calculate the inheritance distance between baseType and derivedType.
        /// </summary>
        private int GetInheritanceDistanceRecursive(ITypeSymbol baseType, ITypeSymbol? derivedType)
        {
            if (derivedType == null)
                return int.MaxValue;
            if (derivedType.Equals(baseType))
                return 0;
 
            var distance = GetInheritanceDistanceRecursive(baseType, derivedType.BaseType);
 
            if (derivedType.Interfaces.Length != 0)
            {
                foreach (var interfaceType in derivedType.Interfaces)
                {
                    distance = Math.Min(GetInheritanceDistanceRecursive(baseType, interfaceType), distance);
                }
            }
 
            return distance == int.MaxValue ? distance : distance + 1;
        }
 
        /// <summary>
        /// Wrapper function of [GetInheritanceDistance], also consider the class with explicit conversion operator
        /// has the highest priority.
        /// </summary>
        private int GetInheritanceDistance(ITypeSymbol? baseType, ITypeSymbol castType)
        {
            if (baseType is null)
                return 0;
 
            var conversion = _semanticModel.Compilation.ClassifyCommonConversion(baseType, castType);
 
            // If the node has the explicit conversion operator, then it has the shortest distance
            // since explicit conversion operator is defined by users and has the highest priority 
            var distance = conversion.IsUserDefined ? 0 : GetInheritanceDistanceRecursive(baseType, castType);
            return distance;
        }
    }
}