[wasm] optimize WebSocket perf for small buffers (#72800)
authorPavel Savara <pavel.savara@gmail.com>
Tue, 26 Jul 2022 14:47:26 +0000 (16:47 +0200)
committerGitHub <noreply@github.com>
Tue, 26 Jul 2022 14:47:26 +0000 (16:47 +0200)
* use MemoryHandle instead of ArraySegment marshaling to improve WS performance for small buffers
* re-use responseStatus buffer

src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserInterop.cs
src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserWebSocket.cs
src/mono/wasm/runtime/web-socket.ts

index 4e16d14..5ea45c2 100644 (file)
@@ -3,7 +3,7 @@
 
 using System.Threading.Tasks;
 using System.Runtime.InteropServices.JavaScript;
-using System.Runtime.InteropServices;
+using System.Buffers;
 
 namespace System.Net.WebSockets
 {
@@ -28,8 +28,18 @@ namespace System.Net.WebSockets
         public static partial JSObject WebSocketCreate(
             string uri,
             string?[]? subProtocols,
+            IntPtr responseStatusPtr,
             [JSMarshalAs<JSType.Function<JSType.Number, JSType.String>>] Action<int, string> onClosed);
 
+        public static unsafe JSObject UnsafeCreate(
+            string uri,
+            string?[]? subProtocols,
+            MemoryHandle responseHandle,
+            [JSMarshalAs<JSType.Function<JSType.Number, JSType.String>>] Action<int, string> onClosed)
+        {
+            return WebSocketCreate(uri, subProtocols, (IntPtr)responseHandle.Pointer, onClosed);
+        }
+
         [JSImport("INTERNAL.ws_wasm_open")]
         public static partial Task WebSocketOpen(
             JSObject webSocket);
@@ -37,15 +47,36 @@ namespace System.Net.WebSockets
         [JSImport("INTERNAL.ws_wasm_send")]
         public static partial Task? WebSocketSend(
             JSObject webSocket,
-            [JSMarshalAs<JSType.MemoryView>] ArraySegment<byte> buffer,
+            IntPtr bufferPtr,
+            int bufferLength,
             int messageType,
             bool endOfMessage);
 
+        public static unsafe Task? UnsafeSendSync(JSObject jsWs, ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage)
+        {
+            if (buffer.Count == 0)
+            {
+                return WebSocketSend(jsWs, IntPtr.Zero, 0, (int)messageType, endOfMessage);
+            }
+
+            var span = buffer.AsSpan();
+            // we can do this because the bytes in the buffer are always consumed synchronously (not later with Task resolution)
+            fixed (void* spanPtr = span)
+            {
+                return WebSocketSend(jsWs, (IntPtr)spanPtr, buffer.Count, (int)messageType, endOfMessage);
+            }
+        }
+
         [JSImport("INTERNAL.ws_wasm_receive")]
         public static partial Task? WebSocketReceive(
             JSObject webSocket,
-            [JSMarshalAs<JSType.MemoryView>] ArraySegment<byte> buffer,
-            [JSMarshalAs<JSType.MemoryView>] ArraySegment<int> response);
+            IntPtr bufferPtr,
+            int bufferLength);
+
+        public static unsafe Task? ReceiveUnsafeSync(JSObject jsWs, MemoryHandle pinBuffer, int length)
+        {
+            return WebSocketReceive(jsWs, (IntPtr)pinBuffer.Pointer, length);
+        }
 
         [JSImport("INTERNAL.ws_wasm_close")]
         public static partial Task? WebSocketClose(
index 2fea505..42dacbb 100644 (file)
@@ -5,6 +5,7 @@ using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
 using System.Runtime.InteropServices.JavaScript;
+using System.Buffers;
 
 namespace System.Net.WebSockets
 {
@@ -19,6 +20,8 @@ namespace System.Net.WebSockets
         private WebSocketState _state;
         private bool _disposed;
         private bool _aborted;
+        private int[] responseStatus = new int[3];
+        private MemoryHandle? responseStatusHandle;
 
         #region Properties
 
@@ -162,6 +165,7 @@ namespace System.Net.WebSockets
                 }
                 _innerWebSocket?.Dispose();
                 _innerWebSocket = null;
+                responseStatusHandle?.Dispose();
             }
         }
 
@@ -181,7 +185,10 @@ namespace System.Net.WebSockets
                     }
                 };
 
-                _innerWebSocket = BrowserInterop.WebSocketCreate(uri.ToString(), subProtocols, onClose);
+                Memory<int> responseMemory = new Memory<int>(responseStatus);
+                responseStatusHandle = responseMemory.Pin();
+
+                _innerWebSocket = BrowserInterop.UnsafeCreate(uri.ToString(), subProtocols, responseStatusHandle.Value, onClose);
                 var openTask = BrowserInterop.WebSocketOpen(_innerWebSocket);
                 var wrappedTask = CancelationHelper(openTask!, cancellationToken, _state);
 
@@ -211,7 +218,7 @@ namespace System.Net.WebSockets
         {
             try
             {
-                var sendTask = BrowserInterop.WebSocketSend(_innerWebSocket!, buffer, (int)messageType, endOfMessage);
+                var sendTask = BrowserInterop.UnsafeSendSync(_innerWebSocket!, buffer, messageType, endOfMessage);
                 if (sendTask == null)
                 {
                     // return synchronously
@@ -239,18 +246,21 @@ namespace System.Net.WebSockets
         {
             try
             {
-                ArraySegment<int> response = new ArraySegment<int>(new int[3]);
-                var receiveTask = BrowserInterop.WebSocketReceive(_innerWebSocket!, buffer, response);
-                if (receiveTask == null)
+                Memory<byte> bufferMemory = buffer.AsMemory();
+                using (MemoryHandle pinBuffer = bufferMemory.Pin())
                 {
-                    // return synchronously
-                    return ConvertResponse(response);
-                }
+                    var receiveTask = BrowserInterop.ReceiveUnsafeSync(_innerWebSocket!, pinBuffer, bufferMemory.Length);
+                    if (receiveTask == null)
+                    {
+                        // return synchronously
+                        return ConvertResponse();
+                    }
 
-                var wrappedTask = CancelationHelper(receiveTask, cancellationToken, _state);
-                await wrappedTask.ConfigureAwait(true);
+                    var wrappedTask = CancelationHelper(receiveTask, cancellationToken, _state);
+                    await wrappedTask.ConfigureAwait(true);
 
-                return ConvertResponse(response);
+                    return ConvertResponse();
+                }
             }
             catch (OperationCanceledException)
             {
@@ -266,18 +276,18 @@ namespace System.Net.WebSockets
             }
         }
 
-        private WebSocketReceiveResult ConvertResponse(ArraySegment<int> response)
+        private WebSocketReceiveResult ConvertResponse()
         {
             const int countIndex = 0;
             const int typeIndex = 1;
             const int endIndex = 2;
 
-            WebSocketMessageType messageType = (WebSocketMessageType)response[typeIndex];
+            WebSocketMessageType messageType = (WebSocketMessageType)responseStatus[typeIndex];
             if (messageType == WebSocketMessageType.Close)
             {
-                return new WebSocketReceiveResult(response[countIndex], messageType, response[endIndex] != 0, CloseStatus, CloseStatusDescription);
+                return new WebSocketReceiveResult(responseStatus[countIndex], messageType, responseStatus[endIndex] != 0, CloseStatus, CloseStatusDescription);
             }
-            return new WebSocketReceiveResult(response[countIndex], messageType, response[endIndex] != 0);
+            return new WebSocketReceiveResult(responseStatus[countIndex], messageType, responseStatus[endIndex] != 0);
         }
 
         private async Task CloseAsyncCore(WebSocketCloseStatus closeStatus, string? statusDescription, bool waitForCloseReceived, CancellationToken cancellationToken)
index 81c69c2..cfc0b30 100644 (file)
@@ -5,7 +5,9 @@ import { prevent_timer_throttling } from "./scheduling";
 import { Queue } from "./queue";
 import { PromiseController, createPromiseController } from "./promise-controller";
 import { mono_assert } from "./types";
-import { ArraySegment, IDisposable } from "./marshal";
+import { VoidPtr } from "./export-types";
+import { Module } from "./imports";
+import { setI32 } from "./memory";
 
 const wasm_ws_pending_send_buffer = Symbol.for("wasm ws_pending_send_buffer");
 const wasm_ws_pending_send_buffer_offset = Symbol.for("wasm ws_pending_send_buffer_offset");
@@ -16,13 +18,14 @@ const wasm_ws_pending_open_promise = Symbol.for("wasm ws_pending_open_promise");
 const wasm_ws_pending_close_promises = Symbol.for("wasm ws_pending_close_promises");
 const wasm_ws_pending_send_promises = Symbol.for("wasm ws_pending_send_promises");
 const wasm_ws_is_aborted = Symbol.for("wasm ws_is_aborted");
+const wasm_ws_receive_status_ptr = Symbol.for("wasm ws_receive_status_ptr");
 let mono_wasm_web_socket_close_warning = false;
 let _text_decoder_utf8: TextDecoder | undefined = undefined;
 let _text_encoder_utf8: TextEncoder | undefined = undefined;
 const ws_send_buffer_blocking_threshold = 65536;
 const emptyBuffer = new Uint8Array();
 
-export function ws_wasm_create(uri: string, sub_protocols: string[] | null, onClosed: (code: number, reason: string) => void): WebSocketExtension {
+export function ws_wasm_create(uri: string, sub_protocols: string[] | null, receive_status_ptr: VoidPtr, onClosed: (code: number, reason: string) => void): WebSocketExtension {
     mono_assert(uri && typeof uri === "string", () => `ERR12: Invalid uri ${typeof uri}`);
 
     const ws = new globalThis.WebSocket(uri, sub_protocols || undefined) as WebSocketExtension;
@@ -33,6 +36,7 @@ export function ws_wasm_create(uri: string, sub_protocols: string[] | null, onCl
     ws[wasm_ws_pending_open_promise] = open_promise_control;
     ws[wasm_ws_pending_send_promises] = [];
     ws[wasm_ws_pending_close_promises] = [];
+    ws[wasm_ws_receive_status_ptr] = receive_status_ptr;
     ws.binaryType = "arraybuffer";
     const local_on_open = () => {
         if (ws[wasm_ws_is_aborted]) return;
@@ -59,12 +63,9 @@ export function ws_wasm_create(uri: string, sub_protocols: string[] | null, onCl
         // send close to any pending receivers, to wake them
         const receive_promise_queue = ws[wasm_ws_pending_receive_promise_queue];
         receive_promise_queue.drain((receive_promise_control) => {
-
-            const response = new Int32Array([
-                0,// count
-                2, // type:close
-                1]);// end_of_message: true
-            receive_promise_control.responseView.set(response);
+            setI32(receive_status_ptr, 0); // count
+            setI32(<any>receive_status_ptr + 4, 2); // type:close
+            setI32(<any>receive_status_ptr + 8, 1);// end_of_message: true
             receive_promise_control.resolve();
         });
     };
@@ -85,19 +86,20 @@ export function ws_wasm_open(ws: WebSocketExtension): Promise<WebSocketExtension
     return open_promise_control.promise;
 }
 
-export function ws_wasm_send(ws: WebSocketExtension, bufferView: ArraySegment, message_type: number, end_of_message: boolean): Promise<void> | null {
+export function ws_wasm_send(ws: WebSocketExtension, buffer_ptr: VoidPtr, buffer_length: number, message_type: number, end_of_message: boolean): Promise<void> | null {
     mono_assert(!!ws, "ERR17: expected ws instance");
 
-    const whole_buffer = _mono_wasm_web_socket_send_buffering(ws, bufferView, message_type, end_of_message);
+    const buffer_view = new Uint8Array(Module.HEAPU8.buffer, <any>buffer_ptr, buffer_length);
+    const whole_buffer = _mono_wasm_web_socket_send_buffering(ws, buffer_view, message_type, end_of_message);
 
     if (!end_of_message || !whole_buffer) {
         return null;
     }
 
-    return _mono_wasm_web_socket_send_and_wait(ws, whole_buffer, bufferView);
+    return _mono_wasm_web_socket_send_and_wait(ws, whole_buffer);
 }
 
-export function ws_wasm_receive(ws: WebSocketExtension, bufferView: ArraySegment, responseView: ArraySegment): Promise<void> | null {
+export function ws_wasm_receive(ws: WebSocketExtension, buffer_ptr: VoidPtr, buffer_length: number): Promise<void> | null {
     mono_assert(!!ws, "ERR18: expected ws instance");
 
     const receive_event_queue = ws[wasm_ws_pending_receive_event_queue];
@@ -112,14 +114,14 @@ export function ws_wasm_receive(ws: WebSocketExtension, bufferView: ArraySegment
         mono_assert(receive_promise_queue.getLength() == 0, "ERR20: Invalid WS state");
 
         // finish synchronously
-        _mono_wasm_web_socket_receive_buffering(receive_event_queue, bufferView, responseView);
+        _mono_wasm_web_socket_receive_buffering(ws, receive_event_queue, buffer_ptr, buffer_length);
 
         return null;
     }
     const { promise, promise_control } = createPromiseController<void>();
     const receive_promise_control = promise_control as ReceivePromiseControl;
-    receive_promise_control.bufferView = bufferView;
-    receive_promise_control.responseView = responseView;
+    receive_promise_control.buffer_ptr = buffer_ptr;
+    receive_promise_control.buffer_length = buffer_length;
     receive_promise_queue.enqueue(receive_promise_control);
 
     return promise;
@@ -181,10 +183,9 @@ export function ws_wasm_abort(ws: WebSocketExtension): void {
     ws.close(1000, "Connection was aborted.");
 }
 
-function _mono_wasm_web_socket_send_and_wait(ws: WebSocketExtension, buffer: Uint8Array | string, managedBuffer: IDisposable): Promise<void> | null {
-    // send and return promise
-    ws.send(buffer);
-    managedBuffer.dispose();
+// send and return promise
+function _mono_wasm_web_socket_send_and_wait(ws: WebSocketExtension, buffer_view: Uint8Array | string): Promise<void> | null {
+    ws.send(buffer_view);
     ws[wasm_ws_pending_send_buffer] = null;
 
     // if the remaining send buffer is small, we don't block so that the throughput doesn't suffer.
@@ -260,19 +261,20 @@ function _mono_wasm_web_socket_on_message(ws: WebSocketExtension, event: Message
     }
     while (promise_queue.getLength() && event_queue.getLength()) {
         const promise_control = promise_queue.dequeue()!;
-        _mono_wasm_web_socket_receive_buffering(event_queue,
-            promise_control.bufferView, promise_control.responseView);
+        _mono_wasm_web_socket_receive_buffering(ws, event_queue,
+            promise_control.buffer_ptr, promise_control.buffer_length);
         promise_control.resolve();
     }
     prevent_timer_throttling();
 }
 
-function _mono_wasm_web_socket_receive_buffering(event_queue: Queue<any>, bufferView: ArraySegment, responseView: ArraySegment) {
+function _mono_wasm_web_socket_receive_buffering(ws: WebSocketExtension, event_queue: Queue<any>, buffer_ptr: VoidPtr, buffer_length: number) {
     const event = event_queue.peek();
 
-    const count = Math.min(bufferView.length, event.data.length - event.offset);
+    const count = Math.min(buffer_length, event.data.length - event.offset);
     if (count > 0) {
         const sourceView = event.data.subarray(event.offset, event.offset + count);
+        const bufferView = new Uint8Array(Module.HEAPU8.buffer, <any>buffer_ptr, buffer_length);
         bufferView.set(sourceView, 0);
         event.offset += count;
     }
@@ -280,18 +282,16 @@ function _mono_wasm_web_socket_receive_buffering(event_queue: Queue<any>, buffer
     if (end_of_message) {
         event_queue.dequeue();
     }
-
-    const response = new Int32Array([count, event.type, end_of_message]);
-    responseView.set(response);
-
-    bufferView.dispose();
-    responseView.dispose();
+    const response_ptr = ws[wasm_ws_receive_status_ptr];
+    setI32(response_ptr, count);
+    setI32(<any>response_ptr + 4, event.type);
+    setI32(<any>response_ptr + 8, end_of_message);
 }
 
-function _mono_wasm_web_socket_send_buffering(ws: WebSocketExtension, bufferView: ArraySegment, message_type: number, end_of_message: boolean): Uint8Array | string | null {
+function _mono_wasm_web_socket_send_buffering(ws: WebSocketExtension, buffer_view: Uint8Array, message_type: number, end_of_message: boolean): Uint8Array | string | null {
     let buffer = ws[wasm_ws_pending_send_buffer];
     let offset = 0;
-    const length = bufferView.length;
+    const length = buffer_view.byteLength;
 
     if (buffer) {
         offset = ws[wasm_ws_pending_send_buffer_offset];
@@ -302,11 +302,11 @@ function _mono_wasm_web_socket_send_buffering(ws: WebSocketExtension, bufferView
             if (offset + length > buffer.length) {
                 const newbuffer = new Uint8Array((offset + length + 50) * 1.5); // exponential growth
                 newbuffer.set(buffer, 0);// copy previous buffer
-                bufferView.copyTo(newbuffer.subarray(offset));// append copy at the end
+                newbuffer.subarray(offset).set(buffer_view);// append copy at the end
                 ws[wasm_ws_pending_send_buffer] = buffer = newbuffer;
             }
             else {
-                bufferView.copyTo(buffer.subarray(offset));// append copy at the end
+                buffer.subarray(offset).set(buffer_view);// append copy at the end
             }
             offset += length;
             ws[wasm_ws_pending_send_buffer_offset] = offset;
@@ -315,7 +315,7 @@ function _mono_wasm_web_socket_send_buffering(ws: WebSocketExtension, bufferView
     else if (!end_of_message) {
         // create new buffer
         if (length !== 0) {
-            buffer = <Uint8Array>bufferView.slice(); // copy
+            buffer = <Uint8Array>buffer_view.slice(); // copy
             offset = length;
             ws[wasm_ws_pending_send_buffer_offset] = offset;
             ws[wasm_ws_pending_send_buffer] = buffer;
@@ -324,8 +324,8 @@ function _mono_wasm_web_socket_send_buffering(ws: WebSocketExtension, bufferView
     }
     else {
         if (length !== 0) {
-            // we could use the unsafe view, because it will be immediately used in ws.send()
-            buffer = <Uint8Array>bufferView._unsafe_create_view();
+            // we could use the un-pinned view, because it will be immediately used in ws.send()
+            buffer = buffer_view;
             offset = length;
         }
     }
@@ -361,14 +361,15 @@ type WebSocketExtension = WebSocket & {
     [wasm_ws_pending_send_promises]: PromiseController<void>[]
     [wasm_ws_pending_close_promises]: PromiseController<void>[]
     [wasm_ws_is_aborted]: boolean
+    [wasm_ws_receive_status_ptr]: VoidPtr
     [wasm_ws_pending_send_buffer_offset]: number
     [wasm_ws_pending_send_buffer_type]: number
     [wasm_ws_pending_send_buffer]: Uint8Array | null
 }
 
 type ReceivePromiseControl = PromiseController<void> & {
-    bufferView: ArraySegment,
-    responseView: ArraySegment
+    buffer_ptr: VoidPtr,
+    buffer_length: number,
 }
 
 type Message = {