Switch over to managed Marvin implementation for string hashing (#17029)
[platform/upstream/coreclr.git] / src / mscorlib / src / System / Globalization / CompareInfo.Windows.cs
1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4
5 using System.Buffers;
6 using System.Diagnostics;
7 using System.Security;
8 using System.Runtime.CompilerServices;
9 using System.Runtime.InteropServices;
10
namespace System.Globalization
{
    public partial class CompareInfo
    {
        /// <summary>
        /// Initializes the Windows-specific sort state for <paramref name="culture"/>.
        /// It records the culture name and sort name. Unless globalization invariant mode is on,
        /// it also asks the OS (via <c>LCMapStringEx</c> with <c>LCMAP_SORTHANDLE</c>) for a sort handle.
        /// </summary>
        /// <param name="culture">The culture whose sort rules this instance uses.</param>
        private unsafe void InitSort(CultureInfo culture)
        {
            _sortName = culture.SortName;
            m_name = culture._name;

            if (_invariantMode)
            {
                _sortHandle = IntPtr.Zero;
                return;
            }

            const uint LCMAP_SORTHANDLE = 0x20000000;

            IntPtr handle;
            int ret = Interop.Kernel32.LCMapStringEx(_sortName, LCMAP_SORTHANDLE, null, 0, &handle, IntPtr.Size, null, null, IntPtr.Zero);

            // When no handle is produced, keep IntPtr.Zero. The NLS wrappers below then pass
            // _sortName as the locale instead of the handle.
            _sortHandle = ret > 0 ? handle : IntPtr.Zero;
        }

        /// <summary>
        /// Runs the OS ordinal search (<c>Interop.Kernel32.FindStringOrdinal</c>).
        /// It searches <paramref name="cchSource"/> characters of <paramref name="stringSource"/>,
        /// starting at <paramref name="offset"/>, for the first <paramref name="cchValue"/>
        /// characters of <paramref name="value"/>.
        /// </summary>
        /// <returns>
        /// The match index relative to the start of <paramref name="stringSource"/>.
        /// If the OS returns a negative (not found) result, that value is returned unchanged.
        /// </returns>
        private static unsafe int FindStringOrdinal(
            uint dwFindStringOrdinalFlags,
            string stringSource,
            int offset,
            int cchSource,
            string value,
            int cchValue,
            bool bIgnoreCase)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            fixed (char* pSource = stringSource)
            fixed (char* pValue = value)
            {
                int ret = Interop.Kernel32.FindStringOrdinal(
                            dwFindStringOrdinalFlags,
                            pSource + offset,
                            cchSource,
                            pValue,
                            cchValue,
                            bIgnoreCase ? 1 : 0);

                // The OS reports a position relative to pSource + offset; convert it back to an
                // index into stringSource, but leave a negative (not found) result untouched.
                return ret < 0 ? ret : ret + offset;
            }
        }

        /// <summary>
        /// Span overload of the OS ordinal search. Searches all of <paramref name="source"/> for all
        /// of <paramref name="value"/>.
        /// </summary>
        /// <returns>The OS result unchanged: an index into <paramref name="source"/>, or a negative value.</returns>
        private static unsafe int FindStringOrdinal(
            uint dwFindStringOrdinalFlags,
            ReadOnlySpan<char> source,
            ReadOnlySpan<char> value,
            bool bIgnoreCase)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            fixed (char* pSource = &MemoryMarshal.GetReference(source))
            fixed (char* pValue = &MemoryMarshal.GetReference(value))
            {
                return Interop.Kernel32.FindStringOrdinal(
                            dwFindStringOrdinalFlags,
                            pSource,
                            source.Length,
                            pValue,
                            value.Length,
                            bIgnoreCase ? 1 : 0);
            }
        }

        /// <summary>
        /// Forward ordinal search for <paramref name="value"/> in
        /// <c>source[startIndex .. startIndex + count)</c>.
        /// </summary>
        /// <returns>Index into <paramref name="source"/> of the first match, or a negative value.</returns>
        internal static int IndexOfOrdinalCore(string source, string value, int startIndex, int count, bool ignoreCase)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            Debug.Assert(source != null);
            Debug.Assert(value != null);

            return FindStringOrdinal(FIND_FROMSTART, source, startIndex, count, value, value.Length, ignoreCase);
        }

        /// <summary>Forward ordinal search over spans; both spans are expected to be non-empty.</summary>
        internal static int IndexOfOrdinalCore(ReadOnlySpan<char> source, ReadOnlySpan<char> value, bool ignoreCase)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            Debug.Assert(source.Length != 0);
            Debug.Assert(value.Length != 0);

            return FindStringOrdinal(FIND_FROMSTART, source, value, ignoreCase);
        }

        /// <summary>
        /// Backward ordinal search. <paramref name="startIndex"/> is the LAST index of the window, so
        /// the searched window is <c>source[startIndex - count + 1 .. startIndex]</c>.
        /// </summary>
        /// <returns>Index into <paramref name="source"/> of the last match, or a negative value.</returns>
        internal static int LastIndexOfOrdinalCore(string source, string value, int startIndex, int count, bool ignoreCase)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            Debug.Assert(source != null);
            Debug.Assert(value != null);

            return FindStringOrdinal(FIND_FROMEND, source, startIndex - count + 1, count, value, value.Length, ignoreCase);
        }

        /// <summary>
        /// Computes a culture-aware hash of <paramref name="source"/>. It builds the OS sort key
        /// (<c>LCMAP_SORTKEY</c> plus the native flags for <paramref name="options"/>) and hashes
        /// those bytes with Marvin.
        /// </summary>
        /// <returns>0 for an empty string; otherwise the Marvin hash of the sort key bytes.</returns>
        /// <exception cref="ArgumentException">The OS failed to produce the sort key.</exception>
        private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(source != null);
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            if (source.Length == 0)
            {
                return 0;
            }

            uint flags = LCMAP_SORTKEY | (uint)GetNativeCompareFlags(options);
            string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

            fixed (char* pSource = source)
            {
                // Sizing call: with no output buffer the OS only reports the sort key length.
                int sortKeyLength = Interop.Kernel32.LCMapStringEx(localeName,
                                                  flags,
                                                  pSource, source.Length,
                                                  null, 0,
                                                  null, null, _sortHandle);
                if (sortKeyLength == 0)
                {
                    throw new ArgumentException(SR.Arg_ExternalException);
                }

                // Small keys live on the stack; larger ones borrow a pooled array. A rented array
                // may be longer than requested, so only the first sortKeyLength bytes are hashed.
                byte[] borrowedArr = null;
                Span<byte> span = sortKeyLength <= 512 ?
                    stackalloc byte[512] :
                    (borrowedArr = ArrayPool<byte>.Shared.Rent(sortKeyLength));

                try
                {
                    fixed (byte* pSortKey = &MemoryMarshal.GetReference(span))
                    {
                        // Fill call: the sort key must actually be written into the buffer.
                        // Previously this passed (null, 0), so the buffer was never filled and the
                        // hash was computed over zeroed or stale pooled bytes.
                        if (Interop.Kernel32.LCMapStringEx(localeName,
                                                          flags,
                                                          pSource, source.Length,
                                                          pSortKey, sortKeyLength,
                                                          null, null, _sortHandle) != sortKeyLength)
                        {
                            throw new ArgumentException(SR.Arg_ExternalException);
                        }
                    }

                    return Marvin.ComputeHash32(span.Slice(0, sortKeyLength), Marvin.DefaultSeed);
                }
                finally
                {
                    // Return the borrowed array even when the fill call throws.
                    if (borrowedArr != null)
                    {
                        ArrayPool<byte>.Shared.Return(borrowedArr);
                    }
                }
            }
        }

        /// <summary>
        /// Ordinal, case-insensitive comparison performed by the OS (<c>CompareStringOrdinal</c>).
        /// </summary>
        /// <returns>The OS result minus 2, i.e. the OS's 1/2/3 result mapped onto -1/0/1.</returns>
        private static unsafe int CompareStringOrdinalIgnoreCase(char* string1, int count1, char* string2, int count2)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            return Interop.Kernel32.CompareStringOrdinal(string1, count1, string2, count2, true) - 2;
        }

        // TODO https://github.com/dotnet/coreclr/issues/13827:
        // This method shouldn't be necessary, as we should be able to just use the overload
        // that takes two spans.  But due to this issue, that's adding significant overhead.
        /// <summary>
        /// Linguistic comparison of a span against a string via <c>CompareStringEx</c>.
        /// The process fail-fasts if the OS reports failure (result 0).
        /// </summary>
        /// <returns>-1, 0 or 1.</returns>
        private unsafe int CompareString(ReadOnlySpan<char> string1, string string2, CompareOptions options)
        {
            Debug.Assert(string2 != null);
            Debug.Assert(!_invariantMode);
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

            fixed (char* pLocaleName = localeName)
            fixed (char* pString1 = &MemoryMarshal.GetReference(string1))
            fixed (char* pString2 = &string2.GetRawStringData())
            {
                int result = Interop.Kernel32.CompareStringEx(
                                    pLocaleName,
                                    (uint)GetNativeCompareFlags(options),
                                    pString1,
                                    string1.Length,
                                    pString2,
                                    string2.Length,
                                    null,
                                    null,
                                    _sortHandle);

                if (result == 0)
                {
                    Environment.FailFast("CompareStringEx failed");
                }

                // Map CompareStringEx return value to -1, 0, 1.
                return result - 2;
            }
        }

        /// <summary>
        /// Linguistic comparison of two spans via <c>CompareStringEx</c>.
        /// The process fail-fasts if the OS reports failure (result 0).
        /// </summary>
        /// <returns>-1, 0 or 1.</returns>
        private unsafe int CompareString(ReadOnlySpan<char> string1, ReadOnlySpan<char> string2, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

            fixed (char* pLocaleName = localeName)
            fixed (char* pString1 = &MemoryMarshal.GetReference(string1))
            fixed (char* pString2 = &MemoryMarshal.GetReference(string2))
            {
                int result = Interop.Kernel32.CompareStringEx(
                                    pLocaleName,
                                    (uint)GetNativeCompareFlags(options),
                                    pString1,
                                    string1.Length,
                                    pString2,
                                    string2.Length,
                                    null,
                                    null,
                                    _sortHandle);

                if (result == 0)
                {
                    Environment.FailFast("CompareStringEx failed");
                }

                // Map CompareStringEx return value to -1, 0, 1.
                return result - 2;
            }
        }

        /// <summary>
        /// Linguistic search over spans via <c>FindNLSStringEx</c>.
        /// The match length is written through <paramref name="pcchFound"/> when it is non-null.
        /// </summary>
        /// <returns>The raw OS result: a match index relative to the span, or a negative value.</returns>
        private unsafe int FindString(
                    uint dwFindNLSStringFlags,
                    ReadOnlySpan<char> lpStringSource,
                    ReadOnlySpan<char> lpStringValue,
                    int* pcchFound)
        {
            Debug.Assert(!_invariantMode);

            string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

            fixed (char* pLocaleName = localeName)
            fixed (char* pSource = &MemoryMarshal.GetReference(lpStringSource))
            fixed (char* pValue = &MemoryMarshal.GetReference(lpStringValue))
            {
                return Interop.Kernel32.FindNLSStringEx(
                                    pLocaleName,
                                    dwFindNLSStringFlags,
                                    pSource,
                                    lpStringSource.Length,
                                    pValue,
                                    lpStringValue.Length,
                                    pcchFound,
                                    null,
                                    null,
                                    _sortHandle);
            }
        }

        /// <summary>
        /// Linguistic search via <c>FindNLSStringEx</c>. It searches
        /// <c>lpStringSource[startSource .. startSource + cchSource)</c> for
        /// <c>lpStringValue[startValue .. startValue + cchValue)</c>.
        /// </summary>
        /// <returns>
        /// The raw OS result. A match index is relative to <paramref name="startSource"/>,
        /// and callers add the window start back. Otherwise the result is negative.
        /// </returns>
        private unsafe int FindString(
            uint dwFindNLSStringFlags,
            string lpStringSource,
            int startSource,
            int cchSource,
            string lpStringValue,
            int startValue,
            int cchValue,
            int* pcchFound)
        {
            Debug.Assert(!_invariantMode);

            string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

            fixed (char* pLocaleName = localeName)
            fixed (char* pSource = lpStringSource)
            fixed (char* pValue = lpStringValue)
            {
                return Interop.Kernel32.FindNLSStringEx(
                                    pLocaleName,
                                    dwFindNLSStringFlags,
                                    pSource + startSource,
                                    cchSource,
                                    pValue + startValue,
                                    cchValue,
                                    pcchFound,
                                    null,
                                    null,
                                    _sortHandle);
            }
        }

        /// <summary>
        /// Forward search for <paramref name="target"/> in <c>source[startIndex .. startIndex + count)</c>.
        /// <see cref="CompareOptions.Ordinal"/> uses a managed char-by-char scan; other options use
        /// the OS linguistic search. The match length is written through
        /// <paramref name="matchLengthPtr"/> when it is non-null.
        /// </summary>
        /// <returns>Index into <paramref name="source"/> of the match, or -1.</returns>
        internal unsafe int IndexOfCore(string source, string target, int startIndex, int count, CompareOptions options, int* matchLengthPtr)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(source != null);
            Debug.Assert(target != null);
            Debug.Assert((options & CompareOptions.OrdinalIgnoreCase) == 0);

            // An empty target matches immediately with zero length.
            if (target.Length == 0)
            {
                if (matchLengthPtr != null)
                {
                    *matchLengthPtr = 0;
                }
                return startIndex;
            }

            if (source.Length == 0)
            {
                return -1;
            }

            if ((options & CompareOptions.Ordinal) != 0)
            {
                int ordinalIndex = FastIndexOfString(source, target, startIndex, count, target.Length, findLastIndex: false);
                if (ordinalIndex >= 0 && matchLengthPtr != null)
                {
                    *matchLengthPtr = target.Length;
                }
                return ordinalIndex;
            }

            int found = FindString(FIND_FROMSTART | (uint)GetNativeCompareFlags(options), source, startIndex, count,
                                   target, 0, target.Length, matchLengthPtr);

            // FindString reports a position relative to startIndex.
            return found >= 0 ? found + startIndex : -1;
        }

        /// <summary>
        /// Linguistic forward search over non-empty spans. Only <see cref="CompareOptions.None"/>
        /// and <see cref="CompareOptions.IgnoreCase"/> are expected.
        /// </summary>
        /// <returns>The raw OS result: an index into <paramref name="source"/>, or a negative value.</returns>
        internal unsafe int IndexOfCore(ReadOnlySpan<char> source, ReadOnlySpan<char> target, CompareOptions options, int* matchLengthPtr)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(source.Length != 0);
            Debug.Assert(target.Length != 0);
            Debug.Assert((options == CompareOptions.None || options == CompareOptions.IgnoreCase));

            return FindString(FIND_FROMSTART | (uint)GetNativeCompareFlags(options), source, target, matchLengthPtr);
        }

        /// <summary>
        /// Backward search for <paramref name="target"/>. <paramref name="startIndex"/> is the LAST
        /// index of the window, so the window is <c>source[startIndex - count + 1 .. startIndex]</c>.
        /// </summary>
        /// <returns>
        /// Index into <paramref name="source"/> of the last match, or -1.
        /// For an empty target, <paramref name="startIndex"/> is returned.
        /// </returns>
        private unsafe int LastIndexOfCore(string source, string target, int startIndex, int count, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(!string.IsNullOrEmpty(source));
            Debug.Assert(target != null);
            Debug.Assert((options & CompareOptions.OrdinalIgnoreCase) == 0);

            // TODO: Consider moving this up to the relevant APIs we need to ensure this behavior for
            // and add a precondition that target is not empty.
            if (target.Length == 0)
            {
                return startIndex;       // keep Whidbey compatibility
            }

            if ((options & CompareOptions.Ordinal) != 0)
            {
                return FastIndexOfString(source, target, startIndex, count, target.Length, findLastIndex: true);
            }

            int windowStart = startIndex - count + 1;
            int found = FindString(FIND_FROMEND | (uint)GetNativeCompareFlags(options), source, windowStart,
                                   count, target, 0, target.Length, null);

            // FindString reports a position relative to the window start.
            return found >= 0 ? found + windowStart : -1;
        }

        /// <summary>Linguistic prefix test (<c>FIND_STARTSWITH</c>) on non-empty strings.</summary>
        private unsafe bool StartsWith(string source, string prefix, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(!string.IsNullOrEmpty(source));
            Debug.Assert(!string.IsNullOrEmpty(prefix));
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            return FindString(FIND_STARTSWITH | (uint)GetNativeCompareFlags(options), source, 0, source.Length,
                              prefix, 0, prefix.Length, null) >= 0;
        }

        /// <summary>Linguistic prefix test (<c>FIND_STARTSWITH</c>) on non-empty spans.</summary>
        private unsafe bool StartsWith(ReadOnlySpan<char> source, ReadOnlySpan<char> prefix, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(!source.IsEmpty);
            Debug.Assert(!prefix.IsEmpty);
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            return FindString(FIND_STARTSWITH | (uint)GetNativeCompareFlags(options), source, prefix, null) >= 0;
        }

        /// <summary>Linguistic suffix test (<c>FIND_ENDSWITH</c>) on non-empty strings.</summary>
        private unsafe bool EndsWith(string source, string suffix, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(!string.IsNullOrEmpty(source));
            Debug.Assert(!string.IsNullOrEmpty(suffix));
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            return FindString(FIND_ENDSWITH | (uint)GetNativeCompareFlags(options), source, 0, source.Length,
                              suffix, 0, suffix.Length, null) >= 0;
        }

        /// <summary>Linguistic suffix test (<c>FIND_ENDSWITH</c>) on non-empty spans.</summary>
        private unsafe bool EndsWith(ReadOnlySpan<char> source, ReadOnlySpan<char> suffix, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            Debug.Assert(!source.IsEmpty);
            Debug.Assert(!suffix.IsEmpty);
            Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

            return FindString(FIND_ENDSWITH | (uint)GetNativeCompareFlags(options), source, suffix, null) >= 0;
        }

        // PAL ends here

        // OS sort handle obtained in InitSort. IntPtr.Zero means "use _sortName as the locale instead".
        [NonSerialized]
        private IntPtr _sortHandle;

        // LCMapStringEx flags.
        private const uint LCMAP_SORTKEY = 0x00000400;
        private const uint LCMAP_HASH    = 0x00040000;

        // Search-direction / anchoring flags passed to FindNLSStringEx and FindStringOrdinal.
        private const int FIND_STARTSWITH = 0x00100000;
        private const int FIND_ENDSWITH = 0x00200000;
        private const int FIND_FROMSTART = 0x00400000;
        private const int FIND_FROMEND = 0x00800000;

        // TODO: Instead of this method could we just have upstack code call IndexOfOrdinal with ignoreCase = false?
        /// <summary>
        /// Managed ordinal (exact <see cref="char"/>) search for the first <paramref name="targetCount"/>
        /// characters of <paramref name="target"/> within a window of <paramref name="sourceCount"/>
        /// characters of <paramref name="source"/>.
        /// </summary>
        /// <param name="startIndex">
        /// For a forward search, the first index of the window. For a backward search
        /// (<paramref name="findLastIndex"/> = true), the last index of the window.
        /// </param>
        /// <returns>Index into <paramref name="source"/> of the first (or last) match, or -1.</returns>
        /// <remarks>
        /// Reads <c>target[0]</c> unconditionally. The visible callers return early for an empty target.
        /// </remarks>
        private static unsafe int FastIndexOfString(string source, string target, int startIndex, int sourceCount, int targetCount, bool findLastIndex)
        {
            int retValue = -1;

            // Leftmost index of the window being searched.
            int sourceStartIndex = findLastIndex ? startIndex - sourceCount + 1 : startIndex;

            fixed (char* pSource = source, spTarget = target)
            {
                char* spSubSource = pSource + sourceStartIndex;

                // Last window-relative position at which a full match could still begin.
                int lastCandidate = sourceCount - targetCount;
                if (lastCandidate < 0)
                {
                    return -1;
                }

                char patternChar0 = spTarget[0];

                if (findLastIndex)
                {
                    for (int ctrSrc = lastCandidate; ctrSrc >= 0; ctrSrc--)
                    {
                        if (spSubSource[ctrSrc] != patternChar0)
                        {
                            continue;
                        }

                        int ctrPat;
                        for (ctrPat = 1; ctrPat < targetCount; ctrPat++)
                        {
                            if (spSubSource[ctrSrc + ctrPat] != spTarget[ctrPat])
                            {
                                break;
                            }
                        }

                        if (ctrPat == targetCount)
                        {
                            retValue = ctrSrc;
                            break;
                        }
                    }
                }
                else
                {
                    for (int ctrSrc = 0; ctrSrc <= lastCandidate; ctrSrc++)
                    {
                        if (spSubSource[ctrSrc] != patternChar0)
                        {
                            continue;
                        }

                        int ctrPat;
                        for (ctrPat = 1; ctrPat < targetCount; ctrPat++)
                        {
                            if (spSubSource[ctrSrc + ctrPat] != spTarget[ctrPat])
                            {
                                break;
                            }
                        }

                        if (ctrPat == targetCount)
                        {
                            retValue = ctrSrc;
                            break;
                        }
                    }
                }

                // Convert the window-relative match back to an index into source.
                if (retValue >= 0)
                {
                    retValue += sourceStartIndex;
                }
            }

            return retValue;
        }

        /// <summary>
        /// Builds a <see cref="SortKey"/> for <paramref name="source"/> from the OS sort key bytes
        /// (<c>LCMAP_SORTKEY</c> plus the native flags for <paramref name="options"/>).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="options"/> contains a flag outside the valid sort-key mask,
        /// or the OS failed to produce the key.
        /// </exception>
        private unsafe SortKey CreateSortKey(string source, CompareOptions options)
        {
            Debug.Assert(!_invariantMode);

            if (source == null) { throw new ArgumentNullException(nameof(source)); }

            if ((options & ValidSortkeyCtorMaskOffFlags) != 0)
            {
                throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
            }

            byte[] keyData;
            if (source.Length == 0)
            {
                keyData = Array.Empty<byte>();
            }
            else
            {
                uint flags = LCMAP_SORTKEY | (uint)GetNativeCompareFlags(options);
                string localeName = _sortHandle != IntPtr.Zero ? null : _sortName;

                fixed (char* pSource = source)
                {
                    // Sizing call: with no output buffer the OS only reports the key length.
                    int sortKeyLength = Interop.Kernel32.LCMapStringEx(localeName,
                                                flags,
                                                pSource, source.Length,
                                                null, 0,
                                                null, null, _sortHandle);
                    if (sortKeyLength == 0)
                    {
                        throw new ArgumentException(SR.Arg_ExternalException);
                    }

                    keyData = new byte[sortKeyLength];

                    // Fill call: write the key into keyData and verify the same length comes back.
                    fixed (byte* pBytes = keyData)
                    {
                        if (Interop.Kernel32.LCMapStringEx(localeName,
                                                flags,
                                                pSource, source.Length,
                                                pBytes, keyData.Length,
                                                null, null, _sortHandle) != sortKeyLength)
                        {
                            throw new ArgumentException(SR.Arg_ExternalException);
                        }
                    }
                }
            }

            return new SortKey(Name, source, options, keyData);
        }

        /// <summary>
        /// Asks the OS (<c>IsNLSDefinedString</c> with <c>COMPARE_STRING</c>) whether
        /// <paramref name="text"/> is defined for NLS string comparison.
        /// </summary>
        private static unsafe bool IsSortable(char* text, int length)
        {
            Debug.Assert(!GlobalizationMode.Invariant);

            return Interop.Kernel32.IsNLSDefinedString(Interop.Kernel32.COMPARE_STRING, 0, IntPtr.Zero, text, length);
        }

        private const int COMPARE_OPTIONS_ORDINAL = 0x40000000;       // Ordinal
        private const int NORM_IGNORECASE = 0x00000001;       // Ignores case.  (use LINGUISTIC_IGNORECASE instead)
        private const int NORM_IGNOREKANATYPE = 0x00010000;       // Does not differentiate between Hiragana and Katakana characters. Corresponding Hiragana and Katakana will compare as equal.
        private const int NORM_IGNORENONSPACE = 0x00000002;       // Ignores nonspacing. This flag also removes Japanese accent characters.  (use LINGUISTIC_IGNOREDIACRITIC instead)
        private const int NORM_IGNORESYMBOLS = 0x00000004;       // Ignores symbols.
        private const int NORM_IGNOREWIDTH = 0x00020000;       // Does not differentiate between a single-byte character and the same character as a double-byte character.
        private const int NORM_LINGUISTIC_CASING = 0x08000000;       // use linguistic rules for casing
        private const int SORT_STRINGSORT = 0x00001000;       // Treats punctuation the same as symbols.

        /// <summary>
        /// Translates managed <see cref="CompareOptions"/> into native NLS compare flags.
        /// <c>NORM_LINGUISTIC_CASING</c> is always included, except that exactly
        /// <see cref="CompareOptions.Ordinal"/> maps to <c>COMPARE_OPTIONS_ORDINAL</c> alone.
        /// </summary>
        private static int GetNativeCompareFlags(CompareOptions options)
        {
            // Use "linguistic casing" by default (load the culture's casing exception tables)
            int nativeCompareFlags = NORM_LINGUISTIC_CASING;

            if ((options & CompareOptions.IgnoreCase) != 0) { nativeCompareFlags |= NORM_IGNORECASE; }
            if ((options & CompareOptions.IgnoreKanaType) != 0) { nativeCompareFlags |= NORM_IGNOREKANATYPE; }
            if ((options & CompareOptions.IgnoreNonSpace) != 0) { nativeCompareFlags |= NORM_IGNORENONSPACE; }
            if ((options & CompareOptions.IgnoreSymbols) != 0) { nativeCompareFlags |= NORM_IGNORESYMBOLS; }
            if ((options & CompareOptions.IgnoreWidth) != 0) { nativeCompareFlags |= NORM_IGNOREWIDTH; }
            if ((options & CompareOptions.StringSort) != 0) { nativeCompareFlags |= SORT_STRINGSORT; }

            // TODO: Can we try for GetNativeCompareFlags to never
            // take Ordinal or OrdinalIgnoreCase.  This value is not part of Win32, we just handle it special
            // in some places.
            // Suffix & Prefix shouldn't use this, make sure to turn off the NORM_LINGUISTIC_CASING flag
            if (options == CompareOptions.Ordinal) { nativeCompareFlags = COMPARE_OPTIONS_ORDINAL; }

            Debug.Assert(((options & ~(CompareOptions.IgnoreCase |
                                          CompareOptions.IgnoreKanaType |
                                          CompareOptions.IgnoreNonSpace |
                                          CompareOptions.IgnoreSymbols |
                                          CompareOptions.IgnoreWidth |
                                          CompareOptions.StringSort)) == 0) ||
                             (options == CompareOptions.Ordinal), "[CompareInfo.GetNativeCompareFlags]Expected all flags to be handled");

            return nativeCompareFlags;
        }

        /// <summary>
        /// Queries the OS (<c>GetNLSVersionEx</c>) for the NLS version of this sort and wraps it in a
        /// <see cref="SortVersion"/>. If the OS reports an effective id of 0, <c>LCID</c> is used instead.
        /// </summary>
        private unsafe SortVersion GetSortVersion()
        {
            Debug.Assert(!_invariantMode);

            Interop.Kernel32.NlsVersionInfoEx nlsVersion = new Interop.Kernel32.NlsVersionInfoEx();
            nlsVersion.dwNLSVersionInfoSize = Marshal.SizeOf(typeof(Interop.Kernel32.NlsVersionInfoEx));

            // NOTE(review): the return value of GetNLSVersionEx is not checked; on failure the
            // zero-initialized struct fields are used as-is.
            Interop.Kernel32.GetNLSVersionEx(Interop.Kernel32.COMPARE_STRING, _sortName, &nlsVersion);
            return new SortVersion(
                        nlsVersion.dwNLSVersion,
                        nlsVersion.dwEffectiveId == 0 ? LCID : nlsVersion.dwEffectiveId,
                        nlsVersion.guidCustomVersion);
        }
    }
}