Simplify ManagedWebSocket.ApplyMask (#87331)
author Stephen Toub <stoub@microsoft.com>
Sat, 10 Jun 2023 11:38:35 +0000 (07:38 -0400)
committer GitHub <noreply@github.com>
Sat, 10 Jun 2023 11:38:35 +0000 (07:38 -0400)
The alignment code is neither necessary nor ideally implemented, and it can make shorter lengths slower to process. Deleting it for now; if we later decide that forced alignment is valuable, we can add it back in a more optimal manner.

src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs

index 114cbbe..3b24288 100644 (file)
@@ -1451,61 +1451,26 @@ namespace System.Net.WebSockets
             {
                 byte* toMaskPtr = toMaskBeg;
                 byte* toMaskEnd = toMaskBeg + toMask.Length;
-                byte* maskPtr = (byte*)&mask;
 
                 if (toMaskEnd - toMaskPtr >= sizeof(int))
                 {
-                    // align our pointer to sizeof(int)
-
-                    while ((ulong)toMaskPtr % sizeof(int) != 0)
-                    {
-                        Debug.Assert(toMaskPtr < toMaskEnd);
-
-                        *toMaskPtr++ ^= maskPtr[maskIndex];
-                        maskIndex = (maskIndex + 1) & 3;
-                    }
-
-                    int rolledMask;
-                    if (BitConverter.IsLittleEndian)
-                    {
-                        rolledMask = (int)BitOperations.RotateRight((uint)mask, maskIndex * 8);
-                    }
-                    else
-                    {
-                        rolledMask = (int)BitOperations.RotateLeft((uint)mask, maskIndex * 8);
-                    }
-
-                    // use SIMD if possible.
+                    int rolledMask = BitConverter.IsLittleEndian ?
+                        (int)BitOperations.RotateRight((uint)mask, maskIndex * 8) :
+                        (int)BitOperations.RotateLeft((uint)mask, maskIndex * 8);
 
-                    if (Vector.IsHardwareAccelerated && Vector<byte>.Count % sizeof(int) == 0 && (toMaskEnd - toMaskPtr) >= Vector<byte>.Count)
+                    // Process Vector<byte>.Count bytes at a time.
+                    if (Vector.IsHardwareAccelerated && (toMaskEnd - toMaskPtr) >= Vector<byte>.Count)
                     {
-                        // align our pointer to Vector<byte>.Count
-
-                        while ((ulong)toMaskPtr % (uint)Vector<byte>.Count != 0)
-                        {
-                            Debug.Assert(toMaskPtr < toMaskEnd);
-
-                            *(int*)toMaskPtr ^= rolledMask;
-                            toMaskPtr += sizeof(int);
-                        }
-
-                        // use SIMD.
-
-                        if (toMaskEnd - toMaskPtr >= Vector<byte>.Count)
+                        Vector<byte> maskVector = Vector.AsVectorByte(new Vector<int>(rolledMask));
+                        do
                         {
-                            Vector<byte> maskVector = Vector.AsVectorByte(new Vector<int>(rolledMask));
-
-                            do
-                            {
-                                *(Vector<byte>*)toMaskPtr ^= maskVector;
-                                toMaskPtr += Vector<byte>.Count;
-                            }
-                            while (toMaskEnd - toMaskPtr >= Vector<byte>.Count);
+                            *(Vector<byte>*)toMaskPtr ^= maskVector;
+                            toMaskPtr += Vector<byte>.Count;
                         }
+                        while (toMaskEnd - toMaskPtr >= Vector<byte>.Count);
                     }
 
-                    // process remaining data (or all, if couldn't use SIMD) 4 bytes at a time.
-
+                    // Process 4 bytes at a time.
                     while (toMaskEnd - toMaskPtr >= sizeof(int))
                     {
                         *(int*)toMaskPtr ^= rolledMask;
@@ -1513,8 +1478,8 @@ namespace System.Net.WebSockets
                     }
                 }
 
-                // do any remaining data a byte at a time.
-
+                // Process 1 byte at a time.
+                byte* maskPtr = (byte*)&mask;
                 while (toMaskPtr != toMaskEnd)
                 {
                     *toMaskPtr++ ^= maskPtr[maskIndex];