MemoryExtensions.Replace(Span<T>, T, T) implemented (#76337)
author Günther Foidl <gue@korporal.at>
Fri, 11 Nov 2022 20:50:56 +0000 (21:50 +0100)
committer GitHub <noreply@github.com>
Fri, 11 Nov 2022 20:50:56 +0000 (15:50 -0500)
* Defined API

* Tests

* Scalar implementation

* Use EqualityComparer<T>.Default instead

* Delegation to SpanHelpers.Replace

* ReplaceValueType implemented

* Use ushort instead of short, as it doesn't sign-extend for broadcast and in the scalar loop

* Forward string.Replace(char, char) to SpanHelpers.ReplaceValueType

* Process remainder vectorized only when not done already and with max width available

* Split into inlineable scalar path and non-inlineable vectorized path

* Replaced open coded loops with Replace

* Don't use EqualityComparer<T>.Default

Cf. https://github.com/dotnet/runtime/pull/76337#discussion_r982886319

* Remove guards for remainder

Cf. https://github.com/dotnet/runtime/pull/76337#discussion_r983448480

* Don't split method into scalar and vectorized and don't force inlining of scalar-part

* Fixed assert

ReplaceValueType is also called from string.Replace(char, char), so the Debug.Assert was in the wrong position: platforms without hardware acceleration are allowed to call the method, so the assert must not be placed at its entry.

* Better handling of remainder from the vectorized loop(s)

Intentionally leave one iteration off, as the remaining elements are done vectorized anyway. This eliminates the less probable case (cf. https://github.com/dotnet/runtime/pull/76337#discussion_r983448480) that the last vector is done twice.

* PR feedback

src/libraries/Common/src/System/IO/Archiving.Utils.Unix.cs
src/libraries/Common/src/System/IO/Archiving.Utils.Windows.cs
src/libraries/System.Memory/ref/System.Memory.cs
src/libraries/System.Memory/tests/Span/Replace.T.cs [new file with mode: 0644]
src/libraries/System.Memory/tests/System.Memory.Tests.csproj
src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs
src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs
src/libraries/System.Private.CoreLib/src/System/Text/StringBuilder.cs
src/libraries/System.Private.Uri/src/System/Uri.cs

index 18ab525..ce0aaf6 100644 (file)
@@ -11,7 +11,7 @@ namespace System.IO
         {
             // Remove leading separators.
             int nonSlash = path.IndexOfAnyExcept('/');
-            if (nonSlash == -1)
+            if (nonSlash < 0)
             {
                 nonSlash = path.Length;
             }
index 4125639..beceebc 100644 (file)
@@ -48,7 +48,7 @@ namespace System.IO
         {
             // Remove leading separators.
             int nonSlash = path.IndexOfAnyExcept('/', '\\');
-            if (nonSlash == -1)
+            if (nonSlash < 0)
             {
                 nonSlash = path.Length;
             }
@@ -76,12 +76,7 @@ namespace System.IO
 
                     // To ensure tar files remain compatible with Unix, and per the ZIP File Format Specification 4.4.17.1,
                     // all slashes should be forward slashes.
-                    int pos;
-                    while ((pos = dest.IndexOf('\\')) >= 0)
-                    {
-                        dest[pos] = '/';
-                        dest = dest.Slice(pos + 1);
-                    }
+                    dest.Replace('\\', '/');
                 });
             }
         }
index 9e0c91d..6c8f777 100644 (file)
@@ -293,6 +293,7 @@ namespace System
         public static bool Overlaps<T>(this System.ReadOnlySpan<T> span, System.ReadOnlySpan<T> other, out int elementOffset) { throw null; }
         public static bool Overlaps<T>(this System.Span<T> span, System.ReadOnlySpan<T> other) { throw null; }
         public static bool Overlaps<T>(this System.Span<T> span, System.ReadOnlySpan<T> other, out int elementOffset) { throw null; }
+        public static void Replace<T>(this System.Span<T> span, T oldValue, T newValue) where T : System.IEquatable<T>? { }
         public static void Reverse<T>(this System.Span<T> span) { }
         public static int SequenceCompareTo<T>(this System.ReadOnlySpan<T> span, System.ReadOnlySpan<T> other) where T : System.IComparable<T>? { throw null; }
         public static int SequenceCompareTo<T>(this System.Span<T> span, System.ReadOnlySpan<T> other) where T : System.IComparable<T>? { throw null; }
diff --git a/src/libraries/System.Memory/tests/Span/Replace.T.cs b/src/libraries/System.Memory/tests/Span/Replace.T.cs
new file mode 100644 (file)
index 0000000..92b4c16
--- /dev/null
@@ -0,0 +1,149 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using Xunit;
+
+namespace System.SpanTests
+{
+    public class ReplaceTests_Byte : ReplaceTests<byte> { protected override byte Create(int value) => (byte)value; } // 1-byte integer elements
+    public class ReplaceTests_Int16 : ReplaceTests<short> { protected override short Create(int value) => (short)value; } // 2-byte integer elements
+    public class ReplaceTests_Int32 : ReplaceTests<int> { protected override int Create(int value) => value; } // 4-byte integer elements
+    public class ReplaceTests_Int64 : ReplaceTests<long> { protected override long Create(int value) => value; } // 8-byte integer elements
+    public class ReplaceTests_Char : ReplaceTests<char> { protected override char Create(int value) => (char)value; } // UTF-16 code unit elements
+    public class ReplaceTests_Double : ReplaceTests<double> { protected override double Create(int value) => value; } // floating-point elements
+    public class ReplaceTests_Record : ReplaceTests<SimpleRecord> { protected override SimpleRecord Create(int value) => new(value); } // SimpleRecord is declared outside this file
+    public class ReplaceTests_CustomEquatable : ReplaceTests<CustomEquatable> { protected override CustomEquatable Create(int value) => new((byte)value); } // user-defined struct elements
+
+    // Test element type: wraps a single byte and compares by that byte through IEquatable<CustomEquatable>.
+    public readonly struct CustomEquatable : IEquatable<CustomEquatable>
+    {
+        public CustomEquatable(byte value) => Value = value;
+        public byte Value { get; }
+
+        public bool Equals(CustomEquatable other) => Value == other.Value;
+    }
+
+    // Shared test suite for Span<T>.Replace; each derived class supplies the element type through Create.
+    public abstract class ReplaceTests<T> where T : IEquatable<T>
+    {
+        private readonly T _oldValue;
+        private readonly T _newValue;
+
+        protected ReplaceTests()
+        {
+            _oldValue = Create('a');
+            _newValue = Create('b');
+        }
+
+        [Fact]
+        public void ZeroLengthSpan()
+        {
+            // An empty span must be accepted without throwing.
+            Exception thrown = Record.Exception(() => Span<T>.Empty.Replace(_oldValue, _newValue));
+            Assert.Null(thrown);
+        }
+
+        [Theory]
+        [MemberData(nameof(Length_MemberData))]
+        public void AllElementsNeedToBeReplaced(int length)
+        {
+            T[] expected = CreateArray(length, _newValue);
+            Span<T> buffer = CreateArray(length, _oldValue);
+
+            buffer.Replace(_oldValue, _newValue);
+
+            // Every element held the old value, so the whole span must now hold the new one.
+            Assert.Equal(expected, buffer.ToArray());
+        }
+
+        [Theory]
+        [MemberData(nameof(Length_MemberData))]
+        public void DefaultToBeReplaced(int length)
+        {
+            T[] expected = CreateArray(length, _newValue);
+            Span<T> buffer = CreateArray(length); // no fill values: every element stays default(T)
+
+            buffer.Replace(default, _newValue);
+
+            // default(T) has to work as the value being searched for.
+            Assert.Equal(expected, buffer.ToArray());
+        }
+
+        [Theory]
+        [MemberData(nameof(Length_MemberData))]
+        public void NoElementsNeedToBeReplaced(int length)
+        {
+            // Fill pattern built from '0' and '1'; the old value ('a') does not occur in it.
+            T[] pattern = { Create('0'), Create('1') };
+            Span<T> buffer = CreateArray(length, pattern);
+            T[] expected = buffer.ToArray();
+
+            buffer.Replace(_oldValue, _newValue);
+
+            // Nothing matched, so the contents must be untouched.
+            Assert.Equal(expected, buffer.ToArray());
+        }
+
+        [Theory]
+        [MemberData(nameof(Length_MemberData))]
+        public void SomeElementsNeedToBeReplaced(int length)
+        {
+            T[] pattern = { Create('0'), Create('1') };
+
+            // Plant the old value at both ends of the input ...
+            Span<T> buffer = CreateArray(length, pattern);
+            buffer[0] = _oldValue;
+            buffer[^1] = _oldValue;
+
+            // ... and expect exactly those two positions to change.
+            T[] expected = CreateArray(length, pattern);
+            expected[0] = _newValue;
+            expected[^1] = _newValue;
+
+            buffer.Replace(_oldValue, _newValue);
+            Assert.Equal(expected, buffer.ToArray());
+        }
+
+        [Theory]
+        [MemberData(nameof(Length_MemberData))]
+        public void OldAndNewValueAreSame(int length)
+        {
+            T[] pattern = { Create('0'), Create('1') };
+
+            Span<T> buffer = CreateArray(length, pattern);
+            buffer[0] = _oldValue;
+            buffer[^1] = _oldValue;
+            T[] expected = buffer.ToArray();
+
+            buffer.Replace(_oldValue, _oldValue);
+
+            // Replacing a value with itself must leave the span unchanged.
+            Assert.Equal(expected, buffer.ToArray());
+        }
+
+        public static IEnumerable<object[]> Length_MemberData()
+        {
+            int[] lengths = { 1, 2, 4, 7, 15, 16, 17, 31, 32, 33, 100 };
+            foreach (int candidate in lengths)
+            {
+                yield return new object[] { candidate };
+            }
+        }
+
+        // Maps an integer (the tests pass character codes) to an element of the type under test.
+        protected abstract T Create(int value);
+
+        private T[] CreateArray(int length, params T[] values)
+        {
+            // With no values supplied the array is left at default(T);
+            // otherwise the supplied values are repeated cyclically.
+            T[] result = new T[length];
+            for (int i = 0; values.Length > 0 && i < result.Length; i++)
+            {
+                result[i] = values[i % values.Length];
+            }
+
+            return result;
+        }
+    }
+}
index fa93b4b..e775856 100644 (file)
     <Compile Include="Span\LastIndexOfSequence.T.cs" />
     <Compile Include="Span\Overflow.cs" />
     <Compile Include="Span\Overlaps.cs" />
+    <Compile Include="Span\Replace.T.cs" />
     <Compile Include="Span\Reverse.cs" />
     <Compile Include="Span\SequenceCompareTo.bool.cs" />
     <Compile Include="Span\SequenceCompareTo.byte.cs" />
index cf80b57..9bb51d9 100644 (file)
@@ -2923,6 +2923,68 @@ namespace System
             }
         }
 
+        /// <summary>Replaces all occurrences of <paramref name="oldValue"/> with <paramref name="newValue"/>.</summary>
+        /// <typeparam name="T">The type of the elements in the span.</typeparam>
+        /// <param name="span">The span in which the elements should be replaced.</param>
+        /// <param name="oldValue">The value to be replaced with <paramref name="newValue"/>.</param>
+        /// <param name="newValue">The value to replace all occurrences of <paramref name="oldValue"/>.</param>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static void Replace<T>(this Span<T> span, T oldValue, T newValue) where T : IEquatable<T>?
+        {
+            if (SpanHelpers.CanVectorizeAndBenefit<T>(span.Length))
+            {
+                nuint length = (uint)span.Length;
+
+                if (Unsafe.SizeOf<T>() == sizeof(byte))
+                {
+                    ref byte src = ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span));
+                    SpanHelpers.ReplaceValueType(
+                        ref src,
+                        ref src,
+                        Unsafe.As<T, byte>(ref oldValue),
+                        Unsafe.As<T, byte>(ref newValue),
+                        length);
+                }
+                else if (Unsafe.SizeOf<T>() == sizeof(ushort))
+                {
+                    // Use ushort rather than short, as this avoids a sign-extending move.
+                    ref ushort src = ref Unsafe.As<T, ushort>(ref MemoryMarshal.GetReference(span));
+                    SpanHelpers.ReplaceValueType(
+                        ref src,
+                        ref src,
+                        Unsafe.As<T, ushort>(ref oldValue),
+                        Unsafe.As<T, ushort>(ref newValue),
+                        length);
+                }
+                else if (Unsafe.SizeOf<T>() == sizeof(int))
+                {
+                    ref int src = ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span));
+                    SpanHelpers.ReplaceValueType(
+                        ref src,
+                        ref src,
+                        Unsafe.As<T, int>(ref oldValue),
+                        Unsafe.As<T, int>(ref newValue),
+                        length);
+                }
+                else
+                {
+                    Debug.Assert(Unsafe.SizeOf<T>() == sizeof(long));
+                    ref long src = ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span));
+                    SpanHelpers.ReplaceValueType(
+                        ref src,
+                        ref src,
+                        Unsafe.As<T, long>(ref oldValue),
+                        Unsafe.As<T, long>(ref newValue),
+                        length);
+                }
+            }
+            else
+            {
+                // Scalar fallback only: spans already handled by ReplaceValueType above must not be scanned a second time.
+                SpanHelpers.Replace(span, oldValue, newValue);
+            }
+        }
+
         /// <summary>Finds the length of any common prefix shared between <paramref name="span"/> and <paramref name="other"/>.</summary>
         /// <typeparam name="T">The type of the elements in the spans.</typeparam>
         /// <param name="span">The first sequence to compare.</param>
index e3bef2b..ea734fd 100644 (file)
@@ -2620,6 +2620,103 @@ namespace System
             return -1;
         }
 
+        // Scalar replace: walks the span once and overwrites each element equal to oldValue with newValue.
+        public static void Replace<T>(Span<T> span, T oldValue, T newValue) where T : IEquatable<T>?
+        {
+            if (default(T) is null && oldValue is null)
+            {
+                // Equals cannot be invoked on a null oldValue, so null elements are matched directly.
+                foreach (ref T element in span)
+                {
+                    element ??= newValue;
+                }
+
+                return;
+            }
+
+            Debug.Assert(oldValue is not null);
+
+            foreach (ref T element in span)
+            {
+                if (oldValue.Equals(element))
+                {
+                    element = newValue;
+                }
+            }
+        }
+
+        // Writes length elements from src to dst, substituting newValue for each element equal to oldValue; src and dst may be the same reference (in-place).
+        public static void ReplaceValueType<T>(ref T src, ref T dst, T oldValue, T newValue, nuint length) where T : struct
+        {
+            if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<T>.Count)
+            {
+                for (nuint idx = 0; idx < length; ++idx)
+                {
+                    T original = Unsafe.Add(ref src, idx);
+                    Unsafe.Add(ref dst, idx) = EqualityComparer<T>.Default.Equals(original, oldValue) ? newValue : original;
+                }
+            }
+            else
+            {
+                Debug.Assert(Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported, "Vector128 is not HW-accelerated or not supported");
+
+                nuint idx = 0;
+
+                if (!Vector256.IsHardwareAccelerated || length < (uint)Vector256<T>.Count)
+                {
+                    nuint lastVectorIndex = length - (uint)Vector128<T>.Count;
+                    Vector128<T> oldValues = Vector128.Create(oldValue);
+                    Vector128<T> newValues = Vector128.Create(newValue);
+                    Vector128<T> original, mask, result;
+
+                    do
+                    {
+                        original = Vector128.LoadUnsafe(ref src, idx);
+                        mask = Vector128.Equals(oldValues, original);
+                        result = Vector128.ConditionalSelect(mask, newValues, original);
+                        result.StoreUnsafe(ref dst, idx);
+
+                        idx += (uint)Vector128<T>.Count;
+                    }
+                    while (idx < lastVectorIndex);
+
+                    // There are (0, Vector128<T>.Count] elements remaining now.
+                    // As the operation is idempotent, and we know that in total there are at least Vector128<T>.Count
+                    // elements available, we read a vector from the very end, perform the replace and write the
+                    // resulting vector at the very end, which eliminates the scalar processing of the remaining elements.
+                    original = Vector128.LoadUnsafe(ref src, lastVectorIndex);
+                    mask = Vector128.Equals(oldValues, original);
+                    result = Vector128.ConditionalSelect(mask, newValues, original);
+                    result.StoreUnsafe(ref dst, lastVectorIndex);
+                }
+                else
+                {
+                    Debug.Assert(Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported, "Vector256 is not HW-accelerated or not supported");
+
+                    nuint lastVectorIndex = length - (uint)Vector256<T>.Count;
+                    Vector256<T> oldValues = Vector256.Create(oldValue);
+                    Vector256<T> newValues = Vector256.Create(newValue);
+                    Vector256<T> original, mask, result;
+
+                    do
+                    {
+                        original = Vector256.LoadUnsafe(ref src, idx);
+                        mask = Vector256.Equals(oldValues, original);
+                        result = Vector256.ConditionalSelect(mask, newValues, original);
+                        result.StoreUnsafe(ref dst, idx);
+
+                        idx += (uint)Vector256<T>.Count;
+                    }
+                    while (idx < lastVectorIndex);
+
+                    original = Vector256.LoadUnsafe(ref src, lastVectorIndex); // same overlapping tail handling as the Vector128 path
+                    mask = Vector256.Equals(oldValues, original);
+                    result = Vector256.ConditionalSelect(mask, newValues, original);
+                    result.StoreUnsafe(ref dst, lastVectorIndex);
+                }
+            }
+        }
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector128<T> equals) where T : struct
         {
index 50b30cb..a2524d0 100644 (file)
@@ -1008,56 +1008,21 @@ namespace System
             // Copy the remaining characters, doing the replacement as we go.
             ref ushort pSrc = ref Unsafe.Add(ref GetRawStringDataAsUInt16(), (uint)copyLength);
             ref ushort pDst = ref Unsafe.Add(ref result.GetRawStringDataAsUInt16(), (uint)copyLength);
-            nuint i = 0;
 
-            if (Vector.IsHardwareAccelerated && Length >= Vector<ushort>.Count)
+            // If the string is long enough for vectorization to kick in, we'd like to
+            // process the remaining elements vectorized too.
+            // Thus we adjust the pointers so that at least one full vector from the end can be processed.
+            nuint length = (uint)Length;
+            if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
             {
-                Vector<ushort> oldChars = new(oldChar);
-                Vector<ushort> newChars = new(newChar);
-
-                Vector<ushort> original;
-                Vector<ushort> equals;
-                Vector<ushort> results;
-
-                if (remainingLength > (nuint)Vector<ushort>.Count)
-                {
-                    nuint lengthToExamine = remainingLength - (nuint)Vector<ushort>.Count;
-
-                    do
-                    {
-                        original = Vector.LoadUnsafe(ref pSrc, i);
-                        equals = Vector.Equals(original, oldChars);
-                        results = Vector.ConditionalSelect(equals, newChars, original);
-                        results.StoreUnsafe(ref pDst, i);
-
-                        i += (nuint)Vector<ushort>.Count;
-                    }
-                    while (i < lengthToExamine);
-                }
-
-                // There are [0, Vector<ushort>.Count) elements remaining now.
-                // As the operation is idempotent, and we know that in total there are at least Vector<ushort>.Count
-                // elements available, we read a vector from the very end of the string, perform the replace
-                // and write to the destination at the very end.
-                // Thus we can eliminate the scalar processing of the remaining elements.
-                // We perform this operation even if there are 0 elements remaining, as it is cheaper than the
-                // additional check which would introduce a branch here.
-
-                i = (uint)(Length - Vector<ushort>.Count);
-                original = Vector.LoadUnsafe(ref GetRawStringDataAsUInt16(), i);
-                equals = Vector.Equals(original, oldChars);
-                results = Vector.ConditionalSelect(equals, newChars, original);
-                results.StoreUnsafe(ref result.GetRawStringDataAsUInt16(), i);
-            }
-            else
-            {
-                for (; i < remainingLength; ++i)
-                {
-                    ushort currentChar = Unsafe.Add(ref pSrc, i);
-                    Unsafe.Add(ref pDst, i) = currentChar == oldChar ? newChar : currentChar;
-                }
+                nuint adjust = (length - remainingLength) & ((uint)Vector128<ushort>.Count - 1);
+                pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
+                pDst = ref Unsafe.Subtract(ref pDst, adjust);
+                remainingLength += adjust;
             }
 
+            SpanHelpers.ReplaceValueType(ref pSrc, ref pDst, oldChar, newChar, remainingLength);
+
             return result;
         }
 
index 1c0c6f4..65a228f 100644 (file)
@@ -1965,12 +1965,7 @@ namespace System.Text
                     int endInChunk = Math.Min(chunk.m_ChunkLength, endIndexInChunk);
 
                     Span<char> span = chunk.m_ChunkChars.AsSpan(curInChunk, endInChunk - curInChunk);
-                    int i;
-                    while ((i = span.IndexOf(oldChar)) >= 0)
-                    {
-                        span[i] = newChar;
-                        span = span.Slice(i + 1);
-                    }
+                    span.Replace(oldChar, newChar);
                 }
 
                 if (startIndexInChunk >= 0)
index 634f4cf..b24d459 100644 (file)
@@ -1035,12 +1035,7 @@ namespace System
                 // Plus going through Compress will turn them into / anyway
                 // Converting / back into \
                 Span<char> slashSpan = result.AsSpan(0, count);
-                int slashPos;
-                while ((slashPos = slashSpan.IndexOf('/')) >= 0)
-                {
-                    slashSpan[slashPos] = '\\';
-                    slashSpan = slashSpan.Slice(slashPos + 1);
-                }
+                slashSpan.Replace('/', '\\');
 
                 return new string(result, 0, count);
             }