|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.FixReturnType
{
/// <summary>
/// Helps fix void-returning methods or local functions to return a correct type.
/// </summary>
/// <remarks>
/// The replacement type is the type of the returned (or expression-bodied) value, falling back to
/// <c>object</c> when it has no type. For <c>async</c> members, <c>void</c>/<c>Task</c> becomes
/// <c>Task&lt;T&gt;</c> and <c>ValueTask</c> becomes <c>ValueTask&lt;T&gt;</c>; any other async return
/// type is left alone. Fix-all is not supported, so exactly one diagnostic is handled per invocation.
/// </remarks>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.FixReturnType), Shared]
internal class CSharpFixReturnTypeCodeFixProvider : SyntaxEditorBasedCodeFixProvider
{
    // error CS0127: Since 'M()' returns void, a return keyword must not be followed by an object expression
    // error CS1997: Since 'M()' is an async method that returns 'Task', a return keyword must not be followed by an object expression
    // error CS0201: Only assignment, call, increment, decrement, await, and new object expressions can be used as a statement
    public override ImmutableArray<string> FixableDiagnosticIds => ImmutableArray.Create("CS0127", "CS1997", "CS0201");

    [ImportingConstructor]
    [Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
    public CSharpFixReturnTypeCodeFixProvider()
        : base(supportsFixAll: false)
    {
    }

    public override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var document = context.Document;
        var diagnostics = context.Diagnostics;
        var cancellationToken = context.CancellationToken;

        var analyzedTypes = await TryGetOldAndNewReturnTypeAsync(document, diagnostics, cancellationToken).ConfigureAwait(false);
        if (analyzedTypes == default)
            return;

        if (IsVoid(analyzedTypes.declarationToFix) && IsVoid(analyzedTypes.fixedDeclaration))
        {
            // Don't offer a code fix if the return type is void and return is followed by a void expression.
            // See https://github.com/dotnet/roslyn/issues/47089
            return;
        }

        RegisterCodeFix(context, CSharpCodeFixesResources.Fix_return_type, nameof(CSharpCodeFixesResources.Fix_return_type));
        return;

        static bool IsVoid(TypeSyntax typeSyntax)
            => typeSyntax is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.VoidKeyword };
    }

    /// <summary>
    /// Locates the return type to replace for the single diagnostic in <paramref name="diagnostics"/> and
    /// computes its replacement.
    /// </summary>
    /// <returns>
    /// The existing return type syntax paired with the replacement syntax (annotated for simplification and
    /// carrying the original trivia), or <see langword="default"/> when no fix can be computed.
    /// </returns>
    private static async Task<(TypeSyntax declarationToFix, TypeSyntax fixedDeclaration)> TryGetOldAndNewReturnTypeAsync(
        Document document, ImmutableArray<Diagnostic> diagnostics, CancellationToken cancellationToken)
    {
        // Fix-all is disabled in the constructor, so only one diagnostic is expected here.
        Debug.Assert(diagnostics.Length == 1);
        var location = diagnostics[0].Location;
        var node = location.FindNode(getInnermostNodeForTie: true, cancellationToken);

        // The diagnostic may sit on a `return` statement (use its expression) or directly on the offending expression.
        var returnedValue = node is ReturnStatementSyntax returnStatement ? returnStatement.Expression : node;
        if (returnedValue is null)
            return default;

        var (declarationTypeToFix, isAsync) = TryGetDeclarationTypeToFix(node);
        if (declarationTypeToFix is null)
            return default;

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var compilation = semanticModel.Compilation;
        var returnedType = semanticModel.GetTypeInfo(returnedValue, cancellationToken).Type;

        // Special case when tuple has elements with unknown type, e.g. `(null, default)`:
        // those unknown elements are replaced with `object`.
        if (returnedType is null && returnedValue is TupleExpressionSyntax tuple)
            returnedType = InferTupleType(tuple, semanticModel, cancellationToken);

        returnedType ??= compilation.ObjectType;

        TypeSyntax fixedDeclaration;
        if (isAsync)
        {
            var previousReturnType = semanticModel.GetTypeInfo(declarationTypeToFix, cancellationToken).Type;
            if (previousReturnType is null)
                return default;

            // void, Task -> Task<T>
            // ValueTask  -> ValueTask<T>
            // other type -> we cannot infer anything
            INamedTypeSymbol? taskType = null;
            if (previousReturnType.SpecialType is SpecialType.System_Void ||
                SymbolEqualityComparer.Default.Equals(previousReturnType, compilation.TaskType()))
            {
                taskType = compilation.TaskOfTType();
            }
            else if (SymbolEqualityComparer.Default.Equals(previousReturnType, compilation.ValueTaskType()))
            {
                taskType = compilation.ValueTaskOfTType();
            }

            if (taskType is null)
                return default;

            fixedDeclaration = taskType.Construct(returnedType).GenerateTypeSyntax(allowVar: false);
        }
        else
        {
            fixedDeclaration = returnedType.GenerateTypeSyntax(allowVar: false);
        }

        fixedDeclaration = fixedDeclaration.WithAdditionalAnnotations(Simplifier.Annotation).WithTriviaFrom(declarationTypeToFix);
        return (declarationTypeToFix, fixedDeclaration);
    }

    protected override async Task FixAllAsync(Document document, ImmutableArray<Diagnostic> diagnostics, SyntaxEditor editor, CodeActionOptionsProvider fallbackOptions, CancellationToken cancellationToken)
    {
        var (declarationTypeToFix, fixedDeclaration) =
            await TryGetOldAndNewReturnTypeAsync(document, diagnostics, cancellationToken).ConfigureAwait(false);

        // The analysis is recomputed here; if it yields no fix (default tuple), there is nothing to replace.
        if (declarationTypeToFix is null || fixedDeclaration is null)
            return;

        editor.ReplaceNode(declarationTypeToFix, fixedDeclaration);
    }

    /// <summary>
    /// Finds the return type of the nearest enclosing method or local function of <paramref name="node"/>,
    /// and whether that member is declared <c>async</c>.
    /// </summary>
    /// <returns>
    /// <see langword="default"/> (null type) when there is no enclosing method/local function, or when the
    /// nearest enclosing function is a lambda or anonymous method.
    /// </returns>
    private static (TypeSyntax? type, bool isAsync) TryGetDeclarationTypeToFix(SyntaxNode node)
    {
        foreach (var ancestor in node.GetAncestors())
        {
            switch (ancestor)
            {
                // void M() { return 1; }
                // async Task M() { return 1; }
                case MethodDeclarationSyntax method:
                    return (method.ReturnType, method.Modifiers.Any(SyntaxKind.AsyncKeyword));

                // void local() { return 1; }
                // async Task local() { return 1; }
                case LocalFunctionStatementSyntax localFunction:
                    return (localFunction.ReturnType, localFunction.Modifiers.Any(SyntaxKind.AsyncKeyword));

                // Action a = () => 1;
                // Code inside a lambda/anonymous method belongs to that function, not to the enclosing
                // method or local function, so changing the outer return type would not fix the error.
                case AnonymousFunctionExpressionSyntax:
                    return default;
            }
        }

        return default;
    }

    /// <summary>
    /// Builds a tuple type for a tuple expression that has no natural type (e.g. <c>(null, 1)</c>),
    /// using <c>object</c> for each element whose type is unknown and recursing into nested tuples.
    /// </summary>
    /// <returns>The inferred tuple type, or <see langword="null"/> if the tuple has fewer than two elements.</returns>
    private static ITypeSymbol? InferTupleType(TupleExpressionSyntax tuple, SemanticModel semanticModel, CancellationToken cancellationToken)
    {
        var compilation = semanticModel.Compilation;
        var arguments = tuple.Arguments;

        // A tuple type needs at least two elements; CreateTupleTypeSymbol throws for smaller cardinalities.
        if (arguments.Count < 2)
            return null;

        var elementTypes = ImmutableArray.CreateBuilder<ITypeSymbol>(arguments.Count);
        foreach (var argument in arguments)
        {
            var argumentExpression = argument.Expression;
            var type = semanticModel.GetTypeInfo(argumentExpression, cancellationToken).Type;

            // Nested tuple with unknown type, e.g. `(string.Empty, (2, null))`
            if (type is null && argumentExpression is TupleExpressionSyntax nestedTuple)
                type = InferTupleType(nestedTuple, semanticModel, cancellationToken);

            elementTypes.Add(type ?? compilation.ObjectType);
        }

        // Unlike constructing System.ValueTuple`N directly, CreateTupleTypeSymbol nests the trailing
        // elements into TRest correctly for tuples with more than seven elements.
        return compilation.CreateTupleTypeSymbol(elementTypes.MoveToImmutable());
    }
}
}
|