Eliminate cancellation deadlock in RateLimiter implementations (#90285)
authorBrennan <brecon@microsoft.com>
Thu, 10 Aug 2023 15:28:45 +0000 (08:28 -0700)
committerGitHub <noreply@github.com>
Thu, 10 Aug 2023 15:28:45 +0000 (08:28 -0700)
src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs
src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/FixedWindowRateLimiter.cs
src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/SlidingWindowRateLimiter.cs
src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs
src/libraries/System.Threading.RateLimiting/tests/ConcurrencyLimiterTests.cs

index f375e50..6b5a401 100644 (file)
@@ -132,6 +132,8 @@ namespace System.Threading.RateLimiting
                 return new ValueTask<RateLimitLease>(SuccessfulLease);
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
+
             // Perf: Check SemaphoreSlim implementation instead of locking
             lock (Lock)
             {
@@ -152,7 +154,7 @@ namespace System.Threading.RateLimiting
                             RequestRegistration oldestRequest = _queue.DequeueHead();
                             _queueCount -= oldestRequest.Count;
                             Debug.Assert(_queueCount >= 0);
-                            if (!oldestRequest.Tcs.TrySetResult(FailedLease))
+                            if (!oldestRequest.TrySetResult(FailedLease))
                             {
                                 // Updating queue count is handled by the cancellation code
                                 _queueCount += oldestRequest.Count;
@@ -161,7 +163,7 @@ namespace System.Threading.RateLimiting
                             {
                                 Interlocked.Increment(ref _failedLeasesCount);
                             }
-                            oldestRequest.CancellationTokenRegistration.Dispose();
+                            disposer.Add(oldestRequest);
                         }
                         while (_options.QueueLimit - _queueCount < permitCount);
                     }
@@ -173,22 +175,12 @@ namespace System.Threading.RateLimiting
                     }
                 }
 
-                CancelQueueState tcs = new CancelQueueState(permitCount, this, cancellationToken);
-                CancellationTokenRegistration ctr = default;
-                if (cancellationToken.CanBeCanceled)
-                {
-                    ctr = cancellationToken.Register(static obj =>
-                    {
-                        ((CancelQueueState)obj!).TrySetCanceled();
-                    }, tcs);
-                }
-
-                RequestRegistration request = new RequestRegistration(permitCount, tcs, ctr);
+                var request = new RequestRegistration(permitCount, this, cancellationToken);
                 _queue.EnqueueTail(request);
                 _queueCount += permitCount;
                 Debug.Assert(_queueCount <= _options.QueueLimit);
 
-                return new ValueTask<RateLimitLease>(request.Tcs.Task);
+                return new ValueTask<RateLimitLease>(request.Task);
             }
         }
 
@@ -224,8 +216,15 @@ namespace System.Threading.RateLimiting
             return false;
         }
 
+#if DEBUG
+        // for unit testing
+        internal event Action? ReleasePreHook;
+        internal event Action? ReleasePostHook;
+#endif
+
         private void Release(int releaseCount)
         {
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (_disposed)
@@ -236,6 +235,10 @@ namespace System.Threading.RateLimiting
                 _permitCount += releaseCount;
                 Debug.Assert(_permitCount <= _options.PermitLimit);
 
+#if DEBUG
+                ReleasePreHook?.Invoke();
+#endif
+
                 while (_queue.Count > 0)
                 {
                     RequestRegistration nextPendingRequest =
@@ -245,15 +248,21 @@ namespace System.Threading.RateLimiting
 
                     // Request was handled already, either via cancellation or being kicked from the queue due to a newer request being queued.
                     // We just need to remove the item and let the next queued item be considered for completion.
-                    if (nextPendingRequest.Tcs.Task.IsCompleted)
+                    if (nextPendingRequest.Task.IsCompleted)
                     {
                         nextPendingRequest =
                             _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                             ? _queue.DequeueHead()
                             : _queue.DequeueTail();
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
+                        continue;
                     }
-                    else if (_permitCount >= nextPendingRequest.Count)
+
+#if DEBUG
+                    ReleasePostHook?.Invoke();
+#endif
+
+                    if (_permitCount >= nextPendingRequest.Count)
                     {
                         nextPendingRequest =
                             _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
@@ -266,7 +275,7 @@ namespace System.Threading.RateLimiting
 
                         ConcurrencyLease lease = nextPendingRequest.Count == 0 ? SuccessfulLease : new ConcurrencyLease(true, this, nextPendingRequest.Count);
                         // Check if request was canceled
-                        if (!nextPendingRequest.Tcs.TrySetResult(lease))
+                        if (!nextPendingRequest.TrySetResult(lease))
                         {
                             // Queued item was canceled so add count back
                             _permitCount += nextPendingRequest.Count;
@@ -277,7 +286,7 @@ namespace System.Threading.RateLimiting
                         {
                             Interlocked.Increment(ref _successfulLeasesCount);
                         }
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                         Debug.Assert(_queueCount >= 0);
                     }
                     else
@@ -289,7 +298,6 @@ namespace System.Threading.RateLimiting
                 if (_permitCount == _options.PermitLimit)
                 {
                     Debug.Assert(_idleSince is null);
-                    Debug.Assert(_queueCount == 0);
                     _idleSince = Stopwatch.GetTimestamp();
                 }
             }
@@ -303,6 +311,7 @@ namespace System.Threading.RateLimiting
                 return;
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (_disposed)
@@ -315,8 +324,8 @@ namespace System.Threading.RateLimiting
                     RequestRegistration next = _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                         ? _queue.DequeueHead()
                         : _queue.DequeueTail();
-                    next.CancellationTokenRegistration.Dispose();
-                    next.Tcs.TrySetResult(FailedLease);
+                    disposer.Add(next);
+                    next.TrySetResult(FailedLease);
                 }
             }
         }
@@ -385,49 +394,68 @@ namespace System.Threading.RateLimiting
             }
         }
 
-        private readonly struct RequestRegistration
+        private sealed class RequestRegistration : TaskCompletionSource<RateLimitLease>
         {
-            public RequestRegistration(int requestedCount, TaskCompletionSource<RateLimitLease> tcs,
-                CancellationTokenRegistration cancellationTokenRegistration)
-            {
-                Count = requestedCount;
-                // Perf: Use AsyncOperation<TResult> instead
-                Tcs = tcs;
-                CancellationTokenRegistration = cancellationTokenRegistration;
-            }
+            private readonly CancellationToken _cancellationToken;
+            private CancellationTokenRegistration _cancellationTokenRegistration;
 
-            public int Count { get; }
+            // this field is used only by the disposal mechanics and never shared between threads
+            private RequestRegistration? _next;
 
-            public TaskCompletionSource<RateLimitLease> Tcs { get; }
+            public RequestRegistration(int permitCount, ConcurrencyLimiter limiter, CancellationToken cancellationToken)
+                : base(limiter, TaskCreationOptions.RunContinuationsAsynchronously)
+            {
+                Count = permitCount;
+                _cancellationToken = cancellationToken;
 
-            public CancellationTokenRegistration CancellationTokenRegistration { get; }
-        }
+                // RequestRegistration objects are created while the limiter lock is held
+                // if cancellationToken fires before or while the lock is held, UnsafeRegister
+                // is going to invoke the callback synchronously, but this does not create
+                // a deadlock because lock are reentrant
+                if (cancellationToken.CanBeCanceled)
+#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER
+                    _cancellationTokenRegistration = cancellationToken.UnsafeRegister(Cancel, this);
+#else
+                    _cancellationTokenRegistration = cancellationToken.Register(Cancel, this);
+#endif
+            }
 
-        private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
-        {
-            private readonly int _permitCount;
-            private readonly ConcurrencyLimiter _limiter;
-            private readonly CancellationToken _cancellationToken;
+            public int Count { get; }
 
-            public CancelQueueState(int permitCount, ConcurrencyLimiter limiter, CancellationToken cancellationToken)
-                : base(TaskCreationOptions.RunContinuationsAsynchronously)
+            private static void Cancel(object? state)
             {
-                _permitCount = permitCount;
-                _limiter = limiter;
-                _cancellationToken = cancellationToken;
+                if (state is RequestRegistration registration && registration.TrySetCanceled(registration._cancellationToken))
+                {
+                    var limiter = (ConcurrencyLimiter)registration.Task.AsyncState!;
+                    lock (limiter.Lock)
+                    {
+                        limiter._queueCount -= registration.Count;
+                    }
+                }
             }
 
-            public new bool TrySetCanceled()
+            /// <summary>
+            /// Collects registrations to dispose outside the limiter lock to avoid deadlock.
+            /// </summary>
+            public struct Disposer : IDisposable
             {
-                if (TrySetCanceled(_cancellationToken))
+                private RequestRegistration? _next;
+
+                public void Add(RequestRegistration request)
+                {
+                    request._next = _next;
+                    _next = request;
+                }
+
+                public void Dispose()
                 {
-                    lock (_limiter.Lock)
+                    for (var current = _next; current is not null; current = current._next)
                     {
-                        _limiter._queueCount -= _permitCount;
+                        current._cancellationTokenRegistration.Dispose();
                     }
-                    return true;
+
+                    _next = null;
                 }
-                return false;
             }
         }
     }
index 897e5fa..d09c797 100644 (file)
@@ -151,6 +151,7 @@ namespace System.Threading.RateLimiting
                 return new ValueTask<RateLimitLease>(SuccessfulLease);
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (TryLeaseUnsynchronized(permitCount, out RateLimitLease? lease))
@@ -170,7 +171,7 @@ namespace System.Threading.RateLimiting
                             RequestRegistration oldestRequest = _queue.DequeueHead();
                             _queueCount -= oldestRequest.Count;
                             Debug.Assert(_queueCount >= 0);
-                            if (!oldestRequest.Tcs.TrySetResult(FailedLease))
+                            if (!oldestRequest.TrySetResult(FailedLease))
                             {
                                 _queueCount += oldestRequest.Count;
                             }
@@ -178,7 +179,7 @@ namespace System.Threading.RateLimiting
                             {
                                 Interlocked.Increment(ref _failedLeasesCount);
                             }
-                            oldestRequest.CancellationTokenRegistration.Dispose();
+                            disposer.Add(oldestRequest);
                         }
                         while (_options.QueueLimit - _queueCount < permitCount);
                     }
@@ -190,22 +191,12 @@ namespace System.Threading.RateLimiting
                     }
                 }
 
-                CancelQueueState tcs = new CancelQueueState(permitCount, this, cancellationToken);
-                CancellationTokenRegistration ctr = default;
-                if (cancellationToken.CanBeCanceled)
-                {
-                    ctr = cancellationToken.Register(static obj =>
-                    {
-                        ((CancelQueueState)obj!).TrySetCanceled();
-                    }, tcs);
-                }
-
-                RequestRegistration registration = new RequestRegistration(permitCount, tcs, ctr);
+                var registration = new RequestRegistration(permitCount, this, cancellationToken);
                 _queue.EnqueueTail(registration);
                 _queueCount += permitCount;
                 Debug.Assert(_queueCount <= _options.QueueLimit);
 
-                return new ValueTask<RateLimitLease>(registration.Tcs.Task);
+                return new ValueTask<RateLimitLease>(registration.Task);
             }
         }
 
@@ -280,6 +271,8 @@ namespace System.Threading.RateLimiting
         // Used in tests that test behavior with specific time intervals
         private void ReplenishInternal(long nowTicks)
         {
+            using var disposer = default(RequestRegistration.Disposer);
+
             // Method is re-entrant (from Timer), lock to avoid multiple simultaneous replenishes
             lock (Lock)
             {
@@ -315,13 +308,13 @@ namespace System.Threading.RateLimiting
 
                     // Request was handled already, either via cancellation or being kicked from the queue due to a newer request being queued.
                     // We just need to remove the item and let the next queued item be considered for completion.
-                    if (nextPendingRequest.Tcs.Task.IsCompleted)
+                    if (nextPendingRequest.Task.IsCompleted)
                     {
                         nextPendingRequest =
                             _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                             ? _queue.DequeueHead()
                             : _queue.DequeueTail();
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                     }
                     else if (_permitCount >= nextPendingRequest.Count)
                     {
@@ -335,7 +328,7 @@ namespace System.Threading.RateLimiting
                         _permitCount -= nextPendingRequest.Count;
                         Debug.Assert(_permitCount >= 0);
 
-                        if (!nextPendingRequest.Tcs.TrySetResult(SuccessfulLease))
+                        if (!nextPendingRequest.TrySetResult(SuccessfulLease))
                         {
                             // Queued item was canceled so add count back
                             _permitCount += nextPendingRequest.Count;
@@ -346,7 +339,7 @@ namespace System.Threading.RateLimiting
                         {
                             Interlocked.Increment(ref _successfulLeasesCount);
                         }
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                         Debug.Assert(_queueCount >= 0);
                     }
                     else
@@ -359,7 +352,6 @@ namespace System.Threading.RateLimiting
                 if (_permitCount == _options.PermitLimit)
                 {
                     Debug.Assert(_idleSince is null);
-                    Debug.Assert(_queueCount == 0);
                     _idleSince = Stopwatch.GetTimestamp();
                 }
             }
@@ -373,6 +365,7 @@ namespace System.Threading.RateLimiting
                 return;
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (_disposed)
@@ -386,8 +379,8 @@ namespace System.Threading.RateLimiting
                     RequestRegistration next = _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                         ? _queue.DequeueHead()
                         : _queue.DequeueTail();
-                    next.CancellationTokenRegistration.Dispose();
-                    next.Tcs.TrySetResult(FailedLease);
+                    disposer.Add(next);
+                    next.TrySetResult(FailedLease);
                 }
             }
         }
@@ -437,48 +430,68 @@ namespace System.Threading.RateLimiting
             }
         }
 
-        private readonly struct RequestRegistration
+        private sealed class RequestRegistration : TaskCompletionSource<RateLimitLease>
         {
-            public RequestRegistration(int permitCount, TaskCompletionSource<RateLimitLease> tcs, CancellationTokenRegistration cancellationTokenRegistration)
+            private readonly CancellationToken _cancellationToken;
+            private CancellationTokenRegistration _cancellationTokenRegistration;
+
+            // this field is used only by the disposal mechanics and never shared between threads
+            private RequestRegistration? _next;
+
+            public RequestRegistration(int permitCount, FixedWindowRateLimiter limiter, CancellationToken cancellationToken)
+                : base(limiter, TaskCreationOptions.RunContinuationsAsynchronously)
             {
                 Count = permitCount;
-                // Use VoidAsyncOperationWithData<T> instead
-                Tcs = tcs;
-                CancellationTokenRegistration = cancellationTokenRegistration;
+                _cancellationToken = cancellationToken;
+
+                // RequestRegistration objects are created while the limiter lock is held
+                // if cancellationToken fires before or while the lock is held, UnsafeRegister
+                // is going to invoke the callback synchronously, but this does not create
+                // a deadlock because lock are reentrant
+                if (cancellationToken.CanBeCanceled)
+#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER
+                    _cancellationTokenRegistration = cancellationToken.UnsafeRegister(Cancel, this);
+#else
+                    _cancellationTokenRegistration = cancellationToken.Register(Cancel, this);
+#endif
             }
 
             public int Count { get; }
 
-            public TaskCompletionSource<RateLimitLease> Tcs { get; }
-
-            public CancellationTokenRegistration CancellationTokenRegistration { get; }
-        }
-
-        private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
-        {
-            private readonly int _permitCount;
-            private readonly FixedWindowRateLimiter _limiter;
-            private readonly CancellationToken _cancellationToken;
-
-            public CancelQueueState(int permitCount, FixedWindowRateLimiter limiter, CancellationToken cancellationToken)
-                : base(TaskCreationOptions.RunContinuationsAsynchronously)
+            private static void Cancel(object? state)
             {
-                _permitCount = permitCount;
-                _limiter = limiter;
-                _cancellationToken = cancellationToken;
+                if (state is RequestRegistration registration && registration.TrySetCanceled(registration._cancellationToken))
+                {
+                    var limiter = (FixedWindowRateLimiter)registration.Task.AsyncState!;
+                    lock (limiter.Lock)
+                    {
+                        limiter._queueCount -= registration.Count;
+                    }
+                }
             }
 
-            public new bool TrySetCanceled()
+            /// <summary>
+            /// Collects registrations to dispose outside the limiter lock to avoid deadlock.
+            /// </summary>
+            public struct Disposer : IDisposable
             {
-                if (TrySetCanceled(_cancellationToken))
+                private RequestRegistration? _next;
+
+                public void Add(RequestRegistration request)
+                {
+                    request._next = _next;
+                    _next = request;
+                }
+
+                public void Dispose()
                 {
-                    lock (_limiter.Lock)
+                    for (var current = _next; current is not null; current = current._next)
                     {
-                        _limiter._queueCount -= _permitCount;
+                        current._cancellationTokenRegistration.Dispose();
                     }
-                    return true;
+
+                    _next = null;
                 }
-                return false;
             }
         }
     }
index 340362c..a179720 100644 (file)
@@ -163,6 +163,7 @@ namespace System.Threading.RateLimiting
                 return new ValueTask<RateLimitLease>(SuccessfulLease);
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (TryLeaseUnsynchronized(permitCount, out RateLimitLease? lease))
@@ -182,7 +183,7 @@ namespace System.Threading.RateLimiting
                             RequestRegistration oldestRequest = _queue.DequeueHead();
                             _queueCount -= oldestRequest.Count;
                             Debug.Assert(_queueCount >= 0);
-                            if (!oldestRequest.Tcs.TrySetResult(FailedLease))
+                            if (!oldestRequest.TrySetResult(FailedLease))
                             {
                                 _queueCount += oldestRequest.Count;
                             }
@@ -190,7 +191,7 @@ namespace System.Threading.RateLimiting
                             {
                                 Interlocked.Increment(ref _failedLeasesCount);
                             }
-                            oldestRequest.CancellationTokenRegistration.Dispose();
+                            disposer.Add(oldestRequest);
                         }
                         while (_options.QueueLimit - _queueCount < permitCount);
                     }
@@ -202,22 +203,12 @@ namespace System.Threading.RateLimiting
                     }
                 }
 
-                CancelQueueState tcs = new CancelQueueState(permitCount, this, cancellationToken);
-                CancellationTokenRegistration ctr = default;
-                if (cancellationToken.CanBeCanceled)
-                {
-                    ctr = cancellationToken.Register(static obj =>
-                    {
-                        ((CancelQueueState)obj!).TrySetCanceled();
-                    }, tcs);
-                }
-
-                RequestRegistration registration = new RequestRegistration(permitCount, tcs, ctr);
+                var registration = new RequestRegistration(permitCount, this, cancellationToken);
                 _queue.EnqueueTail(registration);
                 _queueCount += permitCount;
                 Debug.Assert(_queueCount <= _options.QueueLimit);
 
-                return new ValueTask<RateLimitLease>(registration.Tcs.Task);
+                return new ValueTask<RateLimitLease>(registration.Task);
             }
         }
 
@@ -286,6 +277,8 @@ namespace System.Threading.RateLimiting
         // Used in tests that test behavior with specific time intervals
         private void ReplenishInternal(long nowTicks)
         {
+            using var disposer = default(RequestRegistration.Disposer);
+
             // Method is re-entrant (from Timer), lock to avoid multiple simultaneous replenishes
             lock (Lock)
             {
@@ -325,13 +318,13 @@ namespace System.Threading.RateLimiting
 
                     // Request was handled already, either via cancellation or being kicked from the queue due to a newer request being queued.
                     // We just need to remove the item and let the next queued item be considered for completion.
-                    if (nextPendingRequest.Tcs.Task.IsCompleted)
+                    if (nextPendingRequest.Task.IsCompleted)
                     {
                         nextPendingRequest =
                             _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                             ? _queue.DequeueHead()
                             : _queue.DequeueTail();
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                     }
                     // If we have enough permits after replenishing to serve the queued requests
                     else if (_permitCount >= nextPendingRequest.Count)
@@ -347,7 +340,7 @@ namespace System.Threading.RateLimiting
                         _requestsPerSegment[_currentSegmentIndex] += nextPendingRequest.Count;
                         Debug.Assert(_permitCount >= 0);
 
-                        if (!nextPendingRequest.Tcs.TrySetResult(SuccessfulLease))
+                        if (!nextPendingRequest.TrySetResult(SuccessfulLease))
                         {
                             // Queued item was canceled so add count back
                             _permitCount += nextPendingRequest.Count;
@@ -359,7 +352,7 @@ namespace System.Threading.RateLimiting
                         {
                             Interlocked.Increment(ref _successfulLeasesCount);
                         }
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                         Debug.Assert(_queueCount >= 0);
                     }
                     else
@@ -372,7 +365,6 @@ namespace System.Threading.RateLimiting
                 if (_permitCount == _options.PermitLimit)
                 {
                     Debug.Assert(_idleSince is null);
-                    Debug.Assert(_queueCount == 0);
                     _idleSince = Stopwatch.GetTimestamp();
                 }
             }
@@ -386,6 +378,7 @@ namespace System.Threading.RateLimiting
                 return;
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (_disposed)
@@ -399,8 +392,8 @@ namespace System.Threading.RateLimiting
                     RequestRegistration next = _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                         ? _queue.DequeueHead()
                         : _queue.DequeueTail();
-                    next.CancellationTokenRegistration.Dispose();
-                    next.Tcs.TrySetResult(FailedLease);
+                    disposer.Add(next);
+                    next.TrySetResult(FailedLease);
                 }
             }
         }
@@ -450,48 +443,68 @@ namespace System.Threading.RateLimiting
             }
         }
 
-        private readonly struct RequestRegistration
+        private sealed class RequestRegistration : TaskCompletionSource<RateLimitLease>
         {
-            public RequestRegistration(int permitCount, TaskCompletionSource<RateLimitLease> tcs, CancellationTokenRegistration cancellationTokenRegistration)
+            private readonly CancellationToken _cancellationToken;
+            private CancellationTokenRegistration _cancellationTokenRegistration;
+
+            // this field is used only by the disposal mechanics and never shared between threads
+            private RequestRegistration? _next;
+
+            public RequestRegistration(int permitCount, SlidingWindowRateLimiter limiter, CancellationToken cancellationToken)
+                : base(limiter, TaskCreationOptions.RunContinuationsAsynchronously)
             {
                 Count = permitCount;
-                // Use VoidAsyncOperationWithData<T> instead
-                Tcs = tcs;
-                CancellationTokenRegistration = cancellationTokenRegistration;
+                _cancellationToken = cancellationToken;
+
+                // RequestRegistration objects are created while the limiter lock is held
+                // if cancellationToken fires before or while the lock is held, UnsafeRegister
+                // is going to invoke the callback synchronously, but this does not create
+                // a deadlock because lock are reentrant
+                if (cancellationToken.CanBeCanceled)
+#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER
+                    _cancellationTokenRegistration = cancellationToken.UnsafeRegister(Cancel, this);
+#else
+                    _cancellationTokenRegistration = cancellationToken.Register(Cancel, this);
+#endif
             }
 
             public int Count { get; }
 
-            public TaskCompletionSource<RateLimitLease> Tcs { get; }
-
-            public CancellationTokenRegistration CancellationTokenRegistration { get; }
-        }
-
-        private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
-        {
-            private readonly int _permitCount;
-            private readonly SlidingWindowRateLimiter _limiter;
-            private readonly CancellationToken _cancellationToken;
-
-            public CancelQueueState(int permitCount, SlidingWindowRateLimiter limiter, CancellationToken cancellationToken)
-                : base(TaskCreationOptions.RunContinuationsAsynchronously)
+            private static void Cancel(object? state)
             {
-                _permitCount = permitCount;
-                _limiter = limiter;
-                _cancellationToken = cancellationToken;
+                if (state is RequestRegistration registration && registration.TrySetCanceled(registration._cancellationToken))
+                {
+                    var limiter = (SlidingWindowRateLimiter)registration.Task.AsyncState!;
+                    lock (limiter.Lock)
+                    {
+                        limiter._queueCount -= registration.Count;
+                    }
+                }
             }
 
-            public new bool TrySetCanceled()
+            /// <summary>
+            /// Collects registrations to dispose outside the limiter lock to avoid deadlock.
+            /// </summary>
+            public struct Disposer : IDisposable
             {
-                if (TrySetCanceled(_cancellationToken))
+                private RequestRegistration? _next;
+
+                public void Add(RequestRegistration request)
+                {
+                    request._next = _next;
+                    _next = request;
+                }
+
+                public void Dispose()
                 {
-                    lock (_limiter.Lock)
+                    for (var current = _next; current is not null; current = current._next)
                     {
-                        _limiter._queueCount -= _permitCount;
+                        current._cancellationTokenRegistration.Dispose();
                     }
-                    return true;
+
+                    _next = null;
                 }
-                return false;
             }
         }
     }
index 3c22644..5ad7859 100644 (file)
@@ -156,6 +156,7 @@ namespace System.Threading.RateLimiting
                 return new ValueTask<RateLimitLease>(SuccessfulLease);
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (TryLeaseUnsynchronized(tokenCount, out RateLimitLease? lease))
@@ -175,7 +176,7 @@ namespace System.Threading.RateLimiting
                             RequestRegistration oldestRequest = _queue.DequeueHead();
                             _queueCount -= oldestRequest.Count;
                             Debug.Assert(_queueCount >= 0);
-                            if (!oldestRequest.Tcs.TrySetResult(FailedLease))
+                            if (!oldestRequest.TrySetResult(FailedLease))
                             {
                                 // Updating queue count is handled by the cancellation code
                                 _queueCount += oldestRequest.Count;
@@ -184,7 +185,7 @@ namespace System.Threading.RateLimiting
                             {
                                 Interlocked.Increment(ref _failedLeasesCount);
                             }
-                            oldestRequest.CancellationTokenRegistration.Dispose();
+                            disposer.Add(oldestRequest);
                         }
                         while (_options.QueueLimit - _queueCount < tokenCount);
                     }
@@ -196,22 +197,12 @@ namespace System.Threading.RateLimiting
                     }
                 }
 
-                CancelQueueState tcs = new CancelQueueState(tokenCount, this, cancellationToken);
-                CancellationTokenRegistration ctr = default;
-                if (cancellationToken.CanBeCanceled)
-                {
-                    ctr = cancellationToken.Register(static obj =>
-                    {
-                        ((CancelQueueState)obj!).TrySetCanceled();
-                    }, tcs);
-                }
-
-                RequestRegistration registration = new RequestRegistration(tokenCount, tcs, ctr);
+                var registration = new RequestRegistration(tokenCount, this, cancellationToken);
                 _queue.EnqueueTail(registration);
                 _queueCount += tokenCount;
                 Debug.Assert(_queueCount <= _options.QueueLimit);
 
-                return new ValueTask<RateLimitLease>(registration.Tcs.Task);
+                return new ValueTask<RateLimitLease>(registration.Task);
             }
         }
 
@@ -288,6 +279,8 @@ namespace System.Threading.RateLimiting
         // Used in tests to avoid dealing with real time
         private void ReplenishInternal(long nowTicks)
         {
+            using var disposer = default(RequestRegistration.Disposer);
+
             // method is re-entrant (from Timer), lock to avoid multiple simultaneous replenishes
             lock (Lock)
             {
@@ -330,13 +323,13 @@ namespace System.Threading.RateLimiting
 
                     // Request was handled already, either via cancellation or being kicked from the queue due to a newer request being queued.
                     // We just need to remove the item and let the next queued item be considered for completion.
-                    if (nextPendingRequest.Tcs.Task.IsCompleted)
+                    if (nextPendingRequest.Task.IsCompleted)
                     {
                         nextPendingRequest =
                             _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                             ? queue.DequeueHead()
                             : queue.DequeueTail();
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                     }
                     else if (_tokenCount >= nextPendingRequest.Count)
                     {
@@ -350,7 +343,7 @@ namespace System.Threading.RateLimiting
                         _tokenCount -= nextPendingRequest.Count;
                         Debug.Assert(_tokenCount >= 0);
 
-                        if (!nextPendingRequest.Tcs.TrySetResult(SuccessfulLease))
+                        if (!nextPendingRequest.TrySetResult(SuccessfulLease))
                         {
                             // Queued item was canceled so add count back
                             _tokenCount += nextPendingRequest.Count;
@@ -361,7 +354,7 @@ namespace System.Threading.RateLimiting
                         {
                             Interlocked.Increment(ref _successfulLeasesCount);
                         }
-                        nextPendingRequest.CancellationTokenRegistration.Dispose();
+                        disposer.Add(nextPendingRequest);
                         Debug.Assert(_queueCount >= 0);
                     }
                     else
@@ -374,7 +367,6 @@ namespace System.Threading.RateLimiting
                 if (_tokenCount == _options.TokenLimit)
                 {
                     Debug.Assert(_idleSince is null);
-                    Debug.Assert(_queueCount == 0);
                     _idleSince = Stopwatch.GetTimestamp();
                 }
             }
@@ -388,6 +380,7 @@ namespace System.Threading.RateLimiting
                 return;
             }
 
+            using var disposer = default(RequestRegistration.Disposer);
             lock (Lock)
             {
                 if (_disposed)
@@ -401,8 +394,8 @@ namespace System.Threading.RateLimiting
                     RequestRegistration next = _options.QueueProcessingOrder == QueueProcessingOrder.OldestFirst
                         ? _queue.DequeueHead()
                         : _queue.DequeueTail();
-                    next.CancellationTokenRegistration.Dispose();
-                    next.Tcs.TrySetResult(FailedLease);
+                    disposer.Add(next);
+                    next.TrySetResult(FailedLease);
                 }
             }
         }
@@ -452,48 +445,68 @@ namespace System.Threading.RateLimiting
             }
         }
 
-        private readonly struct RequestRegistration
+        private sealed class RequestRegistration : TaskCompletionSource<RateLimitLease>
         {
-            public RequestRegistration(int tokenCount, TaskCompletionSource<RateLimitLease> tcs, CancellationTokenRegistration cancellationTokenRegistration)
-            {
-                Count = tokenCount;
-                // Use VoidAsyncOperationWithData<T> instead
-                Tcs = tcs;
-                CancellationTokenRegistration = cancellationTokenRegistration;
-            }
+            private readonly CancellationToken _cancellationToken;
+            private CancellationTokenRegistration _cancellationTokenRegistration;
 
-            public int Count { get; }
+            // this field is used only by the disposal mechanics and never shared between threads
+            private RequestRegistration? _next;
 
-            public TaskCompletionSource<RateLimitLease> Tcs { get; }
+            public RequestRegistration(int permitCount, TokenBucketRateLimiter limiter, CancellationToken cancellationToken)
+                : base(limiter, TaskCreationOptions.RunContinuationsAsynchronously)
+            {
+                Count = permitCount;
+                _cancellationToken = cancellationToken;
 
-            public CancellationTokenRegistration CancellationTokenRegistration { get; }
-        }
+                // RequestRegistration objects are created while the limiter lock is held
+                // if cancellationToken fires before or while the lock is held, UnsafeRegister
+                // is going to invoke the callback synchronously, but this does not create
+                // a deadlock because lock are reentrant
+                if (cancellationToken.CanBeCanceled)
+#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER
+                    _cancellationTokenRegistration = cancellationToken.UnsafeRegister(Cancel, this);
+#else
+                    _cancellationTokenRegistration = cancellationToken.Register(Cancel, this);
+#endif
+            }
 
-        private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
-        {
-            private readonly int _tokenCount;
-            private readonly TokenBucketRateLimiter _limiter;
-            private readonly CancellationToken _cancellationToken;
+            public int Count { get; }
 
-            public CancelQueueState(int tokenCount, TokenBucketRateLimiter limiter, CancellationToken cancellationToken)
-                : base(TaskCreationOptions.RunContinuationsAsynchronously)
+            private static void Cancel(object? state)
             {
-                _tokenCount = tokenCount;
-                _limiter = limiter;
-                _cancellationToken = cancellationToken;
+                if (state is RequestRegistration registration && registration.TrySetCanceled(registration._cancellationToken))
+                {
+                    var limiter = (TokenBucketRateLimiter)registration.Task.AsyncState!;
+                    lock (limiter.Lock)
+                    {
+                        limiter._queueCount -= registration.Count;
+                    }
+                }
             }
 
-            public new bool TrySetCanceled()
+            /// <summary>
+            /// Collects registrations to dispose outside the limiter lock to avoid deadlock.
+            /// </summary>
+            public struct Disposer : IDisposable
             {
-                if (TrySetCanceled(_cancellationToken))
+                private RequestRegistration? _next;
+
+                public void Add(RequestRegistration request)
+                {
+                    request._next = _next;
+                    _next = request;
+                }
+
+                public void Dispose()
                 {
-                    lock (_limiter.Lock)
+                    for (var current = _next; current is not null; current = current._next)
                     {
-                        _limiter._queueCount -= _tokenCount;
+                        current._cancellationTokenRegistration.Dispose();
                     }
-                    return true;
+
+                    _next = null;
                 }
-                return false;
             }
         }
     }
index cc503a0..66d70b6 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Reflection;
 using System.Threading.Tasks;
 using Xunit;
 
@@ -123,6 +124,51 @@ namespace System.Threading.RateLimiting.Test
             Assert.True(lease.IsAcquired);
         }
 
+#if DEBUG
+        [Fact]
+        public Task DoesNotDeadlockCleaningUpCanceledRequestedLease_Pre() =>
+            DoesNotDeadlockCleaningUpCanceledRequestedLease((limiter, hook) => SetReleasePreHook(limiter, hook));
+
+        [Fact]
+        public Task DoesNotDeadlockCleaningUpCanceledRequestedLease_Post() =>
+            DoesNotDeadlockCleaningUpCanceledRequestedLease((limiter, hook) => SetReleasePostHook(limiter, hook));
+
+        private void SetReleasePreHook(ConcurrencyLimiter limiter, Action hook)
+        {
+            typeof(ConcurrencyLimiter).GetEvent("ReleasePreHook", BindingFlags.NonPublic | BindingFlags.Instance).AddMethod.Invoke(limiter, new object[] { hook });
+        }
+
+        private void SetReleasePostHook(ConcurrencyLimiter limiter, Action hook)
+        {
+            typeof(ConcurrencyLimiter).GetEvent("ReleasePostHook", BindingFlags.NonPublic | BindingFlags.Instance).AddMethod.Invoke(limiter, new object[] { hook });
+        }
+
+        private async Task DoesNotDeadlockCleaningUpCanceledRequestedLease(Action<ConcurrencyLimiter, Action> attachHook)
+        {
+            using var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions
+            {
+                PermitLimit = 1,
+                QueueProcessingOrder = QueueProcessingOrder.OldestFirst,
+                QueueLimit = 1
+            });
+            var lease = limiter.AttemptAcquire(1);
+            Assert.True(lease.IsAcquired);
+
+            var cts = new CancellationTokenSource();
+            _ = limiter.AcquireAsync(1, cts.Token);
+            attachHook(limiter, () =>
+            {
+                Task.Run(cts.Cancel);
+                Thread.Sleep(1);
+            });
+
+            var task1 = Task.Delay(1000);
+            var task2 = Task.Run(lease.Dispose);
+            Assert.Same(task2, await Task.WhenAny(task1, task2));
+            await task2;
+        }
+#endif
+
         [Fact]
         public override async Task FailsWhenQueuingMoreThanLimit_OldestFirst()
         {