LoopbackProxyServer cleanup (dotnet/corefx#38038)
authorCory Nelson <phrosty@gmail.com>
Tue, 4 Jun 2019 18:04:49 +0000 (11:04 -0700)
committerGitHub <noreply@github.com>
Tue, 4 Jun 2019 18:04:49 +0000 (11:04 -0700)
* Resolves dotnet/corefx#32808.

Ensure all connections are finished prior to Dispose() returning.
Shutdown CONNECT sessions properly.
Dispose sockets rather than relying on finalizers.
Eat fewer errors automatically.
Send errors to an event source to assist with troubleshooting.

Commit migrated from https://github.com/dotnet/corefx/commit/055dd469f7d402f496fcab31ae91e80efd474fa9

src/libraries/Common/tests/System/Net/EventSourceTestLogging.cs
src/libraries/Common/tests/System/Net/Http/LoopbackProxyServer.cs
src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj

index a1fb5e8..c1d4a07 100644 (file)
@@ -32,6 +32,12 @@ namespace System.Net.Test.Common
             WriteEvent(2, message);
         }
 
+        [Event(3, Keywords = Keywords.Debug, Level = EventLevel.Error)]
+        public void TestAncillaryError(Exception ex)
+        {
+            WriteEvent(3, ex.ToString());
+        }
+
         public static class Keywords
         {
             public const EventKeywords Default = (EventKeywords)0x0001;
index 4dafffd..c6c06fe 100644 (file)
@@ -3,11 +3,13 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.IO;
 using System.Net;
 using System.Net.Http;
 using System.Net.Sockets;
+using System.Runtime.ExceptionServices;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
@@ -63,27 +65,46 @@ namespace System.Net.Test.Common
         {
             Task.Run(async () =>
             {
+                var activeTasks = new ConcurrentDictionary<Task, int>();
+
                 try
                 {
                     while (true)
                     {
                         Socket s = await _listener.AcceptAsync().ConfigureAwait(false);
-                        var ignored = Task.Run(async () =>
+
+                        var connectionTask = Task.Run(async () =>
                         {
                             try
                             {
                                 await ProcessConnection(s).ConfigureAwait(false);
                             }
-                            catch (Exception)
+                            catch (Exception ex)
                             {
-                                // Ignore exceptions.
+                                EventSourceTestLogging.Log.TestAncillaryError(ex);
                             }
                         });
+
+                        activeTasks.TryAdd(connectionTask, 0);
+                        _ = connectionTask.ContinueWith(t => activeTasks.TryRemove(connectionTask, out _), TaskContinuationOptions.ExecuteSynchronously);
                     }
                 }
-                catch (Exception)
+                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted)
+                {
+                    // caused during Dispose() to cancel the loop. ignore.
+                }
+                catch (Exception ex)
                 {
-                    // Ignore exceptions.
+                    EventSourceTestLogging.Log.TestAncillaryError(ex);
+                }
+
+                try
+                {
+                    await Task.WhenAll(activeTasks.Keys).ConfigureAwait(false);
+                }
+                catch (Exception ex)
+                {
+                    EventSourceTestLogging.Log.TestAncillaryError(ex);
                 }
 
                 _serverStopped.Set();
@@ -94,13 +115,13 @@ namespace System.Net.Test.Common
         {
             Interlocked.Increment(ref _connections);
 
-            using (var ns = new NetworkStream(s))
+            using (var ns = new NetworkStream(s, ownsSocket: true))
             using (var reader = new StreamReader(ns))
             using (var writer = new StreamWriter(ns) { AutoFlush = true })
             {
                 while(true)
                 {
-                    if (!(await ProcessRequest(reader, writer).ConfigureAwait(false)))
+                    if (!(await ProcessRequest(s, reader, writer).ConfigureAwait(false)))
                     {
                         break;
                     }
@@ -108,7 +129,7 @@ namespace System.Net.Test.Common
             }
         }
 
-        private async Task<bool> ProcessRequest(StreamReader reader, StreamWriter writer)
+        private async Task<bool> ProcessRequest(Socket clientSocket, StreamReader reader, StreamWriter writer)
         {
             var headers = new Dictionary<string, string>();
             string url = null;
@@ -160,7 +181,7 @@ namespace System.Net.Test.Common
                 int remotePort = int.Parse(tokens[1]);
 
                 Send200Response(writer);
-                await ProcessConnectMethod((NetworkStream)reader.BaseStream, remoteHost, remotePort).ConfigureAwait(false);
+                await ProcessConnectMethod(clientSocket, (NetworkStream)reader.BaseStream, remoteHost, remotePort).ConfigureAwait(false);
 
                 return false; // connection can't be used for any more requests
             }
@@ -206,10 +227,10 @@ namespace System.Net.Test.Common
             }
         }
 
-        private async Task ProcessConnectMethod(NetworkStream clientStream, string remoteHost, int remotePort)
+        private async Task ProcessConnectMethod(Socket clientSocket, NetworkStream clientStream, string remoteHost, int remotePort)
         {
             // Open connection to destination server.
-            Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            using Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
             await serverSocket.ConnectAsync(remoteHost, remotePort).ConfigureAwait(false);
             NetworkStream serverStream = new NetworkStream(serverSocket);
 
@@ -218,39 +239,64 @@ namespace System.Net.Test.Common
             {
                 try
                 {
-                    byte[] buffer = new byte[8000];
-                    int bytesRead;
-                    while ((bytesRead = await clientStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false)) > 0)
-                    {
-                        await serverStream.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false);
-                    }
-                    serverStream.Flush();
+                    await clientStream.CopyToAsync(serverStream).ConfigureAwait(false);
                     serverSocket.Shutdown(SocketShutdown.Send);
                 }
-                catch (IOException)
+                catch (Exception ex)
                 {
-                    // Ignore rude disconnects from either side.
+                    HandleExceptions(ex);
                 }
             });
+
             Task serverCopyTask = Task.Run(async () =>
             {
                 try
                 {
-                    byte[] buffer = new byte[8000];
-                    int bytesRead;
-                    while ((bytesRead = await serverStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false)) > 0)
-                    {
-                        await clientStream.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false);
-                    }
-                    clientStream.Flush();
+                    await serverStream.CopyToAsync(clientStream).ConfigureAwait(false);
+                    clientSocket.Shutdown(SocketShutdown.Send);
                 }
-                catch (IOException)
+                catch (Exception ex)
                 {
-                    // Ignore rude disconnects from either side.
+                    HandleExceptions(ex);
                 }
             });
 
-            await Task.WhenAny(new[] { clientCopyTask, serverCopyTask }).ConfigureAwait(false);
+            await Task.WhenAll(new[] { clientCopyTask, serverCopyTask }).ConfigureAwait(false);
+
+            /// <summary>Closes sockets to cause both tasks to end, and eats connection reset/aborted errors.</summary>
+            void HandleExceptions(Exception ex)
+            {
+                SocketError sockErr = (ex.InnerException as SocketException)?.SocketErrorCode ?? SocketError.Success;
+
+                // If aborted, the other task failed and is asking this task to end.
+                if (sockErr == SocketError.OperationAborted)
+                {
+                    return;
+                }
+
+                // Ask the other task to end by disposing, causing OperationAborted.
+                try
+                {
+                    clientSocket.Close();
+                }
+                catch (ObjectDisposedException)
+                {
+                }
+
+                try
+                {
+                    serverSocket.Close();
+                }
+                catch (ObjectDisposedException)
+                {
+                }
+
+                // Eat reset/abort.
+                if (sockErr != SocketError.ConnectionReset && sockErr != SocketError.ConnectionAborted)
+                {
+                    ExceptionDispatchInfo.Capture(ex).Throw();
+                }
+            }
         }
 
         private void Send200Response(StreamWriter writer)
index 8f7f3db..ba5d73d 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <ProjectGuid>{7C395A91-D955-444C-98BF-D3F809A56CE1}</ProjectGuid>
     <StringResourcesPath>../src/Resources/Strings.resx</StringResourcesPath>
@@ -28,6 +28,9 @@
     <Compile Include="$(CommonTestPath)\System\Net\Configuration.WebSockets.cs">
       <Link>Common\System\Net\Configuration.WebSockets.cs</Link>
     </Compile>
+    <Compile Include="$(CommonTestPath)\System\Net\EventSourceTestLogging.cs">
+      <Link>Common\System\Net\EventSourceTestLogging.cs</Link>
+    </Compile>
     <Compile Include="$(CommonTestPath)\System\Net\Http\LoopbackProxyServer.cs">
       <Link>Common\System\Net\Http\LoopbackProxyServer.cs</Link>
     </Compile>