Fix System.Net.Sockets telemetry (#42188)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Fri, 25 Sep 2020 08:16:18 +0000 (10:16 +0200)
committerGitHub <noreply@github.com>
Fri, 25 Sep 2020 08:16:18 +0000 (10:16 +0200)
* Correct NameResolutionTelemetry logic

Taken from c4c9a2f99b7e339388199086d3014041abccc21e

* Enable listening to multiple sources with TestEventListener

* Workaround EventWrittenEventArgs bug when the EventArgs are stored

Workaround https://github.com/dotnet/runtime/issues/42128

* Correct System.Net.Sockets Telemetry

* Avoid using value tuple in TestEventListener

* Remove unnecessary argument to OnCompletedInternal

* Remove redundant Telemetry.IsEnabled check

* Log Connect/Accept start before the initial context capture

* Use SocketHelperBase in Accept tests

* Avoid duplicate events for BeginConnect without ConnextEx support

* Enable Sync Socket tests

* Revert unrelated SocketPal change

* Log the correct ErrorCode in case of socket disposal

* Add more info on TelemetryTest timeout

* Add PlatformSpecific attribute to ConnectFailure test

* Add missing BeginConnect AfterConnect call on sync failure

* Add comment around GetHelperBase

* Correct WaitForEventCountersAsync helper

* Assert against SocketError.TimedOut in ConnectFailed test

* Add EndConnect comment

* Log Failure around sync SocketPal exceptions

* Don't assert that the exception message is empty

* Dispose socket in a different Thread

* Disable Telemetry failure tests for Sync on RedHat7

* Use more descriptive names in MemberData generation

* Avoid using reflection for #42128 workaround

* Remove ConnectCanceled event

src/libraries/Common/tests/System/Diagnostics/Tracing/TestEventListener.cs
src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketsTelemetry.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/TelemetryTest.cs

index b12f502..74c0d9f 100644 (file)
@@ -9,90 +9,91 @@ namespace System.Diagnostics.Tracing
     /// <summary>Simple event listener than invokes a callback for each event received.</summary>
     internal sealed class TestEventListener : EventListener
     {
-        private readonly string _targetSourceName;
-        private readonly Guid _targetSourceGuid;
-        private readonly EventLevel _level;
+        private class Settings
+        {
+            public EventLevel Level;
+            public EventKeywords Keywords;
+        }
+
+        private readonly Dictionary<string, Settings> _names = new Dictionary<string, Settings>();
+        private readonly Dictionary<Guid, Settings> _guids = new Dictionary<Guid, Settings>();
+
         private readonly double? _eventCounterInterval;
 
         private Action<EventWrittenEventArgs> _eventWritten;
-        private List<EventSource> _tmpEventSourceList = new List<EventSource>();
+        private readonly List<EventSource> _eventSourceList = new List<EventSource>();
 
         public TestEventListener(string targetSourceName, EventLevel level, double? eventCounterInterval = null)
         {
-            // Store the arguments
-            _targetSourceName = targetSourceName;
-            _level = level;
             _eventCounterInterval = eventCounterInterval;
-
-            LoadSourceList();
+            AddSource(targetSourceName, level);
         }
 
         public TestEventListener(Guid targetSourceGuid, EventLevel level, double? eventCounterInterval = null)
         {
-            // Store the arguments
-            _targetSourceGuid = targetSourceGuid;
-            _level = level;
             _eventCounterInterval = eventCounterInterval;
-
-            LoadSourceList();
+            AddSource(targetSourceGuid, level);
         }
 
-        private void LoadSourceList()
+        public void AddSource(string name, EventLevel level, EventKeywords keywords = EventKeywords.All) =>
+            AddSource(name, null, level, keywords);
+
+        public void AddSource(Guid guid, EventLevel level, EventKeywords keywords = EventKeywords.All) =>
+            AddSource(null, guid, level, keywords);
+
+        private void AddSource(string name, Guid? guid, EventLevel level, EventKeywords keywords)
         {
-            // The base constructor, which is called before this constructor,
-            // will invoke the virtual OnEventSourceCreated method for each
-            // existing EventSource, which means OnEventSourceCreated will be
-            // called before _targetSourceGuid and _level have been set.  As such,
-            // we store a temporary list that just exists from the moment this instance
-            // is created (instance field initializers run before the base constructor)
-            // and until we finish construction... in that window, OnEventSourceCreated
-            // will store the sources into the list rather than try to enable them directly,
-            // and then here we can enumerate that list, then clear it out.
-            List<EventSource> sources;
-            lock (_tmpEventSourceList)
+            lock (_eventSourceList)
             {
-                sources = _tmpEventSourceList;
-                _tmpEventSourceList = null;
-            }
-            foreach (EventSource source in sources)
-            {
-                EnableSourceIfMatch(source);
+                var settings = new Settings()
+                {
+                    Level = level,
+                    Keywords = keywords
+                };
+
+                if (name is not null)
+                    _names.Add(name, settings);
+
+                if (guid.HasValue)
+                    _guids.Add(guid.Value, settings);
+
+                foreach (EventSource source in _eventSourceList)
+                {
+                    if (name == source.Name || guid == source.Guid)
+                    {
+                        EnableEventSource(source, level, keywords);
+                    }
+                }
             }
         }
 
+        public void AddActivityTracking() =>
+            AddSource("System.Threading.Tasks.TplEventSource", EventLevel.Informational, (EventKeywords)0x80 /* TasksFlowActivityIds */);
+
         protected override void OnEventSourceCreated(EventSource eventSource)
         {
-            List<EventSource> tmp = _tmpEventSourceList;
-            if (tmp != null)
+            lock (_eventSourceList)
             {
-                lock (tmp)
+                _eventSourceList.Add(eventSource);
+
+                if (_names.TryGetValue(eventSource.Name, out Settings settings) ||
+                    _guids.TryGetValue(eventSource.Guid, out settings))
                 {
-                    if (_tmpEventSourceList != null)
-                    {
-                        _tmpEventSourceList.Add(eventSource);
-                        return;
-                    }
+                    EnableEventSource(eventSource, settings.Level, settings.Keywords);
                 }
             }
-
-            EnableSourceIfMatch(eventSource);
         }
 
-        private void EnableSourceIfMatch(EventSource source)
+        private void EnableEventSource(EventSource source, EventLevel level, EventKeywords keywords)
         {
-            if (source.Name.Equals(_targetSourceName) ||
-                source.Guid.Equals(_targetSourceGuid))
+            var args = new Dictionary<string, string>();
+
+            if (_eventCounterInterval != null)
             {
-                if (_eventCounterInterval != null)
-                {
-                    var args = new Dictionary<string, string> { { "EventCounterIntervalSec", _eventCounterInterval?.ToString() } };
-                    EnableEvents(source, _level, EventKeywords.All, args);
-                }
-                else
-                {
-                    EnableEvents(source, _level);
-                }
+                args.Add("EventCounterIntervalSec", _eventCounterInterval.ToString());
             }
+
+            EnableEvents(source, level, keywords, args);
         }
 
         public void RunWithCallback(Action<EventWrittenEventArgs> handler, Action body)
index 98cb7cc..60cc964 100644 (file)
@@ -466,34 +466,9 @@ namespace System.Net
 
                 if (NameResolutionTelemetry.Log.IsEnabled())
                 {
-                    ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(hostName);
-
-                    Task coreTask;
-                    try
-                    {
-                        coreTask = NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses);
-                    }
-                    catch when (LogFailure(stopwatch))
-                    {
-                        Debug.Fail("LogFailure should return false");
-                        throw;
-                    }
-
-                    coreTask.ContinueWith(
-                        (task, state) =>
-                        {
-                            NameResolutionTelemetry.Log.AfterResolution(
-                                stopwatch: (ValueStopwatch)state!,
-                                successful: task.IsCompletedSuccessfully);
-                        },
-                        state: stopwatch,
-                        cancellationToken: default,
-                        TaskContinuationOptions.ExecuteSynchronously,
-                        TaskScheduler.Default);
-
-                    // coreTask is not actually a base Task, but Task<IPHostEntry> / Task<IPAddress[]>
-                    // We have to return it and not the continuation
-                    return coreTask;
+                    return justAddresses
+                        ? (Task)GetAddrInfoWithTelemetryAsync<IPAddress[]>(hostName, justAddresses)
+                        : (Task)GetAddrInfoWithTelemetryAsync<IPHostEntry>(hostName, justAddresses);
                 }
                 else
                 {
@@ -506,6 +481,23 @@ namespace System.Net
                 RunAsync(s => GetHostEntryCore((string)s), hostName);
         }
 
+        private static async Task<T> GetAddrInfoWithTelemetryAsync<T>(string hostName, bool justAddresses)
+            where T : class
+        {
+            ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(hostName);
+
+            T? result = null;
+            try
+            {
+                result = await ((Task<T>)NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses)).ConfigureAwait(false);
+                return result;
+            }
+            finally
+            {
+                NameResolutionTelemetry.Log.AfterResolution(stopwatch, successful: result is not null);
+            }
+        }
+
         private static Task<TResult> RunAsync<TResult>(Func<object, TResult> func, object arg) =>
             Task.Factory.StartNew(func!, arg, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
 
index 52558da..0e73552 100644 (file)
@@ -1104,25 +1104,37 @@ namespace System.Net.Sockets
 
             // This may throw ObjectDisposedException.
             SafeSocketHandle acceptedSocketHandle;
-            SocketError errorCode = SocketPal.Accept(
-                _handle,
-                socketAddress.Buffer,
-                ref socketAddress.InternalSize,
-                out acceptedSocketHandle);
+            SocketError errorCode;
+            try
+            {
+                errorCode = SocketPal.Accept(
+                    _handle,
+                    socketAddress.Buffer,
+                    ref socketAddress.InternalSize,
+                    out acceptedSocketHandle);
+            }
+            catch (Exception ex)
+            {
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.AfterAccept(SocketError.Interrupted, ex.Message);
+                }
+
+                throw;
+            }
 
             // Throw an appropriate SocketException if the native call fails.
             if (errorCode != SocketError.Success)
             {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptFailedAndStop(errorCode, null);
-
                 Debug.Assert(acceptedSocketHandle.IsInvalid);
                 UpdateAcceptSocketErrorForDisposed(ref errorCode);
+
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(errorCode);
+
                 UpdateStatusAfterSocketErrorAndThrowException(errorCode);
             }
-            else
-            {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStop();
-            }
+
+            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(SocketError.Success);
 
             Debug.Assert(!acceptedSocketHandle.IsInvalid);
 
@@ -2140,8 +2152,6 @@ namespace System.Net.Sockets
 
         internal IAsyncResult UnsafeBeginConnect(EndPoint remoteEP, AsyncCallback? callback, object? state, bool flowContext = false)
         {
-            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStart(remoteEP);
-
             if (CanUseConnectEx(remoteEP))
             {
                 return BeginConnectEx(remoteEP, flowContext, callback, state);
@@ -2363,7 +2373,23 @@ namespace System.Net.Sockets
         //    int - Return code from async Connect, 0 for success, SocketError.NotConnected otherwise
         public void EndConnect(IAsyncResult asyncResult)
         {
-            ThrowIfDisposed();
+            // There are three AsyncResult types we support in EndConnect:
+            // - ConnectAsyncResult - a fully synchronous operation that already completed, wrapped in an AsyncResult
+            // - MultipleAddressConnectAsyncResult - a parent operation for other Connects (connecting to DnsEndPoint)
+            // - ConnectOverlappedAsyncResult - a connect to an IPEndPoint
+            // For Telemetry, we already logged everything for ConnectAsyncResult in DoConnect,
+            // and we want to avoid logging duplicated events for MultipleAddressConnect.
+            // Therefore, we always check that asyncResult is ConnectOverlapped before logging.
+
+            if (Disposed)
+            {
+                if (SocketsTelemetry.Log.IsEnabled() && asyncResult is ConnectOverlappedAsyncResult)
+                {
+                    SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket);
+                }
+
+                ThrowObjectDisposedException();
+            }
 
             // Validate input parameters.
             if (asyncResult == null)
@@ -2391,13 +2417,13 @@ namespace System.Net.Sockets
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"asyncResult:{asyncResult}");
 
             Exception? ex = castedAsyncResult.Result as Exception;
+
             if (ex != null || (SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
             {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectFailedAndStop((SocketError)castedAsyncResult.ErrorCode, ex?.Message);
+                SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode;
 
                 if (ex == null)
                 {
-                    SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode;
                     UpdateConnectSocketErrorForDisposed(ref errorCode);
                     // Update the internal state of this socket according to the error before throwing.
                     SocketException se = SocketExceptionFactory.CreateSocketException((int)errorCode, castedAsyncResult.RemoteEndPoint);
@@ -2405,11 +2431,19 @@ namespace System.Net.Sockets
                     ex = se;
                 }
 
+                if (SocketsTelemetry.Log.IsEnabled() && castedAsyncResult is ConnectOverlappedAsyncResult)
+                {
+                    SocketsTelemetry.Log.AfterConnect(errorCode, ex.Message);
+                }
+
                 if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, ex);
                 ExceptionDispatchInfo.Throw(ex);
             }
 
-            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStop();
+            if (SocketsTelemetry.Log.IsEnabled() && castedAsyncResult is ConnectOverlappedAsyncResult)
+            {
+                SocketsTelemetry.Log.AfterConnect(SocketError.Success);
+            }
 
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint);
         }
@@ -3533,21 +3567,33 @@ namespace System.Net.Sockets
             if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStart(_rightEndPoint);
 
             int socketAddressSize = GetAddressSize(_rightEndPoint);
-            SocketError errorCode = SocketPal.AcceptAsync(this, _handle, acceptHandle, receiveSize, socketAddressSize, asyncResult);
+            SocketError errorCode;
+            try
+            {
+                errorCode = SocketPal.AcceptAsync(this, _handle, acceptHandle, receiveSize, socketAddressSize, asyncResult);
+            }
+            catch (Exception ex)
+            {
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.AfterAccept(SocketError.Interrupted, ex.Message);
+                }
+
+                throw;
+            }
 
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"AcceptAsync returns:{errorCode} {asyncResult}");
 
             // Throw an appropriate SocketException if the native call fails synchronously.
             if (!CheckErrorAndUpdateStatus(errorCode))
             {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptFailedAndStop(errorCode, null);
-
                 UpdateAcceptSocketErrorForDisposed(ref errorCode);
+
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(errorCode);
+
                 throw new SocketException((int)errorCode);
             }
 
-            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStop();
-
             // Finish the flow capture, maybe complete here.
             asyncResult.FinishPostingAsyncOp(ref Caches.AcceptClosureCache);
 
@@ -3573,7 +3619,12 @@ namespace System.Net.Sockets
         }
         private Socket EndAcceptCommon(out byte[]? buffer, out int bytesTransferred, IAsyncResult asyncResult)
         {
-            ThrowIfDisposed();
+            if (Disposed)
+            {
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(SocketError.Interrupted);
+
+                ThrowObjectDisposedException();
+            }
 
             // Validate input parameters.
             if (asyncResult == null)
@@ -3594,21 +3645,23 @@ namespace System.Net.Sockets
             bytesTransferred = (int)castedAsyncResult.BytesTransferred;
             buffer = castedAsyncResult.Buffer;
 
+            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.BytesReceived(bytesTransferred);
+
             castedAsyncResult.EndCalled = true;
 
             // Throw an appropriate SocketException if the native call failed asynchronously.
             SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode;
+
             if (errorCode != SocketError.Success)
             {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptFailedAndStop(errorCode, null);
-
                 UpdateAcceptSocketErrorForDisposed(ref errorCode);
+
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(errorCode);
+
                 UpdateStatusAfterSocketErrorAndThrowException(errorCode);
             }
-            else
-            {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStop();
-            }
+
+            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterAccept(SocketError.Success);
 
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Accepted(socket, socket.RemoteEndPoint, socket.LocalEndPoint);
             return socket;
@@ -3662,16 +3715,23 @@ namespace System.Net.Sockets
             SafeSocketHandle? acceptHandle;
             e.AcceptSocket = GetOrCreateAcceptSocket(e.AcceptSocket, true, "AcceptSocket", out acceptHandle);
 
+            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStart(_rightEndPoint!);
+
             // Prepare for and make the native call.
             e.StartOperationCommon(this, SocketAsyncOperation.Accept);
             e.StartOperationAccept();
-            SocketError socketError = SocketError.Success;
+            SocketError socketError;
             try
             {
                 socketError = e.DoOperationAccept(this, _handle, acceptHandle);
             }
-            catch
+            catch (Exception ex)
             {
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.AfterAccept(SocketError.Interrupted, ex.Message);
+                }
+
                 // Clear in-use flag on event args object.
                 e.Complete();
                 throw;
@@ -3762,12 +3822,17 @@ namespace System.Net.Sockets
                     _rightEndPoint = endPointSnapshot;
                 }
 
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.ConnectStart(e._socketAddress!);
+                }
+
                 // Prepare for the native call.
                 e.StartOperationCommon(this, SocketAsyncOperation.Connect);
                 e.StartOperationConnect(multipleConnect: null, userSocket);
 
                 // Make the native call.
-                SocketError socketError = SocketError.Success;
+                SocketError socketError;
                 try
                 {
                     if (CanUseConnectEx(endPointSnapshot))
@@ -3780,8 +3845,13 @@ namespace System.Net.Sockets
                         socketError = e.DoOperationConnect(this, _handle);
                     }
                 }
-                catch
+                catch (Exception ex)
                 {
+                    if (SocketsTelemetry.Log.IsEnabled())
+                    {
+                        SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
+                    }
+
                     _rightEndPoint = oldEndPoint;
                     _localEndPoint = null;
 
@@ -4221,22 +4291,36 @@ namespace System.Net.Sockets
         {
             if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStart(socketAddress);
 
-            SocketError errorCode = SocketPal.Connect(_handle, socketAddress.Buffer, socketAddress.Size);
+            SocketError errorCode;
+            try
+            {
+                errorCode = SocketPal.Connect(_handle, socketAddress.Buffer, socketAddress.Size);
+            }
+            catch (Exception ex)
+            {
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
+                }
+
+                throw;
+            }
 
             // Throw an appropriate SocketException if the native call fails.
             if (errorCode != SocketError.Success)
             {
-                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectFailedAndStop(errorCode, null);
-
                 UpdateConnectSocketErrorForDisposed(ref errorCode);
                 // Update the internal state of this socket according to the error before throwing.
                 SocketException socketException = SocketExceptionFactory.CreateSocketException((int)errorCode, endPointSnapshot);
                 UpdateStatusAfterSocketError(socketException);
                 if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException);
+
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(errorCode);
+
                 throw socketException;
             }
 
-            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStop();
+            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);
 
             if (_rightEndPoint == null)
             {
@@ -4617,6 +4701,14 @@ namespace System.Net.Sockets
             EndPoint endPointSnapshot = remoteEP;
             Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);
 
+            if (SocketsTelemetry.Log.IsEnabled())
+            {
+                SocketsTelemetry.Log.ConnectStart(socketAddress);
+
+                // Ignore flowContext when using Telemetry to avoid losing Activity tracking
+                flowContext = true;
+            }
+
             WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);
 
             // Allocate the async result and the event we'll pass to the thread pool.
@@ -4639,8 +4731,13 @@ namespace System.Net.Sockets
             {
                 errorCode = SocketPal.ConnectAsync(this, _handle, socketAddress.Buffer, socketAddress.Size, asyncResult);
             }
-            catch
+            catch (Exception ex)
             {
+                if (SocketsTelemetry.Log.IsEnabled())
+                {
+                    SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
+                }
+
                 // _rightEndPoint will always equal oldEndPoint.
                 _rightEndPoint = oldEndPoint;
                 _localEndPoint = null;
@@ -4662,6 +4759,8 @@ namespace System.Net.Sockets
                 _rightEndPoint = oldEndPoint;
                 _localEndPoint = null;
 
+                if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(errorCode);
+
                 throw new SocketException((int)errorCode);
             }
 
index 6b304a2..7f86d70 100644 (file)
@@ -158,6 +158,9 @@ namespace System.Net.Sockets
                     // so we can set the results right now.
                     FreeNativeOverlapped(overlapped);
                     FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None);
+
+                    if (SocketsTelemetry.Log.IsEnabled()) AfterConnectAcceptTelemetry();
+
                     return SocketError.Success;
                 }
 
@@ -173,6 +176,9 @@ namespace System.Net.Sockets
                     // Completed synchronously with a failure.
                     FreeNativeOverlapped(overlapped);
                     FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None);
+
+                    if (SocketsTelemetry.Log.IsEnabled()) AfterConnectAcceptTelemetry();
+
                     return socketError;
                 }
 
@@ -205,6 +211,9 @@ namespace System.Net.Sockets
                     _singleBufferHandleState = SingleBufferHandleState.None;
                     FreeNativeOverlapped(overlapped);
                     FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None);
+
+                    if (SocketsTelemetry.Log.IsEnabled()) AfterConnectAcceptTelemetry();
+
                     return SocketError.Success;
                 }
 
@@ -221,6 +230,9 @@ namespace System.Net.Sockets
                     _singleBufferHandleState = SingleBufferHandleState.None;
                     FreeNativeOverlapped(overlapped);
                     FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None);
+
+                    if (SocketsTelemetry.Log.IsEnabled()) AfterConnectAcceptTelemetry();
+
                     return socketError;
                 }
 
index 114a65d..2a40157 100644 (file)
@@ -198,11 +198,38 @@ namespace System.Net.Sockets
 
         public event EventHandler<SocketAsyncEventArgs>? Completed;
 
+        private void OnCompletedInternal()
+        {
+            if (SocketsTelemetry.Log.IsEnabled())
+            {
+                AfterConnectAcceptTelemetry();
+            }
+
+            OnCompleted(this);
+        }
+
         protected virtual void OnCompleted(SocketAsyncEventArgs e)
         {
             Completed?.Invoke(e._currentSocket, e);
         }
 
+        private void AfterConnectAcceptTelemetry()
+        {
+            switch (LastOperation)
+            {
+                case SocketAsyncOperation.Accept:
+                    SocketsTelemetry.Log.AfterAccept(SocketError);
+                    break;
+
+                case SocketAsyncOperation.Connect:
+                    if (_multipleConnect is null)
+                    {
+                        SocketsTelemetry.Log.AfterConnect(SocketError);
+                    }
+                    break;
+            }
+        }
+
         // DisconnectResuseSocket property.
         public bool DisconnectReuseSocket
         {
@@ -420,7 +447,7 @@ namespace System.Net.Sockets
         private static void ExecutionCallback(object? state)
         {
             var thisRef = (SocketAsyncEventArgs)state!;
-            thisRef.OnCompleted(thisRef);
+            thisRef.OnCompletedInternal();
         }
 
         // Marks this object as no longer "in-use". Will also execute a Dispose deferred
@@ -509,7 +536,9 @@ namespace System.Net.Sockets
             _currentSocket = socket;
 
             // Capture execution context if needed (it is unless explicitly disabled).
-            if (_flowExecutionContext)
+            // If Telemetry is enabled, make sure to capture the context if we're making a Connect or Accept call to preserve the activity
+            if (_flowExecutionContext ||
+                (SocketsTelemetry.Log.IsEnabled() && (operation == SocketAsyncOperation.Connect || operation == SocketAsyncOperation.Accept)))
             {
                 _context = ExecutionContext.Capture();
             }
@@ -547,8 +576,6 @@ namespace System.Net.Sockets
                     _acceptBuffer = new byte[_acceptAddressBufferCount];
                 }
             }
-
-            if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStart(_currentSocket!._rightEndPoint!);
         }
 
         internal void StartOperationConnect(MultipleConnectAsync? multipleConnect, bool userSocket)
@@ -556,9 +583,6 @@ namespace System.Net.Sockets
             _multipleConnect = multipleConnect;
             _connectSocket = null;
             _userSocket = userSocket;
-
-            // Log only the actual connect operation to a remote endpoint.
-            if (SocketsTelemetry.Log.IsEnabled() && multipleConnect == null) SocketsTelemetry.Log.ConnectStart(_socketAddress!);
         }
 
         internal void CancelConnectAsync()
@@ -572,8 +596,6 @@ namespace System.Net.Sockets
                 }
                 else
                 {
-                    if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectCanceledAndStop();
-
                     // Otherwise we're doing a normal ConnectAsync - cancel it by closing the socket.
                     // _currentSocket will only be null if _multipleConnect was set, so we don't have to check.
                     if (_currentSocket == null)
@@ -589,12 +611,6 @@ namespace System.Net.Sockets
         {
             SetResults(socketError, bytesTransferred, flags);
 
-            if (SocketsTelemetry.Log.IsEnabled())
-            {
-                if (_multipleConnect == null && _completedOperation == SocketAsyncOperation.Connect) SocketsTelemetry.Log.ConnectFailedAndStop(socketError, null);
-                if (_completedOperation == SocketAsyncOperation.Accept) SocketsTelemetry.Log.AcceptFailedAndStop(socketError, null);
-            }
-
             // This will be null if we're doing a static ConnectAsync to a DnsEndPoint with AddressFamily.Unspecified;
             // the attempt socket will be closed anyways, so not updating the state is OK.
             // If we're doing a static ConnectAsync to an IPEndPoint, we need to dispose
@@ -640,7 +656,7 @@ namespace System.Net.Sockets
 
             if (context == null)
             {
-                OnCompleted(this);
+                OnCompletedInternal();
             }
             else
             {
@@ -656,7 +672,7 @@ namespace System.Net.Sockets
 
             if (context == null)
             {
-                OnCompleted(this);
+                OnCompletedInternal();
             }
             else
             {
@@ -677,7 +693,7 @@ namespace System.Net.Sockets
             Complete();
             if (context == null)
             {
-                OnCompleted(this);
+                OnCompletedInternal();
             }
             else
             {
@@ -715,13 +731,9 @@ namespace System.Net.Sockets
                             }
                             catch (ObjectDisposedException) { }
                         }
-
-                        if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStop();
                     }
                     else
                     {
-                        if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptFailedAndStop(socketError, null);
-
                         SetResults(socketError, bytesTransferred, flags);
                         _acceptSocket = null;
                         _currentSocket.UpdateStatusAfterSocketError(socketError);
@@ -741,16 +753,12 @@ namespace System.Net.Sockets
                             catch (ObjectDisposedException) { }
                         }
 
-                        if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStop();
-
                         // Mark socket connected.
                         _currentSocket!.SetToConnected();
                         _connectSocket = _currentSocket;
                     }
                     else
                     {
-                        if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectFailedAndStop(socketError, null);
-
                         SetResults(socketError, bytesTransferred, flags);
                         _currentSocket!.UpdateStatusAfterSocketError(socketError);
                     }
@@ -814,7 +822,7 @@ namespace System.Net.Sockets
             // Raise completion event.
             if (context == null)
             {
-                OnCompleted(this);
+                OnCompletedInternal();
             }
             else
             {
@@ -834,6 +842,8 @@ namespace System.Net.Sockets
             {
                 FinishOperationSyncFailure(socketError, bytesTransferred, flags);
             }
+
+            if (SocketsTelemetry.Log.IsEnabled()) AfterConnectAcceptTelemetry();
         }
 
         private static void LogBytesTransferEvents(SocketType? socketType, SocketAsyncOperation operation, int bytesTransferred)
index f3430e7..7b7641f 100644 (file)
@@ -1,6 +1,7 @@
 ï»¿// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Diagnostics;
 using System.Diagnostics.Tracing;
 using System.Threading;
 
@@ -26,17 +27,13 @@ namespace System.Net.Sockets
         private long _datagramsSent;
 
         [Event(1, Level = EventLevel.Informational)]
-        public void ConnectStart(string? address)
+        private void ConnectStart(string? address)
         {
-            Interlocked.Increment(ref _outgoingConnectionsEstablished);
-            if (IsEnabled(EventLevel.Informational, EventKeywords.All))
-            {
-                WriteEvent(eventId: 1, address ?? "");
-            }
+            WriteEvent(eventId: 1, address);
         }
 
         [Event(2, Level = EventLevel.Informational)]
-        public void ConnectStop()
+        private void ConnectStop()
         {
             if (IsEnabled(EventLevel.Informational, EventKeywords.All))
             {
@@ -45,105 +42,108 @@ namespace System.Net.Sockets
         }
 
         [Event(3, Level = EventLevel.Error)]
-        public void ConnectFailed(SocketError error, string? exceptionMessage)
+        private void ConnectFailed(SocketError error, string? exceptionMessage)
         {
             if (IsEnabled(EventLevel.Error, EventKeywords.All))
             {
-                WriteEvent(eventId: 3, (int)error, exceptionMessage ?? string.Empty);
+                WriteEvent(eventId: 3, (int)error, exceptionMessage);
             }
         }
 
-        [Event(4, Level = EventLevel.Warning)]
-        public void ConnectCanceled()
+        [Event(4, Level = EventLevel.Informational)]
+        private void AcceptStart(string? address)
         {
-            if (IsEnabled(EventLevel.Warning, EventKeywords.All))
-            {
-                WriteEvent(eventId: 4);
-            }
+            WriteEvent(eventId: 4, address);
         }
 
         [Event(5, Level = EventLevel.Informational)]
-        public void AcceptStart(string? address)
+        private void AcceptStop()
         {
-            Interlocked.Increment(ref _incomingConnectionsEstablished);
             if (IsEnabled(EventLevel.Informational, EventKeywords.All))
             {
-                WriteEvent(eventId: 5, address ?? "");
+                WriteEvent(eventId: 5);
             }
         }
 
-        [Event(6, Level = EventLevel.Informational)]
-        public void AcceptStop()
-        {
-            if (IsEnabled(EventLevel.Informational, EventKeywords.All))
-            {
-                WriteEvent(eventId: 6);
-            }
-        }
-
-        [Event(7, Level = EventLevel.Error)]
-        public void AcceptFailed(SocketError error, string? exceptionMessage)
+        [Event(6, Level = EventLevel.Error)]
+        private void AcceptFailed(SocketError error, string? exceptionMessage)
         {
             if (IsEnabled(EventLevel.Error, EventKeywords.All))
             {
-                WriteEvent(eventId: 7, (int)error, exceptionMessage ?? string.Empty);
+                WriteEvent(eventId: 6, (int)error, exceptionMessage);
             }
         }
 
         [NonEvent]
         public void ConnectStart(Internals.SocketAddress address)
         {
-            ConnectStart(address.ToString());
-        }
-
-        [NonEvent]
-        public void ConnectStart(EndPoint address)
-        {
-            ConnectStart(address.ToString());
+            if (IsEnabled(EventLevel.Informational, EventKeywords.All))
+            {
+                ConnectStart(address.ToString());
+            }
         }
 
         [NonEvent]
-        public void ConnectCanceledAndStop()
+        public void AfterConnect(SocketError error, string? exceptionMessage = null)
         {
-            ConnectCanceled();
-            ConnectStop();
-        }
+            if (error == SocketError.Success)
+            {
+                Debug.Assert(exceptionMessage is null);
+                Interlocked.Increment(ref _outgoingConnectionsEstablished);
+            }
+            else
+            {
+                ConnectFailed(error, exceptionMessage);
+            }
 
-        [NonEvent]
-        public void ConnectFailedAndStop(SocketError error, string? exceptionMessage)
-        {
-            ConnectFailed(error, exceptionMessage);
             ConnectStop();
         }
 
         [NonEvent]
         public void AcceptStart(Internals.SocketAddress address)
         {
-            AcceptStart(address.ToString());
+            if (IsEnabled(EventLevel.Informational, EventKeywords.All))
+            {
+                AcceptStart(address.ToString());
+            }
         }
 
         [NonEvent]
         public void AcceptStart(EndPoint address)
         {
-            AcceptStart(address.ToString());
+            if (IsEnabled(EventLevel.Informational, EventKeywords.All))
+            {
+                AcceptStart(address.ToString());
+            }
         }
 
         [NonEvent]
-        public void AcceptFailedAndStop(SocketError error, string? exceptionMessage)
+        public void AfterAccept(SocketError error, string? exceptionMessage = null)
         {
-            AcceptFailed(error, exceptionMessage);
+            if (error == SocketError.Success)
+            {
+                Debug.Assert(exceptionMessage is null);
+                Interlocked.Increment(ref _incomingConnectionsEstablished);
+            }
+            else
+            {
+                AcceptFailed(error, exceptionMessage);
+            }
+
             AcceptStop();
         }
 
         [NonEvent]
         public void BytesReceived(int count)
         {
+            Debug.Assert(count >= 0);
             Interlocked.Add(ref _bytesReceived, count);
         }
 
         [NonEvent]
         public void BytesSent(int count)
         {
+            Debug.Assert(count >= 0);
             Interlocked.Add(ref _bytesSent, count);
         }
 
index 783fbf5..8f834a5 100644 (file)
@@ -1,14 +1,14 @@
 ï»¿// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System.Collections;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
-using System.Diagnostics;
 using System.Diagnostics.Tracing;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.DotNet.RemoteExecutor;
+using Microsoft.DotNet.XUnitExtensions;
 using Xunit;
 using Xunit.Abstractions;
 
@@ -35,16 +35,323 @@ namespace System.Net.Sockets.Tests
             Assert.NotEmpty(EventSource.GenerateManifest(esType, esType.Assembly.Location));
         }
 
+        public static IEnumerable<object[]> SocketMethods_MemberData()
+        {
+            yield return new[] { "Sync" };
+            yield return new[] { "Task" };
+            yield return new[] { "Apm" };
+            yield return new[] { "Eap" };
+        }
+
+        public static IEnumerable<object[]> SocketMethods_Matrix_MemberData()
+        {
+            return from connectMethod in SocketMethods_MemberData()
+                   from acceptMethod in SocketMethods_MemberData()
+                   select new[] { connectMethod[0], acceptMethod[0] };
+        }
+
+        public static IEnumerable<object[]> SocketMethods_WithBools_MemberData()
+        {
+            return from connectMethod in SocketMethods_MemberData()
+                   from useDnsEndPoint in new[] { true, false }
+                   select new[] { connectMethod[0], useDnsEndPoint };
+        }
+
+        private static async Task<EndPoint> GetRemoteEndPointAsync(string useDnsEndPointString, int port)
+        {
+            const string Address = "microsoft.com";
+
+            if (bool.Parse(useDnsEndPointString))
+            {
+                return new DnsEndPoint(Address, port);
+            }
+            else
+            {
+                IPAddress ip = (await Dns.GetHostAddressesAsync(Address))[0];
+                return new IPEndPoint(ip, port);
+            }
+        }
+
+        // RemoteExecutor only supports simple argument types such as strings
+        // That's why we use this helper method instead of returning SocketHelperBases from MemberDatas directly
+        private static SocketHelperBase GetHelperBase(string socketMethod)
+        {
+            return socketMethod switch
+            {
+                "Sync" => new SocketHelperArraySync(),
+                "Task" => new SocketHelperTask(),
+                "Apm" => new SocketHelperApm(),
+                "Eap" => new SocketHelperEap(),
+                _ => throw new ArgumentException(socketMethod)
+            };
+        }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [MemberData(nameof(SocketMethods_Matrix_MemberData))]
+        public void EventSource_SocketConnectsLoopback_LogsConnectAcceptStartStop(string connectMethod, string acceptMethod)
+        {
+            RemoteExecutor.Invoke(async (connectMethod, acceptMethod) =>
+            {
+                using var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1);
+                listener.AddActivityTracking();
+
+                var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
+                {
+                    using var server = new Socket(SocketType.Stream, ProtocolType.Tcp);
+                    server.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                    server.Listen();
+
+                    using var client = new Socket(SocketType.Stream, ProtocolType.Tcp);
+
+                    Task acceptTask = GetHelperBase(acceptMethod).AcceptAsync(server);
+                    await WaitForEventAsync(events, "AcceptStart");
+
+                    await GetHelperBase(connectMethod).ConnectAsync(client, server.LocalEndPoint);
+                    await acceptTask;
+
+                    await WaitForEventAsync(events, "AcceptStop");
+                    await WaitForEventAsync(events, "ConnectStop");
+
+                    await WaitForEventCountersAsync(events);
+                });
+                Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself
+
+                VerifyStartStopEvents(events, connect: true, expectedCount: 1);
+                VerifyStartStopEvents(events, connect: false, expectedCount: 1);
+
+                Assert.DoesNotContain(events, e => e.Event.EventName == "ConnectFailed");
+                Assert.DoesNotContain(events, e => e.Event.EventName == "AcceptFailed");
+
+                VerifyEventCounters(events, connectCount: 1);
+            }, connectMethod, acceptMethod).Dispose();
+        }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [MemberData(nameof(SocketMethods_WithBools_MemberData))]
+        public void EventSource_SocketConnectsRemote_LogsConnectStartStop(string connectMethod, bool useDnsEndPoint)
+        {
+            RemoteExecutor.Invoke(async (connectMethod, useDnsEndPointString) =>
+            {
+                using var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1);
+                listener.AddActivityTracking();
+
+                var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
+                {
+                    using var client = new Socket(SocketType.Stream, ProtocolType.Tcp);
+
+                    SocketHelperBase socketHelper = GetHelperBase(connectMethod);
+
+                    EndPoint endPoint = await GetRemoteEndPointAsync(useDnsEndPointString, port: 443);
+                    await socketHelper.ConnectAsync(client, endPoint);
+
+                    await WaitForEventAsync(events, "ConnectStop");
+
+                    await WaitForEventCountersAsync(events);
+                });
+                Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself
+
+                VerifyStartStopEvents(events, connect: true, expectedCount: 1);
+
+                Assert.DoesNotContain(events, e => e.Event.EventName == "ConnectFailed");
+
+                VerifyEventCounters(events, connectCount: 1, connectOnly: true);
+            }, connectMethod, useDnsEndPoint.ToString()).Dispose();
+        }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [PlatformSpecific(~(TestPlatforms.OSX | TestPlatforms.FreeBSD))] // Same as Connect.ConnectGetsCanceledByDispose
+        [MemberData(nameof(SocketMethods_WithBools_MemberData))]
+        public void EventSource_SocketConnectFailure_LogsConnectFailed(string connectMethod, bool useDnsEndPoint)
+        {
+            if (connectMethod == "Sync" && PlatformDetection.IsRedHatFamily7)
+            {
+                // [ActiveIssue("https://github.com/dotnet/runtime/issues/42686")]
+                throw new SkipTestException("Disposing a Socket performing a sync operation can hang on RedHat7 systems");
+            }
+
+            RemoteExecutor.Invoke(async (connectMethod, useDnsEndPointString) =>
+            {
+                EndPoint endPoint = await GetRemoteEndPointAsync(useDnsEndPointString, port: 12345);
+
+                using var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1);
+                listener.AddActivityTracking();
+
+                var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
+                {
+                    using var client = new Socket(SocketType.Stream, ProtocolType.Tcp);
+
+                    SocketHelperBase socketHelper = GetHelperBase(connectMethod);
+
+                    Exception ex = await Assert.ThrowsAnyAsync<Exception>(async () =>
+                    {
+                        Task connectTask = socketHelper.ConnectAsync(client, endPoint);
+                        await WaitForEventAsync(events, "ConnectStart");
+                        Task disposeTask = Task.Run(() => client.Dispose());
+                        await new[] { connectTask, disposeTask }.WhenAllOrAnyFailed();
+                    });
+
+                    if (ex is SocketException se)
+                    {
+                        Assert.NotEqual(SocketError.TimedOut, se.SocketErrorCode);
+                    }
+
+                    await WaitForEventAsync(events, "ConnectStop");
+
+                    await WaitForEventCountersAsync(events);
+                });
+
+                VerifyConnectFailureEvents(events);
+            }, connectMethod, useDnsEndPoint.ToString()).Dispose();
+        }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [MemberData(nameof(SocketMethods_MemberData))]
+        public void EventSource_SocketAcceptFailure_LogsAcceptFailed(string acceptMethod)
+        {
+            if (acceptMethod == "Sync" && PlatformDetection.IsRedHatFamily7)
+            {
+                // [ActiveIssue("https://github.com/dotnet/runtime/issues/42686")]
+                throw new SkipTestException("Disposing a Socket performing a sync operation can hang on RedHat7 systems");
+            }
+
+            RemoteExecutor.Invoke(async acceptMethod =>
+            {
+                using var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1);
+                listener.AddActivityTracking();
+
+                var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
+                {
+                    using var server = new Socket(SocketType.Stream, ProtocolType.Tcp);
+                    server.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                    server.Listen();
+
+                    await Assert.ThrowsAnyAsync<Exception>(async () =>
+                    {
+                        Task acceptTask = GetHelperBase(acceptMethod).AcceptAsync(server);
+                        await WaitForEventAsync(events, "AcceptStart");
+                        Task disposeTask = Task.Run(() => server.Dispose());
+                        await new[] { acceptTask, disposeTask }.WhenAllOrAnyFailed();
+                    });
+
+                    await WaitForEventAsync(events, "AcceptStop");
+
+                    await WaitForEventCountersAsync(events);
+                });
+                Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself
+
+                VerifyStartStopEvents(events, connect: false, expectedCount: 1);
+
+                (EventWrittenEventArgs Event, Guid ActivityId) failed = Assert.Single(events, e => e.Event.EventName == "AcceptFailed");
+                Assert.Equal(2, failed.Event.Payload.Count);
+                Assert.True(Enum.IsDefined((SocketError)failed.Event.Payload[0]));
+                Assert.IsType<string>(failed.Event.Payload[1]);
+
+                (_, Guid startActivityId) = Assert.Single(events, e => e.Event.EventName == "AcceptStart");
+                Assert.Equal(startActivityId, failed.ActivityId);
+
+                VerifyEventCounters(events, connectCount: 0);
+            }, acceptMethod).Dispose();
+        }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData("Task", true)]
+        [InlineData("Task", false)]
+        [InlineData("Eap", true)]
+        [InlineData("Eap", false)]
+        public void EventSource_ConnectAsyncCanceled_LogsConnectFailed(string connectMethod, bool useDnsEndPoint)
+        {
+            RemoteExecutor.Invoke(async (connectMethod, useDnsEndPointString) =>
+            {
+                EndPoint endPoint = await GetRemoteEndPointAsync(useDnsEndPointString, port: 12345);
+
+                using var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1);
+                listener.AddActivityTracking();
+
+                var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
+                {
+                    using var client = new Socket(SocketType.Stream, ProtocolType.Tcp);
+
+                    await Assert.ThrowsAnyAsync<Exception>(async () =>
+                    {
+                        switch (connectMethod)
+                        {
+                            case "Task":
+                                using (var cts = new CancellationTokenSource())
+                                {
+                                    ValueTask connectTask = client.ConnectAsync(endPoint, cts.Token);
+                                    await WaitForEventAsync(events, "ConnectStart");
+                                    cts.Cancel();
+                                    await connectTask;
+                                }
+                                break;
+
+                            case "Eap":
+                                using (var saea = new SocketAsyncEventArgs())
+                                {
+                                    var tcs = new TaskCompletionSource();
+                                    saea.RemoteEndPoint = endPoint;
+                                    saea.Completed += (_, __) =>
+                                    {
+                                        Assert.NotEqual(SocketError.Success, saea.SocketError);
+                                        tcs.SetException(new SocketException((int)saea.SocketError));
+                                    };
+                                    Assert.True(client.ConnectAsync(saea));
+                                    await WaitForEventAsync(events, "ConnectStart");
+                                    Socket.CancelConnectAsync(saea);
+                                    await tcs.Task;
+                                }
+                                break;
+                        }
+                    });
+
+                    await WaitForEventAsync(events, "ConnectStop");
+
+                    await WaitForEventCountersAsync(events);
+                });
+
+                VerifyConnectFailureEvents(events);
+            }, connectMethod, useDnsEndPoint.ToString()).Dispose();
+        }
+
+        private static void VerifyConnectFailureEvents(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events)
+        {
+            Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself
+
+            VerifyStartStopEvents(events, connect: true, expectedCount: 1);
+
+            (EventWrittenEventArgs Event, Guid ActivityId) failed = Assert.Single(events, e => e.Event.EventName == "ConnectFailed");
+            Assert.Equal(2, failed.Event.Payload.Count);
+            Assert.True(Enum.IsDefined((SocketError)failed.Event.Payload[0]));
+            Assert.IsType<string>(failed.Event.Payload[1]);
+
+            (_, Guid startActivityId) = Assert.Single(events, e => e.Event.EventName == "ConnectStart");
+            Assert.Equal(startActivityId, failed.ActivityId);
+
+            VerifyEventCounters(events, connectCount: 0);
+        }
+
         [OuterLoop]
         [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
         public void EventSource_EventsRaisedAsExpected()
         {
-            RemoteExecutor.Invoke(() =>
+            RemoteExecutor.Invoke(async () =>
             {
                 using (var listener = new TestEventListener("System.Net.Sockets", EventLevel.Verbose, 0.1))
                 {
-                    var events = new ConcurrentQueue<EventWrittenEventArgs>();
-                    listener.RunWithCallbackAsync(events.Enqueue, async () =>
+                    listener.AddActivityTracking();
+
+                    var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>();
+                    await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () =>
                     {
                         // Invoke several tests to execute code paths while tracing is enabled
 
@@ -65,36 +372,116 @@ namespace System.Net.Sockets.Tests
 
                         await new NetworkStreamTest().CopyToAsync_AllDataCopied(4096, true).ConfigureAwait(false);
                         await new NetworkStreamTest().Timeout_ValidData_Roundtrips().ConfigureAwait(false);
-                        await Task.Delay(300).ConfigureAwait(false);
-                    }).Wait();
-                    Assert.DoesNotContain(events, ev => ev.EventId == 0); // errors from the EventSource itself
-                    VerifyEvents(events, "ConnectStart", 10);
-                    VerifyEvents(events, "ConnectStop", 10);
-
-                    Dictionary<string, double> eventCounters = events.Where(e => e.EventName == "EventCounters").Select(e => (IDictionary<string, object>) e.Payload.Single())
-                        .GroupBy(d => (string)d["Name"], d => (double)d["Mean"], (k, v) => new { Name = k, Value = v.Sum() })
-                        .ToDictionary(p => p.Name, p => p.Value);
-
-                    VerifyEventCounter("incoming-connections-established", eventCounters);
-                    VerifyEventCounter("outgoing-connections-established", eventCounters);
-                    VerifyEventCounter("bytes-received", eventCounters);
-                    VerifyEventCounter("bytes-sent", eventCounters);
-                    VerifyEventCounter("datagrams-received", eventCounters);
-                    VerifyEventCounter("datagrams-sent", eventCounters);
+
+                        await WaitForEventCountersAsync(events);
+                    });
+                    Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself
+
+                    VerifyStartStopEvents(events, connect: true, expectedCount: 10);
+
+                    Assert.DoesNotContain(events, e => e.Event.EventName == "ConnectFailed");
+
+                    VerifyEventCounters(events, connectCount: 10, shouldHaveTransferedBytes: true, shouldHaveDatagrams: true);
                 }
             }).Dispose();
         }
 
-        private static void VerifyEvents(IEnumerable<EventWrittenEventArgs> events, string eventName, int expectedCount)
+        private static void VerifyStartStopEvents(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events, bool connect, int expectedCount)
         {
-            EventWrittenEventArgs[] starts = events.Where(e => e.EventName == eventName).ToArray();
+            string startName = connect ? "ConnectStart" : "AcceptStart";
+            (EventWrittenEventArgs Event, Guid ActivityId)[] starts = events.Where(e => e.Event.EventName == startName).ToArray();
             Assert.Equal(expectedCount, starts.Length);
+            foreach ((EventWrittenEventArgs Event, _) in starts)
+            {
+                object startPayload = Assert.Single(Event.Payload);
+                Assert.False(string.IsNullOrWhiteSpace(startPayload as string));
+            }
+
+            string stopName = connect ? "ConnectStop" : "AcceptStop";
+            (EventWrittenEventArgs Event, Guid ActivityId)[] stops = events.Where(e => e.Event.EventName == stopName).ToArray();
+            Assert.Equal(expectedCount, stops.Length);
+            Assert.All(stops, stop => Assert.Empty(stop.Event.Payload));
+
+            for (int i = 0; i < expectedCount; i++)
+            {
+                Assert.NotEqual(Guid.Empty, starts[i].ActivityId);
+                Assert.Equal(starts[i].ActivityId, stops[i].ActivityId);
+            }
+        }
+
+        private static async Task WaitForEventAsync(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events, string name)
+        {
+            DateTime startTime = DateTime.UtcNow;
+            while (!events.Any(e => e.Event.EventName == name))
+            {
+                if (DateTime.UtcNow.Subtract(startTime) > TimeSpan.FromSeconds(30))
+                    throw new TimeoutException($"Timed out waiting for {name}");
+
+                await Task.Delay(100);
+            }
+        }
+
+        private static async Task WaitForEventCountersAsync(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events)
+        {
+            DateTime startTime = DateTime.UtcNow;
+            int startCount = events.Count;
+
+            while (events.Skip(startCount).Count(e => IsBytesSentEventCounter(e.Event)) < 2)
+            {
+                if (DateTime.UtcNow.Subtract(startTime) > TimeSpan.FromSeconds(30))
+                    throw new TimeoutException($"Timed out waiting for EventCounters");
+
+                await Task.Delay(100);
+            }
+
+            static bool IsBytesSentEventCounter(EventWrittenEventArgs e)
+            {
+                if (e.EventName != "EventCounters")
+                    return false;
+
+                var dictionary = (IDictionary<string, object>)e.Payload.Single();
+
+                return (string)dictionary["Name"] == "bytes-sent";
+            }
         }
 
-        private static void VerifyEventCounter(string name, Dictionary<string, double> eventCounters)
+        private static void VerifyEventCounters(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events, int connectCount, bool connectOnly = false, bool shouldHaveTransferedBytes = false, bool shouldHaveDatagrams = false)
         {
-            Assert.True(eventCounters.ContainsKey(name));
-            Assert.True(eventCounters[name] > 0);
+            Dictionary<string, double[]> eventCounters = events
+                .Where(e => e.Event.EventName == "EventCounters")
+                .Select(e => (IDictionary<string, object>)e.Event.Payload.Single())
+                .GroupBy(d => (string)d["Name"], d => (double)(d.ContainsKey("Mean") ? d["Mean"] : d["Increment"]))
+                .ToDictionary(p => p.Key, p => p.ToArray());
+
+            Assert.True(eventCounters.TryGetValue("outgoing-connections-established", out double[] outgoingConnections));
+            Assert.Equal(connectCount, outgoingConnections[^1]);
+
+            Assert.True(eventCounters.TryGetValue("incoming-connections-established", out double[] incomingConnections));
+            Assert.Equal(connectOnly ? 0 : connectCount, incomingConnections[^1]);
+
+            Assert.True(eventCounters.TryGetValue("bytes-received", out double[] bytesReceived));
+            if (shouldHaveTransferedBytes)
+            {
+                Assert.True(bytesReceived[^1] > 0);
+            }
+
+            Assert.True(eventCounters.TryGetValue("bytes-sent", out double[] bytesSent));
+            if (shouldHaveTransferedBytes)
+            {
+                Assert.True(bytesSent[^1] > 0);
+            }
+
+            Assert.True(eventCounters.TryGetValue("datagrams-received", out double[] datagramsReceived));
+            if (shouldHaveDatagrams)
+            {
+                Assert.True(datagramsReceived[^1] > 0);
+            }
+
+            Assert.True(eventCounters.TryGetValue("datagrams-sent", out double[] datagramsSent));
+            if (shouldHaveDatagrams)
+            {
+                Assert.True(datagramsSent[^1] > 0);
+            }
         }
     }
 }