Serialize Dns async-over-sync requests for the same host (#49171)
author Stephen Toub <stoub@microsoft.com>
Fri, 5 Mar 2021 21:15:30 +0000 (16:15 -0500)
committer GitHub <noreply@github.com>
Fri, 5 Mar 2021 21:15:30 +0000 (16:15 -0500)
* Serialize Dns async-over-sync requests for the same host

* Update src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs

src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs

index f684e9b..b4ca39a 100644 (file)
@@ -628,9 +628,6 @@ namespace System.Net
             }
         }
 
-        private static Task<TResult> RunAsync<TResult>(Func<object, TResult> func, object arg, CancellationToken cancellationToken) =>
-            Task.Factory.StartNew(func!, arg, cancellationToken, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
-
         private static IPHostEntry CreateHostEntryForAddress(IPAddress address) =>
             new IPHostEntry
             {
@@ -656,5 +653,71 @@ namespace System.Net
             NameResolutionTelemetry.Log.AfterResolution(stopwatch, successful: false);
             return false;
         }
+
+        /// <summary>Mapping from key to the most recently queued task for that key; RunAsync touches it only while holding a lock on it.</summary>
+        private static readonly Dictionary<object, Task> s_tasks = new Dictionary<object, Task>();
+
+        /// <summary>Queue the function to be invoked asynchronously, after any task already queued for the same key.</summary>
+        /// <remarks>
+        /// Since this is doing synchronous work on a thread pool thread, we want to limit how many threads end up being
+        /// blocked.  We could employ a semaphore to limit overall usage, but a common case is that DNS requests are made
+        /// for only a handful of endpoints, and a reasonable compromise is to ensure that requests for a given host are
+        /// serialized.  Once the data for that host is cached locally by the OS, the subsequent requests should all complete
+        /// very quickly, and if the head-of-line request is taking a long time due to the connection to the server, we won't
+        /// block lots of threads all getting data for that one host.  We also still want to issue the request to the OS, rather
+        /// than having all concurrent requests for the same host share the exact same task, so that any shuffling of the results
+        /// by the OS to enable round robin is still perceived.
+        /// </remarks>
+        private static Task<TResult> RunAsync<TResult>(Func<object, TResult> func, object key, CancellationToken cancellationToken)
+        {
+            Task<TResult>? task = null;
+
+            lock (s_tasks)
+            {
+                // Get the previous task for this key, if there is one.
+                s_tasks.TryGetValue(key, out Task? prevTask);
+                prevTask ??= Task.CompletedTask;
+
+                // Invoke the function in a queued work item when the previous task completes. Note that some callers expect the
+                // returned task to have the key as the task's AsyncState.
+                task = prevTask.ContinueWith(delegate
+                {
+                    Debug.Assert(!Monitor.IsEntered(s_tasks));
+                    try
+                    {
+                        return func(key);
+                    }
+                    finally
+                    {
+                        // When the work is done, remove this key/task pair from the dictionary if this is still the current task.
+                        // The call site assigns both the local and the dictionary entry while holding this same lock, so once
+                        // the lock is acquired here those writes are guaranteed to be visible.
+                        lock (s_tasks)
+                        {
+                            ((ICollection<KeyValuePair<object, Task>>)s_tasks).Remove(new KeyValuePair<object, Task>(key!, task!));
+                        }
+                    }
+                }, key, cancellationToken, TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);
+
+                // Store the task as the current task for this key *before* hooking up the cancellation cleanup below: if the
+                // task is already canceled, that cleanup runs inline right here and must find the entry in order to remove it.
+                s_tasks[key] = task;
+
+                // If it's possible the task may end up getting canceled, it won't have a chance to remove itself from
+                // the dictionary if it is canceled, so use a separate continuation to do so.
+                if (cancellationToken.CanBeCanceled)
+                {
+                    task.ContinueWith((canceledTask, state) =>
+                    {
+                        lock (s_tasks)
+                        {
+                            ((ICollection<KeyValuePair<object, Task>>)s_tasks).Remove(new KeyValuePair<object, Task>(state!, canceledTask));
+                        }
+                    }, key, CancellationToken.None, TaskContinuationOptions.OnlyOnCanceled | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
+                }
+            }
+
+            return task;
+        }
     }
 }