// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
+using System.Runtime.ExceptionServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
{
Task.Run(async () =>
{
+ var activeTasks = new ConcurrentDictionary<Task, int>();
+
try
{
while (true)
{
Socket s = await _listener.AcceptAsync().ConfigureAwait(false);
- var ignored = Task.Run(async () =>
+
+ var connectionTask = Task.Run(async () =>
{
try
{
await ProcessConnection(s).ConfigureAwait(false);
}
- catch (Exception)
+ catch (Exception ex)
{
- // Ignore exceptions.
+ EventSourceTestLogging.Log.TestAncillaryError(ex);
}
});
+
+ activeTasks.TryAdd(connectionTask, 0);
+ _ = connectionTask.ContinueWith(t => activeTasks.TryRemove(connectionTask, out _), TaskContinuationOptions.ExecuteSynchronously);
}
}
- catch (Exception)
+ catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted)
+ {
+ // caused during Dispose() to cancel the loop. ignore.
+ }
+ catch (Exception ex)
{
- // Ignore exceptions.
+ EventSourceTestLogging.Log.TestAncillaryError(ex);
+ }
+
+ try
+ {
+ await Task.WhenAll(activeTasks.Keys).ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ EventSourceTestLogging.Log.TestAncillaryError(ex);
}
_serverStopped.Set();
{
Interlocked.Increment(ref _connections);
- using (var ns = new NetworkStream(s))
+ using (var ns = new NetworkStream(s, ownsSocket: true))
using (var reader = new StreamReader(ns))
using (var writer = new StreamWriter(ns) { AutoFlush = true })
{
while(true)
{
- if (!(await ProcessRequest(reader, writer).ConfigureAwait(false)))
+ if (!(await ProcessRequest(s, reader, writer).ConfigureAwait(false)))
{
break;
}
}
}
- private async Task<bool> ProcessRequest(StreamReader reader, StreamWriter writer)
+ private async Task<bool> ProcessRequest(Socket clientSocket, StreamReader reader, StreamWriter writer)
{
var headers = new Dictionary<string, string>();
string url = null;
int remotePort = int.Parse(tokens[1]);
Send200Response(writer);
- await ProcessConnectMethod((NetworkStream)reader.BaseStream, remoteHost, remotePort).ConfigureAwait(false);
+ await ProcessConnectMethod(clientSocket, (NetworkStream)reader.BaseStream, remoteHost, remotePort).ConfigureAwait(false);
return false; // connection can't be used for any more requests
}
}
}
- private async Task ProcessConnectMethod(NetworkStream clientStream, string remoteHost, int remotePort)
+ private async Task ProcessConnectMethod(Socket clientSocket, NetworkStream clientStream, string remoteHost, int remotePort)
{
// Open connection to destination server.
- Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ using Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await serverSocket.ConnectAsync(remoteHost, remotePort).ConfigureAwait(false);
NetworkStream serverStream = new NetworkStream(serverSocket);
{
try
{
- byte[] buffer = new byte[8000];
- int bytesRead;
- while ((bytesRead = await clientStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false)) > 0)
- {
- await serverStream.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false);
- }
- serverStream.Flush();
+ await clientStream.CopyToAsync(serverStream).ConfigureAwait(false);
serverSocket.Shutdown(SocketShutdown.Send);
}
- catch (IOException)
+ catch (Exception ex)
{
- // Ignore rude disconnects from either side.
+ HandleExceptions(ex);
}
});
+
Task serverCopyTask = Task.Run(async () =>
{
try
{
- byte[] buffer = new byte[8000];
- int bytesRead;
- while ((bytesRead = await serverStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false)) > 0)
- {
- await clientStream.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false);
- }
- clientStream.Flush();
+ await serverStream.CopyToAsync(clientStream).ConfigureAwait(false);
+ clientSocket.Shutdown(SocketShutdown.Send);
}
- catch (IOException)
+ catch (Exception ex)
{
- // Ignore rude disconnects from either side.
+ HandleExceptions(ex);
}
});
- await Task.WhenAny(new[] { clientCopyTask, serverCopyTask }).ConfigureAwait(false);
+ await Task.WhenAll(new[] { clientCopyTask, serverCopyTask }).ConfigureAwait(false);
+
+ /// <summary>Closes sockets to cause both tasks to end, and eats connection reset/aborted errors.</summary>
+ void HandleExceptions(Exception ex)
+ {
+ SocketError sockErr = (ex.InnerException as SocketException)?.SocketErrorCode ?? SocketError.Success;
+
+ // If aborted, the other task failed and is asking this task to end.
+ if (sockErr == SocketError.OperationAborted)
+ {
+ return;
+ }
+
+ // Ask the other task to end by disposing, causing OperationAborted.
+ try
+ {
+ clientSocket.Close();
+ }
+ catch (ObjectDisposedException)
+ {
+ }
+
+ try
+ {
+ serverSocket.Close();
+ }
+ catch (ObjectDisposedException)
+ {
+ }
+
+ // Eat reset/abort.
+ if (sockErr != SocketError.ConnectionReset && sockErr != SocketError.ConnectionAborted)
+ {
+ ExceptionDispatchInfo.Capture(ex).Throw();
+ }
+ }
}
private void Send200Response(StreamWriter writer)