// File: SymbolKey\SymbolKeyTestBase.cs
// Web Access
// Project: ..\..\..\src\EditorFeatures\CSharpTest\Microsoft.CodeAnalysis.CSharp.EditorFeatures.UnitTests.csproj (Microsoft.CodeAnalysis.CSharp.EditorFeatures.UnitTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.CSharp.Test.Utilities;
using Xunit;
 
namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.SymbolId
{
    public abstract class SymbolKeyTestBase : CSharpTestBase
    {
        /// <summary>
        /// Bit flags that select how the helpers in this class compare and resolve symbol keys.
        /// </summary>
        [Flags]
        internal enum SymbolKeyComparison
        {
            /// <summary>No flag set: neither case nor assembly ids are ignored.</summary>
            None = 0,

            /// <summary>Ask the key comparer to ignore case.</summary>
            IgnoreCase = 1 << 0,

            /// <summary>Ask the key comparer (and key resolution) to ignore assembly ids.</summary>
            IgnoreAssemblyIds = 1 << 1
        }
 
        /// <summary>
        /// Bit flags naming the groups of source symbols a caller is interested in.
        /// Every member except <see cref="All"/> occupies a single bit and can be combined with <c>|</c>.
        /// </summary>
        [Flags]
        internal enum SymbolCategory
        {
            /// <summary>
            /// Has the value 0 (no bits set). A bitwise AND of this member with any other member is
            /// therefore always 0, so it can only be recognised by an equality comparison.
            /// </summary>
            All = 0,

            /// <summary>Namespace symbols.</summary>
            DeclaredNamespace = 1 << 1,

            /// <summary>Named types and type parameters.</summary>
            DeclaredType = 1 << 2,

            /// <summary>Fields, events, properties and methods.</summary>
            NonTypeMember = 1 << 3,

            /// <summary>Parameter symbols.</summary>
            Parameter = 1 << 4,
        }
 
        #region "Verification"
 
        /// <summary>
        /// Sorts both symbol sequences by <c>Name</c>, asserts they have the same length, and then
        /// verifies the symbols pairwise (same position after sorting) with <c>ResolveAndVerifySymbol</c>.
        /// </summary>
        /// <param name="newSymbols">Symbols whose keys are created and resolved.</param>
        /// <param name="originalSymbols">Symbols each resolved result is expected to equal.</param>
        /// <param name="originalComp">Compilation the keys are resolved against.</param>
        internal static void ResolveAndVerifySymbolList(IEnumerable<ISymbol> newSymbols, IEnumerable<ISymbol> originalSymbols, CSharpCompilation originalComp)
        {
            var sortedNew = newSymbols.OrderBy(symbol => symbol.Name).ToList();
            var sortedOriginal = originalSymbols.OrderBy(symbol => symbol.Name).ToList();

            // Pairing by position only makes sense when both sides have the same number of entries.
            Assert.Equal(sortedOriginal.Count, sortedNew.Count);

            var position = 0;
            foreach (var newSymbol in sortedNew)
            {
                ResolveAndVerifySymbol(newSymbol, sortedOriginal[position], originalComp);
                position++;
            }
        }
 
        /// <summary>
        /// Looks up the type of <paramref name="node"/> in <paramref name="model"/> and verifies it
        /// against <paramref name="sourceSymbol"/> via <c>ResolveAndVerifySymbol</c>.
        /// </summary>
        internal static void ResolveAndVerifyTypeSymbol(ExpressionSyntax node, ITypeSymbol sourceSymbol, SemanticModel model, CSharpCompilation sourceComp)
        {
            var info = model.GetTypeInfo(node);

            // Use the converted type only when the type info reports no Type.
            var typeToVerify = info.Type;
            if (typeToVerify == null)
            {
                typeToVerify = info.ConvertedType;
            }

            ResolveAndVerifySymbol(typeToVerify, sourceSymbol, sourceComp);
        }
 
        /// <summary>
        /// Looks up the symbol that <paramref name="node"/> refers to in <paramref name="model"/> and
        /// forwards it, with the remaining arguments, to the symbol-based <c>ResolveAndVerifySymbol</c> overload.
        /// </summary>
        internal static void ResolveAndVerifySymbol(ExpressionSyntax node, ISymbol sourceSymbol, SemanticModel model, CSharpCompilation sourceComp, SymbolKeyComparison comparison = SymbolKeyComparison.None)
        {
            var referencedSymbol = model.GetSymbolInfo(node).Symbol;
            ResolveAndVerifySymbol(referencedSymbol, sourceSymbol, sourceComp, comparison);
        }
 
        /// <summary>
        /// Asserts that <paramref name="symbol1"/> and <paramref name="symbol2"/> produce equal symbol keys,
        /// and that resolving the key of <paramref name="symbol1"/> in <paramref name="compilation2"/> yields
        /// a non-null symbol equal to <paramref name="symbol2"/> with the same hash code.
        /// </summary>
        internal static void ResolveAndVerifySymbol(ISymbol symbol1, ISymbol symbol2, Compilation compilation2, SymbolKeyComparison comparison = SymbolKeyComparison.None)
        {
            // 1. The two symbols must have equal keys.
            AssertSymbolKeysEqual(symbol1, symbol2, comparison);

            // 2. The key of the first symbol must resolve to something in the second compilation.
            var resolved = ResolveSymbol(symbol1, compilation2, comparison);
            Assert.NotNull(resolved);

            // 3. What it resolves to must be the second symbol, including its hash code.
            Assert.Equal(symbol2, resolved);
            Assert.Equal(symbol2.GetHashCode(), resolved.GetHashCode());
        }
 
        /// <summary>
        /// Creates a symbol key for <paramref name="originalSymbol"/>, asserts that the key survives a
        /// round trip through its string form, and resolves it in <paramref name="targetCompilation"/>.
        /// </summary>
        /// <param name="comparison">
        /// Only the <c>IgnoreAssemblyIds</c> flag is read here; it is passed on to the key resolution.
        /// </param>
        /// <returns>The <c>Symbol</c> of the resolution result.</returns>
        internal static ISymbol ResolveSymbol(ISymbol originalSymbol, Compilation targetCompilation, SymbolKeyComparison comparison)
        {
            var key = SymbolKey.Create(originalSymbol, CancellationToken.None);

            // Serialization check: a key rebuilt from its string form must compare equal
            // under the strict comparer (no case folding, assembly keys respected).
            var roundTripped = new SymbolKey(key.ToString());
            var strictComparer = SymbolKey.GetComparer(ignoreCase: false, ignoreAssemblyKeys: false);
            Assert.True(strictComparer.Equals(key, roundTripped));

            var ignoreAssemblyIds = (comparison & SymbolKeyComparison.IgnoreAssemblyIds) != 0;
            return key.Resolve(targetCompilation, ignoreAssemblyIds).Symbol;
        }
 
        /// <summary>
        /// Creates symbol keys for both symbols and asserts that they compare equal (or unequal, when
        /// <paramref name="expectEqual"/> is <c>false</c>) under the comparer selected by <paramref name="comparison"/>.
        /// </summary>
        /// <param name="symbol1">First symbol.</param>
        /// <param name="symbol2">Second symbol.</param>
        /// <param name="comparison">Flags choosing whether case and/or assembly ids are ignored.</param>
        /// <param name="expectEqual">Whether the keys are expected to be equal.</param>
        internal static void AssertSymbolKeysEqual(ISymbol symbol1, ISymbol symbol2, SymbolKeyComparison comparison, bool expectEqual = true)
        {
            var key1 = SymbolKey.Create(symbol1, CancellationToken.None);
            var key2 = SymbolKey.Create(symbol2, CancellationToken.None);

            // With SymbolKeyComparison.None both switches are false: the comparison does not
            // ignore case and does not ignore assembly ids.
            var ignoreCase = (comparison & SymbolKeyComparison.IgnoreCase) != 0;
            var ignoreAssemblyIds = (comparison & SymbolKeyComparison.IgnoreAssemblyIds) != 0;

            // Failure message naming the comparison mode that was used.
            var modePrefix = ignoreCase ? "SymbolID IgnoreCase" : "SymbolID";
            var modeInfix = ignoreAssemblyIds ? " IgnoreAssemblyIds " : " ";
            var message = modePrefix + modeInfix + "Compare";

            var keysAreEqual = CodeAnalysis.SymbolKey.GetComparer(ignoreCase, ignoreAssemblyIds).Equals(key2, key1);

            if (!expectEqual)
            {
                Assert.False(keysAreEqual, message);
                return;
            }

            Assert.True(keysAreEqual, message);
        }
 
        #endregion
 
        #region "Utilities"
 
        /// <summary>
        /// Collects the block bodies of the declarations of <paramref name="symbol"/>.
        /// Only method-like and accessor declarations are considered, and only bodies that
        /// contain at least one statement are returned.
        /// </summary>
        internal static List<BlockSyntax> GetBlockSyntaxList(IMethodSymbol symbol)
        {
            var bodies = new List<BlockSyntax>();

            foreach (var reference in symbol.DeclaringSyntaxReferences)
            {
                // Any other kind of declaration node contributes no body.
                BlockSyntax body = reference.GetSyntax() switch
                {
                    BaseMethodDeclarationSyntax methodDeclaration => methodDeclaration.Body,
                    AccessorDeclarationSyntax accessorDeclaration => accessorDeclaration.Body,
                    _ => null,
                };

                if (body == null || !body.Statements.Any())
                {
                    continue;
                }

                bodies.Add(body);
            }

            return bodies;
        }
 
        /// <summary>
        /// Returns the explicitly declared source symbols of <paramref name="compilation"/> whose kind
        /// belongs to one of the requested categories. Local symbols are not included.
        /// </summary>
        /// <param name="compilation">Compilation whose source symbols are enumerated.</param>
        /// <param name="category">
        /// One or more <see cref="SymbolCategory"/> flags. <see cref="SymbolCategory.All"/> selects every
        /// category (namespaces, types and type parameters, non-type members, parameters).
        /// </param>
        /// <returns>A lazily evaluated sequence of the matching symbols.</returns>
        internal static IEnumerable<ISymbol> GetSourceSymbols(Microsoft.CodeAnalysis.CSharp.CSharpCompilation compilation, SymbolCategory category)
        {
            // NYI for local symbols
            var allSymbols = GetSourceSymbols(compilation, includeLocal: false);

            // SymbolCategory.All has the value 0, so "(category & flag) != 0" can never be true for it.
            // Without this special case a request for All would select no kinds and return nothing.
            var includeEverything = category == SymbolCategory.All;
            bool IsRequested(SymbolCategory flag) => includeEverything || (category & flag) != 0;

            var kinds = new HashSet<SymbolKind>();
            if (IsRequested(SymbolCategory.DeclaredNamespace))
            {
                kinds.Add(SymbolKind.Namespace);
            }

            if (IsRequested(SymbolCategory.DeclaredType))
            {
                kinds.Add(SymbolKind.NamedType);
                kinds.Add(SymbolKind.TypeParameter);
            }

            if (IsRequested(SymbolCategory.NonTypeMember))
            {
                kinds.Add(SymbolKind.Field);
                kinds.Add(SymbolKind.Event);
                kinds.Add(SymbolKind.Property);
                kinds.Add(SymbolKind.Method);
            }

            if (IsRequested(SymbolCategory.Parameter))
            {
                kinds.Add(SymbolKind.Parameter);
            }

            // Implicitly declared symbols are always filtered out, whatever their kind.
            return allSymbols.Where(s => !s.IsImplicitlyDeclared && kinds.Contains(s.Kind));
        }
 
        /// <summary>
        /// Builds a list containing, in this order: the members reachable from the global namespace of the
        /// compilation's source module (with method parameters, and locals when
        /// <paramref name="includeLocal"/> is set), the using-alias symbols, the assembly, and its modules.
        /// </summary>
        internal static IList<ISymbol> GetSourceSymbols(CSharpCompilation compilation, bool includeLocal)
        {
            // A dumper is only created when locals were asked for; a null dumper means "skip locals".
            LocalSymbolDumper localDumper = null;
            if (includeLocal)
            {
                localDumper = new LocalSymbolDumper(compilation);
            }

            var symbols = new List<ISymbol>();
            GetSourceMemberSymbols(compilation.SourceModule.GlobalNamespace.GetPublicSymbol(), symbols, localDumper);

            // Aliases are collected unconditionally (the original author left an open question
            // about whether this should depend on includeLocal).
            GetSourceAliasSymbols(compilation, symbols);

            // NOTE(review): Assembly is read through a Compilation-typed reference, presumably so that it
            // binds to the ISymbol-typed member of the base type - confirm before simplifying.
            Compilation baseCompilation = compilation;
            symbols.Add(baseCompilation.Assembly);
            symbols.AddRange(baseCompilation.Assembly.Modules);

            return symbols;
        }
 
        #endregion
 
        #region "Private Helpers"
 
        /// <summary>
        /// Recursively appends every member of <paramref name="symbol"/> to <paramref name="list"/>.
        /// Nested types and namespaces are descended into; methods also contribute their parameters;
        /// methods and fields contribute local symbols when <paramref name="localDumper"/> is not null.
        /// </summary>
        private static void GetSourceMemberSymbols(INamespaceOrTypeSymbol symbol, List<ISymbol> list, LocalSymbolDumper localDumper)
        {
            foreach (var member in symbol.GetMembers())
            {
                // The member itself is recorded before anything it contains.
                list.Add(member);

                var kind = member.Kind;
                if (kind == SymbolKind.NamedType || kind == SymbolKind.Namespace)
                {
                    GetSourceMemberSymbols((INamespaceOrTypeSymbol)member, list, localDumper);
                }
                else if (kind == SymbolKind.Method)
                {
                    var method = (IMethodSymbol)member;
                    foreach (var parameter in method.Parameters)
                    {
                        list.Add(parameter);
                    }

                    if (localDumper != null)
                    {
                        localDumper.GetLocalSymbols(method.GetSymbol(), list);
                    }
                }
                else if (kind == SymbolKind.Field)
                {
                    if (localDumper != null)
                    {
                        localDumper.GetLocalSymbols(member.GetSymbol<FieldSymbol>(), list);
                    }
                }
            }
        }
 
        /// <summary>
        /// Appends to <paramref name="list"/> the declared symbol of every using directive that has an
        /// alias, across all syntax trees of <paramref name="comp"/>. Null symbols and symbols already
        /// present in the list are skipped.
        /// </summary>
        private static void GetSourceAliasSymbols(CSharpCompilation comp, List<ISymbol> list)
        {
            foreach (var tree in comp.SyntaxTrees)
            {
                var aliasDirectives = tree.GetRoot()
                    .DescendantNodes()
                    .OfType<UsingDirectiveSyntax>()
                    .Where(directive => directive.Alias != null);
                var model = comp.GetSemanticModel(tree);

                foreach (var directive in aliasDirectives)
                {
                    // The symbol is obtained from the directive as a whole rather than from the alias identifier.
                    var aliasSymbol = model.GetDeclaredSymbol(directive);
                    if (aliasSymbol == null || list.Contains(aliasSymbol))
                    {
                        continue;
                    }

                    list.Add(aliasSymbol);
                }
            }
        }
 
        #endregion
 
        private class LocalSymbolDumper
        {
            private readonly CSharpCompilation _compilation;
            /// <summary>Stores the compilation from which semantic models are later requested.</summary>
            public LocalSymbolDumper(CSharpCompilation compilation)
            {
                _compilation = compilation;
            }
 
            /// <summary>
            /// For each declaration of the field that is a variable declarator with an initializer, appends
            /// to <paramref name="list"/> the variables declared inside the initializer value (with array and
            /// pointer types of locals) and the symbols of an anonymous object or function initializer.
            /// </summary>
            public void GetLocalSymbols(FieldSymbol symbol, List<ISymbol> list)
            {
                foreach (var reference in symbol.DeclaringSyntaxReferences)
                {
                    var declarator = reference.GetSyntax() as VariableDeclaratorSyntax;
                    var initializer = declarator?.Initializer;
                    if (initializer == null)
                    {
                        // Not a variable declarator, or a declarator without an initializer.
                        continue;
                    }

                    var model = _compilation.GetSemanticModel(declarator.SyntaxTree);

                    // Variables declared within the initializer expression.
                    var flow = model.AnalyzeDataFlow(initializer.Value);
                    GetLocalAndType(flow, list);

                    GetAnonymousExprSymbols(initializer.Value, model, list);
                }
            }
 
            /// <summary>
            /// For each declaration of the method, appends to <paramref name="list"/> the symbols found in
            /// its block body (declared variables, anonymous types and functions, labels) and, for a
            /// constructor with an initializer, the symbols found in the initializer's arguments.
            /// </summary>
            public void GetLocalSymbols(MethodSymbol symbol, List<ISymbol> list)
            {
                foreach (var reference in symbol.DeclaringSyntaxReferences)
                {
                    var node = reference.GetSyntax();
                    var model = _compilation.GetSemanticModel(node.SyntaxTree);

                    // Only method-like and accessor declarations can supply a block body here.
                    BlockSyntax body = node switch
                    {
                        BaseMethodDeclarationSyntax methodDeclaration => methodDeclaration.Body,
                        AccessorDeclarationSyntax accessorDeclaration => accessorDeclaration.Body,
                        _ => null,
                    };

                    if (body != null && body.Statements.Any())
                    {
                        GetLocalAndType(model.AnalyzeDataFlow(body), list);
                        GetAnonymousTypeOrFuncSymbols(body, model, list);
                        GetLabelSymbols(body, model, list);
                    }

                    // C# specific: arguments of a constructor's this(...) / base(...) initializer.
                    var ctorInitializer = (node as ConstructorDeclarationSyntax)?.Initializer;
                    if (ctorInitializer == null)
                    {
                        continue;
                    }

                    foreach (var argument in ctorInitializer.ArgumentList.Arguments)
                    {
                        var argumentExpression = argument.Expression;

                        // Unlike for bodies, only the declared variables are added here, not their types.
                        var argumentFlow = model.AnalyzeDataFlow(argumentExpression);
                        list.AddRange(argumentFlow.VariablesDeclared);

                        GetAnonymousExprSymbols(argumentExpression, model, list);
                    }
                }
            }
 
            /// <summary>
            /// Appends every variable declared in <paramref name="df"/> to <paramref name="list"/>; a local
            /// whose type is an array or pointer type is followed by that type.
            /// </summary>
            private static void GetLocalAndType(DataFlowAnalysis df, List<ISymbol> list)
            {
                foreach (var declared in df.VariablesDeclared)
                {
                    list.Add(declared);

                    if (declared is not ILocalSymbol local)
                    {
                        continue;
                    }

                    if (local.Type.Kind is SymbolKind.ArrayType or SymbolKind.PointerType)
                    {
                        list.Add(local.Type);
                    }
                }
            }
 
            /// <summary>
            /// Appends the declared symbol of every labeled statement in <paramref name="body"/>, followed by
            /// the declared symbol of every switch label in it. Results are added without a null check.
            /// </summary>
            private static void GetLabelSymbols(BlockSyntax body, SemanticModel model, List<ISymbol> list)
            {
                // "Label:" statements - the symbol is requested for the statement node.
                foreach (var labeledStatement in body.DescendantNodes().OfType<LabeledStatementSyntax>())
                {
                    list.Add(model.GetDeclaredSymbol(labeledStatement));
                }

                // Switch labels - the symbol is requested for the label node itself,
                // not for the label's value expression.
                foreach (var switchLabel in body.DescendantNodes().OfType<SwitchLabelSyntax>())
                {
                    list.Add(model.GetDeclaredSymbol(switchLabel));
                }
            }
 
            /// <summary>
            /// Runs <c>GetAnonymousExprSymbols</c> over the descendants of <paramref name="body"/>, grouped in
            /// this order: simple lambdas, parenthesized lambdas, anonymous methods, anonymous object creations.
            /// </summary>
            private static void GetAnonymousTypeOrFuncSymbols(BlockSyntax body, SemanticModel model, List<ISymbol> list)
            {
                // Each group is a separate lazy walk over the body; the array order fixes the output order.
                var groups = new IEnumerable<ExpressionSyntax>[]
                {
                    body.DescendantNodes().OfType<SimpleLambdaExpressionSyntax>(),
                    body.DescendantNodes().OfType<ParenthesizedLambdaExpressionSyntax>(),
                    body.DescendantNodes().OfType<AnonymousMethodExpressionSyntax>(),
                    body.DescendantNodes().OfType<AnonymousObjectCreationExpressionSyntax>(),
                };

                foreach (var group in groups)
                {
                    foreach (var expression in group)
                    {
                        GetAnonymousExprSymbols(expression, model, list);
                    }
                }
            }
 
            /// <summary>
            /// Appends the symbols associated with an anonymous object creation, anonymous method, or lambda
            /// expression to <paramref name="list"/>; any other expression kind is ignored.
            /// For an anonymous-function conversion the expression's symbol is added (without a null check);
            /// otherwise, when the expression has a non-delegate type, that type and all of its members are added.
            /// </summary>
            private static void GetAnonymousExprSymbols(ExpressionSyntax expr, SemanticModel model, List<ISymbol> list)
            {
                switch (expr.Kind())
                {
                    case SyntaxKind.AnonymousObjectCreationExpression:
                    case SyntaxKind.AnonymousMethodExpression:
                    case SyntaxKind.ParenthesizedLambdaExpression:
                    case SyntaxKind.SimpleLambdaExpression:
                        break;
                    default:
                        return;
                }

                var typeInfo = model.GetTypeInfo(expr);
                var conversion = model.GetConversion(expr);

                if (conversion.IsAnonymousFunction)
                {
                    // Anonymous function: record the symbol bound to the expression rather than a type.
                    // (Original note: a lambda has no Type unless it is part of a cast, e.g.
                    // var f = (Func<int>)(() => { return 1; }); where Type is the delegate.)
                    list.Add(model.GetSymbolInfo(expr).Symbol);
                    return;
                }

                var type = typeInfo.Type;
                if (type == null || type.TypeKind == TypeKind.Delegate)
                {
                    return;
                }

                // bug#12625
                // The type itself (original note: a NamedType with empty name) plus its members;
                // per the original note, the .ctor that GetSymbolInfo would give is among the members.
                list.Add(type);
                list.AddRange(type.GetMembers());
            }
        }
    }
}