Optimize the Regex source generator's handling of `Compilation` objects. (#65431)
author: Theodore Tsirpanis <teo@tsirpanis.gr>
Wed, 23 Feb 2022 13:54:47 +0000 (15:54 +0200)
committer: GitHub <noreply@github.com>
Wed, 23 Feb 2022 13:54:47 +0000 (08:54 -0500)
* Optimize the Regex source generator's handling of Compilation objects.

* Use the common downlevel IsExternalInit file in the regex source generator.

And remove a now-unused file.

* Address PR feedback; revert to the old way of matching symbols.

* Fix an outdated comment.

src/libraries/System.Text.RegularExpressions/gen/RegexGenerator.Emitter.cs
src/libraries/System.Text.RegularExpressions/gen/RegexGenerator.Parser.cs
src/libraries/System.Text.RegularExpressions/gen/RegexGenerator.cs
src/libraries/System.Text.RegularExpressions/gen/Stubs.cs
src/libraries/System.Text.RegularExpressions/gen/System.Text.RegularExpressions.Generator.csproj

index d0a0684..3340a14 100644 (file)
@@ -37,7 +37,7 @@ namespace System.Text.RegularExpressions.Generator
         };
 
         /// <summary>Generates the code for one regular expression class.</summary>
-        private static (string, ImmutableArray<Diagnostic>) EmitRegexType(RegexType regexClass, Compilation compilation)
+        private static (string, ImmutableArray<Diagnostic>) EmitRegexType(RegexType regexClass, bool allowUnsafe)
         {
             var sb = new StringBuilder(1024);
             var writer = new IndentedTextWriter(new StringWriter(sb));
@@ -78,7 +78,7 @@ namespace System.Text.RegularExpressions.Generator
             generatedName += ComputeStringHash(generatedName).ToString("X");
 
             // Generate the regex type
-            ImmutableArray<Diagnostic> diagnostics = EmitRegexMethod(writer, regexClass.Method, generatedName, compilation);
+            ImmutableArray<Diagnostic> diagnostics = EmitRegexMethod(writer, regexClass.Method, generatedName, allowUnsafe);
 
             while (writer.Indent != 0)
             {
@@ -145,7 +145,7 @@ namespace System.Text.RegularExpressions.Generator
         }
 
         /// <summary>Generates the code for a regular expression method.</summary>
-        private static ImmutableArray<Diagnostic> EmitRegexMethod(IndentedTextWriter writer, RegexMethod rm, string id, Compilation compilation)
+        private static ImmutableArray<Diagnostic> EmitRegexMethod(IndentedTextWriter writer, RegexMethod rm, string id, bool allowUnsafe)
         {
             string patternExpression = Literal(rm.Pattern);
             string optionsExpression = Literal(rm.Options);
@@ -170,8 +170,6 @@ namespace System.Text.RegularExpressions.Generator
                 return ImmutableArray.Create(Diagnostic.Create(DiagnosticDescriptors.LimitedSourceGeneration, rm.MethodSyntax.GetLocation()));
             }
 
-            bool allowUnsafe = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
-
             writer.WriteLine($"new {id}();");
             writer.WriteLine();
             writer.WriteLine($"    private {id}()");
index a7ba253..5053b53 100644 (file)
@@ -21,34 +21,42 @@ namespace System.Text.RegularExpressions.Generator
         private const string RegexName = "System.Text.RegularExpressions.Regex";
         private const string RegexGeneratorAttributeName = "System.Text.RegularExpressions.RegexGeneratorAttribute";
 
-        private static bool IsSyntaxTargetForGeneration(SyntaxNode node) =>
+        private static bool IsSyntaxTargetForGeneration(SyntaxNode node, CancellationToken cancellationToken) =>
             // We don't have a semantic model here, so the best we can do is say whether there are any attributes.
             node is MethodDeclarationSyntax { AttributeLists: { Count: > 0 } };
 
-        private static MethodDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
+        private static bool IsSemanticTargetForGeneration(SemanticModel semanticModel, MethodDeclarationSyntax methodDeclarationSyntax, CancellationToken cancellationToken)
         {
-            var methodDeclarationSyntax = (MethodDeclarationSyntax)context.Node;
-
             foreach (AttributeListSyntax attributeListSyntax in methodDeclarationSyntax.AttributeLists)
             {
                 foreach (AttributeSyntax attributeSyntax in attributeListSyntax.Attributes)
                 {
-                    if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol &&
+                    if (semanticModel.GetSymbolInfo(attributeSyntax, cancellationToken).Symbol is IMethodSymbol attributeSymbol &&
                         attributeSymbol.ContainingType.ToDisplayString() == RegexGeneratorAttributeName)
                     {
-                        return methodDeclarationSyntax;
+                        return true;
                     }
                 }
             }
 
-            return null;
+            return false;
         }
 
         // Returns null if nothing to do, Diagnostic if there's an error to report, or RegexType if the type was analyzed successfully.
-        private static object? GetRegexTypeToEmit(Compilation compilation, MethodDeclarationSyntax methodSyntax, CancellationToken cancellationToken)
+        private static object? GetSemanticTargetForGeneration(GeneratorSyntaxContext context, CancellationToken cancellationToken)
         {
+            var methodSyntax = (MethodDeclarationSyntax)context.Node;
+            SemanticModel sm = context.SemanticModel;
+
+            if (!IsSemanticTargetForGeneration(sm, methodSyntax, cancellationToken))
+            {
+                return null;
+            }
+
+            Compilation compilation = sm.Compilation;
             INamedTypeSymbol? regexSymbol = compilation.GetBestTypeByMetadataName(RegexName);
             INamedTypeSymbol? regexGeneratorAttributeSymbol = compilation.GetBestTypeByMetadataName(RegexGeneratorAttributeName);
+
             if (regexSymbol is null || regexGeneratorAttributeSymbol is null)
             {
                 // Required types aren't available
@@ -61,8 +69,6 @@ namespace System.Text.RegularExpressions.Generator
                 return null;
             }
 
-            SemanticModel sm = compilation.GetSemanticModel(methodSyntax.SyntaxTree);
-
             IMethodSymbol? regexMethodSymbol = sm.GetDeclaredSymbol(methodSyntax, cancellationToken) as IMethodSymbol;
             if (regexMethodSymbol is null)
             {
index 558c613..b840aff 100644 (file)
@@ -25,31 +25,31 @@ namespace System.Text.RegularExpressions.Generator
     {
         public void Initialize(IncrementalGeneratorInitializationContext context)
         {
+            // To avoid invalidating the generator's output when anything from the compilation
+            // changes, we will extract from it the only thing we care about: whether unsafe
+            // code is allowed.
+            IncrementalValueProvider<bool> allowUnsafeProvider =
+                context.CompilationProvider
+                .Select((x, _) => x.Options is CSharpCompilationOptions { AllowUnsafe: true });
+
             // Contains one entry per regex method, either the generated code for that regex method,
             // a diagnostic to fail with, or null if no action should be taken for that regex.
             IncrementalValueProvider<ImmutableArray<object?>> codeOrDiagnostics =
                 context.SyntaxProvider
 
-                // Find all MethodDeclarationSyntax nodes attributed with RegexGenerator
-                .CreateSyntaxProvider(static (s, _) => IsSyntaxTargetForGeneration(s), static (ctx, _) => GetSemanticTargetForGeneration(ctx))
+                // Find all MethodDeclarationSyntax nodes attributed with RegexGenerator and gather the required information
+                .CreateSyntaxProvider(IsSyntaxTargetForGeneration, GetSemanticTargetForGeneration)
                 .Where(static m => m is not null)
 
-                // Pair each with the compilation
-                .Combine(context.CompilationProvider)
-
-                // Use a custom comparer that ignores the compilation. We want to avoid regenerating for regex methods
-                // that haven't been changed, but any change to a regex method will change the Compilation, so we ignore
-                // the Compilation for purposes of caching.
-                .WithComparer(new LambdaComparer<(MethodDeclarationSyntax?, Compilation)>(
-                    static (left, right) => EqualityComparer<MethodDeclarationSyntax>.Default.Equals(left.Item1, right.Item1),
-                    static o => o.Item1?.GetHashCode() ?? 0))
+                // Pair each with whether unsafe code is allowed
+                .Combine(allowUnsafeProvider)
 
-                // Get the resulting code string or error Diagnostic for each MethodDeclarationSyntax/Compilation pair
-                .Select((state, cancellationToken) =>
+                // Get the resulting code string or error Diagnostic for
+                // each MethodDeclarationSyntax/allow-unsafe-blocks pair
+                .Select((state, _) =>
                 {
-                    Debug.Assert(state.Item1 is not null);
-                    object? result = GetRegexTypeToEmit(state.Item2, state.Item1, cancellationToken);
-                    return result is RegexType regexType ? EmitRegexType(regexType, state.Item2) : result;
+                    Debug.Assert(state.Left is not null);
+                    return state.Left is RegexType regexType ? EmitRegexType(regexType, state.Right) : state.Left;
                 })
                 .Collect();
 
@@ -83,21 +83,5 @@ namespace System.Text.RegularExpressions.Generator
                 context.AddSource("RegexGenerator.g.cs", string.Join(Environment.NewLine, code));
             });
         }
-
-        private sealed class LambdaComparer<T> : IEqualityComparer<T>
-        {
-            private readonly Func<T?, T?, bool> _equal;
-            private readonly Func<T?, int> _getHashCode;
-
-            public LambdaComparer(Func<T?, T?, bool> equal, Func<T?, int> getHashCode)
-            {
-                _equal = equal;
-                _getHashCode = getHashCode;
-            }
-
-            public bool Equals(T? x, T? y) => _equal(x, y);
-
-            public int GetHashCode(T obj) => _getHashCode(obj);
-        }
     }
 }
index e0d6dfb..9b99a91 100644 (file)
@@ -84,8 +84,3 @@ namespace System.Text.RegularExpressions
         public const int WholeString = -4;
     }
 }
-
-namespace System.Runtime.CompilerServices
-{
-    internal static class IsExternalInit { }
-}
index 51c49cc..1178e5e 100644 (file)
@@ -23,6 +23,7 @@
   <ItemGroup>
     <!-- Common generator support -->
     <Compile Include="$(CommonPath)Roslyn\GetBestTypeByMetadataName.cs" Link="Common\Roslyn\GetBestTypeByMetadataName.cs" />
+    <Compile Include="$(CommonPath)System\Runtime\CompilerServices\IsExternalInit.cs" Link="Common\System\Runtime\CompilerServices\IsExternalInit.cs" />
 
     <!-- Code included from System.Text.RegularExpressions -->
     <Compile Include="$(CommonPath)System\HexConverter.cs" Link="Production\HexConverter.cs" />