Consume DistributedContextPropagator in DiagnosticsHandler (#55392)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Tue, 13 Jul 2021 23:44:57 +0000 (16:44 -0700)
committerGitHub <noreply@github.com>
Tue, 13 Jul 2021 23:44:57 +0000 (16:44 -0700)
12 files changed:
src/libraries/System.Diagnostics.DiagnosticSource/tests/PropagatorTests.cs
src/libraries/System.Net.Http/ref/System.Net.Http.cs
src/libraries/System.Net.Http/ref/System.Net.Http.csproj
src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/DiagnosticsHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/DiagnosticsHandlerLoggingStrings.cs
src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RedirectHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/tests/FunctionalTests/DiagnosticsTests.cs

index 2a914f62ad5309604a18a91210dc49c2b2aa57da..5935ad4770e268cefcf5a598302a7e52c580663d 100644 (file)
@@ -527,6 +527,14 @@ namespace System.Diagnostics.Tests
             return list;
         }
 
+        [Fact]
+        public void TestBuiltInPropagatorsAreCached()
+        {
+            Assert.Same(DistributedContextPropagator.CreateDefaultPropagator(), DistributedContextPropagator.CreateDefaultPropagator());
+            Assert.Same(DistributedContextPropagator.CreateNoOutputPropagator(), DistributedContextPropagator.CreateNoOutputPropagator());
+            Assert.Same(DistributedContextPropagator.CreatePassThroughPropagator(), DistributedContextPropagator.CreatePassThroughPropagator());
+        }
+
         [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
         public void TestCustomPropagator()
         {
index 3cb29a543b0770de8eef749047b7d33c5054650e..ef465d5a3d05384d001c7cf000e95066c6528925 100644 (file)
@@ -394,6 +394,8 @@ namespace System.Net.Http
         public bool EnableMultipleHttp2Connections { get { throw null; } set { } }
         public Func<SocketsHttpConnectionContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.IO.Stream>>? ConnectCallback { get { throw null; } set { } }
         public Func<SocketsHttpPlaintextStreamFilterContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.IO.Stream>>? PlaintextStreamFilter { get { throw null; } set { } }
+        [System.CLSCompliantAttribute(false)]
+        public System.Diagnostics.DistributedContextPropagator? ActivityHeadersPropagator { get { throw null; } set { } }
     }
     public sealed class SocketsHttpConnectionContext
     {
index ae6a8158fa04eb698560bd5702dd4efdb16b7db1..1b20e03c7905a9224a2195d6fe370fe60289ba5f 100644 (file)
@@ -14,5 +14,6 @@
     <ProjectReference Include="..\..\System.Net.Security\ref\System.Net.Security.csproj" />
     <ProjectReference Include="..\..\System.Security.Cryptography.X509Certificates\ref\System.Security.Cryptography.X509Certificates.csproj" />
     <ProjectReference Include="..\..\System.Text.Encoding\ref\System.Text.Encoding.csproj" />
+    <ProjectReference Include="..\..\System.Diagnostics.DiagnosticSource\ref\System.Diagnostics.DiagnosticSource.csproj" />
   </ItemGroup>
 </Project>
index 4805010002b2240d8ba5105b708bdfa0fd89f857..1cd0dbdb7a645fd4a6e8ad81b20ddcec63f315e5 100644 (file)
@@ -3,13 +3,12 @@
 
 using System.Collections.Generic;
 using System.IO;
-using System.Net.Quic;
-using System.Net.Quic.Implementations;
 using System.Net.Security;
 using System.Runtime.Versioning;
 using System.Threading;
 using System.Threading.Tasks;
 using System.Diagnostics.CodeAnalysis;
+using System.Diagnostics;
 
 namespace System.Net.Http
 {
@@ -173,6 +172,13 @@ namespace System.Net.Http
             set => throw new PlatformNotSupportedException();
         }
 
+        [CLSCompliant(false)]
+        public DistributedContextPropagator? ActivityHeadersPropagator
+        {
+            get => throw new PlatformNotSupportedException();
+            set => throw new PlatformNotSupportedException();
+        }
+
         protected internal override Task<HttpResponseMessage> SendAsync(
             HttpRequestMessage request, CancellationToken cancellationToken) => throw new PlatformNotSupportedException();
 
index 21b161f9fbb4a083859957eb8284318a9a58f4e5..75954afccab501772ba79238009c5ba56286972e 100644 (file)
@@ -13,36 +13,59 @@ namespace System.Net.Http
     /// <summary>
     /// DiagnosticHandler notifies DiagnosticSource subscribers about outgoing Http requests
     /// </summary>
-    internal sealed class DiagnosticsHandler : DelegatingHandler
+    internal sealed class DiagnosticsHandler : HttpMessageHandlerStage
     {
         private static readonly DiagnosticListener s_diagnosticListener =
                 new DiagnosticListener(DiagnosticsHandlerLoggingStrings.DiagnosticListenerName);
 
-        /// <summary>
-        /// DiagnosticHandler constructor
-        /// </summary>
-        /// <param name="innerHandler">Inner handler: Windows or Unix implementation of HttpMessageHandler.
-        /// Note that DiagnosticHandler is the latest in the pipeline </param>
-        public DiagnosticsHandler(HttpMessageHandler innerHandler) : base(innerHandler)
+        private readonly HttpMessageHandler _innerHandler;
+        private readonly DistributedContextPropagator _propagator;
+        private readonly HeaderDescriptor[]? _propagatorFields;
+
+        public DiagnosticsHandler(HttpMessageHandler innerHandler, DistributedContextPropagator propagator, bool autoRedirect = false)
         {
+            Debug.Assert(IsGloballyEnabled());
+            Debug.Assert(innerHandler is not null && propagator is not null);
+
+            _innerHandler = innerHandler;
+            _propagator = propagator;
+
+            // Prepare HeaderDescriptors for fields we need to clear when following redirects
+            if (autoRedirect && _propagator.Fields is IReadOnlyCollection<string> fields && fields.Count > 0)
+            {
+                var fieldDescriptors = new List<HeaderDescriptor>(fields.Count);
+                foreach (string field in fields)
+                {
+                    if (field is not null && HeaderDescriptor.TryGet(field, out HeaderDescriptor descriptor))
+                    {
+                        fieldDescriptors.Add(descriptor);
+                    }
+                }
+                _propagatorFields = fieldDescriptors.ToArray();
+            }
         }
 
-        internal static bool IsEnabled()
+        private static bool IsEnabled()
         {
-            // check if there is a parent Activity (and propagation is not suppressed)
-            // or if someone listens to HttpHandlerDiagnosticListener
-            return IsGloballyEnabled() && (Activity.Current != null || s_diagnosticListener.IsEnabled());
+            // check if there is a parent Activity or if someone listens to HttpHandlerDiagnosticListener
+            return Activity.Current != null || s_diagnosticListener.IsEnabled();
         }
 
         internal static bool IsGloballyEnabled() => GlobalHttpSettings.DiagnosticsHandler.EnableActivityPropagation;
 
-        // SendAsyncCore returns already completed ValueTask for when async: false is passed.
-        // Internally, it calls the synchronous Send method of the base class.
-        protected internal override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken) =>
-            SendAsyncCore(request, async: false, cancellationToken).AsTask().GetAwaiter().GetResult();
-
-        protected internal override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) =>
-            SendAsyncCore(request, async: true, cancellationToken).AsTask();
+        internal override ValueTask<HttpResponseMessage> SendAsync(HttpRequestMessage request, bool async, CancellationToken cancellationToken)
+        {
+            if (IsEnabled())
+            {
+                return SendAsyncCore(request, async, cancellationToken);
+            }
+            else
+            {
+                return async ?
+                    new ValueTask<HttpResponseMessage>(_innerHandler.SendAsync(request, cancellationToken)) :
+                    new ValueTask<HttpResponseMessage>(_innerHandler.Send(request, cancellationToken));
+            }
+        }
 
         private async ValueTask<HttpResponseMessage> SendAsyncCore(HttpRequestMessage request, bool async,
             CancellationToken cancellationToken)
@@ -58,6 +81,16 @@ namespace System.Net.Http
                 throw new ArgumentNullException(nameof(request), SR.net_http_handler_norequest);
             }
 
+            // Since we are reusing the request message instance on redirects, clear any existing headers
+            // Do so before writing DiagnosticListener events as instrumentations use those to inject headers
+            if (request.WasRedirected() && _propagatorFields is HeaderDescriptor[] fields)
+            {
+                foreach (HeaderDescriptor field in fields)
+                {
+                    request.Headers.Remove(field);
+                }
+            }
+
             Activity? activity = null;
             DiagnosticListener diagnosticListener = s_diagnosticListener;
 
@@ -72,8 +105,8 @@ namespace System.Net.Http
                 try
                 {
                     return async ?
-                        await base.SendAsync(request, cancellationToken).ConfigureAwait(false) :
-                        base.Send(request, cancellationToken);
+                        await _innerHandler.SendAsync(request, cancellationToken).ConfigureAwait(false) :
+                        _innerHandler.Send(request, cancellationToken);
                 }
                 finally
                 {
@@ -119,8 +152,8 @@ namespace System.Net.Http
             try
             {
                 response = async ?
-                    await base.SendAsync(request, cancellationToken).ConfigureAwait(false) :
-                    base.Send(request, cancellationToken);
+                    await _innerHandler.SendAsync(request, cancellationToken).ConfigureAwait(false) :
+                    _innerHandler.Send(request, cancellationToken);
                 return response;
             }
             catch (OperationCanceledException)
@@ -170,6 +203,16 @@ namespace System.Net.Http
             }
         }
 
+        protected override void Dispose(bool disposing)
+        {
+            if (disposing)
+            {
+                _innerHandler.Dispose();
+            }
+
+            base.Dispose(disposing);
+        }
+
         #region private
 
         private sealed class ActivityStartData
@@ -269,42 +312,18 @@ namespace System.Net.Http
             public override string ToString() => $"{{ {nameof(Response)} = {Response}, {nameof(LoggingRequestId)} = {LoggingRequestId}, {nameof(Timestamp)} = {Timestamp}, {nameof(RequestTaskStatus)} = {RequestTaskStatus} }}";
         }
 
-        private static void InjectHeaders(Activity currentActivity, HttpRequestMessage request)
+        private void InjectHeaders(Activity currentActivity, HttpRequestMessage request)
         {
-            if (currentActivity.IdFormat == ActivityIdFormat.W3C)
+            _propagator.Inject(currentActivity, request, static (carrier, key, value) =>
             {
-                if (!request.Headers.Contains(DiagnosticsHandlerLoggingStrings.TraceParentHeaderName))
+                if (carrier is HttpRequestMessage request &&
+                    key is not null &&
+                    HeaderDescriptor.TryGet(key, out HeaderDescriptor descriptor) &&
+                    !request.Headers.TryGetHeaderValue(descriptor, out _))
                 {
-                    request.Headers.TryAddWithoutValidation(DiagnosticsHandlerLoggingStrings.TraceParentHeaderName, currentActivity.Id);
-                    if (currentActivity.TraceStateString != null)
-                    {
-                        request.Headers.TryAddWithoutValidation(DiagnosticsHandlerLoggingStrings.TraceStateHeaderName, currentActivity.TraceStateString);
-                    }
+                    request.Headers.TryAddWithoutValidation(descriptor, value);
                 }
-            }
-            else
-            {
-                if (!request.Headers.Contains(DiagnosticsHandlerLoggingStrings.RequestIdHeaderName))
-                {
-                    request.Headers.TryAddWithoutValidation(DiagnosticsHandlerLoggingStrings.RequestIdHeaderName, currentActivity.Id);
-                }
-            }
-
-            // we expect baggage to be empty or contain a few items
-            using (IEnumerator<KeyValuePair<string, string?>> e = currentActivity.Baggage.GetEnumerator())
-            {
-                if (e.MoveNext())
-                {
-                    var baggage = new List<string>();
-                    do
-                    {
-                        KeyValuePair<string, string?> item = e.Current;
-                        baggage.Add(new NameValueHeaderValue(WebUtility.UrlEncode(item.Key), WebUtility.UrlEncode(item.Value)).ToString());
-                    }
-                    while (e.MoveNext());
-                    request.Headers.TryAddWithoutValidation(DiagnosticsHandlerLoggingStrings.CorrelationContextHeaderName, baggage);
-                }
-            }
+            });
         }
 
         [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern",
index 0fa57394c1cc326d545e8e964be34b58dff0d0b1..cd91daaed3cbcfa48b7098bc359f03a0219af839 100644 (file)
@@ -15,11 +15,5 @@ namespace System.Net.Http
         public const string ExceptionEventName = "System.Net.Http.Exception";
         public const string ActivityName = "System.Net.Http.HttpRequestOut";
         public const string ActivityStartName = "System.Net.Http.HttpRequestOut.Start";
-
-        public const string RequestIdHeaderName = "Request-Id";
-        public const string CorrelationContextHeaderName = "Correlation-Context";
-
-        public const string TraceParentHeaderName = "traceparent";
-        public const string TraceStateHeaderName = "tracestate";
     }
 }
index 2e3289643cfbed5cebfcfbfc79ba3c9da08e752a..fce5166f279edf35995cfdea5d5d8e6584ea5b79 100644 (file)
@@ -9,6 +9,7 @@ using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Threading;
 using System.Threading.Tasks;
+using System.Diagnostics;
 #if TARGET_BROWSER
 using HttpHandlerType = System.Net.Http.BrowserHttpHandler;
 #else
@@ -20,7 +21,14 @@ namespace System.Net.Http
     public partial class HttpClientHandler : HttpMessageHandler
     {
         private readonly HttpHandlerType _underlyingHandler;
-        private readonly DiagnosticsHandler? _diagnosticsHandler;
+
+        private HttpMessageHandler Handler
+#if TARGET_BROWSER
+            { get; }
+#else
+            => _underlyingHandler;
+#endif
+
         private ClientCertificateOption _clientCertificateOptions;
 
         private volatile bool _disposed;
@@ -28,10 +36,15 @@ namespace System.Net.Http
         public HttpClientHandler()
         {
             _underlyingHandler = new HttpHandlerType();
+
+#if TARGET_BROWSER
+            Handler = _underlyingHandler;
             if (DiagnosticsHandler.IsGloballyEnabled())
             {
-                _diagnosticsHandler = new DiagnosticsHandler(_underlyingHandler);
+                Handler = new DiagnosticsHandler(Handler, DistributedContextPropagator.Current);
             }
+#endif
+
             ClientCertificateOptions = ClientCertificateOption.Manual;
         }
 
@@ -288,21 +301,11 @@ namespace System.Net.Http
         public IDictionary<string, object?> Properties => _underlyingHandler.Properties;
 
         [UnsupportedOSPlatform("browser")]
-        protected internal override HttpResponseMessage Send(HttpRequestMessage request,
-            CancellationToken cancellationToken)
-        {
-            return DiagnosticsHandler.IsEnabled() && _diagnosticsHandler != null ?
-                _diagnosticsHandler.Send(request, cancellationToken) :
-                _underlyingHandler.Send(request, cancellationToken);
-        }
+        protected internal override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken) =>
+            Handler.Send(request, cancellationToken);
 
-        protected internal override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
-            CancellationToken cancellationToken)
-        {
-            return DiagnosticsHandler.IsEnabled() && _diagnosticsHandler != null ?
-                _diagnosticsHandler.SendAsync(request, cancellationToken) :
-                _underlyingHandler.SendAsync(request, cancellationToken);
-        }
+        protected internal override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) =>
+            Handler.SendAsync(request, cancellationToken);
 
         // lazy-load the validator func so it can be trimmed by the ILLinker if it isn't used.
         private static Func<HttpRequestMessage, X509Certificate2?, X509Chain?, SslPolicyErrors, bool>? s_dangerousAcceptAnyServerCertificateValidator;
index d0c88d60f7714239e0870050d12d8f41c0192148..daeda85fff79ac27d86fdc89839aee7d96a7efbb 100644 (file)
@@ -15,6 +15,7 @@ namespace System.Net.Http
 
         private const int MessageNotYetSent = 0;
         private const int MessageAlreadySent = 1;
+        private const int MessageIsRedirect = 2;
 
         // Track whether the message has been sent.
         // The message shouldn't be sent again if this field is equal to MessageAlreadySent.
@@ -159,12 +160,13 @@ namespace System.Net.Http
             return sb.ToString();
         }
 
-        internal bool MarkAsSent()
-        {
-            return Interlocked.Exchange(ref _sendStatus, MessageAlreadySent) == MessageNotYetSent;
-        }
+        internal bool MarkAsSent() => Interlocked.CompareExchange(ref _sendStatus, MessageAlreadySent, MessageNotYetSent) == MessageNotYetSent;
+
+        internal bool WasSentByHttpClient() => (_sendStatus & MessageAlreadySent) != 0;
+
+        internal void MarkAsRedirected() => _sendStatus |= MessageIsRedirect;
 
-        internal bool WasSentByHttpClient() => _sendStatus == MessageAlreadySent;
+        internal bool WasRedirected() => (_sendStatus & MessageIsRedirect) != 0;
 
         #region IDisposable Members
 
index 03cb3e9a4062432617bfbcfe392e59d0ed928179..a24a6403b299d1dd66024f7f74c4bd44600d0989 100644 (file)
@@ -8,6 +8,7 @@ using System.Net.Quic.Implementations;
 using System.Runtime.Versioning;
 using System.Threading;
 using System.Threading.Tasks;
+using System.Diagnostics;
 
 namespace System.Net.Http
 {
@@ -47,6 +48,8 @@ namespace System.Net.Http
         internal HeaderEncodingSelector<HttpRequestMessage>? _requestHeaderEncodingSelector;
         internal HeaderEncodingSelector<HttpRequestMessage>? _responseHeaderEncodingSelector;
 
+        internal DistributedContextPropagator? _activityHeadersPropagator = DistributedContextPropagator.Current;
+
         internal Version _maxHttpVersion;
 
         internal SslClientAuthenticationOptions? _sslOptions;
@@ -119,6 +122,7 @@ namespace System.Net.Http
                 _connectCallback = _connectCallback,
                 _plaintextStreamFilter = _plaintextStreamFilter,
                 _initialHttp2StreamWindowSize = _initialHttp2StreamWindowSize,
+                _activityHeadersPropagator = _activityHeadersPropagator,
             };
 
             // TODO: Remove if/when QuicImplementationProvider is removed from System.Net.Quic.
index 856b4e61da394f44a864eb1bf6e1c98e226eaa53..5d4b4846a37ebdca46a4d71dc0b5d99e925a3975 100644 (file)
@@ -75,6 +75,8 @@ namespace System.Net.Http
                     }
                 }
 
+                request.MarkAsRedirected();
+
                 // Issue the redirected request.
                 response = await _redirectInnerHandler.SendAsync(request, async, cancellationToken).ConfigureAwait(false);
             }
index 42361fba08085544e9f939b7d26e7259604a43ae..68fbd071e7d64a28b0226149531d6bb5cae159b2 100644 (file)
@@ -9,6 +9,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using System.Diagnostics.CodeAnalysis;
 using System.Text;
+using System.Diagnostics;
 
 namespace System.Net.Http
 {
@@ -448,6 +449,22 @@ namespace System.Net.Http
             }
         }
 
+        /// <summary>
+        /// Gets or sets the <see cref="DistributedContextPropagator"/> to use when propagating the distributed trace and context.
+        /// Use <see langword="null"/> to disable propagation.
+        /// Defaults to <see cref="DistributedContextPropagator.Current"/>.
+        /// </summary>
+        [CLSCompliant(false)]
+        public DistributedContextPropagator? ActivityHeadersPropagator
+        {
+            get => _settings._activityHeadersPropagator;
+            set
+            {
+                CheckDisposedOrStarted();
+                _settings._activityHeadersPropagator = value;
+            }
+        }
+
         protected override void Dispose(bool disposing)
         {
             if (disposing && !_disposed)
@@ -478,6 +495,12 @@ namespace System.Net.Http
                 handler = new HttpAuthenticatedConnectionHandler(poolManager);
             }
 
+            // DiagnosticsHandler is inserted before RedirectHandler so that trace propagation is done on redirects as well
+            if (DiagnosticsHandler.IsGloballyEnabled() && settings._activityHeadersPropagator is DistributedContextPropagator propagator)
+            {
+                handler = new DiagnosticsHandler(handler, propagator, settings._allowAutoRedirect);
+            }
+
             if (settings._allowAutoRedirect)
             {
                 // Just as with WinHttpHandler, for security reasons, we do not support authentication on redirects
index 1fb6fd925fd33c07535698db6bcc7689bdfe38b2..49e4b0a384806dadf7a4dcf135a2f5d29e75bc58 100644 (file)
@@ -256,7 +256,7 @@ namespace System.Net.Http.Functional.Tests
                         GetProperty<HttpRequestMessage>(kvp.Value, "Request");
                         TaskStatus status = GetProperty<TaskStatus>(kvp.Value, "RequestTaskStatus");
                         Assert.Equal(TaskStatus.Canceled, status);
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                 });
 
@@ -308,6 +308,7 @@ namespace System.Net.Http.Functional.Tests
                 parentActivity.AddBaggage("correlationId", Guid.NewGuid().ToString("N").ToString());
                 parentActivity.AddBaggage("moreBaggage", Guid.NewGuid().ToString("N").ToString());
                 parentActivity.AddTag("tag", "tag"); // add tag to ensure it is not injected into request
+                parentActivity.TraceStateString = "Foo";
 
                 parentActivity.Start();
 
@@ -344,7 +345,7 @@ namespace System.Net.Http.Functional.Tests
                         activityStopResponseLogged = GetProperty<HttpResponseMessage>(kvp.Value, "Response");
                         TaskStatus requestStatus = GetProperty<TaskStatus>(kvp.Value, "RequestTaskStatus");
                         Assert.Equal(TaskStatus.RanToCompletion, requestStatus);
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                 });
 
@@ -403,13 +404,10 @@ namespace System.Net.Http.Functional.Tests
                         HttpRequestMessage request = GetProperty<HttpRequestMessage>(kvp.Value, "Request");
                         Assert.True(request.Headers.TryGetValues("Request-Id", out var requestId));
                         Assert.True(request.Headers.TryGetValues("Correlation-Context", out var correlationContext));
-                        Assert.Equal(3, correlationContext.Count());
-                        Assert.Contains("key=value", correlationContext);
-                        Assert.Contains("bad%2Fkey=value", correlationContext);
-                        Assert.Contains("goodkey=bad%2Fvalue", correlationContext);
+                        Assert.Equal("key=value, goodkey=bad%2Fvalue, bad%2Fkey=value", Assert.Single(correlationContext));
                         TaskStatus requestStatus = GetProperty<TaskStatus>(kvp.Value, "RequestTaskStatus");
                         Assert.Equal(TaskStatus.RanToCompletion, requestStatus);
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                     else if (kvp.Key.Equals("System.Net.Http.Exception"))
                     {
@@ -467,7 +465,7 @@ namespace System.Net.Http.Functional.Tests
 
                         Assert.False(request.Headers.TryGetValues("traceparent", out var _));
                         Assert.False(request.Headers.TryGetValues("tracestate", out var _));
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                 });
 
@@ -519,7 +517,7 @@ namespace System.Net.Http.Functional.Tests
                     }
                     else if (kvp.Key.Equals("System.Net.Http.HttpRequestOut.Stop"))
                     {
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                 });
 
@@ -608,7 +606,7 @@ namespace System.Net.Http.Functional.Tests
                         GetProperty<HttpRequestMessage>(kvp.Value, "Request");
                         TaskStatus requestStatus = GetProperty<TaskStatus>(kvp.Value, "RequestTaskStatus");
                         Assert.Equal(TaskStatus.Faulted, requestStatus);
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                     else if (kvp.Key.Equals("System.Net.Http.Exception"))
                     {
@@ -647,7 +645,7 @@ namespace System.Net.Http.Functional.Tests
                         GetProperty<HttpRequestMessage>(kvp.Value, "Request");
                         TaskStatus requestStatus = GetProperty<TaskStatus>(kvp.Value, "RequestTaskStatus");
                         Assert.Equal(TaskStatus.Faulted, requestStatus);
-                        activityStopTcs.SetResult();;
+                        activityStopTcs.SetResult();
                     }
                     else if (kvp.Key.Equals("System.Net.Http.Exception"))
                     {
@@ -796,42 +794,49 @@ namespace System.Net.Http.Functional.Tests
             }, UseVersion.ToString(), TestAsync.ToString()).Dispose();
         }
 
-        [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
-        public void SendAsync_ExpectedActivityPropagationWithoutListener()
+        public static IEnumerable<object[]> UseSocketsHttpHandler_WithIdFormat_MemberData()
         {
-            RemoteExecutor.Invoke(async (useVersion, testAsync) =>
-            {
-                Activity parent = new Activity("parent").Start();
+            yield return new object[] { true, ActivityIdFormat.Hierarchical };
+            yield return new object[] { true, ActivityIdFormat.W3C };
+            yield return new object[] { false, ActivityIdFormat.Hierarchical };
+            yield return new object[] { false, ActivityIdFormat.W3C };
+        }
 
-                await GetFactoryForVersion(useVersion).CreateClientAndServerAsync(
-                    async uri =>
-                    {
-                        await GetAsync(useVersion, testAsync, uri);
-                    },
-                    async server =>
-                    {
-                        HttpRequestData requestData = await server.AcceptConnectionSendResponseAndCloseAsync();
-                        AssertHeadersAreInjected(requestData, parent);
-                    });
-            }, UseVersion.ToString(), TestAsync.ToString()).Dispose();
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+        [MemberData(nameof(UseSocketsHttpHandler_WithIdFormat_MemberData))]
+        public async Task SendAsync_ExpectedActivityPropagationWithoutListener(bool useSocketsHttpHandler, ActivityIdFormat idFormat)
+        {
+            Activity parent = new Activity("parent");
+            parent.SetIdFormat(idFormat);
+            parent.Start();
+
+            await GetFactoryForVersion(UseVersion).CreateClientAndServerAsync(
+                async uri =>
+                {
+                    await GetAsync(UseVersion.ToString(), TestAsync.ToString(), uri, useSocketsHttpHandler: useSocketsHttpHandler);
+                },
+                async server =>
+                {
+                    HttpRequestData requestData = await server.HandleRequestAsync();
+                    AssertHeadersAreInjected(requestData, parent);
+                });
         }
 
-        [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
-        public void SendAsync_ExpectedActivityPropagationWithoutListenerOrParentActivity()
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task SendAsync_ExpectedActivityPropagationWithoutListenerOrParentActivity(bool useSocketsHttpHandler)
         {
-            RemoteExecutor.Invoke(async (useVersion, testAsync) =>
-            {
-                await GetFactoryForVersion(useVersion).CreateClientAndServerAsync(
-                    async uri =>
-                    {
-                        await GetAsync(useVersion, testAsync, uri);
-                    },
-                    async server =>
-                    {
-                        HttpRequestData requestData = await server.AcceptConnectionSendResponseAndCloseAsync();
-                        AssertNoHeadersAreInjected(requestData);
-                    });
-            }, UseVersion.ToString(), TestAsync.ToString()).Dispose();
+            await GetFactoryForVersion(UseVersion).CreateClientAndServerAsync(
+                async uri =>
+                {
+                    await GetAsync(UseVersion.ToString(), TestAsync.ToString(), uri, useSocketsHttpHandler: useSocketsHttpHandler);
+                },
+                async server =>
+                {
+                    HttpRequestData requestData = await server.HandleRequestAsync();
+                    AssertNoHeadersAreInjected(requestData);
+                });
         }
 
         [ConditionalTheory(nameof(EnableActivityPropagationEnvironmentVariableIsNotSetAndRemoteExecutorSupported))]
@@ -877,6 +882,56 @@ namespace System.Net.Http.Functional.Tests
             }, UseVersion.ToString(), TestAsync.ToString(), envVarValue).Dispose();
         }
 
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+        [MemberData(nameof(UseSocketsHttpHandler_WithIdFormat_MemberData))]
+        public async Task SendAsync_HeadersAreInjectedOnRedirects(bool useSocketsHttpHandler, ActivityIdFormat idFormat)
+        {
+            Activity parent = new Activity("parent");
+            parent.SetIdFormat(idFormat);
+            parent.TraceStateString = "Foo";
+            parent.Start();
+
+            await GetFactoryForVersion(UseVersion).CreateServerAsync(async (originalServer, originalUri) =>
+            {
+                await GetFactoryForVersion(UseVersion).CreateServerAsync(async (redirectServer, redirectUri) =>
+                {
+                    Task clientTask = GetAsync(UseVersion.ToString(), TestAsync.ToString(), originalUri, useSocketsHttpHandler: useSocketsHttpHandler);
+
+                    Task<HttpRequestData> serverTask = originalServer.HandleRequestAsync(HttpStatusCode.Redirect, new[] { new HttpHeaderData("Location", redirectUri.AbsoluteUri) });
+
+                    await Task.WhenAny(clientTask, serverTask);
+                    Assert.False(clientTask.IsCompleted, $"{clientTask.Status}: {clientTask.Exception}");
+                    HttpRequestData firstRequestData = await serverTask;
+                    AssertHeadersAreInjected(firstRequestData, parent);
+
+                    serverTask = redirectServer.HandleRequestAsync();
+                    await TestHelper.WhenAllCompletedOrAnyFailed(clientTask, serverTask);
+                    HttpRequestData secondRequestData = await serverTask;
+                    AssertHeadersAreInjected(secondRequestData, parent);
+
+                    if (idFormat == ActivityIdFormat.W3C)
+                    {
+                        string firstParent = GetHeaderValue(firstRequestData, "traceparent");
+                        string firstState = GetHeaderValue(firstRequestData, "tracestate");
+                        Assert.True(ActivityContext.TryParse(firstParent, firstState, out ActivityContext firstContext));
+
+                        string secondParent = GetHeaderValue(secondRequestData, "traceparent");
+                        string secondState = GetHeaderValue(secondRequestData, "tracestate");
+                        Assert.True(ActivityContext.TryParse(secondParent, secondState, out ActivityContext secondContext));
+
+                        Assert.Equal(firstContext.TraceId, secondContext.TraceId);
+                        Assert.Equal(firstContext.TraceFlags, secondContext.TraceFlags);
+                        Assert.Equal(firstContext.TraceState, secondContext.TraceState);
+                        Assert.NotEqual(firstContext.SpanId, secondContext.SpanId);
+                    }
+                    else
+                    {
+                        Assert.NotEqual(GetHeaderValue(firstRequestData, "Request-Id"), GetHeaderValue(secondRequestData, "Request-Id"));
+                    }
+                });
+            });
+        }
+
         [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
         [InlineData(true)]
         [InlineData(false)]
@@ -893,12 +948,56 @@ namespace System.Net.Http.Functional.Tests
                         (HttpRequestMessage request, _) = await GetAsync(useVersion, testAsync, uri);
 
                         string headerName = parent.IdFormat == ActivityIdFormat.Hierarchical ? "Request-Id" : "traceparent";
+
                         Assert.Equal(bool.Parse(switchValue), request.Headers.Contains(headerName));
                     },
                     async server => await server.HandleRequestAsync());
             }, UseVersion.ToString(), TestAsync.ToString(), switchValue.ToString()).Dispose();
         }
 
+        public static IEnumerable<object[]> SocketsHttpHandlerPropagators_WithIdFormat_MemberData()
+        {
+            foreach (var propagator in new[] { null, DistributedContextPropagator.CreateDefaultPropagator(), DistributedContextPropagator.CreateNoOutputPropagator(), DistributedContextPropagator.CreatePassThroughPropagator() })
+            {
+                foreach (ActivityIdFormat format in new[] { ActivityIdFormat.Hierarchical, ActivityIdFormat.W3C })
+                {
+                    yield return new object[] { propagator, format };
+                }
+            }
+        }
+
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+        [MemberData(nameof(SocketsHttpHandlerPropagators_WithIdFormat_MemberData))]
+        public async Task SendAsync_CustomSocketsHttpHandlerPropagator_PropagatorIsUsed(DistributedContextPropagator propagator, ActivityIdFormat idFormat)
+        {
+            Activity parent = new Activity("parent");
+            parent.SetIdFormat(idFormat);
+            parent.Start();
+
+            await GetFactoryForVersion(UseVersion).CreateClientAndServerAsync(
+                async uri =>
+                {
+                    using var handler = new SocketsHttpHandler { ActivityHeadersPropagator = propagator };
+                    handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; };
+                    using var client = new HttpClient(handler);
+                    var request = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true);
+                    await client.SendAsync(TestAsync, request);
+                },
+                async server =>
+                {
+                    HttpRequestData requestData = await server.HandleRequestAsync();
+
+                    if (propagator is null || ReferenceEquals(propagator, DistributedContextPropagator.CreateNoOutputPropagator()))
+                    {
+                        AssertNoHeadersAreInjected(requestData);
+                    }
+                    else
+                    {
+                        AssertHeadersAreInjected(requestData, parent, ReferenceEquals(propagator, DistributedContextPropagator.CreatePassThroughPropagator()));
+                    }
+                });
+        }
+
         private static T GetProperty<T>(object obj, string propertyName)
         {
             Type t = obj.GetType();
@@ -925,7 +1024,7 @@ namespace System.Net.Http.Functional.Tests
             Assert.Null(GetHeaderValue(request, "Correlation-Context"));
         }
 
-        private static void AssertHeadersAreInjected(HttpRequestData request, Activity parent)
+        private static void AssertHeadersAreInjected(HttpRequestData request, Activity parent, bool passthrough = false)
         {
             string requestId = GetHeaderValue(request, "Request-Id");
             string traceparent = GetHeaderValue(request, "traceparent");
@@ -935,7 +1034,7 @@ namespace System.Net.Http.Functional.Tests
             {
                 Assert.True(requestId != null, "Request-Id was not injected when instrumentation was enabled");
                 Assert.StartsWith(parent.Id, requestId);
-                Assert.NotEqual(parent.Id, requestId);
+                Assert.Equal(passthrough, parent.Id == requestId);
                 Assert.Null(traceparent);
                 Assert.Null(tracestate);
             }
@@ -944,6 +1043,7 @@ namespace System.Net.Http.Functional.Tests
                 Assert.Null(requestId);
                 Assert.True(traceparent != null, "traceparent was not injected when W3C instrumentation was enabled");
                 Assert.StartsWith($"00-{parent.TraceId.ToHexString()}-", traceparent);
+                Assert.Equal(passthrough, parent.Id == traceparent);
                 Assert.Equal(parent.TraceStateString, tracestate);
             }
 
@@ -960,10 +1060,23 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
-        private static async Task<(HttpRequestMessage, HttpResponseMessage)> GetAsync(string useVersion, string testAsync, Uri uri, CancellationToken cancellationToken = default)
+        private static async Task<(HttpRequestMessage, HttpResponseMessage)> GetAsync(string useVersion, string testAsync, Uri uri, CancellationToken cancellationToken = default, bool useSocketsHttpHandler = false)
         {
-            HttpClientHandler handler = CreateHttpClientHandler(useVersion);
-            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+            HttpMessageHandler handler;
+            if (useSocketsHttpHandler)
+            {
+                var socketsHttpHandler = new SocketsHttpHandler();
+                socketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; };
+                handler = socketsHttpHandler;
+            }
+            else
+            {
+                handler = new HttpClientHandler
+                {
+                    ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates
+                };
+            }
+
             using var client = new HttpClient(handler);
             var request = CreateRequest(HttpMethod.Get, uri, Version.Parse(useVersion), exactVersion: true);
             return (request, await client.SendAsync(bool.Parse(testAsync), request, cancellationToken));