Support non-ascii in fgVNBasedIntrinsicExpansionForCall_ReadUtf8 (#89383)
authorEgor Bogatov <egorbo@gmail.com>
Tue, 25 Jul 2023 17:59:52 +0000 (19:59 +0200)
committerGitHub <noreply@github.com>
Tue, 25 Jul 2023 17:59:52 +0000 (19:59 +0200)
src/coreclr/jit/helperexpansion.cpp
src/coreclr/minipal/Windows/CMakeLists.txt
src/coreclr/pal/inc/rt/cpp/stdbool.h [new file with mode: 0644]
src/tests/JIT/opt/Vectorization/ReadUtf8.cs

index e8c237a..795245d 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 #include "jitpch.h"
+#include <minipal/utf8.h>
 #ifdef _MSC_VER
 #pragma hdrstop
 #endif
@@ -1238,7 +1239,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall(BasicBlock** pBlock, Statement
 
 //------------------------------------------------------------------------------
 // fgVNBasedIntrinsicExpansionForCall_ReadUtf8 : Expand NI_System_Text_UTF8Encoding_UTF8EncodingSealed_ReadUtf8
-//    when src data is a string literal (UTF16) that can be narrowed to ASCII (UTF8), e.g.:
+//    when src data is a string literal (UTF16) that can be converted to UTF8, e.g.:
 //
 //      string str = "Hello, world!";
 //      int bytesWritten = ReadUtf8(ref str[0], str.Length, buffer, buffer.Length);
@@ -1282,6 +1283,8 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
         return false;
     }
 
+    assert(strObj != nullptr);
+
     // We mostly expect string literal objects here, but let's be more agile just in case
     if (!info.compCompHnd->isObjectImmutable(strObj))
     {
@@ -1289,52 +1292,57 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
         return false;
     }
 
-    GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode();
+    const GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode();
     if (!srcLen->gtVNPair.BothEqual() || !vnStore->IsVNInt32Constant(srcLen->gtVNPair.GetLiberal()))
     {
         JITDUMP("ReadUtf8: srcLen is not constant\n")
         return false;
     }
 
-    const int      MaxPossibleUnrollThreshold = 256;
-    const unsigned unrollThreshold            = min(getUnrollThreshold(UnrollKind::Memcpy), MaxPossibleUnrollThreshold);
-    const unsigned srcLenCns                  = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal());
-    if ((srcLenCns == 0) || (srcLenCns > unrollThreshold))
+    // Source UTF16 (U16) string length in characters
+    const unsigned srcLenCnsU16            = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal());
+    const int      MaxU16BufferSizeInChars = 256;
+    if ((srcLenCnsU16 == 0) || (srcLenCnsU16 > MaxU16BufferSizeInChars))
     {
         // TODO: handle srcLenCns == 0 if it's a common case
-        JITDUMP("ReadUtf8: srcLenCns is out of unrollable range\n")
+        JITDUMP("ReadUtf8: srcLenCns is 0 or > MaxPossibleUnrollThreshold\n")
         return false;
     }
 
-    // Read the string literal (UTF16) into a local buffer (UTF8)
-    assert(strObj != nullptr);
-    uint16_t bufferU16[MaxPossibleUnrollThreshold];
-    uint8_t  bufferU8[MaxPossibleUnrollThreshold]; // twice smaller because of narrowing
-
-    // Both must be within [0..INT_MAX] range as we're going to cast them to int
-    assert((unsigned)srcLenCns <= INT_MAX);
-    assert((unsigned)strObjOffset <= INT_MAX);
+    uint16_t bufferU16[MaxU16BufferSizeInChars];
 
     // getObjectContent is expected to validate the offset and length
-    if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)srcLenCns * 2, (int)strObjOffset))
+    // NOTE: (int) casts should not overflow:
+    //  * srcLenCns is <= MaxUTF16BufferSizeInChars
+    //  * strObjOffset is already checked to be <= INT_MAX
+    if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)(srcLenCnsU16 * sizeof(uint16_t)),
+                                            (int)strObjOffset))
     {
         JITDUMP("ReadUtf8: getObjectContent returned false.\n")
         return false;
     }
 
-    for (unsigned charIndex = 0; charIndex < srcLenCns; charIndex++)
+    const int MaxU8BufferSizeInBytes = 256;
+    uint8_t   bufferU8[MaxU8BufferSizeInBytes];
+
+    const int srcLenU8 = (int)minipal_convert_utf16_to_utf8((const CHAR16_T*)bufferU16, srcLenCnsU16, (char*)bufferU8,
+                                                            MaxU8BufferSizeInBytes, 0);
+    if (srcLenU8 <= 0)
     {
-        // Buffer keeps the original utf16 chars
-        uint16_t ch = bufferU16[charIndex];
-        if (ch > 127)
-        {
-            // Only ASCII is supported.
-            JITDUMP("ReadUtf8: %dth char is not ASCII.\n", charIndex)
-            return false;
-        }
+        // E.g. output buffer is too small
+        JITDUMP("ReadUtf8: minipal_convert_utf16_to_utf8 returned <= 0\n")
+        return false;
+    }
 
-        // Narrow U16 to U8 in the same buffer
-        bufferU8[charIndex] = (uint8_t)ch;
+    // The API is expected to return [1..MaxU8BufferSizeInBytes] real length of the UTF-8 value
+    // stored in bufferU8
+    assert((unsigned)srcLenU8 <= MaxU8BufferSizeInBytes);
+
+    // Now that we know the exact UTF8 buffer length we can check if it's unrollable
+    if (srcLenU8 > (int)getUnrollThreshold(UnrollKind::Memcpy))
+    {
+        JITDUMP("ReadUtf8: srcLenU8 is out of unrollable range\n")
+        return false;
     }
 
     DebugInfo debugInfo = stmt->GetDebugInfo();
@@ -1373,10 +1381,10 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
     fgMorphStmtBlockOps(block, stmt);
     gtUpdateStmtSideEffects(stmt);
 
-    // srcLenCns is the length of the string literal in chars (UTF16)
+    // srcLenU8 is the length of the string literal in chars (UTF16)
     // but we're going to use the same value as the "bytesWritten" result in the fast path and in the length check.
-    GenTree* srcLenCnsNode = gtNewIconNode(srcLenCns);
-    fgValueNumberTreeConst(srcLenCnsNode);
+    GenTree* srcLenU8Node = gtNewIconNode(srcLenU8);
+    fgValueNumberTreeConst(srcLenU8Node);
 
     // We're going to insert the following blocks:
     //
@@ -1384,12 +1392,12 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
     //
     //  lengthCheckBb:
     //      bytesWritten = -1;
-    //      if (dstLen <srcLen)
+    //      if (dstLen < srcLenU8)
     //          goto block;
     //
     //  fastpathBb:
     //      <unrolled block copy>
-    //      bytesWritten = srcLenCns * 2;
+    //      bytesWritten = srcLenU8;
     //
     //  block:
     //      use(bytesWritten)
@@ -1406,7 +1414,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
     fgInsertStmtAtEnd(lengthCheckBb, fgNewStmtFromTree(bytesWrittenDefaultVal, debugInfo));
 
     GenTree* dstLen      = call->gtArgs.GetUserArgByIndex(3)->GetNode();
-    GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenCnsNode);
+    GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenU8Node);
     lengthCheck->gtFlags |= GTF_RELOP_JMP_USED;
     Statement* lengthCheckStmt = fgNewStmtFromTree(gtNewOperNode(GT_JTRUE, TYP_VOID, lengthCheck), debugInfo);
     fgInsertStmtAtEnd(lengthCheckBb, lengthCheckStmt);
@@ -1424,14 +1432,14 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
     fastpathBb->bbFlags |= BBF_INTERNAL;
 
     // The widest type we can use for loads
-    const var_types maxLoadType = roundDownMaxType(srcLenCns);
+    const var_types maxLoadType = roundDownMaxType(srcLenU8);
     assert(genTypeSize(maxLoadType) > 0);
 
     // How many iterations we need to copy UTF8 const data to the destination
-    unsigned iterations = srcLenCns / genTypeSize(maxLoadType);
+    unsigned iterations = srcLenU8 / genTypeSize(maxLoadType);
 
     // Add one more iteration if we have a remainder
-    iterations += (srcLenCns % genTypeSize(maxLoadType) == 0) ? 0 : 1;
+    iterations += (srcLenU8 % genTypeSize(maxLoadType) == 0) ? 0 : 1;
 
     GenTree* dstPtr = call->gtArgs.GetUserArgByIndex(2)->GetNode();
     for (unsigned i = 0; i < iterations; i++)
@@ -1441,7 +1449,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
         // Last iteration: overlap with previous load if needed
         if (i == iterations - 1)
         {
-            offset = (ssize_t)srcLenCns - genTypeSize(maxLoadType);
+            offset = (ssize_t)srcLenU8 - genTypeSize(maxLoadType);
         }
 
         // We're going to emit the following tree (in case of SIMD16 load):
@@ -1465,7 +1473,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
     }
 
     // Finally, store the number of bytes written to the resultLcl local
-    Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenCnsNode)), debugInfo);
+    Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenU8Node)), debugInfo);
     fgInsertStmtAtEnd(fastpathBb, finalStmt);
     fastpathBb->bbCodeOffs    = block->bbCodeOffsEnd;
     fastpathBb->bbCodeOffsEnd = block->bbCodeOffsEnd;
index 0c83eea..3f8368b 100644 (file)
@@ -1,6 +1,7 @@
 set(SOURCES
     doublemapping.cpp
     dn-u16.cpp
+    ${CLR_SRC_NATIVE_DIR}/minipal/utf8.c
 )
 
 if(NOT CLR_CROSS_COMPONENTS_BUILD)
diff --git a/src/coreclr/pal/inc/rt/cpp/stdbool.h b/src/coreclr/pal/inc/rt/cpp/stdbool.h
new file mode 100644 (file)
index 0000000..b23533a
--- /dev/null
@@ -0,0 +1,4 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#include "palrt.h"
index 957f6b5..29a9534 100644 (file)
@@ -5,6 +5,7 @@ using System;
 using System.Numerics;
 using System.Runtime.CompilerServices;
 using System.Text;
+using System.Text.Unicode;
 using System.Threading;
 using Xunit;
 
@@ -20,6 +21,11 @@ public class ReadUtf8
             Test_hello();
             Test_CJK();
             Test_SIMD();
+            Test_1();
+            Test_2();
+            Test_3();
+            Test_4();
+            Test_5();
             Thread.Sleep(10);
         }
         return 100;
@@ -242,4 +248,85 @@ public class ReadUtf8
                 throw new Exception($"{item} != 0");
         }
     }
+
+    // ReadUtf8 is used inside Utf8.TryWrite + interpolation syntax:
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void Test_1()
+    {
+        var buffer = new byte[1024];
+        ValidateResult("", Utf8.TryWrite(buffer, $"", out var written1), buffer, written1);
+        ValidateResult("1", Utf8.TryWrite(buffer, $"1", out var written2), buffer, written2);
+        ValidateResult("12", Utf8.TryWrite(buffer, $"12", out var written3), buffer, written3);
+        ValidateResult("123", Utf8.TryWrite(buffer, $"123", out var written4), buffer, written4);
+        ValidateResult("1234", Utf8.TryWrite(buffer, $"1234", out var written5), buffer, written5);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void Test_2()
+    {
+        var buffer = new byte[1024];
+        ValidateResult("12345", Utf8.TryWrite(buffer, $"12345", out var written1), buffer, written1);
+        ValidateResult("123456", Utf8.TryWrite(buffer, $"123456", out var written2), buffer, written2);
+        ValidateResult("1234567", Utf8.TryWrite(buffer, $"1234567", out var written3), buffer, written3);
+        ValidateResult("12345678", Utf8.TryWrite(buffer, $"12345678", out var written4), buffer, written4);
+        ValidateResult("123456789", Utf8.TryWrite(buffer, $"123456789", out var written5), buffer, written5);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void Test_3()
+    {
+        var buffer = new byte[1024];
+        ValidateResult("123456789A", Utf8.TryWrite(buffer, $"123456789A", out var written1), buffer, written1);
+        ValidateResult("123456789AB", Utf8.TryWrite(buffer, $"123456789AB", out var written2), buffer, written2);
+        ValidateResult("123456789ABC", Utf8.TryWrite(buffer, $"123456789ABC", out var written3), buffer, written3);
+        ValidateResult("123456789ABCD", Utf8.TryWrite(buffer, $"123456789ABCD", out var written4), buffer, written4);
+        ValidateResult("123456789ABCDE", Utf8.TryWrite(buffer, $"123456789ABCDE", out var written5), buffer, written5);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void Test_4()
+    {
+        var buffer = new byte[1024];
+        ValidateResult("123456789ABCDEF", Utf8.TryWrite(buffer, $"123456789ABCDEF", out var written1), buffer, written1);
+        ValidateResult("123456789ABCDEF\u0419", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419", out var written2), buffer, written2);
+        ValidateResult("123456789ABCDEF\u0419\u044C", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044C", out var written3), buffer, written3);
+        ValidateResult("123456789ABCDEF\u0419\u044Cf", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf", out var written4), buffer, written4);
+        ValidateResult("123456789ABCDEF\u0419\u044Cf.", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf.", out var written5), buffer, written5);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void Test_5()
+    {
+        var buffer = new byte[1024];
+        ValidateResult("\uD800b", Utf8.TryWrite(buffer, $"\uD800b", out var written1), buffer, written1);
+        ValidateResult("1\uD800b", Utf8.TryWrite(buffer, $"1\uD800b", out var written2), buffer, written2);
+        ValidateResult("11\uD800b", Utf8.TryWrite(buffer, $"11\uD800b", out var written3), buffer, written3);
+        ValidateResult("\uD800b\uD800b", Utf8.TryWrite(buffer, $"\uD800b\uD800b", out var written4), buffer, written4);
+        ValidateResult("\uD800b435345435", Utf8.TryWrite(buffer, $"\uD800b435345435", out var written5), buffer, written5);
+        ValidateResult("342532523\uD800b\uD800b35235", Utf8.TryWrite(buffer, $"342532523\uD800b\uD800b35235", out var written6), buffer, written6);
+        ValidateResult("efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", Utf8.TryWrite(buffer, $"efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", out var written7), buffer, written7);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static void ValidateResult(string str, bool actualResult, byte[] actualData, int actualBytesWritten)
+    {
+        byte[] expectedData = new byte[actualData.Length];
+        bool expectedResult = Utf8.TryWrite(expectedData, $"{str}", out int expectedBytesWritten);
+        if (expectedResult != actualResult)
+        {
+            throw new Exception($"Unexpected return value: {actualResult}");
+        }
+
+        if (actualBytesWritten != expectedBytesWritten)
+        {
+            throw new Exception($"bytesWritten value: {actualBytesWritten} != {expectedBytesWritten}");
+        }
+
+        if (expectedResult && !actualData.AsSpan(0, actualBytesWritten).SequenceEqual(
+                expectedData.AsSpan(0, expectedBytesWritten)))
+        {
+            throw new Exception("actualData != expectedData");
+        }
+    }
 }