Do not call into MsQuic inside a lock (#67037)
authorRadek Zikmund <32671551+rzikm@users.noreply.github.com>
Tue, 29 Mar 2022 20:11:09 +0000 (22:11 +0200)
committerGitHub <noreply@github.com>
Tue, 29 Mar 2022 20:11:09 +0000 (22:11 +0200)
Fixes #59345

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs

index c8f4bb9..7b9903a 100644 (file)
@@ -149,7 +149,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             try
             {
-                Debug.Assert(!Monitor.IsEntered(_state));
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicApi.Api.SetCallbackHandlerDelegate(
                     _state.Handle,
                     s_connectionDelegate,
@@ -187,7 +187,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             _state.StateGCHandle = GCHandle.Alloc(_state);
             try
             {
-                Debug.Assert(!Monitor.IsEntered(_state));
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 uint status = MsQuicApi.Api.ConnectionOpenDelegate(
                     MsQuicApi.Api.Registration,
                     s_connectionDelegate,
@@ -389,7 +389,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventPeerCertificateReceived(State state, ref ConnectionEvent connectionEvent)
         {
-            SslPolicyErrors sslPolicyErrors  = SslPolicyErrors.None;
+            SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None;
             X509Chain? chain = null;
             X509Certificate2? certificate = null;
             X509Certificate2Collection? additionalCertificates = null;
@@ -606,13 +606,13 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         internal override int GetRemoteAvailableUnidirectionalStreamCount()
         {
-            Debug.Assert(!Monitor.IsEntered(_state));
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_UNIDI_STREAM_COUNT);
         }
 
         internal override int GetRemoteAvailableBidirectionalStreamCount()
         {
-            Debug.Assert(!Monitor.IsEntered(_state));
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_BIDI_STREAM_COUNT);
         }
 
@@ -645,7 +645,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet((IPEndPoint)_remoteEndPoint);
                 unsafe
                 {
-                    Debug.Assert(!Monitor.IsEntered(_state));
+                    Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                     status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&address);
                     QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
                 }
@@ -668,7 +668,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                     SOCKADDR_INET quicAddress = MsQuicAddressHelpers.IPEndPointToINet(new IPEndPoint(address, port));
                     unsafe
                     {
-                        Debug.Assert(!Monitor.IsEntered(_state));
+                        Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                         status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&quicAddress);
                         QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
                     }
@@ -689,7 +689,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             try
             {
-                Debug.Assert(!Monitor.IsEntered(_state));
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 status = MsQuicApi.Api.ConnectionStartDelegate(
                     _state.Handle,
                     _configuration,
@@ -723,7 +723,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             try
             {
-                Debug.Assert(!Monitor.IsEntered(_state));
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicApi.Api.ConnectionShutdownDelegate(
                     _state.Handle,
                     Flags,
@@ -851,7 +851,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             if (_state.Handle != null && !_state.Handle.IsInvalid && !_state.Handle.IsClosed)
             {
                 // Handle can be null if outbound constructor failed and we are called from finalizer.
-                Debug.Assert(!Monitor.IsEntered(_state));
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicApi.Api.ConnectionShutdownDelegate(
                     _state.Handle,
                     QUIC_CONNECTION_SHUTDOWN_FLAGS.SILENT,
index 71f19c2..0e65bf6 100644 (file)
@@ -87,6 +87,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             _stateHandle = GCHandle.Alloc(_state);
             try
             {
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 uint status = MsQuicApi.Api.ListenerOpenDelegate(
                     MsQuicApi.Api.Registration,
                     s_listenerDelegate,
@@ -185,6 +186,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             QuicBuffer[]? buffers = null;
             try
             {
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicAlpnHelper.Prepare(applicationProtocols, out handles, out buffers);
                 status = MsQuicApi.Api.ListenerStartDelegate(_state.Handle, (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(buffers, 0), (uint)applicationProtocols.Count, ref address);
             }
@@ -200,6 +202,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             QuicExceptionHelpers.ThrowIfFailed(status, "ListenerStart failed.");
 
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             SOCKADDR_INET inetAddress = MsQuicParameterHelpers.GetINetParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.LISTENER, (uint)QUIC_PARAM_LISTENER.LOCAL_ADDRESS);
             return MsQuicAddressHelpers.INetToIPEndPoint(ref inetAddress);
         }
@@ -216,6 +219,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             if (_state.Handle != null)
             {
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicApi.Api.ListenerStopDelegate(_state.Handle);
             }
         }
@@ -276,6 +280,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 connectionHandle = new SafeMsQuicConnectionHandle(evt->Data.NewConnection.Connection);
 
+                Debug.Assert(!Monitor.IsEntered(state), "!Monitor.IsEntered(state)");
                 uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
                 if (MsQuicStatusHelper.SuccessfulStatusCode(status))
                 {
index 1a1df28..a8c2bc3 100644 (file)
@@ -119,6 +119,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             _state.StateGCHandle = GCHandle.Alloc(_state);
             try
             {
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 MsQuicApi.Api.SetCallbackHandlerDelegate(
                     _state.Handle,
                     s_streamDelegate,
@@ -164,6 +165,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             try
             {
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 uint status = MsQuicApi.Api.StreamOpenDelegate(
                     connectionState.Handle,
                     flags,
@@ -173,6 +175,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer.");
 
+                Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
                 status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED);
                 QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream.");
             }
@@ -227,7 +230,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             get
             {
                 ThrowIfDisposed();
-                return  _writeTimeout;
+                return _writeTimeout;
             }
             set
             {
@@ -420,6 +423,8 @@ namespace System.Net.Quic.Implementations.MsQuic
             long abortError;
             bool preCanceled = false;
 
+            int bytesRead = -1;
+            bool reenableReceive = false;
             lock (_state)
             {
                 initialReadState = _state.ReadState;
@@ -482,22 +487,32 @@ namespace System.Net.Quic.Implementations.MsQuic
                 {
                     _state.ReadState = ReadState.None;
 
-                    int taken = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span);
-                    ReceiveComplete(taken);
+                    bytesRead = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span);
 
-                    if (taken != _state.ReceiveQuicBuffersTotalBytes)
+                    if (bytesRead != _state.ReceiveQuicBuffersTotalBytes)
                     {
                         // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer.
-                        EnableReceive();
+                        reenableReceive = true;
                     }
                     else if (_state.ReceiveIsFinal)
                     {
                         // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN
                         _state.ReadState = ReadState.ReadsCompleted;
                     }
+                }
+            }
+
+            // methods below need to be called outside of the lock
+            if (bytesRead > -1)
+            {
+                ReceiveComplete(bytesRead);
 
-                    return new ValueTask<int>(taken);
+                if (reenableReceive)
+                {
+                    EnableReceive();
                 }
+
+                return new ValueTask<int>(bytesRead);
             }
 
             // All success scenarios returned at this point. Failure scenarios below:
@@ -510,7 +525,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                     ex = new InvalidOperationException("Only one read is supported at a time.");
                     break;
                 case ReadState.Aborted:
-                    ex =  preCanceled ? new OperationCanceledException(cancellationToken) :
+                    ex = preCanceled ? new OperationCanceledException(cancellationToken) :
                           ThrowHelper.GetStreamAbortedException(abortError);
                     break;
                 case ReadState.ConnectionClosed:
@@ -609,6 +624,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private void StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode)
         {
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode);
             QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed.");
         }
@@ -818,7 +834,8 @@ namespace System.Net.Quic.Implementations.MsQuic
                 {
                     // Handle race condition when stream can be closed handling SHUTDOWN_COMPLETE.
                     StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
-                } catch (ObjectDisposedException) { };
+                }
+                catch (ObjectDisposedException) { };
             }
 
             if (abortRead)
@@ -826,7 +843,8 @@ namespace System.Net.Quic.Implementations.MsQuic
                 try
                 {
                     StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, 0xffffffff);
-                } catch (ObjectDisposedException) { };
+                }
+                catch (ObjectDisposedException) { };
             }
 
             if (completeRead)
@@ -845,6 +863,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private void EnableReceive()
         {
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true);
             QuicExceptionHelpers.ThrowIfFailed(status, "StreamReceiveSetEnabled failed.");
         }
@@ -1289,6 +1308,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             _state.BufferArrays[0] = handle;
             _state.SendBufferCount = 1;
 
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamSendDelegate(
                 _state.Handle,
                 quicBuffers,
@@ -1352,6 +1372,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 ++count;
             }
 
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamSendDelegate(
                 _state.Handle,
                 quicBuffers,
@@ -1412,6 +1433,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 _state.BufferArrays[i] = handle;
             }
 
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamSendDelegate(
                 _state.Handle,
                 quicBuffers,
@@ -1434,6 +1456,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private void ReceiveComplete(int bufferLength)
         {
+            Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
             uint status = MsQuicApi.Api.StreamReceiveCompleteDelegate(_state.Handle, (ulong)bufferLength);
             QuicExceptionHelpers.ThrowIfFailed(status, "Could not complete receive call.");
         }