Use StartsWith in regex compiler / source gen for shorter strings (#65222)
authorStephen Toub <stoub@microsoft.com>
Tue, 15 Feb 2022 02:02:47 +0000 (21:02 -0500)
committerGitHub <noreply@github.com>
Tue, 15 Feb 2022 02:02:47 +0000 (21:02 -0500)
The RegexCompiler and source generator currently special-case strings < 64 chars in length and unroll the loop, using a series of ulong and uint comparisons where possible.  While efficient, this makes the generated code harder to read, and the source generator also has endianness issues when the compiled binary is then used on a machine with different endianness.  The JIT is going to start doing such unrolling as part of StartsWith, so we can leave the optimization up to it; it'll be able to do it better, anyway, with its optimization applying to more uses, using vectors where applicable, etc.

src/libraries/System.Text.RegularExpressions/gen/RegexGenerator.Emitter.cs
src/libraries/System.Text.RegularExpressions/src/System/Text/RegularExpressions/RegexCompiler.cs

index 00be97d..39324d2 100644 (file)
@@ -2234,110 +2234,37 @@ namespace System.Text.RegularExpressions.Generator
             {
                 Debug.Assert(node.Kind is RegexNodeKind.Multi, $"Unexpected type: {node.Kind}");
 
-                bool caseInsensitive = IsCaseInsensitive(node);
-
                 string str = node.Str!;
-                Debug.Assert(str.Length != 0);
+                Debug.Assert(str.Length >= 2);
 
-                const int MaxUnrollLength = 64;
-                if (str.Length <= MaxUnrollLength)
+                if (IsCaseInsensitive(node)) // StartsWith(..., XxIgnoreCase) won't necessarily be the same as char-by-char comparison
                 {
-                    // Unroll shorter strings.
-
-                    // For strings more than two characters and when performing case-sensitive searches, we try to do fewer comparisons
-                    // by comparing 2 or 4 characters at a time.  Because we might be compiling on one endianness and running on another,
-                    // both little and big endian values are emitted and which is used is selected at run-time.
-                    ReadOnlySpan<byte> byteStr = MemoryMarshal.AsBytes(str.AsSpan());
-                    bool useMultiCharReads = !caseInsensitive && byteStr.Length >= sizeof(uint);
-                    if (useMultiCharReads)
-                    {
-                        additionalDeclarations.Add("global::System.ReadOnlySpan<byte> byteSpan;");
-                        writer.WriteLine($"byteSpan = global::System.Runtime.InteropServices.MemoryMarshal.AsBytes({sliceSpan});");
-                    }
-
-                    writer.Write("if (");
+                    // This case should be relatively rare.  It will only occur with IgnoreCase and a series of non-ASCII characters.
 
-                    bool emittedFirstCheck = false;
                     if (emitLengthCheck)
                     {
-                        writer.Write($"(uint){sliceSpan}.Length < {sliceStaticPos + str.Length}");
-                        emittedFirstCheck = true;
-                    }
-
-                    void EmitOr()
-                    {
-                        if (emittedFirstCheck)
-                        {
-                            writer.WriteLine(" ||");
-                            writer.Write("    ");
-                        }
-                        emittedFirstCheck = true;
+                        EmitSpanLengthCheck(str.Length);
                     }
 
-                    if (useMultiCharReads)
+                    using (EmitBlock(writer, $"for (int i = 0; i < {Literal(node.Str)}.Length; i++)"))
                     {
-                        while (byteStr.Length >= sizeof(ulong))
+                        string textSpanIndex = sliceStaticPos > 0 ? $"i + {sliceStaticPos}" : "i";
+                        using (EmitBlock(writer, $"if ({ToLower(hasTextInfo, options, $"{sliceSpan}[{textSpanIndex}]")} != {Literal(str)}[i])"))
                         {
-                            EmitOr();
-                            string byteSpan = sliceStaticPos > 0 ? $"byteSpan.Slice({sliceStaticPos * sizeof(char)})" : "byteSpan";
-                            writer.Write($"global::System.Buffers.Binary.BinaryPrimitives.ReadUInt64LittleEndian({byteSpan}) != 0x{BinaryPrimitives.ReadUInt64LittleEndian(byteStr):X}ul");
-                            sliceStaticPos += sizeof(ulong) / sizeof(char);
-                            byteStr = byteStr.Slice(sizeof(ulong));
-                        }
-
-                        while (byteStr.Length >= sizeof(uint))
-                        {
-                            EmitOr();
-                            string byteSpan = sliceStaticPos > 0 ? $"byteSpan.Slice({sliceStaticPos * sizeof(char)})" : "byteSpan";
-                            writer.Write($"global::System.Buffers.Binary.BinaryPrimitives.ReadUInt32LittleEndian({byteSpan}) != 0x{BinaryPrimitives.ReadUInt32LittleEndian(byteStr):X}u");
-                            sliceStaticPos += sizeof(uint) / sizeof(char);
-                            byteStr = byteStr.Slice(sizeof(uint));
+                            writer.WriteLine($"goto {doneLabel};");
                         }
                     }
-
-                    // Emit remaining comparisons character by character.
-                    for (int i = (str.Length * sizeof(char) - byteStr.Length) / sizeof(char); i < str.Length; i++)
-                    {
-                        EmitOr();
-                        writer.Write($"{ToLowerIfNeeded(hasTextInfo, options, $"{sliceSpan}[{sliceStaticPos}]", caseInsensitive)} != {Literal(str[i])}");
-                        sliceStaticPos++;
-                    }
-
-                    writer.WriteLine(")");
-                    using (EmitBlock(writer, null))
-                    {
-                        writer.WriteLine($"goto {doneLabel};");
-                    }
                 }
                 else
                 {
-                    // Longer strings are compared character by character.  If this is a case-sensitive comparison, we can simply
-                    // delegate to StartsWith.  If this is case-insensitive, we open-code the comparison loop, as we need to lowercase
-                    // each character involved, and none of the StringComparison options provide the right semantics of comparing
-                    // character-by-character while respecting the culture.
-                    if (!caseInsensitive)
+                    string sourceSpan = sliceStaticPos > 0 ? $"{sliceSpan}.Slice({sliceStaticPos})" : sliceSpan;
+                    using (EmitBlock(writer, $"if (!global::System.MemoryExtensions.StartsWith({sourceSpan}, {Literal(node.Str)}))"))
                     {
-                        string sourceSpan = sliceStaticPos > 0 ? $"{sliceSpan}.Slice({sliceStaticPos})" : sliceSpan;
-                        using (EmitBlock(writer, $"if (!global::System.MemoryExtensions.StartsWith({sourceSpan}, {Literal(node.Str)}))"))
-                        {
-                            writer.WriteLine($"goto {doneLabel};");
-                        }
-                        sliceStaticPos += node.Str.Length;
-                    }
-                    else
-                    {
-                        EmitSpanLengthCheck(str.Length);
-                        using (EmitBlock(writer, $"for (int i = 0; i < {Literal(node.Str)}.Length; i++)"))
-                        {
-                            string textSpanIndex = sliceStaticPos > 0 ? $"i + {sliceStaticPos}" : "i";
-                            using (EmitBlock(writer, $"if ({ToLower(hasTextInfo, options, $"{sliceSpan}[{textSpanIndex}]")} != {Literal(str)}[i])"))
-                            {
-                                writer.WriteLine($"goto {doneLabel};");
-                            }
-                        }
-                        sliceStaticPos += node.Str.Length;
+                        writer.WriteLine($"goto {doneLabel};");
                     }
                 }
+
+                sliceStaticPos += node.Str.Length;
             }
 
             void EmitSingleCharLoop(RegexNode node, RegexNode? subsequent = null, bool emitLengthChecksIfRequired = true)
index ceb61bd..1938de0 100644 (file)
@@ -2447,95 +2447,40 @@ namespace System.Text.RegularExpressions
             {
                 Debug.Assert(node.Kind is RegexNodeKind.Multi, $"Unexpected type: {node.Kind}");
 
-                bool caseInsensitive = IsCaseInsensitive(node);
-
-                // If the multi string's length exceeds the maximum length we want to unroll, instead generate a call to StartsWith.
-                // Each character that we unroll results in code generation that increases the size of both the IL and the resulting asm,
-                // and with a large enough string, that can cause significant overhead as well as even risk stack overflow due to
-                // having an obscenely long method.  Such long string lengths in a pattern are generally quite rare.  However, we also
-                // want to unroll for shorter strings, because the overhead of invoking StartsWith instead of doing a few simple
-                // inline comparisons is very measurable, especially if we're doing a culture-sensitive comparison and StartsWith
-                // accesses CultureInfo.CurrentCulture on each call.  We need to be cognizant not only of the cost if the whole
-                // string matches, but also the cost when the comparison fails early on, and thus we pay for the call overhead
-                // but don't reap the benefits of all the vectorization StartsWith can do.
-                const int MaxUnrollLength = 64;
-                if (!caseInsensitive && // StartsWith(..., XxIgnoreCase) won't necessarily be the same as char-by-char comparison
-                    node.Str!.Length > MaxUnrollLength)
-                {
-                    // if (!slice.Slice(sliceStaticPos).StartsWith("...") goto doneLabel;
-                    Ldloca(slice);
-                    Ldc(sliceStaticPos);
-                    Call(s_spanSliceIntMethod);
-                    Ldstr(node.Str);
-                    Call(s_stringAsSpanMethod);
-                    Call(s_spanStartsWith);
-                    BrfalseFar(doneLabel);
-                    sliceStaticPos += node.Str.Length;
-                    return;
-                }
+                string str = node.Str!;
+                Debug.Assert(str.Length >= 2);
 
-                // Emit the length check for the whole string.  If the generated code gets past this point,
-                // we know the span is at least sliceStaticPos + s.Length long.
-                ReadOnlySpan<char> s = node.Str;
-                if (emitLengthCheck)
+                if (IsCaseInsensitive(node)) // StartsWith(..., XxIgnoreCase) won't necessarily be the same as char-by-char comparison
                 {
-                    EmitSpanLengthCheck(s.Length);
-                }
+                    // This case should be relatively rare.  It will only occur with IgnoreCase and a series of non-ASCII characters.
 
-                // If we're doing a case-insensitive comparison, we need to lower case each character,
-                // so we just go character-by-character.  But if we're not, we try to process multiple
-                // characters at a time; this is helpful not only for throughput but also in reducing
-                // the amount of IL and asm that results from this unrolling. This optimization
-                // is subject to endianness issues if the generated code is used on a machine with a
-                // different endianness, but that's not a concern when the code is emitted by the
-                // same process that then uses it.
-                if (!caseInsensitive)
-                {
-                    // On 64-bit, process 4 characters at a time until the string isn't at least 4 characters long.
-                    if (IntPtr.Size == 8)
+                    if (emitLengthCheck)
                     {
-                        const int CharsPerInt64 = 4;
-                        while (s.Length >= CharsPerInt64)
-                        {
-                            // if (Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref MemoryMarshal.GetReference(slice), sliceStaticPos)) != value) goto doneLabel;
-                            EmitTextSpanOffset();
-                            Unaligned(1);
-                            LdindI8();
-                            LdcI8(MemoryMarshal.Read<long>(MemoryMarshal.AsBytes(s)));
-                            BneFar(doneLabel);
-                            sliceStaticPos += CharsPerInt64;
-                            s = s.Slice(CharsPerInt64);
-                        }
+                        EmitSpanLengthCheck(str.Length);
                     }
 
-                    // Of what remains, process 2 characters at a time until the string isn't at least 2 characters long.
-                    const int CharsPerInt32 = 2;
-                    while (s.Length >= CharsPerInt32)
+                    foreach (char c in str)
                     {
-                        // if (Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref MemoryMarshal.GetReference(slice), sliceStaticPos)) != value) goto doneLabel;
+                        // if (c != slice[sliceStaticPos++]) goto doneLabel;
                         EmitTextSpanOffset();
-                        Unaligned(1);
-                        LdindI4();
-                        Ldc(MemoryMarshal.Read<int>(MemoryMarshal.AsBytes(s)));
+                        sliceStaticPos++;
+                        LdindU2();
+                        CallToLower();
+                        Ldc(c);
                         BneFar(doneLabel);
-                        sliceStaticPos += CharsPerInt32;
-                        s = s.Slice(CharsPerInt32);
                     }
                 }
-
-                // Finally, process all of the remaining characters one by one.
-                for (int i = 0; i < s.Length; i++)
+                else
                 {
-                    // if (s[i] != slice[sliceStaticPos++]) goto doneLabel;
-                    EmitTextSpanOffset();
-                    sliceStaticPos++;
-                    LdindU2();
-                    if (caseInsensitive)
-                    {
-                        CallToLower();
-                    }
-                    Ldc(s[i]);
-                    BneFar(doneLabel);
+                    // if (!slice.Slice(sliceStaticPos).StartsWith("...") goto doneLabel;
+                    Ldloca(slice);
+                    Ldc(sliceStaticPos);
+                    Call(s_spanSliceIntMethod);
+                    Ldstr(node.Str);
+                    Call(s_stringAsSpanMethod);
+                    Call(s_spanStartsWith);
+                    BrfalseFar(doneLabel);
+                    sliceStaticPos += node.Str.Length;
                 }
             }