Extract HttpListener.Windows queue handles into a separate class (#37515)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Sat, 20 Jun 2020 08:26:26 +0000 (10:26 +0200)
committerGitHub <noreply@github.com>
Sat, 20 Jun 2020 08:26:26 +0000 (10:26 +0200)
* Extract HttpListener queue handles into a separate class

* Move CreateRequestQueueHandle logic to HttpListenerSession ctor

* Move HttpListenerSession to a separate file

* Expand HttpListener restart test for BeginGetContext

src/libraries/System.Net.HttpListener/src/System.Net.HttpListener.csproj
src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListener.Windows.cs
src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListenerContext.Windows.cs
src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListenerRequest.Windows.cs
src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListenerSession.Windows.cs [new file with mode: 0644]
src/libraries/System.Net.HttpListener/src/System/Net/Windows/ListenerAsyncResult.Windows.cs
src/libraries/System.Net.HttpListener/tests/SimpleHttpTests.cs

index 17a49ab..b4ed96e 100644 (file)
@@ -86,6 +86,7 @@
   </ItemGroup>
   <ItemGroup Condition="'$(TargetsWindows)' == 'true' and '$(ForceManagedImplementation)' != 'true'">
     <Compile Include="System\Net\Windows\HttpListener.Windows.cs" />
+    <Compile Include="System\Net\Windows\HttpListenerSession.Windows.cs" />
     <Compile Include="System\Net\Windows\HttpListenerContext.Windows.cs" />
     <Compile Include="System\Net\Windows\HttpListenerRequest.Windows.cs" />
     <Compile Include="System\Net\Windows\HttpListenerResponse.Windows.cs" />
index 7d98e40..ca6a8ed 100644 (file)
@@ -10,12 +10,10 @@ using System.Diagnostics;
 using System.Net.Security;
 using System.Runtime.ExceptionServices;
 using System.Runtime.InteropServices;
-using System.Security;
 using System.Security.Authentication.ExtendedProtection;
 using System.Security.Principal;
 using System.Text;
 using System.Threading;
-using System.Threading.Tasks;
 
 namespace System.Net
 {
@@ -44,8 +42,8 @@ namespace System.Net
             (byte) 'e', (byte) 'n', (byte) 't', (byte) 'i', (byte) 'c', (byte) 'a', (byte) 't', (byte) 'e'
         };
 
-        private SafeHandle _requestQueueHandle;
-        private ThreadPoolBoundHandle _requestQueueBoundHandle;
+        private HttpListenerSession _currentSession;
+
         private bool _unsafeConnectionNtlmAuthentication;
 
         private HttpServerSessionHandle _serverSessionHandle;
@@ -54,8 +52,6 @@ namespace System.Net
         private bool _V2Initialized;
         private Dictionary<ulong, DisconnectAsyncResult> _disconnectResults;
 
-        internal SafeHandle RequestQueueHandle => _requestQueueHandle;
-
         private void ValidateV2Property()
         {
             // Make sure that calling CheckDisposed and SetupV2Config is an atomic operation. This
@@ -156,31 +152,6 @@ namespace System.Net
             }
         }
 
-        private IntPtr DangerousGetHandle()
-        {
-            return ((HttpRequestQueueV2Handle)_requestQueueHandle).DangerousGetHandle();
-        }
-
-        internal ThreadPoolBoundHandle RequestQueueBoundHandle
-        {
-            get
-            {
-                if (_requestQueueBoundHandle == null)
-                {
-                    lock (_internalLock)
-                    {
-                        if (_requestQueueBoundHandle == null)
-                        {
-                            _requestQueueBoundHandle = ThreadPoolBoundHandle.BindHandle(_requestQueueHandle);
-                            if (NetEventSource.IsEnabled) NetEventSource.Info($"ThreadPoolBoundHandle.BindHandle({_requestQueueHandle}) -> {_requestQueueBoundHandle}");
-                        }
-                    }
-                }
-
-                return _requestQueueBoundHandle;
-            }
-        }
-
         private void SetupV2Config()
         {
             uint statusCode = Interop.HttpApi.ERROR_SUCCESS;
@@ -267,6 +238,8 @@ namespace System.Net
                         return;
                     }
 
+                    Debug.Assert(_currentSession is null);
+
                     // SetupV2Config() is not called in the ctor, because it may throw. This would
                     // be a regression since in v1 the ctor never threw. Besides, ctors should do
                     // minimal work according to the framework design guidelines.
@@ -340,13 +313,13 @@ namespace System.Net
 
         private void AttachRequestQueueToUrlGroup()
         {
-            //
+            Debug.Assert(Monitor.IsEntered(_internalLock));
+
             // Set the association between request queue and url group. After this, requests for registered urls will
             // get delivered to this request queue.
-            //
             Interop.HttpApi.HTTP_BINDING_INFO info = default;
             info.Flags = Interop.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT;
-            info.RequestQueueHandle = DangerousGetHandle();
+            info.RequestQueueHandle = _currentSession.RequestQueueHandle.DangerousGetHandle();
 
             IntPtr infoptr = new IntPtr(&info);
 
@@ -419,51 +392,18 @@ namespace System.Net
 
         private unsafe void CreateRequestQueueHandle()
         {
-            uint statusCode = Interop.HttpApi.ERROR_SUCCESS;
-
-            HttpRequestQueueV2Handle requestQueueHandle = null;
-            statusCode =
-                Interop.HttpApi.HttpCreateRequestQueue(
-                    Interop.HttpApi.s_version, null, null, 0, out requestQueueHandle);
-
-            if (statusCode != Interop.HttpApi.ERROR_SUCCESS)
-            {
-                throw new HttpListenerException((int)statusCode);
-            }
-
-            // Disabling callbacks when IO operation completes synchronously (returns ErrorCodes.ERROR_SUCCESS)
-            if (SkipIOCPCallbackOnSuccess &&
-                !Interop.Kernel32.SetFileCompletionNotificationModes(
-                    requestQueueHandle,
-                    Interop.Kernel32.FileCompletionNotificationModes.SkipCompletionPortOnSuccess |
-                    Interop.Kernel32.FileCompletionNotificationModes.SkipSetEventOnHandle))
-            {
-                throw new HttpListenerException(Marshal.GetLastWin32Error());
-            }
+            Debug.Assert(Monitor.IsEntered(_internalLock));
+            Debug.Assert(_currentSession is null);
 
-            _requestQueueHandle = requestQueueHandle;
+            _currentSession = new HttpListenerSession(this);
         }
 
         private unsafe void CloseRequestQueueHandle()
         {
-            if ((_requestQueueHandle != null) && (!_requestQueueHandle.IsInvalid))
-            {
-                if (NetEventSource.IsEnabled) NetEventSource.Info($"Dispose ThreadPoolBoundHandle: {_requestQueueBoundHandle}");
-                _requestQueueBoundHandle?.Dispose();
-                _requestQueueHandle.Dispose();
+            Debug.Assert(Monitor.IsEntered(_internalLock));
 
-                // CancelIoEx is called after Dispose to prevent a race condition involving parallel GetContext and
-                // HttpReceiveHttpRequest calls. Otherwise, calling CancelIoEx before Dispose might block the synchronous
-                // GetContext call until the next request arrives.
-                try
-                {
-                    Interop.Kernel32.CancelIoEx(_requestQueueHandle, null); // This cancels the synchronous call to HttpReceiveHttpRequest
-                }
-                catch (ObjectDisposedException)
-                {
-                    // Ignore the exception since it only means that the queue handle has been successfully disposed
-                }
-            }
+            _currentSession?.CloseRequestQueueHandle();
+            _currentSession = null;
         }
 
         public void Abort()
@@ -588,6 +528,8 @@ namespace System.Net
                 uint size = 4096;
                 ulong requestId = 0;
                 memoryBlob = new SyncRequestContext((int)size);
+                HttpListenerSession session = _currentSession;
+
                 while (true)
                 {
                     while (true)
@@ -596,7 +538,7 @@ namespace System.Net
                         uint bytesTransferred = 0;
                         statusCode =
                             Interop.HttpApi.HttpReceiveHttpRequest(
-                                _requestQueueHandle,
+                                session.RequestQueueHandle,
                                 requestId,
                                 (uint)Interop.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
                                 memoryBlob.RequestBlob,
@@ -631,10 +573,10 @@ namespace System.Net
                         throw new HttpListenerException((int)statusCode);
                     }
 
-                    if (ValidateRequest(memoryBlob))
+                    if (ValidateRequest(session, memoryBlob))
                     {
                         // We need to hook up our authentication handling code here.
-                        httpContext = HandleAuthentication(memoryBlob, out stoleBlob);
+                        httpContext = HandleAuthentication(session, memoryBlob, out stoleBlob);
                     }
 
                     if (stoleBlob)
@@ -675,12 +617,12 @@ namespace System.Net
             }
         }
 
-        internal unsafe bool ValidateRequest(RequestContextBase requestMemory)
+        internal static unsafe bool ValidateRequest(HttpListenerSession session, RequestContextBase requestMemory)
         {
             // Block potential DOS attacks
             if (requestMemory.RequestBlob->Headers.UnknownHeaderCount > UnknownHeaderLimit)
             {
-                SendError(requestMemory.RequestBlob->RequestId, HttpStatusCode.BadRequest, null);
+                SendError(session, requestMemory.RequestBlob->RequestId, HttpStatusCode.BadRequest, null);
                 return false;
             }
             return true;
@@ -700,7 +642,7 @@ namespace System.Net
                 // prepare the ListenerAsyncResult object (this will have it's own
                 // event that the user can wait on for IO completion - which means we
                 // need to signal it when IO completes)
-                asyncResult = new ListenerAsyncResult(this, state, callback);
+                asyncResult = new ListenerAsyncResult(_currentSession, state, callback);
                 uint statusCode = asyncResult.QueueBeginGetContext();
                 if (statusCode != Interop.HttpApi.ERROR_SUCCESS &&
                     statusCode != Interop.HttpApi.ERROR_IO_PENDING)
@@ -735,8 +677,7 @@ namespace System.Net
                     throw new ArgumentNullException(nameof(asyncResult));
                 }
                 if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"asyncResult: {asyncResult}");
-                ListenerAsyncResult castedAsyncResult = asyncResult as ListenerAsyncResult;
-                if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
+                if (!(asyncResult is ListenerAsyncResult castedAsyncResult) || !(castedAsyncResult.AsyncObject is HttpListenerSession session) || session.Listener != this)
                 {
                     throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
                 }
@@ -764,7 +705,7 @@ namespace System.Net
             return httpContext;
         }
 
-        internal HttpListenerContext HandleAuthentication(RequestContextBase memoryBlob, out bool stoleBlob)
+        internal HttpListenerContext HandleAuthentication(HttpListenerSession session, RequestContextBase memoryBlob, out bool stoleBlob)
         {
             if (NetEventSource.IsEnabled) NetEventSource.Info(this, "HandleAuthentication() memoryBlob:0x" + ((IntPtr)memoryBlob.RequestBlob).ToString("x"));
 
@@ -796,7 +737,7 @@ namespace System.Net
                     {
                         if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Principal: {principal} principal.Identity.Name: {principal.Identity.Name} creating request");
                         stoleBlob = true;
-                        HttpListenerContext ntlmContext = new HttpListenerContext(this, memoryBlob);
+                        HttpListenerContext ntlmContext = new HttpListenerContext(session, memoryBlob);
                         ntlmContext.SetIdentity(principal, null);
                         ntlmContext.Request.ReleasePins();
                         return ntlmContext;
@@ -837,7 +778,7 @@ namespace System.Net
                     oldContext = disconnectResult.Session;
                 }
 
-                httpContext = new HttpListenerContext(this, memoryBlob);
+                httpContext = new HttpListenerContext(session, memoryBlob);
 
                 AuthenticationSchemeSelector authenticationSelector = _authenticationDelegate;
                 if (authenticationSelector != null)
@@ -857,7 +798,7 @@ namespace System.Net
                             NetEventSource.Error(this, SR.Format(SR.net_log_listener_delegate_exception, exception));
                             NetEventSource.Info(this, $"authenticationScheme: {authenticationScheme}");
                         }
-                        SendError(requestId, HttpStatusCode.InternalServerError, null);
+                        SendError(session, requestId, HttpStatusCode.InternalServerError, null);
                         FreeContext(ref httpContext, memoryBlob);
                         return null;
                     }
@@ -979,7 +920,7 @@ namespace System.Net
                             }
                             else
                             {
-                                binding = GetChannelBinding(connectionId, isSecureConnection, extendedProtectionPolicy);
+                                binding = GetChannelBinding(session, connectionId, isSecureConnection, extendedProtectionPolicy);
                                 ContextFlagsPal contextFlags = GetContextFlags(extendedProtectionPolicy, isSecureConnection);
                                 context = new NTAuthentication(true, package, CredentialCache.DefaultNetworkCredentials, null, contextFlags, binding);
                             }
@@ -1073,7 +1014,7 @@ namespace System.Net
                                                     // We may need to call WaitForDisconnect.
                                                     if (disconnectResult == null)
                                                     {
-                                                        RegisterForDisconnectNotification(connectionId, ref disconnectResult);
+                                                        RegisterForDisconnectNotification(session, connectionId, ref disconnectResult);
                                                     }
                                                     if (disconnectResult != null)
                                                     {
@@ -1205,7 +1146,7 @@ namespace System.Net
                         if (httpError != HttpStatusCode.Unauthorized)
                         {
                             if (NetEventSource.IsEnabled) NetEventSource.Info(this, "ConnectionId:" + connectionId + " because of error:" + httpError.ToString());
-                            SendError(requestId, httpError, null);
+                            SendError(session, requestId, httpError, null);
                             return null;
                         }
 
@@ -1217,7 +1158,7 @@ namespace System.Net
                 // Check if we need to call WaitForDisconnect, because if we do and it fails, we want to send a 500 instead.
                 if (disconnectResult == null && newContext != null)
                 {
-                    RegisterForDisconnectNotification(connectionId, ref disconnectResult);
+                    RegisterForDisconnectNotification(session, connectionId, ref disconnectResult);
 
                     // Failed - send 500.
                     if (disconnectResult == null)
@@ -1242,7 +1183,7 @@ namespace System.Net
                         }
 
                         if (NetEventSource.IsEnabled) NetEventSource.Info(this, "connectionId:" + connectionId + " because of failed HttpWaitForDisconnect");
-                        SendError(requestId, HttpStatusCode.InternalServerError, null);
+                        SendError(session, requestId, HttpStatusCode.InternalServerError, null);
                         FreeContext(ref httpContext, memoryBlob);
                         return null;
                     }
@@ -1270,7 +1211,7 @@ namespace System.Net
                 // Send the 401 here.
                 if (httpContext == null)
                 {
-                    SendError(requestId, challenges != null && challenges.Count > 0 ? HttpStatusCode.Unauthorized : HttpStatusCode.Forbidden, challenges);
+                    SendError(session, requestId, challenges != null && challenges.Count > 0 ? HttpStatusCode.Unauthorized : HttpStatusCode.Forbidden, challenges);
                     if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Scheme:" + authenticationScheme);
                     return null;
                 }
@@ -1377,7 +1318,7 @@ namespace System.Net
             }
         }
 
-        private ChannelBinding GetChannelBinding(ulong connectionId, bool isSecureConnection, ExtendedProtectionPolicy policy)
+        private ChannelBinding GetChannelBinding(HttpListenerSession session, ulong connectionId, bool isSecureConnection, ExtendedProtectionPolicy policy)
         {
             if (policy.PolicyEnforcement == PolicyEnforcement.Never)
             {
@@ -1397,7 +1338,7 @@ namespace System.Net
                 return null;
             }
 
-            ChannelBinding result = GetChannelBindingFromTls(connectionId);
+            ChannelBinding result = GetChannelBindingFromTls(session, connectionId);
 
             if (NetEventSource.IsEnabled && result != null)
                 NetEventSource.Info(this, "GetChannelBindingFromTls returned null even though OS supposedly supports Extended Protection");
@@ -1645,22 +1586,22 @@ namespace System.Net
             return challenges;
         }
 
-        private void RegisterForDisconnectNotification(ulong connectionId, ref DisconnectAsyncResult disconnectResult)
+        private static void RegisterForDisconnectNotification(HttpListenerSession session, ulong connectionId, ref DisconnectAsyncResult disconnectResult)
         {
             Debug.Assert(disconnectResult == null);
 
             try
             {
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Calling Interop.HttpApi.HttpWaitForDisconnect");
+                if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "Calling Interop.HttpApi.HttpWaitForDisconnect");
 
-                DisconnectAsyncResult result = new DisconnectAsyncResult(this, connectionId);
+                DisconnectAsyncResult result = new DisconnectAsyncResult(session, connectionId);
 
                 uint statusCode = Interop.HttpApi.HttpWaitForDisconnect(
-                    _requestQueueHandle,
+                    session.RequestQueueHandle,
                     connectionId,
                     result.NativeOverlapped);
 
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Call to Interop.HttpApi.HttpWaitForDisconnect returned:" + statusCode);
+                if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "Call to Interop.HttpApi.HttpWaitForDisconnect returned:" + statusCode);
 
                 if (statusCode == Interop.HttpApi.ERROR_SUCCESS ||
                     statusCode == Interop.HttpApi.ERROR_IO_PENDING)
@@ -1668,7 +1609,7 @@ namespace System.Net
                     // Need to make sure it's going to get returned before adding it to the hash.  That way it'll be handled
                     // correctly in HandleAuthentication's finally.
                     disconnectResult = result;
-                    DisconnectResults[connectionId] = disconnectResult;
+                    session.Listener.DisconnectResults[connectionId] = disconnectResult;
                 }
 
                 if (statusCode == Interop.HttpApi.ERROR_SUCCESS && HttpListener.SkipIOCPCallbackOnSuccess)
@@ -1680,13 +1621,13 @@ namespace System.Net
             catch (Win32Exception exception)
             {
                 uint statusCode = (uint)exception.NativeErrorCode;
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Call to Interop.HttpApi.HttpWaitForDisconnect threw, statusCode:" + statusCode);
+                if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "Call to Interop.HttpApi.HttpWaitForDisconnect threw, statusCode:" + statusCode);
             }
         }
 
-        private void SendError(ulong requestId, HttpStatusCode httpStatusCode, ArrayList challenges)
+        private static void SendError(HttpListenerSession session, ulong requestId, HttpStatusCode httpStatusCode, ArrayList challenges)
         {
-            if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"RequestId: {requestId}");
+            if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, $"RequestId: {requestId}");
             Interop.HttpApi.HTTP_RESPONSE httpResponse = default;
             httpResponse.Version = default;
             httpResponse.Version.MajorVersion = (ushort)1;
@@ -1738,10 +1679,10 @@ namespace System.Net
                             }
                         }
 
-                        if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Calling Interop.HttpApi.HttpSendHtthttpResponse");
+                        if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "Calling Interop.HttpApi.HttpSendHtthttpResponse");
                         statusCode =
                             Interop.HttpApi.HttpSendHttpResponse(
-                                _requestQueueHandle,
+                                session.RequestQueueHandle,
                                 requestId,
                                 0,
                                 &httpResponse,
@@ -1775,12 +1716,12 @@ namespace System.Net
                     }
                 }
             }
-            if (NetEventSource.IsEnabled) NetEventSource.Info(this, "Call to Interop.HttpApi.HttpSendHttpResponse returned:" + statusCode);
+            if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "Call to Interop.HttpApi.HttpSendHttpResponse returned:" + statusCode);
             if (statusCode != Interop.HttpApi.ERROR_SUCCESS)
             {
                 // if we fail to send a 401 something's seriously wrong, abort the request
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, "SendUnauthorized returned:" + statusCode);
-                HttpListenerContext.CancelRequest(_requestQueueHandle, requestId);
+                if (NetEventSource.IsEnabled) NetEventSource.Info(session.Listener, "SendUnauthorized returned:" + statusCode);
+                HttpListenerContext.CancelRequest(session.RequestQueueHandle, requestId);
             }
         }
 
@@ -1799,9 +1740,9 @@ namespace System.Net
             return (int)((Interop.HttpApi.HTTP_REQUEST_CHANNEL_BIND_STATUS*)blob)->ChannelTokenSize;
         }
 
-        internal ChannelBinding GetChannelBindingFromTls(ulong connectionId)
+        internal static ChannelBinding GetChannelBindingFromTls(HttpListenerSession session, ulong connectionId)
         {
-            if (NetEventSource.IsEnabled) NetEventSource.Enter(this, $"connectionId: {connectionId}");
+            if (NetEventSource.IsEnabled) NetEventSource.Enter(session.Listener, $"connectionId: {connectionId}");
 
             // +128 since a CBT is usually <128 thus we need to call HRCC just once. If the CBT
             // is >128 we will get ERROR_MORE_DATA and call again
@@ -1823,7 +1764,7 @@ namespace System.Net
                     // Http.sys team: ServiceName will always be null if
                     // HTTP_RECEIVE_SECURE_CHANNEL_TOKEN flag is set.
                     statusCode = Interop.HttpApi.HttpReceiveClientCertificate(
-                        RequestQueueHandle,
+                        session.RequestQueueHandle,
                         connectionId,
                         (uint)Interop.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_SECURE_CHANNEL_TOKEN,
                         blobPtr,
@@ -1855,7 +1796,7 @@ namespace System.Net
                     {
                         if (NetEventSource.IsEnabled)
                         {
-                            NetEventSource.Error(this, SR.net_ssp_dont_support_cbt);
+                            NetEventSource.Error(session.Listener, SR.net_ssp_dont_support_cbt);
                         }
                         return null; // old schannel library which doesn't support CBT
                     }
@@ -1875,7 +1816,7 @@ namespace System.Net
             private static readonly IOCompletionCallback s_IOCallback = new IOCompletionCallback(WaitCallback);
 
             private readonly ulong _connectionId;
-            private readonly HttpListener _httpListener;
+            private readonly HttpListenerSession _listenerSession;
             private readonly NativeOverlapped* _nativeOverlapped;
             private int _ownershipState;   // 0 = normal, 1 = in HandleAuthentication(), 2 = disconnected, 3 = cleaned up
 
@@ -1919,16 +1860,16 @@ namespace System.Net
                 }
             }
 
-            internal unsafe DisconnectAsyncResult(HttpListener httpListener, ulong connectionId)
+            internal unsafe DisconnectAsyncResult(HttpListenerSession session, ulong connectionId)
             {
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"HttpListener: {httpListener}, ConnectionId: {connectionId}");
+                if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"HttpListener: {session.Listener}, ConnectionId: {connectionId}");
                 _ownershipState = 1;
-                _httpListener = httpListener;
+                _listenerSession = session;
                 _connectionId = connectionId;
 
                 // we can call the Unsafe API here, we won't ever call user code
-                _nativeOverlapped = httpListener.RequestQueueBoundHandle.AllocateNativeOverlapped(s_IOCallback, state: this, pinData: null);
-                if (NetEventSource.IsEnabled) NetEventSource.Info($"DisconnectAsyncResult: ThreadPoolBoundHandle.AllocateNativeOverlapped({httpListener._requestQueueBoundHandle}) -> {_nativeOverlapped->GetHashCode()}");
+                _nativeOverlapped = session.RequestQueueBoundHandle.AllocateNativeOverlapped(s_IOCallback, state: this, pinData: null);
+                if (NetEventSource.IsEnabled) NetEventSource.Info($"DisconnectAsyncResult: ThreadPoolBoundHandle.AllocateNativeOverlapped({session.RequestQueueBoundHandle}) -> {_nativeOverlapped->GetHashCode()}");
             }
 
             internal bool StartOwningDisconnectHandling()
@@ -1964,7 +1905,7 @@ namespace System.Net
             {
                 if (NetEventSource.IsEnabled) NetEventSource.Info(null, "_connectionId:" + asyncResult._connectionId);
 
-                asyncResult._httpListener._requestQueueBoundHandle.FreeNativeOverlapped(nativeOverlapped);
+                asyncResult._listenerSession.RequestQueueBoundHandle.FreeNativeOverlapped(nativeOverlapped);
                 if (Interlocked.Exchange(ref asyncResult._ownershipState, 2) == 0)
                 {
                     asyncResult.HandleDisconnect();
@@ -1981,8 +1922,10 @@ namespace System.Net
 
             private void HandleDisconnect()
             {
-                if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"DisconnectResults {_httpListener.DisconnectResults} removing for _connectionId: {_connectionId}");
-                _httpListener.DisconnectResults.Remove(_connectionId);
+                HttpListener listener = _listenerSession.Listener;
+
+                if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"DisconnectResults {listener.DisconnectResults} removing for _connectionId: {_connectionId}");
+                listener.DisconnectResults.Remove(_connectionId);
                 if (_session != null)
                 {
                     _session.CloseContext();
@@ -1994,7 +1937,7 @@ namespace System.Net
                 IDisposable identity = _authenticatedConnection == null ? null : _authenticatedConnection.Identity as IDisposable;
                 if ((identity != null) &&
                     (_authenticatedConnection.Identity.AuthenticationType == AuthenticationTypes.NTLM) &&
-                    (_httpListener.UnsafeConnectionNtlmAuthentication))
+                    (listener.UnsafeConnectionNtlmAuthentication))
                 {
                     identity.Dispose();
                 }
index 87f9f0a..05863ac 100644 (file)
@@ -15,14 +15,16 @@ namespace System.Net
     public sealed unsafe partial class HttpListenerContext
     {
         private string _mutualAuthentication;
+        internal HttpListenerSession ListenerSession { get; private set; }
 
-        internal HttpListenerContext(HttpListener httpListener, RequestContextBase memoryBlob)
+        internal HttpListenerContext(HttpListenerSession session, RequestContextBase memoryBlob)
         {
-            if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"httpListener {httpListener} requestBlob={((IntPtr)memoryBlob.RequestBlob)}");
-            _listener = httpListener;
+            if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"httpListener {session.Listener} requestBlob={((IntPtr)memoryBlob.RequestBlob)}");
+            _listener = session.Listener;
+            ListenerSession = session;
             Request = new HttpListenerRequest(this, memoryBlob);
-            AuthenticationSchemes = httpListener.AuthenticationSchemes;
-            ExtendedProtectionPolicy = httpListener.ExtendedProtectionPolicy;
+            AuthenticationSchemes = _listener.AuthenticationSchemes;
+            ExtendedProtectionPolicy = _listener.ExtendedProtectionPolicy;
             if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"HttpListener: {_listener} HttpListenerRequest: {Request}");
         }
 
@@ -41,9 +43,9 @@ namespace System.Net
 
         internal HttpListener Listener => _listener;
 
-        internal SafeHandle RequestQueueHandle => _listener.RequestQueueHandle;
+        internal SafeHandle RequestQueueHandle => ListenerSession.RequestQueueHandle;
 
-        internal ThreadPoolBoundHandle RequestQueueBoundHandle => _listener.RequestQueueBoundHandle;
+        internal ThreadPoolBoundHandle RequestQueueBoundHandle => ListenerSession.RequestQueueBoundHandle;
 
         internal ulong RequestId => Request.RequestId;
 
index f43558c..b7ccc86 100644 (file)
@@ -555,7 +555,7 @@ namespace System.Net
 
         internal ChannelBinding GetChannelBinding()
         {
-            return HttpListenerContext.Listener.GetChannelBindingFromTls(_connectionId);
+            return HttpListener.GetChannelBindingFromTls(HttpListenerContext.ListenerSession, _connectionId);
         }
 
         internal void CheckDisposed()
diff --git a/src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListenerSession.Windows.cs b/src/libraries/System.Net.HttpListener/src/System/Net/Windows/HttpListenerSession.Windows.cs
new file mode 100644 (file)
index 0000000..ba2aca5
--- /dev/null
@@ -0,0 +1,87 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.InteropServices;
+using System.Threading;
+
+namespace System.Net
+{
+    internal sealed class HttpListenerSession
+    {
+        public readonly HttpListener Listener;
+        public readonly SafeHandle RequestQueueHandle;
+        private ThreadPoolBoundHandle _requestQueueBoundHandle;
+
+        public ThreadPoolBoundHandle RequestQueueBoundHandle
+        {
+            get
+            {
+                if (_requestQueueBoundHandle == null)
+                {
+                    lock (this)
+                    {
+                        if (_requestQueueBoundHandle == null)
+                        {
+                            _requestQueueBoundHandle = ThreadPoolBoundHandle.BindHandle(RequestQueueHandle);
+                            if (NetEventSource.IsEnabled) NetEventSource.Info($"ThreadPoolBoundHandle.BindHandle({RequestQueueHandle}) -> {_requestQueueBoundHandle}");
+                        }
+                    }
+                }
+
+                return _requestQueueBoundHandle;
+            }
+        }
+
+        public unsafe HttpListenerSession(HttpListener listener)
+        {
+            Listener = listener;
+
+            uint statusCode =
+                Interop.HttpApi.HttpCreateRequestQueue(
+                    Interop.HttpApi.s_version, null, null, 0, out HttpRequestQueueV2Handle requestQueueHandle);
+
+            if (statusCode != Interop.HttpApi.ERROR_SUCCESS)
+            {
+                throw new HttpListenerException((int)statusCode);
+            }
+
+            // Disabling callbacks when IO operation completes synchronously (returns ErrorCodes.ERROR_SUCCESS)
+            if (HttpListener.SkipIOCPCallbackOnSuccess &&
+                !Interop.Kernel32.SetFileCompletionNotificationModes(
+                    requestQueueHandle,
+                    Interop.Kernel32.FileCompletionNotificationModes.SkipCompletionPortOnSuccess |
+                    Interop.Kernel32.FileCompletionNotificationModes.SkipSetEventOnHandle))
+            {
+                throw new HttpListenerException(Marshal.GetLastWin32Error());
+            }
+
+            RequestQueueHandle = requestQueueHandle;
+        }
+
+        public unsafe void CloseRequestQueueHandle()
+        {
+            lock (this)
+            {
+                if (!RequestQueueHandle.IsInvalid)
+                {
+                    if (NetEventSource.IsEnabled) NetEventSource.Info($"Dispose ThreadPoolBoundHandle: {_requestQueueBoundHandle}");
+                    _requestQueueBoundHandle?.Dispose();
+                    RequestQueueHandle.Dispose();
+
+                    // CancelIoEx is called after Dispose to prevent a race condition involving parallel GetContext and
+                    // HttpReceiveHttpRequest calls. Otherwise, calling CancelIoEx before Dispose might block the synchronous
+                    // GetContext call until the next request arrives.
+                    try
+                    {
+                        Interop.Kernel32.CancelIoEx(RequestQueueHandle, null); // This cancels the synchronous call to HttpReceiveHttpRequest
+                    }
+                    catch (ObjectDisposedException)
+                    {
+                        // Ignore the exception since it only means that the queue handle has been successfully disposed
+                    }
+                }
+            }
+        }
+    }
+}
index a1b85b6..e58ecd0 100644 (file)
@@ -13,10 +13,10 @@ namespace System.Net
 
         internal static IOCompletionCallback IOCallback => s_ioCallback;
 
-        internal ListenerAsyncResult(HttpListener listener, object userState, AsyncCallback callback) :
-            base(listener, userState, callback)
+        internal ListenerAsyncResult(HttpListenerSession session, object userState, AsyncCallback callback) :
+            base(session, userState, callback)
         {
-            _requestContext = new AsyncRequestContext(listener.RequestQueueBoundHandle, this);
+            _requestContext = new AsyncRequestContext(session.RequestQueueBoundHandle, this);
         }
 
         private static void IOCompleted(ListenerAsyncResult asyncResult, uint errorCode, uint numBytes)
@@ -34,7 +34,7 @@ namespace System.Net
                 }
                 else
                 {
-                    HttpListener httpWebListener = asyncResult.AsyncObject as HttpListener;
+                    HttpListenerSession listenerSession = asyncResult.AsyncObject as HttpListenerSession;
                     if (errorCode == Interop.HttpApi.ERROR_SUCCESS)
                     {
                         // at this point we have received an unmanaged HTTP_REQUEST and memoryBlob
@@ -42,9 +42,9 @@ namespace System.Net
                         bool stoleBlob = false;
                         try
                         {
-                            if (httpWebListener.ValidateRequest(asyncResult._requestContext))
+                            if (HttpListener.ValidateRequest(listenerSession, asyncResult._requestContext))
                             {
-                                result = httpWebListener.HandleAuthentication(asyncResult._requestContext, out stoleBlob);
+                                result = listenerSession.Listener.HandleAuthentication(listenerSession, asyncResult._requestContext, out stoleBlob);
                             }
                         }
                         finally
@@ -52,17 +52,17 @@ namespace System.Net
                             if (stoleBlob)
                             {
                                 // The request has been handed to the user, which means this code can't reuse the blob.  Reset it here.
-                                asyncResult._requestContext = result == null ? new AsyncRequestContext(httpWebListener.RequestQueueBoundHandle, asyncResult) : null;
+                                asyncResult._requestContext = result == null ? new AsyncRequestContext(listenerSession.RequestQueueBoundHandle, asyncResult) : null;
                             }
                             else
                             {
-                                asyncResult._requestContext.Reset(httpWebListener.RequestQueueBoundHandle, 0, 0);
+                                asyncResult._requestContext.Reset(listenerSession.RequestQueueBoundHandle, 0, 0);
                             }
                         }
                     }
                     else
                     {
-                        asyncResult._requestContext.Reset(httpWebListener.RequestQueueBoundHandle, asyncResult._requestContext.RequestBlob->RequestId, numBytes);
+                        asyncResult._requestContext.Reset(listenerSession.RequestQueueBoundHandle, asyncResult._requestContext.RequestBlob->RequestId, numBytes);
                     }
 
                     // We need to issue a new request, either because auth failed, or because our buffer was too small the first time.
@@ -107,9 +107,9 @@ namespace System.Net
             {
                 if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Calling Interop.HttpApi.HttpReceiveHttpRequest RequestId: {_requestContext.RequestBlob->RequestId}Buffer:0x {((IntPtr)_requestContext.RequestBlob).ToString("x")} Size: {_requestContext.Size}");
                 uint bytesTransferred = 0;
-                HttpListener listener = (HttpListener)AsyncObject;
+                HttpListenerSession listenerSession = (HttpListenerSession)AsyncObject;
                 statusCode = Interop.HttpApi.HttpReceiveHttpRequest(
-                    listener.RequestQueueHandle,
+                    listenerSession.RequestQueueHandle,
                     _requestContext.RequestBlob->RequestId,
                     (uint)Interop.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
                     _requestContext.RequestBlob,
@@ -129,7 +129,7 @@ namespace System.Net
                 {
                     // the buffer was not big enough to fit the headers, we need
                     // to read the RequestId returned, allocate a new buffer of the required size
-                    _requestContext.Reset(listener.RequestQueueBoundHandle, _requestContext.RequestBlob->RequestId, bytesTransferred);
+                    _requestContext.Reset(listenerSession.RequestQueueBoundHandle, _requestContext.RequestBlob->RequestId, bytesTransferred);
                     continue;
                 }
                 else if (statusCode == Interop.HttpApi.ERROR_SUCCESS && HttpListener.SkipIOCPCallbackOnSuccess)
index 93f59cf..12f5fe4 100644 (file)
@@ -168,23 +168,10 @@ namespace System.Net.Tests
             }
         }
 
-        [Fact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/30284")]
-        public void ListenerRestart_BeginGetContext_Success()
-        {
-            using (HttpListenerFactory factory = new HttpListenerFactory())
-            {
-                HttpListener listener = factory.GetListener();
-                listener.BeginGetContext((f) => { }, null);
-                listener.Stop();
-                listener.Start();
-                listener.BeginGetContext((f) => { }, null);
-            }
-        }
-
-        [ConditionalFact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/30284")]
-        public async Task ListenerRestart_GetContext_Success()
+        [ConditionalTheory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ListenerRestart_Success(bool sync)
         {
             const string Content = "ListenerRestart_GetContext_Success";
             using (HttpListenerFactory factory = new HttpListenerFactory())
@@ -195,7 +182,10 @@ namespace System.Net.Tests
                 _output.WriteLine("Connecting to {0}", factory.ListeningUrl);
                 Task<string> clientTask = client.GetStringAsync(factory.ListeningUrl);
 
-                HttpListenerContext context = listener.GetContext();
+                HttpListenerContext context = sync
+                    ? listener.GetContext()
+                    : listener.EndGetContext(listener.BeginGetContext(ar => { }, null));
+
                 HttpListenerResponse response = context.Response;
                 response.OutputStream.Write(Encoding.UTF8.GetBytes(Content));
                 response.OutputStream.Close();
@@ -210,7 +200,7 @@ namespace System.Net.Tests
                     // This may fail if something else took our port while restarting.
                     listener.Start();
                 }
-                catch (Exception e)
+                catch (HttpListenerException e)
                 {
                     _output.WriteLine(e.Message);
                     // Skip test if we lost race and we are unable to bind on same port again.
@@ -221,7 +211,11 @@ namespace System.Net.Tests
 
                 // Repeat request to be sure listener is working.
                 clientTask = client.GetStringAsync(factory.ListeningUrl);
-                context = listener.GetContext();
+
+                context = sync
+                    ? listener.GetContext()
+                    : listener.EndGetContext(listener.BeginGetContext(ar => { }, null));
+
                 response = context.Response;
                 response.OutputStream.Write(Encoding.UTF8.GetBytes(Content));
                 response.OutputStream.Close();