File: GenerateEqualsAndGetHashCodeFromMembers\AbstractGenerateEqualsAndGetHashCodeService.cs
Web Access
Project: ..\..\..\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Formatting.Rules;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.GenerateEqualsAndGetHashCodeFromMembers
{
    internal abstract partial class AbstractGenerateEqualsAndGetHashCodeService : IGenerateEqualsAndGetHashCodeService
    {
        private const string GetHashCodeName = nameof(object.GetHashCode);
        private static readonly SyntaxAnnotation s_specializedFormattingAnnotation = new();
 
        protected abstract bool TryWrapWithUnchecked(
            ImmutableArray<SyntaxNode> statements, out ImmutableArray<SyntaxNode> wrappedStatements);
 
        public async Task<Document> FormatDocumentAsync(Document document, SyntaxFormattingOptions options, CancellationToken cancellationToken)
        {
            var rules = new List<AbstractFormattingRule> { new FormatLargeBinaryExpressionRule(document.GetRequiredLanguageService<ISyntaxFactsService>()) };
            rules.AddRange(Formatter.GetDefaultFormattingRules(document));
 
            var formattedDocument = await Formatter.FormatAsync(
                document, s_specializedFormattingAnnotation,
                options, rules, cancellationToken).ConfigureAwait(false);
            return formattedDocument;
        }
 
        public async Task<IMethodSymbol> GenerateEqualsMethodAsync(
            Document document, INamedTypeSymbol namedType, ImmutableArray<ISymbol> members,
            string? localNameOpt, CancellationToken cancellationToken)
        {
            var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
            var tree = await document.GetRequiredSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
            var generator = document.GetLanguageService<SyntaxGenerator>();
            var generatorInternal = document.GetLanguageService<SyntaxGeneratorInternal>();
            return generator.CreateEqualsMethod(
                generatorInternal, compilation, tree.Options, namedType, members, localNameOpt, s_specializedFormattingAnnotation);
        }
 
        public async Task<IMethodSymbol> GenerateIEquatableEqualsMethodAsync(
            Document document, INamedTypeSymbol namedType,
            ImmutableArray<ISymbol> members, INamedTypeSymbol constructedEquatableType, CancellationToken cancellationToken)
        {
            var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var generator = document.GetLanguageService<SyntaxGenerator>();
            var generatorInternal = document.GetLanguageService<SyntaxGeneratorInternal>();
            return generator.CreateIEquatableEqualsMethod(
                generatorInternal, semanticModel, namedType, members, constructedEquatableType, s_specializedFormattingAnnotation);
        }
 
        public async Task<IMethodSymbol> GenerateEqualsMethodThroughIEquatableEqualsAsync(
            Document document, INamedTypeSymbol containingType, CancellationToken cancellationToken)
        {
            var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
            var tree = await document.GetRequiredSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
            var generator = document.GetRequiredLanguageService<SyntaxGenerator>();
 
            using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var expressions);
            var objName = generator.IdentifierName("obj");
            if (containingType.IsValueType)
            {
                if (generator.SyntaxGeneratorInternal.SupportsPatterns(tree.Options))
                {
                    // return obj is T t && this.Equals(t);
                    var localName = containingType.GetLocalName();
 
                    expressions.Add(
                        generator.SyntaxGeneratorInternal.IsPatternExpression(objName,
                            generator.SyntaxGeneratorInternal.DeclarationPattern(containingType, localName)));
                    expressions.Add(
                        generator.InvocationExpression(
                            generator.MemberAccessExpression(
                                generator.ThisExpression(),
                                generator.IdentifierName(nameof(Equals))),
                            generator.IdentifierName(localName)));
                }
                else
                {
                    // return obj is T && this.Equals((T)obj);
                    expressions.Add(generator.IsTypeExpression(objName, containingType));
                    expressions.Add(
                        generator.InvocationExpression(
                            generator.MemberAccessExpression(
                                generator.ThisExpression(),
                                generator.IdentifierName(nameof(Equals))),
                            generator.CastExpression(containingType, objName)));
                }
            }
            else
            {
                // return this.Equals(obj as T);
                expressions.Add(
                    generator.InvocationExpression(
                        generator.MemberAccessExpression(
                            generator.ThisExpression(),
                            generator.IdentifierName(nameof(Equals))),
                        generator.TryCastExpression(objName, containingType)));
            }
 
            var statement = generator.ReturnStatement(
                expressions.Aggregate(generator.LogicalAndExpression));
 
            return compilation.CreateEqualsMethod(
                ImmutableArray.Create(statement));
        }
 
        public async Task<IMethodSymbol> GenerateGetHashCodeMethodAsync(
            Document document, INamedTypeSymbol namedType,
            ImmutableArray<ISymbol> members, CancellationToken cancellationToken)
        {
            var compilation = await document.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var factory = document.GetRequiredLanguageService<SyntaxGenerator>();
            var generatorInternal = document.GetRequiredLanguageService<SyntaxGeneratorInternal>();
            return CreateGetHashCodeMethod(factory, generatorInternal, compilation, namedType, members);
        }
 
        private IMethodSymbol CreateGetHashCodeMethod(
            SyntaxGenerator factory, SyntaxGeneratorInternal generatorInternal, Compilation compilation,
            INamedTypeSymbol namedType, ImmutableArray<ISymbol> members)
        {
            var statements = CreateGetHashCodeStatements(
                factory, generatorInternal, compilation, namedType, members);
 
            return CodeGenerationSymbolFactory.CreateMethodSymbol(
                attributes: default,
                accessibility: Accessibility.Public,
                modifiers: new DeclarationModifiers(isOverride: true),
                returnType: compilation.GetSpecialType(SpecialType.System_Int32),
                refKind: RefKind.None,
                explicitInterfaceImplementations: default,
                name: GetHashCodeName,
                typeParameters: default,
                parameters: default,
                statements: statements);
        }
 
        private ImmutableArray<SyntaxNode> CreateGetHashCodeStatements(
            SyntaxGenerator factory, SyntaxGeneratorInternal generatorInternal, Compilation compilation,
            INamedTypeSymbol namedType, ImmutableArray<ISymbol> members)
        {
            // See if there's an accessible System.HashCode we can call into to do all the work.
            var hashCodeType = compilation.GetTypeByMetadataName("System.HashCode");
            if (hashCodeType != null && !hashCodeType.IsAccessibleWithin(namedType))
                hashCodeType = null;
 
            var components = factory.GetGetHashCodeComponents(
                generatorInternal, compilation, namedType, members, justMemberReference: true);
 
            if (components.Length > 0 && hashCodeType != null)
            {
                return factory.CreateGetHashCodeStatementsUsingSystemHashCode(
                    factory.SyntaxGeneratorInternal, hashCodeType, components);
            }
 
            // Otherwise, try to just spit out a reasonable hash code for these members.
            var statements = factory.CreateGetHashCodeMethodStatements(
                factory.SyntaxGeneratorInternal, compilation, namedType, members, useInt64: false);
 
            // Unfortunately, our 'reasonable' hash code may overflow in checked contexts.
            // C# can handle this by adding 'checked{}' around the code, VB has to jump
            // through more hoops.
            if (!compilation.Options.CheckOverflow)
            {
                return statements;
            }
 
            if (TryWrapWithUnchecked(statements, out var wrappedStatements))
            {
                return wrappedStatements;
            }
 
            // If tuples are available, use (a, b, c).GetHashCode to simply generate the tuple.
            var valueTupleType = compilation.GetTypeByMetadataName(typeof(ValueTuple).FullName!);
            if (components.Length >= 2 && valueTupleType != null)
            {
                return ImmutableArray.Create(factory.ReturnStatement(
                    factory.InvocationExpression(
                        factory.MemberAccessExpression(
                            factory.TupleExpression(components),
                            GetHashCodeName))));
            }
 
            // Otherwise, use 64bit math to compute the hash.  Importantly, if we always clamp
            // the hash to be 32 bits or less, then the following cannot ever overflow in 64
            // bits: hashCode = hashCode * -1521134295 + j.GetHashCode()
            //
            // So we'll generate lines like: hashCode = (hashCode * -1521134295 + j.GetHashCode()) & 0x7FFFFFFF
            //
            // This does mean all hashcodes will be positive.  But it will avoid the overflow problem.
            return factory.CreateGetHashCodeMethodStatements(
                factory.SyntaxGeneratorInternal, compilation, namedType, members, useInt64: true);
        }
    }
}