// File: AbstractUserDiagnosticTest.FixAllDiagnosticProvider.cs
// Web Access
// Project: ..\..\..\src\CodeStyle\Core\Tests\Microsoft.CodeAnalysis.CodeStyle.LegacyTestFramework.UnitTestUtilities.csproj (Microsoft.CodeAnalysis.CodeStyle.LegacyTestFramework.UnitTestUtilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.UnitTests.Diagnostics;
 
namespace Microsoft.CodeAnalysis.Editor.UnitTests.Diagnostics
{
    public abstract partial class AbstractUserDiagnosticTest
    {
        /// <summary>
        /// A <see cref="FixAllContext.DiagnosticProvider"/> that obtains its diagnostics from a
        /// <see cref="TestDiagnosticAnalyzerDriver"/> and reports only those whose
        /// <see cref="Diagnostic.Id"/> is contained in a fixed set of diagnostic ids.
        /// </summary>
        private class FixAllDiagnosticProvider : FixAllContext.DiagnosticProvider
        {
            private readonly TestDiagnosticAnalyzerDriver _testDriver;
            private readonly ImmutableHashSet<string> _diagnosticIds;

            /// <summary>
            /// Creates a provider over the given driver.
            /// </summary>
            /// <param name="testDriver">Driver that is asked for document and project diagnostics.</param>
            /// <param name="diagnosticIds">Ids of the diagnostics this provider is allowed to return.</param>
            public FixAllDiagnosticProvider(TestDiagnosticAnalyzerDriver testDriver, ImmutableHashSet<string> diagnosticIds)
            {
                _testDriver = testDriver;
                _diagnosticIds = diagnosticIds;
            }

            /// <summary>
            /// Gets the driver's diagnostics for the full span of <paramref name="document"/>,
            /// restricted to the configured diagnostic ids.
            /// </summary>
            /// <param name="document">Document to analyze.</param>
            /// <param name="cancellationToken">Token observed while fetching the syntax root and before invoking the driver.</param>
            /// <returns>
            /// The matching diagnostics, or an empty sequence when the document has no syntax root.
            /// </returns>
            public override async Task<IEnumerable<Diagnostic>> GetDocumentDiagnosticsAsync(Document document, CancellationToken cancellationToken)
            {
                var root = await document.GetSyntaxRootAsync(cancellationToken);
                if (root is null)
                {
                    // GetSyntaxRootAsync yields null for documents that do not support syntax
                    // trees; there is no span to analyze, so report nothing instead of
                    // failing on root.FullSpan.
                    return Enumerable.Empty<Diagnostic>();
                }

                // The driver is invoked without a token, so observe cancellation before starting it.
                cancellationToken.ThrowIfCancellationRequested();

                var diagnostics = await _testDriver.GetDocumentDiagnosticsAsync(document, root.FullSpan);
                return FilterByDiagnosticIds(diagnostics);
            }

            /// <summary>
            /// Gets everything the driver's <c>GetAllDiagnosticsAsync</c> reports for
            /// <paramref name="project"/>, restricted to the configured diagnostic ids.
            /// </summary>
            /// <param name="project">Project to analyze.</param>
            /// <param name="cancellationToken">Token observed before invoking the driver.</param>
            public override Task<IEnumerable<Diagnostic>> GetAllDiagnosticsAsync(Project project, CancellationToken cancellationToken)
                => GetFilteredProjectDiagnosticsAsync(project, includeAllDocumentDiagnostics: true, cancellationToken);

            /// <summary>
            /// Gets everything the driver's <c>GetProjectDiagnosticsAsync</c> reports for
            /// <paramref name="project"/>, restricted to the configured diagnostic ids.
            /// </summary>
            /// <param name="project">Project to analyze.</param>
            /// <param name="cancellationToken">Token observed before invoking the driver.</param>
            public override Task<IEnumerable<Diagnostic>> GetProjectDiagnosticsAsync(Project project, CancellationToken cancellationToken)
                => GetFilteredProjectDiagnosticsAsync(project, includeAllDocumentDiagnostics: false, cancellationToken);

            /// <summary>
            /// Shared implementation of the two project-level queries.
            /// </summary>
            /// <param name="project">Project to analyze.</param>
            /// <param name="includeAllDocumentDiagnostics">
            /// <see langword="true"/> to call the driver's <c>GetAllDiagnosticsAsync</c>;
            /// <see langword="false"/> to call its <c>GetProjectDiagnosticsAsync</c>.
            /// </param>
            /// <param name="cancellationToken">Token observed before invoking the driver.</param>
            private async Task<IEnumerable<Diagnostic>> GetFilteredProjectDiagnosticsAsync(
                Project project, bool includeAllDocumentDiagnostics, CancellationToken cancellationToken)
            {
                // The driver is invoked without a token, so the only place the caller's
                // cancellation request can be honored is before the call starts.
                cancellationToken.ThrowIfCancellationRequested();

                var diagnostics = includeAllDocumentDiagnostics
                    ? await _testDriver.GetAllDiagnosticsAsync(project)
                    : await _testDriver.GetProjectDiagnosticsAsync(project);

                return FilterByDiagnosticIds(diagnostics);
            }

            /// <summary>
            /// Lazily keeps only the diagnostics whose id is in the configured id set.
            /// </summary>
            private IEnumerable<Diagnostic> FilterByDiagnosticIds(IEnumerable<Diagnostic> diagnostics)
                => diagnostics.Where(diagnostic => _diagnosticIds.Contains(diagnostic.Id));
        }
    }
}