From c33b3048ecbd4c79f27014b2bedaa031eb0570d5 Mon Sep 17 00:00:00 2001 From: Egor Bogatov Date: Tue, 25 Jul 2023 19:59:52 +0200 Subject: [PATCH] Support non-ascii in fgVNBasedIntrinsicExpansionForCall_ReadUtf8 (#89383) --- src/coreclr/jit/helperexpansion.cpp | 84 +++++++++++++++------------- src/coreclr/minipal/Windows/CMakeLists.txt | 1 + src/coreclr/pal/inc/rt/cpp/stdbool.h | 4 ++ src/tests/JIT/opt/Vectorization/ReadUtf8.cs | 87 +++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 38 deletions(-) create mode 100644 src/coreclr/pal/inc/rt/cpp/stdbool.h diff --git a/src/coreclr/jit/helperexpansion.cpp b/src/coreclr/jit/helperexpansion.cpp index e8c237a..795245d 100644 --- a/src/coreclr/jit/helperexpansion.cpp +++ b/src/coreclr/jit/helperexpansion.cpp @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. #include "jitpch.h" +#include <minipal/utf8.h> #ifdef _MSC_VER #pragma hdrstop #endif @@ -1238,7 +1239,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall(BasicBlock** pBlock, Statement //------------------------------------------------------------------------------ // fgVNBasedIntrinsicExpansionForCall_ReadUtf8 : Expand NI_System_Text_UTF8Encoding_UTF8EncodingSealed_ReadUtf8 -// when src data is a string literal (UTF16) that can be narrowed to ASCII (UTF8), e.g.: +// when src data is a string literal (UTF16) that can be converted to UTF8, e.g.: // // string str = "Hello, world!"; // int bytesWritten = ReadUtf8(ref str[0], str.Length, buffer, buffer.Length); @@ -1282,6 +1283,8 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, return false; } + assert(strObj != nullptr); + // We mostly expect string literal objects here, but let's be more agile just in case if (!info.compCompHnd->isObjectImmutable(strObj)) { @@ -1289,52 +1292,57 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, return false; } - GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode(); + const
GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode(); if (!srcLen->gtVNPair.BothEqual() || !vnStore->IsVNInt32Constant(srcLen->gtVNPair.GetLiberal())) { JITDUMP("ReadUtf8: srcLen is not constant\n") return false; } - const int MaxPossibleUnrollThreshold = 256; - const unsigned unrollThreshold = min(getUnrollThreshold(UnrollKind::Memcpy), MaxPossibleUnrollThreshold); - const unsigned srcLenCns = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal()); - if ((srcLenCns == 0) || (srcLenCns > unrollThreshold)) + // Source UTF16 (U16) string length in characters + const unsigned srcLenCnsU16 = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal()); + const int MaxU16BufferSizeInChars = 256; + if ((srcLenCnsU16 == 0) || (srcLenCnsU16 > MaxU16BufferSizeInChars)) { // TODO: handle srcLenCns == 0 if it's a common case - JITDUMP("ReadUtf8: srcLenCns is out of unrollable range\n") + JITDUMP("ReadUtf8: srcLenCns is 0 or > MaxPossibleUnrollThreshold\n") return false; } - // Read the string literal (UTF16) into a local buffer (UTF8) - assert(strObj != nullptr); - uint16_t bufferU16[MaxPossibleUnrollThreshold]; - uint8_t bufferU8[MaxPossibleUnrollThreshold]; // twice smaller because of narrowing - - // Both must be within [0..INT_MAX] range as we're going to cast them to int - assert((unsigned)srcLenCns <= INT_MAX); - assert((unsigned)strObjOffset <= INT_MAX); + uint16_t bufferU16[MaxU16BufferSizeInChars]; // getObjectContent is expected to validate the offset and length - if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)srcLenCns * 2, (int)strObjOffset)) + // NOTE: (int) casts should not overflow: + // * srcLenCnsU16 is <= MaxU16BufferSizeInChars + // * strObjOffset is already checked to be <= INT_MAX + if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)(srcLenCnsU16 * sizeof(uint16_t)), + (int)strObjOffset)) { JITDUMP("ReadUtf8: getObjectContent returned false.\n") return false; } - for (unsigned
charIndex = 0; charIndex < srcLenCns; charIndex++) + const int MaxU8BufferSizeInBytes = 256; + uint8_t bufferU8[MaxU8BufferSizeInBytes]; + + const int srcLenU8 = (int)minipal_convert_utf16_to_utf8((const CHAR16_T*)bufferU16, srcLenCnsU16, (char*)bufferU8, + MaxU8BufferSizeInBytes, 0); + if (srcLenU8 <= 0) { - // Buffer keeps the original utf16 chars - uint16_t ch = bufferU16[charIndex]; - if (ch > 127) - { - // Only ASCII is supported. - JITDUMP("ReadUtf8: %dth char is not ASCII.\n", charIndex) - return false; - } + // E.g. output buffer is too small + JITDUMP("ReadUtf8: minipal_convert_utf16_to_utf8 returned <= 0\n") + return false; + } - // Narrow U16 to U8 in the same buffer - bufferU8[charIndex] = (uint8_t)ch; + // The API is expected to return [1..MaxU8BufferSizeInBytes] real length of the UTF-8 value + // stored in bufferU8 + assert((unsigned)srcLenU8 <= MaxU8BufferSizeInBytes); + + // Now that we know the exact UTF8 buffer length we can check if it's unrollable + if (srcLenU8 > (int)getUnrollThreshold(UnrollKind::Memcpy)) + { + JITDUMP("ReadUtf8: srcLenU8 is out of unrollable range\n") + return false; } DebugInfo debugInfo = stmt->GetDebugInfo(); @@ -1373,10 +1381,10 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, fgMorphStmtBlockOps(block, stmt); gtUpdateStmtSideEffects(stmt); - // srcLenCns is the length of the string literal in chars (UTF16) + // srcLenU8 is the length of the string literal in bytes (UTF8) // but we're going to use the same value as the "bytesWritten" result in the fast path and in the length check.
- GenTree* srcLenCnsNode = gtNewIconNode(srcLenCns); - fgValueNumberTreeConst(srcLenCnsNode); + GenTree* srcLenU8Node = gtNewIconNode(srcLenU8); + fgValueNumberTreeConst(srcLenU8Node); // We're going to insert the following blocks: // @@ -1384,12 +1392,12 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, // // lengthCheckBb: // bytesWritten = -1; // if (dstLen < srcLen) // goto block; // // fastpathBb: // <unrolled block copy> - // bytesWritten = srcLenCns * 2; + // bytesWritten = srcLenU8; // // block: // use(bytesWritten) @@ -1406,7 +1414,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, fgInsertStmtAtEnd(lengthCheckBb, fgNewStmtFromTree(bytesWrittenDefaultVal, debugInfo)); GenTree* dstLen = call->gtArgs.GetUserArgByIndex(3)->GetNode(); - GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenCnsNode); + GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenU8Node); lengthCheck->gtFlags |= GTF_RELOP_JMP_USED; Statement* lengthCheckStmt = fgNewStmtFromTree(gtNewOperNode(GT_JTRUE, TYP_VOID, lengthCheck), debugInfo); fgInsertStmtAtEnd(lengthCheckBb, lengthCheckStmt); @@ -1424,14 +1432,14 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, fastpathBb->bbFlags |= BBF_INTERNAL; // The widest type we can use for loads - const var_types maxLoadType = roundDownMaxType(srcLenCns); + const var_types maxLoadType = roundDownMaxType(srcLenU8); assert(genTypeSize(maxLoadType) > 0); // How many iterations we need to copy UTF8 const data to the destination - unsigned iterations = srcLenCns / genTypeSize(maxLoadType); + unsigned iterations = srcLenU8 / genTypeSize(maxLoadType); // Add one more iteration if we have a remainder - iterations += (srcLenCns % genTypeSize(maxLoadType) == 0) ?
0 : 1; GenTree* dstPtr = call->gtArgs.GetUserArgByIndex(2)->GetNode(); for (unsigned i = 0; i < iterations; i++) @@ -1441,7 +1449,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, // Last iteration: overlap with previous load if needed if (i == iterations - 1) { - offset = (ssize_t)srcLenCns - genTypeSize(maxLoadType); + offset = (ssize_t)srcLenU8 - genTypeSize(maxLoadType); } // We're going to emit the following tree (in case of SIMD16 load): @@ -1465,7 +1473,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock, } // Finally, store the number of bytes written to the resultLcl local - Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenCnsNode)), debugInfo); + Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenU8Node)), debugInfo); fgInsertStmtAtEnd(fastpathBb, finalStmt); fastpathBb->bbCodeOffs = block->bbCodeOffsEnd; fastpathBb->bbCodeOffsEnd = block->bbCodeOffsEnd; diff --git a/src/coreclr/minipal/Windows/CMakeLists.txt b/src/coreclr/minipal/Windows/CMakeLists.txt index 0c83eea..3f8368b 100644 --- a/src/coreclr/minipal/Windows/CMakeLists.txt +++ b/src/coreclr/minipal/Windows/CMakeLists.txt @@ -1,6 +1,7 @@ set(SOURCES doublemapping.cpp dn-u16.cpp + ${CLR_SRC_NATIVE_DIR}/minipal/utf8.c ) if(NOT CLR_CROSS_COMPONENTS_BUILD) diff --git a/src/coreclr/pal/inc/rt/cpp/stdbool.h b/src/coreclr/pal/inc/rt/cpp/stdbool.h new file mode 100644 index 0000000..b23533a --- /dev/null +++ b/src/coreclr/pal/inc/rt/cpp/stdbool.h @@ -0,0 +1,4 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +#include "palrt.h" diff --git a/src/tests/JIT/opt/Vectorization/ReadUtf8.cs b/src/tests/JIT/opt/Vectorization/ReadUtf8.cs index 957f6b5..29a9534 100644 --- a/src/tests/JIT/opt/Vectorization/ReadUtf8.cs +++ b/src/tests/JIT/opt/Vectorization/ReadUtf8.cs @@ -5,6 +5,7 @@ using System; using System.Numerics; using System.Runtime.CompilerServices; using System.Text; +using System.Text.Unicode; using System.Threading; using Xunit; @@ -20,6 +21,11 @@ public class ReadUtf8 Test_hello(); Test_CJK(); Test_SIMD(); + Test_1(); + Test_2(); + Test_3(); + Test_4(); + Test_5(); Thread.Sleep(10); } return 100; @@ -242,4 +248,85 @@ public class ReadUtf8 throw new Exception($"{item} != 0"); } } + + // ReadUtf8 is used inside Utf8.TryWrite + interpolation syntax: + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Test_1() + { + var buffer = new byte[1024]; + ValidateResult("", Utf8.TryWrite(buffer, $"", out var written1), buffer, written1); + ValidateResult("1", Utf8.TryWrite(buffer, $"1", out var written2), buffer, written2); + ValidateResult("12", Utf8.TryWrite(buffer, $"12", out var written3), buffer, written3); + ValidateResult("123", Utf8.TryWrite(buffer, $"123", out var written4), buffer, written4); + ValidateResult("1234", Utf8.TryWrite(buffer, $"1234", out var written5), buffer, written5); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Test_2() + { + var buffer = new byte[1024]; + ValidateResult("12345", Utf8.TryWrite(buffer, $"12345", out var written1), buffer, written1); + ValidateResult("123456", Utf8.TryWrite(buffer, $"123456", out var written2), buffer, written2); + ValidateResult("1234567", Utf8.TryWrite(buffer, $"1234567", out var written3), buffer, written3); + ValidateResult("12345678", Utf8.TryWrite(buffer, $"12345678", out var written4), buffer, written4); + ValidateResult("123456789", Utf8.TryWrite(buffer, $"123456789", out var written5), buffer, written5); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Test_3() + { + 
var buffer = new byte[1024]; + ValidateResult("123456789A", Utf8.TryWrite(buffer, $"123456789A", out var written1), buffer, written1); + ValidateResult("123456789AB", Utf8.TryWrite(buffer, $"123456789AB", out var written2), buffer, written2); + ValidateResult("123456789ABC", Utf8.TryWrite(buffer, $"123456789ABC", out var written3), buffer, written3); + ValidateResult("123456789ABCD", Utf8.TryWrite(buffer, $"123456789ABCD", out var written4), buffer, written4); + ValidateResult("123456789ABCDE", Utf8.TryWrite(buffer, $"123456789ABCDE", out var written5), buffer, written5); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Test_4() + { + var buffer = new byte[1024]; + ValidateResult("123456789ABCDEF", Utf8.TryWrite(buffer, $"123456789ABCDEF", out var written1), buffer, written1); + ValidateResult("123456789ABCDEF\u0419", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419", out var written2), buffer, written2); + ValidateResult("123456789ABCDEF\u0419\u044C", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044C", out var written3), buffer, written3); + ValidateResult("123456789ABCDEF\u0419\u044Cf", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf", out var written4), buffer, written4); + ValidateResult("123456789ABCDEF\u0419\u044Cf.", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf.", out var written5), buffer, written5); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Test_5() + { + var buffer = new byte[1024]; + ValidateResult("\uD800b", Utf8.TryWrite(buffer, $"\uD800b", out var written1), buffer, written1); + ValidateResult("1\uD800b", Utf8.TryWrite(buffer, $"1\uD800b", out var written2), buffer, written2); + ValidateResult("11\uD800b", Utf8.TryWrite(buffer, $"11\uD800b", out var written3), buffer, written3); + ValidateResult("\uD800b\uD800b", Utf8.TryWrite(buffer, $"\uD800b\uD800b", out var written4), buffer, written4); + ValidateResult("\uD800b435345435", Utf8.TryWrite(buffer, $"\uD800b435345435", out var written5), buffer, 
written5); + ValidateResult("342532523\uD800b\uD800b35235", Utf8.TryWrite(buffer, $"342532523\uD800b\uD800b35235", out var written6), buffer, written6); + ValidateResult("efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", Utf8.TryWrite(buffer, $"efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", out var written7), buffer, written7); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ValidateResult(string str, bool actualResult, byte[] actualData, int actualBytesWritten) + { + byte[] expectedData = new byte[actualData.Length]; + bool expectedResult = Utf8.TryWrite(expectedData, $"{str}", out int expectedBytesWritten); + if (expectedResult != actualResult) + { + throw new Exception($"Unexpected return value: {actualResult}"); + } + + if (actualBytesWritten != expectedBytesWritten) + { + throw new Exception($"bytesWritten value: {actualBytesWritten} != {expectedBytesWritten}"); + } + + if (expectedResult && !actualData.AsSpan(0, actualBytesWritten).SequenceEqual( + expectedData.AsSpan(0, expectedBytesWritten))) + { + throw new Exception("actualData != expectedData"); + } + } } -- 2.7.4