File: INamedTypeSymbolExtensions.cs
Web Access
Project: ..\..\..\src\CodeStyle\Core\Analyzers\Microsoft.CodeAnalysis.CodeStyle.csproj (Microsoft.CodeAnalysis.CodeStyle)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.Shared.Extensions
{
    /// <summary>
    /// Extension helpers over <see cref="INamedTypeSymbol"/>: walking base-type and containment chains,
    /// computing the interface / abstract members a type still has to implement, computing the members a
    /// type may still override, and discovering the settable members usable as attribute named arguments.
    /// </summary>
    internal static partial class INamedTypeSymbolExtensions
    {
        /// <summary>
        /// Lazily yields <paramref name="namedType"/> followed by each successive <c>BaseType</c> until a
        /// <see langword="null"/> base type is reached.  Yields nothing when <paramref name="namedType"/>
        /// is <see langword="null"/>.
        /// </summary>
        public static IEnumerable<INamedTypeSymbol> GetBaseTypesAndThis(this INamedTypeSymbol? namedType)
        {
            var current = namedType;
            while (current != null)
            {
                yield return current;
                current = current.BaseType;
            }
        }

        /// <summary>
        /// Returns the type parameters declared on <paramref name="symbol"/> and on every type that contains
        /// it, ordered from the outermost containing type inward.  Empty when <paramref name="symbol"/> is
        /// <see langword="null"/>.
        /// </summary>
        public static ImmutableArray<ITypeParameterSymbol> GetAllTypeParameters(this INamedTypeSymbol? symbol)
        {
            var stack = GetContainmentStack(symbol);
            return stack.SelectMany(n => n.TypeParameters).ToImmutableArray();
        }

        /// <summary>
        /// Returns the type arguments of <paramref name="symbol"/> and of every type that contains it,
        /// ordered from the outermost containing type inward.  The containment chain is captured eagerly,
        /// but the flattening of the type arguments is deferred until enumeration.
        /// </summary>
        public static IEnumerable<ITypeSymbol> GetAllTypeArguments(this INamedTypeSymbol? symbol)
        {
            var stack = GetContainmentStack(symbol);
            return stack.SelectMany(n => n.TypeArguments);
        }

        /// <summary>
        /// Pushes <paramref name="symbol"/> and then each <c>ContainingType</c> onto a stack, so that
        /// enumerating the stack visits the outermost type first.
        /// </summary>
        private static Stack<INamedTypeSymbol> GetContainmentStack(INamedTypeSymbol? symbol)
        {
            var stack = new Stack<INamedTypeSymbol>();
            for (var current = symbol; current != null; current = current.ContainingType)
            {
                stack.Push(current);
            }

            return stack;
        }

        /// <summary>
        /// Returns <see langword="true"/> when <paramref name="symbol"/> itself, or any type that
        /// (transitively) contains it, equals <paramref name="outer"/>.  Returns <see langword="false"/>
        /// for a <see langword="null"/> <paramref name="symbol"/>.
        /// </summary>
        public static bool IsContainedWithin([NotNullWhen(returnValue: true)] this INamedTypeSymbol? symbol, INamedTypeSymbol outer)
        {
            // TODO(cyrusn): Should we be using OriginalSymbol here?
            for (var current = symbol; current != null; current = current.ContainingType)
            {
                if (current.Equals(outer))
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// For an abstract <paramref name="symbol"/>, searches <paramref name="type"/> and its base types
        /// for a same-named member whose overridden member is <paramref name="symbol"/>, and returns the
        /// first one found.  Returns <see langword="null"/> when nothing matches, or when
        /// <paramref name="symbol"/> is not abstract.
        /// </summary>
        public static ISymbol? FindImplementationForAbstractMember(this INamedTypeSymbol? type, ISymbol symbol)
        {
            if (symbol.IsAbstract)
            {
                return type.GetBaseTypesAndThis().SelectMany(t => t.GetMembers(symbol.Name))
                                                 .FirstOrDefault(s => symbol.Equals(s.GetOverriddenMember()));
            }

            return null;
        }

        /// <summary>
        /// "Is valid implementation" predicate that accepts any implementation, wherever it is declared.
        /// </summary>
        private static bool ImplementationExists(INamedTypeSymbol classOrStructType, ISymbol member)
            => classOrStructType.FindImplementationForInterfaceMember(member) != null;

        /// <summary>
        /// Decides whether <paramref name="member"/> is already implemented by
        /// <paramref name="classOrStructType"/>.
        /// <list type="bullet">
        /// <item>Interface properties are checked accessor by accessor.</item>
        /// <item>Other interface members are checked with <paramref name="isValidImplementation"/>.</item>
        /// <item>Abstract class members must have an override somewhere in the base chain.</item>
        /// <item>Anything else counts as implemented.</item>
        /// </list>
        /// <paramref name="cancellationToken"/> is unused; it exists to match the delegate shape expected
        /// by <c>GetAllUnimplementedMembers</c>.
        /// </summary>
        private static bool IsImplemented(
            this INamedTypeSymbol classOrStructType,
            ISymbol member,
            Func<INamedTypeSymbol, ISymbol, bool> isValidImplementation,
            CancellationToken cancellationToken)
        {
            if (member.ContainingType.TypeKind == TypeKind.Interface)
            {
                if (member.Kind == SymbolKind.Property)
                {
                    return IsInterfacePropertyImplemented(classOrStructType, (IPropertySymbol)member);
                }
                else
                {
                    return isValidImplementation(classOrStructType, member);
                }
            }

            if (member.IsAbstract)
            {
                if (member.Kind == SymbolKind.Property)
                {
                    return IsAbstractPropertyImplemented(classOrStructType, (IPropertySymbol)member);
                }
                else
                {
                    return classOrStructType.FindImplementationForAbstractMember(member) != null;
                }
            }

            return true;
        }

        /// <summary>
        /// Returns <see langword="true"/> when every implementable (virtual or abstract) accessor of
        /// <paramref name="propertySymbol"/> has an implementation in
        /// <paramref name="classOrStructType"/>.  Missing or non-implementable accessors are ignored.
        /// </summary>
        private static bool IsInterfacePropertyImplemented(INamedTypeSymbol classOrStructType, IPropertySymbol propertySymbol)
        {
            // A property is only fully implemented if both its getter and its setter are implemented.

            return IsAccessorImplemented(propertySymbol.GetMethod, classOrStructType) && IsAccessorImplemented(propertySymbol.SetMethod, classOrStructType);

            // local functions

            static bool IsAccessorImplemented(IMethodSymbol? accessor, INamedTypeSymbol classOrStructType)
            {
                return accessor == null || !IsImplementable(accessor) || classOrStructType.FindImplementationForInterfaceMember(accessor) != null;
            }
        }

        /// <summary>
        /// Returns <see langword="true"/> when each existing accessor of the abstract
        /// <paramref name="propertySymbol"/> has an override in <paramref name="classOrStructType"/> or
        /// one of its base types.
        /// </summary>
        private static bool IsAbstractPropertyImplemented(INamedTypeSymbol classOrStructType, IPropertySymbol propertySymbol)
        {
            // A property is only fully implemented if both its getter and its setter are implemented.
            if (propertySymbol.GetMethod != null)
            {
                if (classOrStructType.FindImplementationForAbstractMember(propertySymbol.GetMethod) == null)
                {
                    return false;
                }
            }

            if (propertySymbol.SetMethod != null)
            {
                if (classOrStructType.FindImplementationForAbstractMember(propertySymbol.SetMethod) == null)
                {
                    return false;
                }
            }

            return true;
        }

        /// <summary>
        /// Returns <see langword="true"/> when the implementation of the interface
        /// <paramref name="member"/> found for <paramref name="classOrStructType"/> is either declared in
        /// an interface, or is an event/method/property with at least one explicit interface
        /// implementation.  <paramref name="isValid"/> and <paramref name="cancellationToken"/> are unused;
        /// they exist to match the delegate shape expected by <c>GetAllUnimplementedMembers</c>.
        /// </summary>
        private static bool IsExplicitlyImplemented(
            this INamedTypeSymbol classOrStructType,
            ISymbol member,
            Func<INamedTypeSymbol, ISymbol, bool> isValid,
            CancellationToken cancellationToken)
        {
            var implementation = classOrStructType.FindImplementationForInterfaceMember(member);

            if (implementation?.ContainingType.TypeKind == TypeKind.Interface)
            {
                // Treat all implementations in interfaces as explicit, even the original declaration with implementation.
                // There are no implicit interface implementations in derived interfaces and it feels reasonable to treat
                // original declaration with implementation as an explicit implementation as well, the implementation is
                // explicitly provided after all. All implementations in interfaces will be treated uniformly.
                return true;
            }

            return implementation switch
            {
                IEventSymbol @event => @event.ExplicitInterfaceImplementations.Length > 0,
                IMethodSymbol method => method.ExplicitInterfaceImplementations.Length > 0,
                IPropertySymbol property => property.ExplicitInterfaceImplementations.Length > 0,
                _ => false,
            };
        }

        /// <summary>
        /// Computes, for each interface in <paramref name="interfaces"/> (plus the interfaces they inherit,
        /// minus those the base type already implements), the members <paramref name="classOrStructType"/>
        /// has not implemented yet.  Types with no missing members are omitted.
        /// </summary>
        /// <param name="classOrStructType">The class or struct being inspected; other type kinds yield an empty result.</param>
        /// <param name="interfaces">The interfaces to check against.</param>
        /// <param name="includeMembersRequiringExplicitImplementation">
        /// When <see langword="true"/>, also reports members that can only be implemented explicitly.
        /// These are non-public but accessible members, properties with non-public implementable
        /// accessors, and static abstract operators that do not mention the implementing type.
        /// </param>
        /// <param name="cancellationToken">Checked while computing the interfaces to implement.</param>
        public static ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)> GetAllUnimplementedMembers(
            this INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfaces,
            bool includeMembersRequiringExplicitImplementation,
            CancellationToken cancellationToken)
        {
            // Named so it does not shadow the static GetMembers helper declared on this class.
            Func<INamedTypeSymbol, ISymbol, ImmutableArray<ISymbol>> memberGetter = includeMembersRequiringExplicitImplementation
                ? GetExplicitlyImplementableMembers
                : GetImplicitlyImplementableMembers;

            return classOrStructType.GetAllUnimplementedMembers(
                interfaces,
                IsImplemented,
                ImplementationExists,
                memberGetter,
                allowReimplementation: false,
                cancellationToken: cancellationToken);

            // local functions

            // For interfaces, keep only public, non-nested-type, implementable members that can be
            // implemented implicitly; for any other type kind, return every member.
            static ImmutableArray<ISymbol> GetImplicitlyImplementableMembers(INamedTypeSymbol type, ISymbol within)
            {
                if (type.TypeKind == TypeKind.Interface)
                {
                    return type.GetMembers().WhereAsArray(m => m.DeclaredAccessibility == Accessibility.Public &&
                                                               m.Kind != SymbolKind.NamedType && IsImplementable(m) &&
                                                               !IsPropertyWithNonPublicImplementableAccessor(m) &&
                                                               IsImplicitlyImplementable(m, within));
                }

                return type.GetMembers();
            }

            // A public property whose implementable getter or setter is not public cannot be implemented implicitly.
            static bool IsPropertyWithNonPublicImplementableAccessor(ISymbol member)
            {
                if (member.Kind != SymbolKind.Property)
                {
                    return false;
                }

                var property = (IPropertySymbol)member;

                return IsNonPublicImplementableAccessor(property.GetMethod) || IsNonPublicImplementableAccessor(property.SetMethod);
            }

            static bool IsNonPublicImplementableAccessor(IMethodSymbol? accessor)
            {
                return accessor != null && IsImplementable(accessor) && accessor.DeclaredAccessibility != Accessibility.Public;
            }

            static bool IsImplicitlyImplementable(ISymbol member, ISymbol within)
            {
                if (member is IMethodSymbol { IsStatic: true, IsAbstract: true, MethodKind: MethodKind.UserDefinedOperator } method)
                {
                    // For example, the following is not implementable implicitly.
                    // interface I { static abstract int operator -(I x); }
                    // But the following is implementable:
                    // interface I<T> where T : I<T> { static abstract int operator -(T x); }

                    // See https://github.com/dotnet/csharplang/blob/main/spec/classes.md#unary-operators.
                    return method.Parameters.Any(static (p, within) => p.Type.Equals(within, SymbolEqualityComparer.Default), within);
                }

                return true;
            }
        }

        /// <summary>A member can be implemented or overridden only when it is virtual or abstract.</summary>
        private static bool IsImplementable(ISymbol m)
            => m.IsVirtual || m.IsAbstract;

        /// <summary>
        /// Like <see cref="GetAllUnimplementedMembersInThis(INamedTypeSymbol, IEnumerable{INamedTypeSymbol}, Func{INamedTypeSymbol, ISymbol, ImmutableArray{ISymbol}}, CancellationToken)"/>,
        /// considering every member of each interface or abstract class.
        /// </summary>
        public static ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)> GetAllUnimplementedMembersInThis(
            this INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfacesOrAbstractClasses,
            CancellationToken cancellationToken)
        {
            // Identical to the overload below with the "every member" getter, so delegate instead of duplicating.
            return classOrStructType.GetAllUnimplementedMembersInThis(
                interfacesOrAbstractClasses,
                GetMembers,
                cancellationToken);
        }

        /// <summary>
        /// Computes the members of <paramref name="interfacesOrAbstractClasses"/> (as listed by
        /// <paramref name="interfaceMemberGetter"/>) that <paramref name="classOrStructType"/> does not
        /// implement itself.
        /// </summary>
        /// <remarks>
        /// An interface implementation counts only when it is declared directly in
        /// <paramref name="classOrStructType"/>.  Interfaces already implemented by a base type are still
        /// considered (re-implementation is allowed).
        /// </remarks>
        public static ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)> GetAllUnimplementedMembersInThis(
            this INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfacesOrAbstractClasses,
            Func<INamedTypeSymbol, ISymbol, ImmutableArray<ISymbol>> interfaceMemberGetter,
            CancellationToken cancellationToken)
        {
            return classOrStructType.GetAllUnimplementedMembers(
                interfacesOrAbstractClasses,
                IsImplemented,
                (_, member) =>
                {
                    var implementation = classOrStructType.FindImplementationForInterfaceMember(member);
                    return implementation != null && Equals(implementation.ContainingType, classOrStructType);
                },
                interfaceMemberGetter,
                allowReimplementation: true,
                cancellationToken: cancellationToken);
        }

        /// <summary>
        /// Computes the interface members that <paramref name="classOrStructType"/> does not implement
        /// explicitly.  Implementations declared inside interfaces also count as explicit.
        /// </summary>
        public static ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)> GetAllUnimplementedExplicitMembers(
            this INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfaces,
            CancellationToken cancellationToken)
        {
            return classOrStructType.GetAllUnimplementedMembers(
                interfaces,
                IsExplicitlyImplemented,
                ImplementationExists,
                GetExplicitlyImplementableMembers,
                allowReimplementation: false,
                cancellationToken: cancellationToken);
        }

        /// <summary>
        /// For interfaces, returns the non-nested-type, implementable members that are accessible from
        /// <paramref name="within"/> and whose implementable accessors are accessible too.  For any other
        /// type kind, returns every member.
        /// </summary>
        private static ImmutableArray<ISymbol> GetExplicitlyImplementableMembers(INamedTypeSymbol type, ISymbol within)
        {
            if (type.TypeKind == TypeKind.Interface)
            {
                return type.GetMembers().WhereAsArray(m => m.Kind != SymbolKind.NamedType &&
                                                           IsImplementable(m) && m.IsAccessibleWithin(within) &&
                                                           !IsPropertyWithInaccessibleImplementableAccessor(m, within));
            }

            return type.GetMembers();
        }

        /// <summary>
        /// Returns <see langword="true"/> when <paramref name="member"/> is a property whose getter or
        /// setter is implementable but not accessible from <paramref name="within"/>.
        /// </summary>
        private static bool IsPropertyWithInaccessibleImplementableAccessor(ISymbol member, ISymbol within)
        {
            if (member.Kind != SymbolKind.Property)
            {
                return false;
            }

            var property = (IPropertySymbol)member;

            return IsInaccessibleImplementableAccessor(property.GetMethod, within) || IsInaccessibleImplementableAccessor(property.SetMethod, within);
        }

        private static bool IsInaccessibleImplementableAccessor(IMethodSymbol? accessor, ISymbol within)
            => accessor != null && IsImplementable(accessor) && !accessor.IsAccessibleWithin(within);

        /// <summary>
        /// Shared driver for the public "unimplemented members" queries.
        /// <para>
        /// Returns an empty result unless <paramref name="classOrStructType"/> is a class or struct, and
        /// <paramref name="interfacesOrAbstractClasses"/> is non-empty and consists either entirely of
        /// interfaces or entirely of abstract classes.
        /// </para>
        /// <para>
        /// Otherwise it expands the specified types into the full set of types to implement.  For each
        /// one, it reports the members rejected by <paramref name="isImplemented"/>, omitting types with
        /// nothing missing.
        /// </para>
        /// </summary>
        private static ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)> GetAllUnimplementedMembers(
            this INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfacesOrAbstractClasses,
            Func<INamedTypeSymbol, ISymbol, Func<INamedTypeSymbol, ISymbol, bool>, CancellationToken, bool> isImplemented,
            Func<INamedTypeSymbol, ISymbol, bool> isValidImplementation,
            Func<INamedTypeSymbol, ISymbol, ImmutableArray<ISymbol>> interfaceMemberGetter,
            bool allowReimplementation,
            CancellationToken cancellationToken)
        {
            Contract.ThrowIfNull(classOrStructType);
            Contract.ThrowIfNull(interfacesOrAbstractClasses);
            Contract.ThrowIfNull(isImplemented);

            if (classOrStructType.TypeKind is not TypeKind.Class and not TypeKind.Struct)
            {
                return ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)>.Empty;
            }

            // Materialize the caller's sequence once.  The emptiness/kind checks below and the expansion in
            // GetTypesToImplement would otherwise each re-enumerate a possibly lazy (and costly) sequence.
            var specifiedTypes = interfacesOrAbstractClasses.ToImmutableArray();

            if (specifiedTypes.IsEmpty)
            {
                return ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)>.Empty;
            }

            // Mixed lists (some interfaces, some abstract classes) are not supported.
            if (!specifiedTypes.All(static t => t.TypeKind == TypeKind.Interface) &&
                !specifiedTypes.All(static t => t.IsAbstractClass()))
            {
                return ImmutableArray<(INamedTypeSymbol type, ImmutableArray<ISymbol> members)>.Empty;
            }

            var typesToImplement = GetTypesToImplement(classOrStructType, specifiedTypes, allowReimplementation, cancellationToken);
            return typesToImplement.SelectAsArray(type => (type, members: GetUnimplementedMembers(classOrStructType, type, isImplemented, isValidImplementation, interfaceMemberGetter, cancellationToken)))
                                   .WhereAsArray(t => t.members.Length > 0);
        }

        /// <summary>
        /// Expands the specified types into the full list to implement.  The first element's kind decides
        /// which path runs; the caller has already verified that all elements share that kind.
        /// </summary>
        private static ImmutableArray<INamedTypeSymbol> GetTypesToImplement(
            INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfacesOrAbstractClasses,
            bool allowReimplementation,
            CancellationToken cancellationToken)
        {
            return interfacesOrAbstractClasses.First().TypeKind == TypeKind.Interface
                ? GetInterfacesToImplement(classOrStructType, interfacesOrAbstractClasses, allowReimplementation, cancellationToken)
                : GetAbstractClassesToImplement(interfacesOrAbstractClasses);
        }

        /// <summary>
        /// Returns each specified abstract class together with its abstract base classes, each at most once.
        /// </summary>
        private static ImmutableArray<INamedTypeSymbol> GetAbstractClassesToImplement(
            IEnumerable<INamedTypeSymbol> abstractClasses)
        {
            // Distinct so that an abstract ancestor shared by several specified classes is reported only
            // once (mirrors the Distinct applied to interfaces in GetInterfacesToImplement).
            return abstractClasses.SelectMany(a => a.GetBaseTypesAndThis())
                                  .Where(t => t.IsAbstractClass())
                                  .Distinct()
                                  .ToImmutableArray();
        }

        /// <summary>
        /// Returns the distinct set of the specified interfaces and all the interfaces they inherit.
        /// Unless <paramref name="allowReimplementation"/> is set, interfaces already implemented by the
        /// base type of <paramref name="classOrStructType"/> are removed.
        /// </summary>
        private static ImmutableArray<INamedTypeSymbol> GetInterfacesToImplement(
            INamedTypeSymbol classOrStructType,
            IEnumerable<INamedTypeSymbol> interfaces,
            bool allowReimplementation,
            CancellationToken cancellationToken)
        {
            // We need to not only implement the specified interface, but also everything it
            // inherits from.
            cancellationToken.ThrowIfCancellationRequested();
            var interfacesToImplement = new List<INamedTypeSymbol>(
                interfaces.SelectMany(i => i.GetAllInterfacesIncludingThis()).Distinct());

            // However, there's no need to re-implement any interfaces that our base types already
            // implement.  By definition they must contain all the necessary methods.
            var baseType = classOrStructType.BaseType;
            var alreadyImplementedInterfaces = baseType == null || allowReimplementation
                ? SpecializedCollections.EmptyEnumerable<INamedTypeSymbol>()
                : baseType.AllInterfaces;

            cancellationToken.ThrowIfCancellationRequested();
            interfacesToImplement.RemoveRange(alreadyImplementedInterfaces);
            return interfacesToImplement.ToImmutableArray();
        }

        /// <summary>
        /// Lists the members of <paramref name="interfaceType"/> (via <paramref name="interfaceMemberGetter"/>)
        /// that <paramref name="isImplemented"/> reports as not implemented by
        /// <paramref name="classOrStructType"/>.  Candidates are limited to:
        /// <list type="bullet">
        /// <item>ordinary methods, user-defined operators and conversions;</item>
        /// <item>indexers and properties referenceable by name;</item>
        /// <item>events referenceable by name;</item>
        /// <item>any other non-type member kind.</item>
        /// </list>
        /// Nested types are always excluded.
        /// </summary>
        private static ImmutableArray<ISymbol> GetUnimplementedMembers(
            this INamedTypeSymbol classOrStructType,
            INamedTypeSymbol interfaceType,
            Func<INamedTypeSymbol, ISymbol, Func<INamedTypeSymbol, ISymbol, bool>, CancellationToken, bool> isImplemented,
            Func<INamedTypeSymbol, ISymbol, bool> isValidImplementation,
            Func<INamedTypeSymbol, ISymbol, ImmutableArray<ISymbol>> interfaceMemberGetter,
            CancellationToken cancellationToken)
        {
            var q = from m in interfaceMemberGetter(interfaceType, classOrStructType)
                    where m.Kind != SymbolKind.NamedType
                    where m.Kind != SymbolKind.Method || ((IMethodSymbol)m).MethodKind is MethodKind.Ordinary or MethodKind.UserDefinedOperator or MethodKind.Conversion
                    where m.Kind != SymbolKind.Property || ((IPropertySymbol)m).IsIndexer || ((IPropertySymbol)m).CanBeReferencedByName
                    where m.Kind != SymbolKind.Event || ((IEventSymbol)m).CanBeReferencedByName
                    where !isImplemented(classOrStructType, m, isValidImplementation, cancellationToken)
                    select m;

            return q.ToImmutableArray();
        }

        /// <summary>
        /// Lazily yields the members usable as named arguments of <paramref name="attributeSymbol"/>:
        /// writable instance fields and read/write instance properties accessible from
        /// <paramref name="within"/>.  The walk covers <paramref name="attributeSymbol"/> and its base
        /// types, stops before the compilation's attribute base type, and yields the first member seen
        /// for each name.
        /// </summary>
        /// <param name="within">
        /// The accessibility context; when <see langword="null"/>, the compilation's assembly is used.
        /// </param>
        public static IEnumerable<ISymbol> GetAttributeNamedParameters(
            this INamedTypeSymbol attributeSymbol,
            Compilation compilation,
            ISymbol? within)
        {
            using var _ = PooledHashSet<string>.GetInstance(out var seenNames);

            var systemAttributeType = compilation.AttributeType();

            // Loop-invariant: resolve the accessibility context once instead of once per member.
            var accessibleWithin = within ?? compilation.Assembly;

            foreach (var type in attributeSymbol.GetBaseTypesAndThis())
            {
                if (type.Equals(systemAttributeType))
                {
                    break;
                }

                foreach (var member in type.GetMembers())
                {
                    var namedParameter = IsAttributeNamedParameter(member, accessibleWithin);

                    // Derived types are visited first, so a same-named member from a base type is skipped.
                    if (namedParameter != null && seenNames.Add(namedParameter.Name))
                    {
                        yield return namedParameter;
                    }
                }
            }
        }

        /// <summary>
        /// Returns <paramref name="symbol"/> when it can be referenced by name, is accessible from
        /// <paramref name="within"/>, and is either a writable instance field, or an instance property
        /// with both accessors present and accessible.  Otherwise returns <see langword="null"/>.
        /// </summary>
        private static ISymbol? IsAttributeNamedParameter(
            ISymbol symbol,
            ISymbol within)
        {
            if (!symbol.CanBeReferencedByName ||
                !symbol.IsAccessibleWithin(within))
            {
                return null;
            }

            switch (symbol.Kind)
            {
                case SymbolKind.Field:
                    var fieldSymbol = (IFieldSymbol)symbol;
                    if (!fieldSymbol.IsConst &&
                        !fieldSymbol.IsReadOnly &&
                        !fieldSymbol.IsStatic)
                    {
                        return fieldSymbol;
                    }

                    break;

                case SymbolKind.Property:
                    var propertySymbol = (IPropertySymbol)symbol;
                    if (!propertySymbol.IsReadOnly &&
                        !propertySymbol.IsWriteOnly &&
                        !propertySymbol.IsStatic &&
                        propertySymbol.GetMethod != null &&
                        propertySymbol.SetMethod != null &&
                        propertySymbol.GetMethod.IsAccessibleWithin(within) &&
                        propertySymbol.SetMethod.IsAccessibleWithin(within))
                    {
                        return propertySymbol;
                    }

                    break;
            }

            return null;
        }

        /// <summary>Member getter that returns every member of <paramref name="type"/>; <paramref name="within"/> is ignored.</summary>
        private static ImmutableArray<ISymbol> GetMembers(INamedTypeSymbol type, ISymbol within)
            => type.GetMembers();

        /// <summary>
        /// Gets the set of members in the inheritance chain of <paramref name="containingType"/> that
        /// are overridable.  The members will be returned in furthest-base type to closest-base
        /// type order.  i.e. the overridable members of <see cref="System.Object"/> will be at the start
        /// of the list, and the members of the direct parent type of <paramref name="containingType"/> 
        /// will be at the end of the list.
        /// 
        /// If a member has already been overridden (in <paramref name="containingType"/> or any base type) 
        /// it will not be included in the list.
        /// </summary>
        public static ImmutableArray<ISymbol> GetOverridableMembers(
            this INamedTypeSymbol containingType, CancellationToken cancellationToken)
        {
            // Keep track of the symbols we've seen and what order we saw them in.  The 
            // order allows us to produce the symbols in the end from the furthest base-type
            // to the closest base-type
            var result = new Dictionary<ISymbol, int>();
            var index = 0;

            if (containingType != null &&
                !containingType.IsScriptClass &&
                !containingType.IsImplicitClass &&
                !containingType.IsStatic)
            {
                if (containingType.TypeKind is TypeKind.Class or TypeKind.Struct)
                {
                    var baseTypes = containingType.GetBaseTypes().Reverse();
                    foreach (var type in baseTypes)
                    {
                        cancellationToken.ThrowIfCancellationRequested();

                        // Prefer overrides in derived classes
                        RemoveOverriddenMembers(result, type, cancellationToken);

                        // Retain overridable methods
                        AddOverridableMembers(result, containingType, type, ref index, cancellationToken);
                    }

                    // Don't suggest already overridden members
                    RemoveOverriddenMembers(result, containingType, cancellationToken);
                }
            }

            // Sort the entries by their recorded index directly rather than re-looking up each key.
            return result.OrderBy(static pair => pair.Value)
                         .Select(static pair => pair.Key)
                         .ToImmutableArray();
        }

        /// <summary>
        /// Records each member of <paramref name="type"/> that <paramref name="containingType"/> may
        /// override, tagging it with the next value of <paramref name="index"/>.
        /// </summary>
        private static void AddOverridableMembers(
            Dictionary<ISymbol, int> result, INamedTypeSymbol containingType,
            INamedTypeSymbol type, ref int index, CancellationToken cancellationToken)
        {
            foreach (var member in type.GetMembers())
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (IsOverridable(member, containingType))
                {
                    result[member] = index++;
                }
            }
        }

        /// <summary>
        /// Returns <see langword="true"/> when <paramref name="member"/> meets all of these conditions:
        /// <list type="bullet">
        /// <item>it is abstract, virtual, or an override;</item>
        /// <item>it is not sealed;</item>
        /// <item>it is accessible from <paramref name="containingType"/>;</item>
        /// <item>it is an event, an ordinary method referenceable by name, or a non-WithEvents property.</item>
        /// </list>
        /// </summary>
        private static bool IsOverridable(ISymbol member, INamedTypeSymbol containingType)
        {
            if (!member.IsAbstract && !member.IsVirtual && !member.IsOverride)
                return false;

            if (member.IsSealed)
                return false;

            if (!member.IsAccessibleWithin(containingType))
                return false;

            return member switch
            {
                IEventSymbol => true,
                IMethodSymbol { MethodKind: MethodKind.Ordinary, CanBeReferencedByName: true } => true,
                IPropertySymbol { IsWithEvents: false } => true,
                _ => false,
            };
        }

        /// <summary>
        /// Removes from <paramref name="result"/> every member overridden by a member of
        /// <paramref name="containingType"/>.
        /// </summary>
        private static void RemoveOverriddenMembers(
            Dictionary<ISymbol, int> result, INamedTypeSymbol containingType, CancellationToken cancellationToken)
        {
            foreach (var member in containingType.GetMembers())
            {
                cancellationToken.ThrowIfCancellationRequested();

                // An implicitly declared override is still something the user can provide their own explicit override
                // for.  This is true for all implicit overrides *except* for the one for `bool object.Equals(object)`.
                // This override is not one the user is allowed to provide their own override for as it must have a very
                // particular implementation to ensure proper record equality semantics.
                if (!member.IsImplicitlyDeclared || IsEqualsObjectOverride(member))
                {
                    var overriddenMember = member.GetOverriddenMember();
                    if (overriddenMember != null)
                        result.Remove(overriddenMember);
                }
            }
        }

        /// <summary>
        /// Returns <see langword="true"/> when <paramref name="member"/> is <c>object.Equals(object)</c>,
        /// or reaches it by following the overridden-member chain.
        /// </summary>
        private static bool IsEqualsObjectOverride(ISymbol? member)
        {
            if (member == null)
                return false;

            if (IsEqualsObject(member))
                return true;

            return IsEqualsObjectOverride(member.GetOverriddenMember());
        }

        /// <summary>
        /// Matches the instance, single-parameter method named <c>Equals</c> declared on <see cref="System.Object"/>.
        /// </summary>
        private static bool IsEqualsObject(ISymbol member)
        {
            return member is IMethodSymbol
            {
                Name: nameof(Equals),
                IsStatic: false,
                ContainingType.SpecialType: SpecialType.System_Object,
                Parameters.Length: 1,
            };
        }

        /// <summary>
        /// Constructs <paramref name="type"/> with <paramref name="typeArguments"/> when at least one type
        /// argument is supplied; otherwise returns <paramref name="type"/> unchanged.
        /// </summary>
        public static INamedTypeSymbol TryConstruct(this INamedTypeSymbol type, ITypeSymbol[] typeArguments)
            => typeArguments.Length > 0 ? type.Construct(typeArguments) : type;
    }
}