File: Shared\Utilities\ExtractTypeHelpers.cs
Web Access
Project: ..\..\..\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeCleanup;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.GenerateType;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.Shared.Utilities
{
    internal static class ExtractTypeHelpers
    {
        public static async Task<(Document containingDocument, SyntaxAnnotation typeAnnotation)> AddTypeToExistingFileAsync(Document document, INamedTypeSymbol newType, AnnotatedSymbolMapping symbolMapping, CodeGenerationOptionsProvider fallbackOptions, CancellationToken cancellationToken)
        {
            var originalRoot = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            var typeDeclaration = originalRoot.GetAnnotatedNodes(symbolMapping.TypeNodeAnnotation).Single();
            var editor = new SyntaxEditor(originalRoot, symbolMapping.AnnotatedSolution.Services);
 
            var context = new CodeGenerationContext(generateMethodBodies: true);
            var info = await document.GetCodeGenerationInfoAsync(context, fallbackOptions, cancellationToken).ConfigureAwait(false);
 
            var newTypeNode = info.Service.CreateNamedTypeDeclaration(newType, CodeGenerationDestination.Unspecified, info, cancellationToken)
                .WithAdditionalAnnotations(SimplificationHelpers.SimplifyModuleNameAnnotation);
 
            var typeAnnotation = new SyntaxAnnotation();
            newTypeNode = newTypeNode.WithAdditionalAnnotations(typeAnnotation);
 
            editor.InsertBefore(typeDeclaration, newTypeNode);
 
            var newDocument = document.WithSyntaxRoot(editor.GetChangedRoot());
            return (newDocument, typeAnnotation);
        }
 
        public static async Task<(Document containingDocument, SyntaxAnnotation typeAnnotation)> AddTypeToNewFileAsync(
            Solution solution,
            string containingNamespaceDisplay,
            string fileName,
            ProjectId projectId,
            IEnumerable<string> folders,
            INamedTypeSymbol newSymbol,
            Document hintDocument,
            CleanCodeGenerationOptionsProvider fallbackOptions,
            CancellationToken cancellationToken)
        {
            var newDocumentId = DocumentId.CreateNewId(projectId, debugName: fileName);
            var newDocumentPath = PathUtilities.CombinePaths(PathUtilities.GetDirectoryName(hintDocument.FilePath), fileName);
 
            var solutionWithInterfaceDocument = solution.AddDocument(newDocumentId, fileName, text: "", folders: folders, filePath: newDocumentPath);
            var newDocument = solutionWithInterfaceDocument.GetRequiredDocument(newDocumentId);
            var newSemanticModel = await newDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
 
            var context = new CodeGenerationContext(
                contextLocation: newSemanticModel.SyntaxTree.GetLocation(new TextSpan()),
                generateMethodBodies: true);
 
            // need to remove the root namespace from the containing namespace display because it is implied
            // For C# this does nothing as there is no root namespace (root namespace is empty string)
            var generateTypeService = newDocument.GetRequiredLanguageService<IGenerateTypeService>();
            var rootNamespace = generateTypeService.GetRootNamespace(newDocument.Project.CompilationOptions);
            var index = rootNamespace.IsEmpty() ? -1 : containingNamespaceDisplay.IndexOf(rootNamespace);
            // if we did find the root namespace as the first element, then we remove it
            // this may leave us with an extra "." character at the start, but when we split it shouldn't matter
            var namespaceWithoutRoot = index == 0
                ? containingNamespaceDisplay.Remove(index, rootNamespace.Length)
                : containingNamespaceDisplay;
 
            var namespaceParts = namespaceWithoutRoot.Split('.').Where(s => !string.IsNullOrEmpty(s));
            var newTypeDocument = await CodeGenerator.AddNamespaceOrTypeDeclarationAsync(
                new CodeGenerationSolutionContext(
                    newDocument.Project.Solution,
                    context,
                    fallbackOptions),
                newSemanticModel.GetEnclosingNamespace(0, cancellationToken),
                newSymbol.GenerateRootNamespaceOrType(namespaceParts.ToArray()),
                cancellationToken).ConfigureAwait(false);
 
            var newCleanupOptions = await newTypeDocument.GetCodeCleanupOptionsAsync(fallbackOptions, cancellationToken).ConfigureAwait(false);
 
            var formattingService = newTypeDocument.GetLanguageService<INewDocumentFormattingService>();
            if (formattingService is not null)
            {
                newTypeDocument = await formattingService.FormatNewDocumentAsync(newTypeDocument, hintDocument, newCleanupOptions, cancellationToken).ConfigureAwait(false);
            }
 
            var syntaxRoot = await newTypeDocument.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
 
            var typeAnnotation = new SyntaxAnnotation();
            var syntaxFacts = newTypeDocument.GetRequiredLanguageService<ISyntaxFactsService>();
 
            var declarationNode = syntaxRoot.DescendantNodes().First(syntaxFacts.IsTypeDeclaration);
            var annotatedRoot = syntaxRoot.ReplaceNode(declarationNode, declarationNode.WithAdditionalAnnotations(typeAnnotation));
 
            newTypeDocument = newTypeDocument.WithSyntaxRoot(annotatedRoot);
 
            var simplified = await Simplifier.ReduceAsync(newTypeDocument, newCleanupOptions.SimplifierOptions, cancellationToken).ConfigureAwait(false);
            var formattedDocument = await Formatter.FormatAsync(simplified, newCleanupOptions.FormattingOptions, cancellationToken).ConfigureAwait(false);
 
            return (formattedDocument, typeAnnotation);
        }
 
        public static string GetTypeParameterSuffix(Document document, SyntaxFormattingOptions formattingOptions, INamedTypeSymbol type, IEnumerable<ISymbol> extractableMembers, CancellationToken cancellationToken)
        {
            var typeParameters = GetRequiredTypeParametersForMembers(type, extractableMembers);
 
            if (type.TypeParameters.Length == 0)
            {
                return string.Empty;
            }
 
            var typeParameterNames = typeParameters.SelectAsArray(p => p.Name);
            var syntaxGenerator = SyntaxGenerator.GetGenerator(document);
 
            return Formatter.Format(syntaxGenerator.SyntaxGeneratorInternal.TypeParameterList(typeParameterNames), document.Project.Solution.Services, formattingOptions, cancellationToken).ToString();
        }
 
        public static ImmutableArray<ITypeParameterSymbol> GetRequiredTypeParametersForMembers(INamedTypeSymbol type, IEnumerable<ISymbol> includedMembers)
        {
            var potentialTypeParameters = GetPotentialTypeParameters(type);
 
            var directlyReferencedTypeParameters = GetDirectlyReferencedTypeParameters(potentialTypeParameters, includedMembers);
 
            // The directly referenced TypeParameters may have constraints that reference other 
            // type parameters.
 
            var allReferencedTypeParameters = new HashSet<ITypeParameterSymbol>(directlyReferencedTypeParameters);
            var unanalyzedTypeParameters = new Queue<ITypeParameterSymbol>(directlyReferencedTypeParameters);
 
            while (!unanalyzedTypeParameters.IsEmpty())
            {
                var typeParameter = unanalyzedTypeParameters.Dequeue();
 
                foreach (var constraint in typeParameter.ConstraintTypes)
                {
                    foreach (var originalTypeParameter in potentialTypeParameters)
                    {
                        if (!allReferencedTypeParameters.Contains(originalTypeParameter) &&
                            DoesTypeReferenceTypeParameter(constraint, originalTypeParameter, new HashSet<ITypeSymbol>()))
                        {
                            allReferencedTypeParameters.Add(originalTypeParameter);
                            unanalyzedTypeParameters.Enqueue(originalTypeParameter);
                        }
                    }
                }
            }
 
            return potentialTypeParameters.WhereAsArray(allReferencedTypeParameters.Contains);
        }
 
        private static ImmutableArray<ITypeParameterSymbol> GetPotentialTypeParameters(INamedTypeSymbol type)
        {
            using var _ = ArrayBuilder<ITypeParameterSymbol>.GetInstance(out var typeParameters);
 
            var typesToVisit = new Stack<INamedTypeSymbol>();
 
            var currentType = type;
            while (currentType != null)
            {
                typesToVisit.Push(currentType);
                currentType = currentType.ContainingType;
            }
 
            while (typesToVisit.Any())
            {
                typeParameters.AddRange(typesToVisit.Pop().TypeParameters);
            }
 
            return typeParameters.ToImmutable();
        }
 
        private static ImmutableArray<ITypeParameterSymbol> GetDirectlyReferencedTypeParameters(IEnumerable<ITypeParameterSymbol> potentialTypeParameters, IEnumerable<ISymbol> includedMembers)
        {
            using var _ = ArrayBuilder<ITypeParameterSymbol>.GetInstance(out var directlyReferencedTypeParameters);
            foreach (var typeParameter in potentialTypeParameters)
            {
                if (includedMembers.Any(m => DoesMemberReferenceTypeParameter(m, typeParameter, new HashSet<ITypeSymbol>())))
                {
                    directlyReferencedTypeParameters.Add(typeParameter);
                }
            }
 
            return directlyReferencedTypeParameters.ToImmutable();
        }
 
        private static bool DoesMemberReferenceTypeParameter(ISymbol member, ITypeParameterSymbol typeParameter, HashSet<ITypeSymbol> checkedTypes)
        {
            switch (member.Kind)
            {
                case SymbolKind.Event:
                    var @event = member as IEventSymbol;
                    return DoesTypeReferenceTypeParameter(@event.Type, typeParameter, checkedTypes);
                case SymbolKind.Method:
                    var method = member as IMethodSymbol;
                    return method.Parameters.Any(static (t, arg) => DoesTypeReferenceTypeParameter(t.Type, arg.typeParameter, arg.checkedTypes), (typeParameter, checkedTypes)) ||
                        method.TypeParameters.Any(static (t, arg) => t.ConstraintTypes.Any(static (c, arg) => DoesTypeReferenceTypeParameter(c, arg.typeParameter, arg.checkedTypes), (arg.typeParameter, arg.checkedTypes)), (typeParameter, checkedTypes)) ||
                        DoesTypeReferenceTypeParameter(method.ReturnType, typeParameter, checkedTypes);
                case SymbolKind.Property:
                    var property = member as IPropertySymbol;
                    return property.Parameters.Any(static (t, arg) => DoesTypeReferenceTypeParameter(t.Type, arg.typeParameter, arg.checkedTypes), (typeParameter, checkedTypes)) ||
                        DoesTypeReferenceTypeParameter(property.Type, typeParameter, checkedTypes);
                case SymbolKind.Field:
                    var field = member as IFieldSymbol;
                    return DoesTypeReferenceTypeParameter(field.Type, typeParameter, checkedTypes);
                default:
                    Debug.Assert(false, string.Format(FeaturesResources.Unexpected_interface_member_kind_colon_0, member.Kind.ToString()));
                    return false;
            }
        }
 
        private static bool DoesTypeReferenceTypeParameter(ITypeSymbol type, ITypeParameterSymbol typeParameter, HashSet<ITypeSymbol> checkedTypes)
        {
            if (!checkedTypes.Add(type))
            {
                return false;
            }
 
            // We want to ignore nullability when comparing as T and T? both are references to the type parameter
            if (type.Equals(typeParameter, SymbolEqualityComparer.Default) ||
                type.GetTypeArguments().Any(static (t, arg) => DoesTypeReferenceTypeParameter(t, arg.typeParameter, arg.checkedTypes), (typeParameter, checkedTypes)))
            {
                return true;
            }
 
            if (type.ContainingType != null &&
                type.Kind != SymbolKind.TypeParameter &&
                DoesTypeReferenceTypeParameter(type.ContainingType, typeParameter, checkedTypes))
            {
                return true;
            }
 
            return false;
        }
    }
}