[release/6.0-rc1] Migrate LoggerMessageGenerator to IIncrementalGenerator (#58271)
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sat, 28 Aug 2021 00:30:23 +0000 (17:30 -0700)
committerGitHub <noreply@github.com>
Sat, 28 Aug 2021 00:30:23 +0000 (17:30 -0700)
* Migrate LoggerMessageGenerator to IIncrementalGenerator

This reduces the time spent in the background in VS running the source generator, since we only need to respond to methods that have the LoggerMessageAttribute on them.

Contributes to #56702

* PR feedback

* PR feedback

Co-authored-by: Eric Erhardt <eric.erhardt@microsoft.com>
src/libraries/Common/tests/SourceGenerators/RoslynTestUtils.cs
src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.Parser.cs
src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.cs
src/libraries/Microsoft.Extensions.Logging.Abstractions/tests/Microsoft.Extensions.Logging.Generators.Tests/LoggerMessageGeneratorParserTests.cs

index 4928270..1c57260 100644 (file)
@@ -141,10 +141,9 @@ namespace SourceGenerators.Tests
         /// Runs a Roslyn generator over a set of source files.
         /// </summary>
         public static async Task<(ImmutableArray<Diagnostic>, ImmutableArray<GeneratedSourceResult>)> RunGenerator(
-            ISourceGenerator generator,
+            IIncrementalGenerator generator,
             IEnumerable<Assembly>? references,
             IEnumerable<string> sources,
-            AnalyzerConfigOptionsProvider? optionsProvider = null,
             bool includeBaseReferences = true,
             CancellationToken cancellationToken = default)
         {
@@ -156,7 +155,9 @@ namespace SourceGenerators.Tests
 
             Compilation? comp = await proj!.GetCompilationAsync(CancellationToken.None).ConfigureAwait(false);
 
-            CSharpGeneratorDriver cgd = CSharpGeneratorDriver.Create(new[] { generator }, optionsProvider: optionsProvider);
+            // workaround https://github.com/dotnet/roslyn/pull/55866. We can remove "LangVersion=Preview" when we get a Roslyn build with that change.
+            CSharpParseOptions options = CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Preview);
+            CSharpGeneratorDriver cgd = CSharpGeneratorDriver.Create(new[] { generator.AsSourceGenerator() }, parseOptions: options);
             GeneratorDriver gd = cgd.RunGenerators(comp!, cancellationToken);
 
             GeneratorDriverRunResult r = gd.GetRunResult();
index 4a99503..b448bc7 100644 (file)
@@ -17,6 +17,8 @@ namespace Microsoft.Extensions.Logging.Generators
     {
         internal class Parser
         {
+            private const string LoggerMessageAttribute = "Microsoft.Extensions.Logging.LoggerMessageAttribute";
+
             private readonly CancellationToken _cancellationToken;
             private readonly Compilation _compilation;
             private readonly Action<Diagnostic> _reportDiagnostic;
@@ -28,13 +30,41 @@ namespace Microsoft.Extensions.Logging.Generators
                 _reportDiagnostic = reportDiagnostic;
             }
 
+            internal static bool IsSyntaxTargetForGeneration(SyntaxNode node) =>
+                node is MethodDeclarationSyntax m && m.AttributeLists.Count > 0;
+
+            internal static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
+            {
+                var methodDeclarationSyntax = (MethodDeclarationSyntax)context.Node;
+
+                foreach (AttributeListSyntax attributeListSyntax in methodDeclarationSyntax.AttributeLists)
+                {
+                    foreach (AttributeSyntax attributeSyntax in attributeListSyntax.Attributes)
+                    {
+                        IMethodSymbol attributeSymbol = context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol as IMethodSymbol;
+                        if (attributeSymbol == null)
+                        {
+                            continue;
+                        }
+
+                        INamedTypeSymbol attributeContainingTypeSymbol = attributeSymbol.ContainingType;
+                        string fullName = attributeContainingTypeSymbol.ToDisplayString();
+
+                        if (fullName == LoggerMessageAttribute)
+                        {
+                            return methodDeclarationSyntax.Parent as ClassDeclarationSyntax;
+                        }
+                    }
+                }
+
+                return null;
+            }
+
             /// <summary>
             /// Gets the set of logging classes containing methods to output.
             /// </summary>
             public IReadOnlyList<LoggerClass> GetLogClasses(IEnumerable<ClassDeclarationSyntax> classes)
             {
-                const string LoggerMessageAttribute = "Microsoft.Extensions.Logging.LoggerMessageAttribute";
-
                 INamedTypeSymbol loggerMessageAttribute = _compilation.GetTypeByMetadataName(LoggerMessageAttribute);
                 if (loggerMessageAttribute == null)
                 {
@@ -442,11 +472,11 @@ namespace Microsoft.Extensions.Logging.Generators
                                             LoggerClass currentLoggerClass = lc;
                                             var parentLoggerClass = (classDec.Parent as TypeDeclarationSyntax);
 
-                                            bool IsAllowedKind(SyntaxKind kind) => 
+                                            bool IsAllowedKind(SyntaxKind kind) =>
                                                 kind == SyntaxKind.ClassDeclaration ||
                                                 kind == SyntaxKind.StructDeclaration ||
                                                 kind == SyntaxKind.RecordDeclaration;
-                                            
+
                                             while (parentLoggerClass != null && IsAllowedKind(parentLoggerClass.Kind()))
                                             {
                                                 currentLoggerClass.ParentClass = new LoggerClass
index 92105d5..7aaca68 100644 (file)
@@ -3,7 +3,10 @@
 
 using System;
 using System.Collections.Generic;
+using System.Collections.Immutable;
 using System.Diagnostics.CodeAnalysis;
+using System.Diagnostics.Tracing;
+using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text;
 using Microsoft.CodeAnalysis;
@@ -15,50 +18,38 @@ using Microsoft.CodeAnalysis.Text;
 namespace Microsoft.Extensions.Logging.Generators
 {
     [Generator]
-    public partial class LoggerMessageGenerator : ISourceGenerator
+    public partial class LoggerMessageGenerator : IIncrementalGenerator
     {
-        [ExcludeFromCodeCoverage]
-        public void Initialize(GeneratorInitializationContext context)
+        public void Initialize(IncrementalGeneratorInitializationContext context)
         {
-            context.RegisterForSyntaxNotifications(SyntaxReceiver.Create);
+            IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
+                .CreateSyntaxProvider(static (s, _) => Parser.IsSyntaxTargetForGeneration(s), static (ctx, _) => Parser.GetSemanticTargetForGeneration(ctx))
+                .Where(static m => m is not null);
+
+            IncrementalValueProvider<(Compilation, ImmutableArray<ClassDeclarationSyntax>)> compilationAndClasses =
+                context.CompilationProvider.Combine(classDeclarations.Collect());
+
+            context.RegisterSourceOutput(compilationAndClasses, static (spc, source) => Execute(source.Item1, source.Item2, spc));
         }
 
-        [ExcludeFromCodeCoverage]
-        public void Execute(GeneratorExecutionContext context)
+        private static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context)
         {
-            if (context.SyntaxReceiver is not SyntaxReceiver receiver || receiver.ClassDeclarations.Count == 0)
+            if (classes.IsDefaultOrEmpty)
             {
                 // nothing to do yet
                 return;
             }
 
-            var p = new Parser(context.Compilation, context.ReportDiagnostic, context.CancellationToken);
-            IReadOnlyList<LoggerClass> logClasses = p.GetLogClasses(receiver.ClassDeclarations);
+            IEnumerable<ClassDeclarationSyntax> distinctClasses = classes.Distinct();
+
+            var p = new Parser(compilation, context.ReportDiagnostic, context.CancellationToken);
+            IReadOnlyList<LoggerClass> logClasses = p.GetLogClasses(distinctClasses);
             if (logClasses.Count > 0)
             {
                 var e = new Emitter();
                 string result = e.Emit(logClasses, context.CancellationToken);
-    
-                context.AddSource("LoggerMessage.g.cs", SourceText.From(result, Encoding.UTF8));
-            }
-        }
 
-        [ExcludeFromCodeCoverage]
-        private sealed class SyntaxReceiver : ISyntaxReceiver
-        {
-            internal static ISyntaxReceiver Create()
-            {
-                return new SyntaxReceiver();
-            }
-
-            public List<ClassDeclarationSyntax> ClassDeclarations { get; } = new ();
-
-            public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
-            {
-                if (syntaxNode is ClassDeclarationSyntax classSyntax)
-                {
-                    ClassDeclarations.Add(classSyntax);
-                }
+                context.AddSource("LoggerMessage.g.cs", SourceText.From(result, Encoding.UTF8));
             }
         }
     }
index b2fe610..a1ec425 100644 (file)
@@ -380,6 +380,7 @@ namespace Microsoft.Extensions.Logging.Generators.Tests
                     public class Void {}
                     public class String {}
                     public struct DateTime {}
+                    public abstract class Attribute {}
                 }
                 namespace System.Collections
                 {
@@ -392,10 +393,12 @@ namespace Microsoft.Extensions.Logging.Generators.Tests
                 }
                 namespace Microsoft.Extensions.Logging
                 {
-                    public class LoggerMessageAttribute {}
+                    public class LoggerMessageAttribute : System.Attribute {}
                 }
                 partial class C
                 {
+                    [Microsoft.Extensions.Logging.LoggerMessage]
+                    public static partial void Log(ILogger logger);
                 }
             ", false, includeBaseReferences: false, includeLoggingReferences: false);