PR feedback: Clarify charIsNonAscii vector usage
[platform/upstream/coreclr.git] / src / System.Private.CoreLib / shared / System / Text / Unicode / Utf16Utility.Validation.cs
1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4
5 using System.Diagnostics;
6 using System.Runtime.Intrinsics;
7 using System.Runtime.Intrinsics.X86;
8 using System.Numerics;
9 using Internal.Runtime.CompilerServices;
10
11 #if BIT64
12 using nint = System.Int64;
13 using nuint = System.UInt64;
14 #else // BIT64
15 using nint = System.Int32;
16 using nuint = System.UInt32;
17 #endif // BIT64
18
19 namespace System.Text.Unicode
20 {
21     internal static unsafe partial class Utf16Utility
22     {
23 #if DEBUG
24         static Utf16Utility()
25         {
26             Debug.Assert(sizeof(nint) == IntPtr.Size && nint.MinValue < 0, "nint is defined incorrectly.");
27             Debug.Assert(sizeof(nuint) == IntPtr.Size && nuint.MinValue == 0, "nuint is defined incorrectly.");
28         }
29 #endif // DEBUG
30
31         // Returns &inputBuffer[inputLength] if the input buffer is valid.
32         /// <summary>
33         /// Given an input buffer <paramref name="pInputBuffer"/> of char length <paramref name="inputLength"/>,
34         /// returns a pointer to where the first invalid data appears in <paramref name="pInputBuffer"/>.
35         /// </summary>
36         /// <remarks>
37         /// Returns a pointer to the end of <paramref name="pInputBuffer"/> if the buffer is well-formed.
38         /// </remarks>
39         public static char* GetPointerToFirstInvalidChar(char* pInputBuffer, int inputLength, out long utf8CodeUnitCountAdjustment, out int scalarCountAdjustment)
40         {
41             Debug.Assert(inputLength >= 0, "Input length must not be negative.");
42             Debug.Assert(pInputBuffer != null || inputLength == 0, "Input length must be zero if input buffer pointer is null.");
43
44             // First, we'll handle the common case of all-ASCII. If this is able to
45             // consume the entire buffer, we'll skip the remainder of this method's logic.
46
47             int numAsciiCharsConsumedJustNow = (int)ASCIIUtility.GetIndexOfFirstNonAsciiChar(pInputBuffer, (uint)inputLength);
48             Debug.Assert(0 <= numAsciiCharsConsumedJustNow && numAsciiCharsConsumedJustNow <= inputLength);
49
50             pInputBuffer += (uint)numAsciiCharsConsumedJustNow;
51             inputLength -= numAsciiCharsConsumedJustNow;
52
53             if (inputLength == 0)
54             {
55                 utf8CodeUnitCountAdjustment = 0;
56                 scalarCountAdjustment = 0;
57                 return pInputBuffer;
58             }
59
60             // If we got here, it means we saw some non-ASCII data, so within our
61             // vectorized code paths below we'll handle all non-surrogate UTF-16
62             // code points branchlessly. We'll only branch if we see surrogates.
63             // 
64             // We still optimistically assume the data is mostly ASCII. This means that the
65             // number of UTF-8 code units and the number of scalars almost matches the number
66             // of UTF-16 code units. As we go through the input and find non-ASCII
67             // characters, we'll keep track of these "adjustment" fixups. To get the
68             // total number of UTF-8 code units required to encode the input data, add
69             // the UTF-8 code unit count adjustment to the number of UTF-16 code units
70             // seen.  To get the total number of scalars present in the input data,
71             // add the scalar count adjustment to the number of UTF-16 code units seen.
72
73             long tempUtf8CodeUnitCountAdjustment = 0;
74             int tempScalarCountAdjustment = 0;
75
76             if (Sse2.IsSupported)
77             {
78                 if (inputLength >= Vector128<ushort>.Count)
79                 {
80                     Vector128<ushort> vector0080 = Vector128.Create((ushort)0x80);
81                     Vector128<ushort> vectorA800 = Vector128.Create((ushort)0xA800);
82                     Vector128<short> vector8800 = Vector128.Create(unchecked((short)0x8800));
83                     Vector128<ushort> vectorZero = Vector128<ushort>.Zero;
84
85                     do
86                     {
87                         Vector128<ushort> utf16Data = Sse2.LoadVector128((ushort*)pInputBuffer); // unaligned
88                         uint mask;
89
90                         // The 'charIsNonAscii' vector we're about to build will have the 0x8000 or the 0x0080
91                         // bit set (but not both!) only if the corresponding input char is non-ASCII. Which of
92                         // the two bits is set doesn't matter, as will be explained in the diagram a few lines
93                         // below.
94
95                         Vector128<ushort> charIsNonAscii;
96                         if (Sse41.IsSupported)
97                         {
98                             // sets 0x0080 bit if corresponding char element is >= 0x0080
99                             charIsNonAscii = Sse41.Min(utf16Data, vector0080);
100                         }
101                         else
102                         {
103                             // sets 0x8000 bit if corresponding char element is >= 0x0080
104                             charIsNonAscii = Sse2.AndNot(vector0080, Sse2.Subtract(vectorZero, Sse2.ShiftRightLogical(utf16Data, 7)));
105                         }
106
107 #if DEBUG
108                         // Quick check to ensure we didn't accidentally set both 0x8080 bits in any element.
109                         uint debugMask = (uint)Sse2.MoveMask(charIsNonAscii.AsByte());
110                         Debug.Assert((debugMask & (debugMask << 1)) == 0, "Two set bits shouldn't occur adjacent to each other in this mask.");
111 #endif // DEBUG
112
113                         // sets 0x8080 bits if corresponding char element is >= 0x0800
114                         Vector128<ushort> charIsThreeByteUtf8Encoded = Sse2.Subtract(vectorZero, Sse2.ShiftRightLogical(utf16Data, 11));
115
116                         mask = (uint)Sse2.MoveMask(Sse2.Or(charIsNonAscii, charIsThreeByteUtf8Encoded).AsByte());
117
118                         // Each odd bit of mask will be 1 only if the char was >= 0x0080,
119                         // and each even bit of mask will be 1 only if the char was >= 0x0800.
120                         //
121                         // Example for UTF-16 input "[ 0123 ] [ 1234 ] ...":
122                         //
123                         //            ,-- set if char[1] is non-ASCII
124                         //            |   ,-- set if char[0] is non-ASCII
125                         //            v   v
126                         // mask = ... 1 1 1 0
127                         //              ^   ^-- set if char[0] is >= 0x0800
128                         //              `-- set if char[1] is >= 0x0800
129                         //
130                         // (If the SSE4.1 code path is taken above, the meaning of the odd and even
131                         // bits are swapped, but the logic below otherwise holds.)
132                         //
133                         // This means we can popcnt the number of set bits, and the result is the
134                         // number of *additional* UTF-8 bytes that each UTF-16 code unit requires as
135                         // it expands. This results in the wrong count for UTF-16 surrogate code
136                         // units (we just counted that each individual code unit expands to 3 bytes,
137                         // but in reality a well-formed UTF-16 surrogate pair expands to 4 bytes).
138                         // We'll handle this in just a moment.
139                         //
140                         // For now, compute the popcnt but squirrel it away. We'll fold it in to the
141                         // cumulative UTF-8 adjustment factor once we determine that there are no
142                         // unpaired surrogates in our data. (Unpaired surrogates would invalidate
143                         // our computed result and we'd have to throw it away.)
144
145                         uint popcnt = (uint)BitOperations.PopCount(mask);
146
147                         // Surrogates need to be special-cased for two reasons: (a) we need
148                         // to account for the fact that we over-counted in the addition above;
149                         // and (b) they require separate validation.
150
151                         utf16Data = Sse2.Add(utf16Data, vectorA800);
152                         mask = (uint)Sse2.MoveMask(Sse2.CompareLessThan(utf16Data.AsInt16(), vector8800).AsByte());
153
154                         if (mask != 0)
155                         {
156                             // There's at least one UTF-16 surrogate code unit present.
157                             // Since we performed a pmovmskb operation on the result of a 16-bit pcmpgtw,
158                             // the resulting bits of 'mask' will occur in pairs:
159                             // - 00 if the corresponding UTF-16 char was not a surrogate code unit;
160                             // - 11 if the corresponding UTF-16 char was a surrogate code unit.
161                             //
162                             // A UTF-16 high/low surrogate code unit has the bit pattern [ 11011q## ######## ],
163                             // where # is any bit; q = 0 represents a high surrogate, and q = 1 represents
164                             // a low surrogate. Since we added 0xA800 in the vectorized operation above,
165                             // our surrogate pairs will now have the bit pattern [ 10000q## ######## ].
166                             // If we logical right-shift each word by 3, we'll end up with the bit pattern
167                             // [ 00010000 q####### ], which means that we can immediately use pmovmskb to
168                             // determine whether a given char was a high or a low surrogate.
169                             //
170                             // Therefore the resulting bits of 'mask2' will occur in pairs:
171                             // - 00 if the corresponding UTF-16 char was a high surrogate code unit;
172                             // - 01 if the corresponding UTF-16 char was a low surrogate code unit;
173                             // - ## (garbage) if the corresponding UTF-16 char was not a surrogate code unit.
174
175                             uint mask2 = (uint)Sse2.MoveMask(Sse2.ShiftRightLogical(utf16Data, 3).AsByte());
176
177                             uint lowSurrogatesMask = mask2 & mask; // 01 only if was a low surrogate char, else 00
178                             uint highSurrogatesMask = (mask2 ^ mask) & 0x5555u; // 01 only if was a high surrogate char, else 00
179
180                             // Now check that each high surrogate is followed by a low surrogate and that each
181                             // low surrogate follows a high surrogate. We make an exception for the case where
182                             // the final char of the vector is a high surrogate, since we can't perform validation
183                             // on it until the next iteration of the loop when we hope to consume the matching
184                             // low surrogate.
185
186                             highSurrogatesMask <<= 2;
187                             if ((ushort)highSurrogatesMask != lowSurrogatesMask)
188                             {
189                                 goto NonVectorizedLoop; // error: mismatched surrogate pair; break out of vectorized logic
190                             }
191
192                             if (highSurrogatesMask > ushort.MaxValue)
193                             {
194                                 // There was a standalone high surrogate at the end of the vector.
195                                 // We'll adjust our counters so that we don't consider this char consumed.
196
197                                 highSurrogatesMask = (ushort)highSurrogatesMask; // don't allow stray high surrogate to be consumed by popcnt
198                                 popcnt -= 2; // the '0xC000_0000' bits in the original mask are shifted out and discarded, so account for that here
199                                 pInputBuffer--;
200                                 inputLength++;
201                             }
202
203                             int surrogatePairsCount = BitOperations.PopCount(highSurrogatesMask);
204
205                             // 2 UTF-16 chars become 1 Unicode scalar
206
207                             tempScalarCountAdjustment -= surrogatePairsCount;
208
209                             // Since each surrogate code unit was >= 0x0800, we eagerly assumed
210                             // it'd be encoded as 3 UTF-8 code units, so our earlier popcnt computation
211                             // assumes that the pair is encoded as 6 UTF-8 code units. Since each
212                             // pair is in reality only encoded as 4 UTF-8 code units, we need to
213                             // perform this adjustment now.
214
215                             nint surrogatePairsCountNint = (nint)(nuint)(uint)surrogatePairsCount; // zero-extend to native int size
216                             tempUtf8CodeUnitCountAdjustment -= surrogatePairsCountNint;
217                             tempUtf8CodeUnitCountAdjustment -= surrogatePairsCountNint;
218                         }
219
220                         tempUtf8CodeUnitCountAdjustment += popcnt;
221                         pInputBuffer += Vector128<ushort>.Count;
222                         inputLength -= Vector128<ushort>.Count;
223                     } while (inputLength >= Vector128<ushort>.Count);
224                 }
225             }
226             else if (Vector.IsHardwareAccelerated)
227             {
228                 if (inputLength >= Vector<ushort>.Count)
229                 {
230                     Vector<ushort> vector0080 = new Vector<ushort>(0x0080);
231                     Vector<ushort> vector0400 = new Vector<ushort>(0x0400);
232                     Vector<ushort> vector0800 = new Vector<ushort>(0x0800);
233                     Vector<ushort> vectorD800 = new Vector<ushort>(0xD800);
234
235                     do
236                     {
237                         // The 'twoOrMoreUtf8Bytes' and 'threeOrMoreUtf8Bytes' vectors will contain
238                         // elements whose values are 0xFFFF (-1 as signed word) iff the corresponding
239                         // UTF-16 code unit was >= 0x0080 and >= 0x0800, respectively. By summing these
240                         // vectors, each element of the sum will contain one of three values:
241                         //
242                         // 0x0000 ( 0) = original char was 0000..007F
243                         // 0xFFFF (-1) = original char was 0080..07FF
244                         // 0xFFFE (-2) = original char was 0800..FFFF
245                         //
246                         // We'll negate them to produce a value 0..2 for each element, then sum all the
247                         // elements together to produce the number of *additional* UTF-8 code units
248                         // required to represent this UTF-16 data. This is similar to the popcnt step
249                         // performed by the SSE2 code path. This will overcount surrogates, but we'll
250                         // handle that shortly.
251
252                         Vector<ushort> utf16Data = Unsafe.ReadUnaligned<Vector<ushort>>(pInputBuffer);
253                         Vector<ushort> twoOrMoreUtf8Bytes = Vector.GreaterThanOrEqual(utf16Data, vector0080);
254                         Vector<ushort> threeOrMoreUtf8Bytes = Vector.GreaterThanOrEqual(utf16Data, vector0800);
255                         Vector<nuint> sumVector = (Vector<nuint>)(Vector<ushort>.Zero - twoOrMoreUtf8Bytes - threeOrMoreUtf8Bytes);
256
257                         // We'll try summing by a natural word (rather than a 16-bit word) at a time,
258                         // which should halve the number of operations we must perform.
259
260                         nuint popcnt = 0;
261                         for (int i = 0; i < Vector<nuint>.Count; i++)
262                         {
263                             popcnt += sumVector[i];
264                         }
265
266                         uint popcnt32 = (uint)popcnt;
267                         if (IntPtr.Size == 8)
268                         {
269                             popcnt32 += (uint)(popcnt >> 32);
270                         }
271
272                         // As in the SSE4.1 paths, compute popcnt but don't fold it in until we
273                         // know there aren't any unpaired surrogates in the input data.
274
275                         popcnt32 = (ushort)popcnt32 + (popcnt32 >> 16);
276
277                         // Now check for surrogates.
278
279                         utf16Data -= vectorD800;
280                         Vector<ushort> surrogateChars = Vector.LessThan(utf16Data, vector0800);
281                         if (surrogateChars != Vector<ushort>.Zero)
282                         {
283                             // There's at least one surrogate (high or low) UTF-16 code unit in
284                             // the vector. We'll build up additional vectors: 'highSurrogateChars'
285                             // and 'lowSurrogateChars', where the elements are 0xFFFF iff the original
286                             // UTF-16 code unit was a high or low surrogate, respectively.
287
288                             Vector<ushort> highSurrogateChars = Vector.LessThan(utf16Data, vector0400);
289                             Vector<ushort> lowSurrogateChars = Vector.AndNot(surrogateChars, highSurrogateChars);
290
291                             // We want to make sure that each high surrogate code unit is followed by
292                             // a low surrogate code unit and each low surrogate code unit follows a
293                             // high surrogate code unit. Since we don't have an equivalent of pmovmskb
294                             // or palignr available to us, we'll do this as a loop. We won't look at
295                             // the very last high surrogate char element since we don't yet know if
296                             // the next vector read will have a low surrogate char element.
297
298                             ushort surrogatePairsCount = 0;
299                             for (int i = 0; i < Vector<ushort>.Count - 1; i++)
300                             {
301                                 surrogatePairsCount -= highSurrogateChars[i]; // turns into +1 or +0
302                                 if (highSurrogateChars[i] != lowSurrogateChars[i + 1])
303                                 {
304                                     goto NonVectorizedLoop; // error: mismatched surrogate pair; break out of vectorized logic
305                                 }
306                             }
307
308                             if (highSurrogateChars[Vector<ushort>.Count - 1] != 0)
309                             {
310                                 // There was a standalone high surrogate at the end of the vector.
311                                 // We'll adjust our counters so that we don't consider this char consumed.
312
313                                 pInputBuffer--;
314                                 inputLength++;
315                                 popcnt32 -= 2;
316                             }
317
318                             nint surrogatePairsCountNint = (nint)surrogatePairsCount; // zero-extend to native int size
319
320                             // 2 UTF-16 chars become 1 Unicode scalar
321
322                             tempScalarCountAdjustment -= (int)surrogatePairsCountNint;
323
324                             // Since each surrogate code unit was >= 0x0800, we eagerly assumed
325                             // it'd be encoded as 3 UTF-8 code units. Each surrogate half is only
326                             // encoded as 2 UTF-8 code units (for 4 UTF-8 code units total),
327                             // so we'll adjust this now.
328
329                             tempUtf8CodeUnitCountAdjustment -= surrogatePairsCountNint;
330                             tempUtf8CodeUnitCountAdjustment -= surrogatePairsCountNint;
331                         }
332
333                         tempUtf8CodeUnitCountAdjustment += popcnt32;
334                         pInputBuffer += Vector<ushort>.Count;
335                         inputLength -= Vector<ushort>.Count;
336                     } while (inputLength >= Vector<ushort>.Count);
337                 }
338             }
339
340         NonVectorizedLoop:
341
342             // Vectorization isn't supported on our current platform, or the input was too small to benefit
343             // from vectorization, or we saw invalid UTF-16 data in the vectorized code paths and need to
344             // drain remaining valid chars before we report failure.
345
346             for (; inputLength > 0; pInputBuffer++, inputLength--)
347             {
348                 uint thisChar = pInputBuffer[0];
349                 if (thisChar <= 0x7F)
350                 {
351                     continue;
352                 }
353
354                 // Bump adjustment by +1 for U+0080..U+07FF; by +2 for U+0800..U+FFFF.
355                 // This optimistically assumes no surrogates, which we'll handle shortly.
356
357                 tempUtf8CodeUnitCountAdjustment += (thisChar + 0x0001_F800u) >> 16;
358
359                 if (!UnicodeUtility.IsSurrogateCodePoint(thisChar))
360                 {
361                     continue;
362                 }
363
364                 // Found a surrogate char. Back out the adjustment we made above, then
365                 // try to consume the entire surrogate pair all at once. We won't bother
366                 // trying to interpret the surrogate pair as a scalar value; we'll only
367                 // validate that its bit pattern matches what's expected for a surrogate pair.
368
369                 tempUtf8CodeUnitCountAdjustment -= 2;
370
371                 if (inputLength == 1)
372                 {
373                     goto Error; // input buffer too small to read a surrogate pair
374                 }
375
376                 thisChar = Unsafe.ReadUnaligned<uint>(pInputBuffer);
377                 if (((thisChar - (BitConverter.IsLittleEndian ? 0xDC00_D800u : 0xD800_DC00u)) & 0xFC00_FC00u) != 0)
378                 {
379                     goto Error; // not a well-formed surrogate pair
380                 }
381
382                 tempScalarCountAdjustment--; // 2 UTF-16 code units -> 1 scalar
383                 tempUtf8CodeUnitCountAdjustment += 2; // 2 UTF-16 code units -> 4 UTF-8 code units
384
385                 pInputBuffer++; // consumed one extra char
386                 inputLength--;
387             }
388
389         Error:
390
391             // Also used for normal return.
392
393             utf8CodeUnitCountAdjustment = tempUtf8CodeUnitCountAdjustment;
394             scalarCountAdjustment = tempScalarCountAdjustment;
395             return pInputBuffer;
396         }
397     }
398 }