QUIC stream limits (#52704)
authorMarie Píchová <11718369+ManickaP@users.noreply.github.com>
Sun, 6 Jun 2021 11:08:37 +0000 (13:08 +0200)
committerGitHub <noreply@github.com>
Sun, 6 Jun 2021 11:08:37 +0000 (13:08 +0200)
Implements the 3rd option Allowing the caller to perform their own wait from #32079 (comment)
Adds WaitForAvailable(Bidi|Uni)rectionalStreamsAsync:
- triggered by peer announcement about new streams (QUIC_CONNECTION_EVENT_TYPE.STREAMS_AVAILABLE)
- if the connection is closed/disposed, the method throws QuicConnectionAbortedException which fitted our H3 better than boolean (can be changed)
Changes stream limit type to int

17 files changed:
src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs
src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockImplementationProvider.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockListener.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicConnectionProvider.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/QuicOptions.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs

index 1be1364..cf83b89 100644 (file)
@@ -20,26 +20,32 @@ namespace System.Net.Test.Common
 
         public override Uri Address => new Uri($"https://{_listener.ListenEndPoint}/");
 
-        public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, GenericLoopbackOptions options = null)
+        public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, Http3Options options = null)
         {
-            options ??= new GenericLoopbackOptions();
+            options ??= new Http3Options();
 
             _cert = Configuration.Certificates.GetServerCertificate();
 
-            var sslOpts = new SslServerAuthenticationOptions
+            var listenerOptions = new QuicListenerOptions()
             {
-                EnabledSslProtocols = options.SslProtocols,
-                ApplicationProtocols = new List<SslApplicationProtocol>
+                ListenEndPoint = new IPEndPoint(options.Address, 0),
+                ServerAuthenticationOptions = new SslServerAuthenticationOptions
                 {
-                    new SslApplicationProtocol("h3-31"),
-                    new SslApplicationProtocol("h3-30"),
-                    new SslApplicationProtocol("h3-29")
+                    EnabledSslProtocols = options.SslProtocols,
+                    ApplicationProtocols = new List<SslApplicationProtocol>
+                    {
+                        new SslApplicationProtocol("h3-31"),
+                        new SslApplicationProtocol("h3-30"),
+                        new SslApplicationProtocol("h3-29")
+                    },
+                    ServerCertificate = _cert,
+                    ClientCertificateRequired = false
                 },
-                ServerCertificate = _cert,
-                ClientCertificateRequired = false
+                MaxUnidirectionalStreams = options.MaxUnidirectionalStreams,
+                MaxBidirectionalStreams = options.MaxBidirectionalStreams,
             };
 
-            _listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, new IPEndPoint(options.Address, 0), sslOpts);
+            _listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, listenerOptions);
         }
 
         public override void Dispose()
@@ -82,7 +88,7 @@ namespace System.Net.Test.Common
 
         public override GenericLoopbackServer CreateServer(GenericLoopbackOptions options = null)
         {
-            return new Http3LoopbackServer(_quicImplementationProvider, options);
+            return new Http3LoopbackServer(_quicImplementationProvider, CreateOptions(options));
         }
 
         public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Task> funcAsync, int millisecondsTimeout = 60000, GenericLoopbackOptions options = null)
@@ -97,5 +103,29 @@ namespace System.Net.Test.Common
             // This method is always unacceptable to call for HTTP/3.
             throw new NotImplementedException("HTTP/3 does not operate over a Socket.");
         }
+
+        private static Http3Options CreateOptions(GenericLoopbackOptions options)
+        {
+            Http3Options http3Options = new Http3Options();
+            if (options != null)
+            {
+                http3Options.Address = options.Address;
+                http3Options.UseSsl = options.UseSsl;
+                http3Options.SslProtocols = options.SslProtocols;
+                http3Options.ListenBacklog = options.ListenBacklog;
+            }
+            return http3Options;
+        }
+    }
+    public class Http3Options : GenericLoopbackOptions
+    {
+        public int MaxUnidirectionalStreams {get; set; }
+
+        public int MaxBidirectionalStreams {get; set; }
+        public Http3Options()
+        {
+            MaxUnidirectionalStreams = 100;
+            MaxBidirectionalStreams = 100;
+        }
     }
 }
index ca6dd5d..51f65a8 100644 (file)
@@ -49,11 +49,6 @@ namespace System.Net.Http
         private int _haveServerQpackDecodeStream;
         private int _haveServerQpackEncodeStream;
 
-        // Manages MAX_STREAM count from server.
-        private long _maximumRequestStreams;
-        private long _requestStreamsRemaining;
-        private readonly Queue<TaskCompletionSourceWithCancellation<bool>> _waitingRequests = new Queue<TaskCompletionSourceWithCancellation<bool>>();
-
         // A connection-level error will abort any future operations.
         private Exception? _abortException;
 
@@ -87,8 +82,6 @@ namespace System.Net.Http
             string altUsedValue = altUsedDefaultPort ? authority.IdnHost : authority.IdnHost + ":" + authority.Port.ToString(Globalization.CultureInfo.InvariantCulture);
             _altUsedEncodedHeader = QPack.QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReferenceToArray(KnownHeaders.AltUsed.Name, altUsedValue);
 
-            _maximumRequestStreams = _requestStreamsRemaining = connection.GetRemoteAvailableBidirectionalStreamCount();
-
             // Errors are observed via Abort().
             _ = SendSettingsAsync();
 
@@ -166,45 +159,34 @@ namespace System.Net.Http
         {
             Debug.Assert(async);
 
-            // Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet.
-
-            TaskCompletionSourceWithCancellation<bool>? waitForAvailableStreamTcs = null;
-
-            lock (SyncObj)
-            {
-                long remaining = _requestStreamsRemaining;
-
-                if (remaining > 0)
-                {
-                    _requestStreamsRemaining = remaining - 1;
-                }
-                else
-                {
-                    waitForAvailableStreamTcs = new TaskCompletionSourceWithCancellation<bool>();
-                    _waitingRequests.Enqueue(waitForAvailableStreamTcs);
-                }
-            }
-
-            if (waitForAvailableStreamTcs != null)
-            {
-                await waitForAvailableStreamTcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
-            }
-
             // Allocate an active request
-
             QuicStream? quicStream = null;
             Http3RequestStream? requestStream = null;
+            ValueTask waitTask = default;
 
             try
             {
-                lock (SyncObj)
+                while (true)
                 {
-                    if (_connection != null)
+                    lock (SyncObj)
                     {
-                        quicStream = _connection.OpenBidirectionalStream();
-                        requestStream = new Http3RequestStream(request, this, quicStream);
-                        _activeRequests.Add(quicStream, requestStream);
+                        if (_connection == null)
+                        {
+                            break;
+                        }
+
+                        if (_connection.GetRemoteAvailableBidirectionalStreamCount() > 0)
+                        {
+                            quicStream = _connection.OpenBidirectionalStream();
+                            requestStream = new Http3RequestStream(request, this, quicStream);
+                            _activeRequests.Add(quicStream, requestStream);
+                            break;
+                        }
+                        waitTask = _connection.WaitForAvailableBidirectionalStreamsAsync(cancellationToken);
                     }
+
+                    // Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet.
+                    await waitTask.ConfigureAwait(false);
                 }
 
                 if (quicStream == null)
@@ -212,8 +194,6 @@ namespace System.Net.Http
                     throw new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure);
                 }
 
-                // 0-byte write to force QUIC to allocate a stream ID.
-                await quicStream.WriteAsync(Array.Empty<byte>(), cancellationToken).ConfigureAwait(false);
                 requestStream!.StreamId = quicStream.StreamId;
 
                 bool goAway;
@@ -247,76 +227,6 @@ namespace System.Net.Http
         }
 
         /// <summary>
-        /// Waits for MAX_STREAMS to be raised by the server.
-        /// </summary>
-        private Task WaitForAvailableRequestStreamAsync(CancellationToken cancellationToken)
-        {
-            TaskCompletionSourceWithCancellation<bool> tcs;
-
-            lock (SyncObj)
-            {
-                long remaining = _requestStreamsRemaining;
-
-                if (remaining > 0)
-                {
-                    _requestStreamsRemaining = remaining - 1;
-                    return Task.CompletedTask;
-                }
-
-                tcs = new TaskCompletionSourceWithCancellation<bool>();
-                _waitingRequests.Enqueue(tcs);
-            }
-
-            // Note: cancellation on connection shutdown is handled in CancelWaiters.
-            return tcs.WaitWithCancellationAsync(cancellationToken).AsTask();
-        }
-
-        /// <summary>
-        /// Cancels any waiting SendAsync calls.
-        /// </summary>
-        /// <remarks>Requires <see cref="SyncObj"/> to be held.</remarks>
-        private void CancelWaiters()
-        {
-            Debug.Assert(Monitor.IsEntered(SyncObj));
-
-            while (_waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation<bool>? tcs))
-            {
-                tcs.TrySetException(new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure));
-            }
-        }
-
-        // TODO: how do we get this event? -> HandleEventStreamsAvailable reports currently available Uni/Bi streams
-        private void OnMaximumStreamCountIncrease(long newMaximumStreamCount)
-        {
-            lock (SyncObj)
-            {
-                if (newMaximumStreamCount <= _maximumRequestStreams)
-                {
-                    return;
-                }
-
-                IncreaseRemainingStreamCount(newMaximumStreamCount - _maximumRequestStreams);
-                _maximumRequestStreams = newMaximumStreamCount;
-            }
-        }
-
-        private void IncreaseRemainingStreamCount(long delta)
-        {
-            Debug.Assert(Monitor.IsEntered(SyncObj));
-            Debug.Assert(delta > 0);
-
-            _requestStreamsRemaining += delta;
-
-            while (_requestStreamsRemaining != 0 && _waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation<bool>? tcs))
-            {
-                if (tcs.TrySetResult(true))
-                {
-                    --_requestStreamsRemaining;
-                }
-            }
-        }
-
-        /// <summary>
         /// Aborts the connection with an error.
         /// </summary>
         /// <remarks>
@@ -358,7 +268,6 @@ namespace System.Net.Http
                     _connectionClosedTask = _connection.CloseAsync((long)connectionResetErrorCode).AsTask();
                 }
 
-                CancelWaiters();
                 CheckForShutdown();
             }
 
@@ -396,7 +305,6 @@ namespace System.Net.Http
                     }
                 }
 
-                CancelWaiters();
                 CheckForShutdown();
             }
 
@@ -414,8 +322,6 @@ namespace System.Net.Http
                 bool removed = _activeRequests.Remove(stream);
                 Debug.Assert(removed == true);
 
-                IncreaseRemainingStreamCount(1);
-
                 if (ShuttingDown)
                 {
                     CheckForShutdown();
index 0732797..0a19368 100644 (file)
@@ -79,10 +79,12 @@ namespace System.Net.Http.Functional.Tests
         }
 
         [Theory]
+        [InlineData(10)]
         [InlineData(100)]
+        [InlineData(1000)]
         public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit)
         {
-            using Http3LoopbackServer server = CreateHttp3LoopbackServer();
+            using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
 
             Task serverTask = Task.Run(async () =>
             {
@@ -100,7 +102,7 @@ namespace System.Net.Http.Functional.Tests
 
                 for (int i = 0; i < streamLimit + 1; ++i)
                 {
-                    using HttpRequestMessage request = new()
+                    HttpRequestMessage request = new()
                     {
                         Method = HttpMethod.Get,
                         RequestUri = server.Address,
@@ -114,6 +116,162 @@ namespace System.Net.Http.Functional.Tests
             await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
         }
 
+        [Theory]
+        [InlineData(10)]
+        [InlineData(100)]
+        [InlineData(1000)]
+        public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit)
+        {
+            using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
+
+            Task serverTask = Task.Run(async () =>
+            {
+                using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
+                for (int i = 0; i < streamLimit; ++i)
+                {
+                    using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
+                    await stream.HandleRequestAsync();
+                }
+            });
+
+            Task clientTask = Task.Run(async () =>
+            {
+                using HttpClient client = CreateHttpClient();
+
+                var tasks = new Task<HttpResponseMessage>[streamLimit];
+                Parallel.For(0, streamLimit, i =>
+                {
+                    HttpRequestMessage request = new()
+                    {
+                        Method = HttpMethod.Get,
+                        RequestUri = server.Address,
+                        Version = HttpVersion30,
+                        VersionPolicy = HttpVersionPolicy.RequestVersionExact
+                    };
+
+                    tasks[i] = client.SendAsync(request);
+                });
+
+                var responses = await Task.WhenAll(tasks);
+                foreach (var response in responses)
+                {
+                    response.Dispose();
+                }
+            });
+
+            await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
+        }
+
+        [Theory]
+        [InlineData(10)]
+        [InlineData(100)]
+        [InlineData(1000)]
+        public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int streamLimit)
+        {
+            // This combination leads to a hang manifesting in CI only. Disabling it until there's more time to investigate.
+            // [ActiveIssue("https://github.com/dotnet/runtime/issues/53688")]
+            if (streamLimit == 10 && this.UseQuicImplementationProvider == QuicImplementationProviders.Mock)
+            {
+                return;
+            }
+
+            using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
+            var lastRequestContentStarted = new TaskCompletionSource();
+
+            Task serverTask = Task.Run(async () =>
+            {
+                // Read the first streamLimit requests, keep the streams open to make the last one wait.
+                using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
+                var streams = new Http3LoopbackStream[streamLimit];
+                for (int i = 0; i < streamLimit; ++i)
+                {
+                    Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
+                    var body = await stream.ReadRequestDataAsync();
+                    streams[i] = stream;
+                }
+
+                // Make the last request running independently.
+                var lastRequest = Task.Run(async () => {
+                    using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
+                    await stream.HandleRequestAsync();
+                });
+
+                // All the initial streamLimit streams are still opened so the last request cannot started yet.
+                Assert.False(lastRequestContentStarted.Task.IsCompleted);
+
+                // Reply to the first streamLimit requests.
+                for (int i = 0; i < streamLimit; ++i)
+                {
+                    await streams[i].SendResponseAsync();
+                    streams[i].Dispose();
+                    // After the first request is fully processed, the last request should unblock and get processed.
+                    if (i == 0)
+                    {
+                        await lastRequestContentStarted.Task;
+                    }
+                }
+                await lastRequest;
+            });
+
+            Task clientTask = Task.Run(async () =>
+            {
+                using HttpClient client = CreateHttpClient();
+
+                // Fire out the first streamLimit requests in parallel, no waiting for the responses yet.
+                var countdown = new CountdownEvent(streamLimit);
+                var tasks = new Task<HttpResponseMessage>[streamLimit];
+                Parallel.For(0, streamLimit, i =>
+                {
+                    HttpRequestMessage request = new()
+                    {
+                        Method = HttpMethod.Post,
+                        RequestUri = server.Address,
+                        Version = HttpVersion30,
+                        VersionPolicy = HttpVersionPolicy.RequestVersionExact,
+                        Content = new StreamContent(new DelegateStream(
+                            canReadFunc: () => true,
+                            readFunc: (buffer, offset, count) =>
+                            {
+                                countdown.Signal();
+                                return 0;
+                            }))
+                    };
+
+                    tasks[i] = client.SendAsync(request);
+                });
+
+                // Wait for the first streamLimit request to get started.
+                countdown.Wait();
+
+                // Fire out the last request, that should wait until the server fully handles at least one request.
+                HttpRequestMessage last = new()
+                {
+                    Method = HttpMethod.Post,
+                    RequestUri = server.Address,
+                    Version = HttpVersion30,
+                    VersionPolicy = HttpVersionPolicy.RequestVersionExact,
+                    Content = new StreamContent(new DelegateStream(
+                        canReadFunc: () => true,
+                        readFunc: (buffer, offset, count) =>
+                        {
+                            lastRequestContentStarted.SetResult();
+                            return 0;
+                        }))
+                };
+                var lastTask = client.SendAsync(last);
+
+                // Wait for all requests to finish. Whether the last request was pending is checked on the server side.
+                var responses = await Task.WhenAll(tasks);
+                foreach (var response in responses)
+                {
+                    response.Dispose();
+                }
+                await lastTask;
+            });
+
+            await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
+        }
+
         [Fact]
         [ActiveIssue("https://github.com/dotnet/runtime/issues/53090")]
         public async Task ReservedFrameType_Throws()
index bff9da8..a702777 100644 (file)
@@ -52,9 +52,9 @@ namespace System.Net.Http.Functional.Tests
             return handler;
         }
 
-        protected Http3LoopbackServer CreateHttp3LoopbackServer()
+        protected Http3LoopbackServer CreateHttp3LoopbackServer(Http3Options options = default)
         {
-            return new Http3LoopbackServer(UseQuicImplementationProvider);
+            return new Http3LoopbackServer(UseQuicImplementationProvider, options);
         }
 
         protected HttpClientHandler CreateHttpClientHandler() => CreateHttpClientHandler(UseVersion, UseQuicImplementationProvider);
@@ -97,7 +97,7 @@ namespace System.Net.Http.Functional.Tests
     internal class VersionHttpClientHandler : HttpClientHandler
     {
         private readonly Version _useVersion;
-        
+
         public VersionHttpClientHandler(Version useVersion)
         {
             _useVersion = useVersion;
@@ -120,7 +120,7 @@ namespace System.Net.Http.Functional.Tests
             {
                 request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
             }
-            
+
             return base.SendAsync(request, cancellationToken);
         }
 
index 4f4abcf..be6df9a 100644 (file)
@@ -27,10 +27,12 @@ namespace System.Net.Quic
         public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public System.Threading.Tasks.ValueTask ConnectAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public void Dispose() { }
-        public long GetRemoteAvailableBidirectionalStreamCount() { throw null; }
-        public long GetRemoteAvailableUnidirectionalStreamCount() { throw null; }
+        public int GetRemoteAvailableBidirectionalStreamCount() { throw null; }
+        public int GetRemoteAvailableUnidirectionalStreamCount() { throw null; }
         public System.Net.Quic.QuicStream OpenBidirectionalStream() { throw null; }
         public System.Net.Quic.QuicStream OpenUnidirectionalStream() { throw null; }
+        public System.Threading.Tasks.ValueTask WaitForAvailableBidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+        public System.Threading.Tasks.ValueTask WaitForAvailableUnidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
     }
     public partial class QuicConnectionAbortedException : System.Net.Quic.QuicException
     {
@@ -73,8 +75,8 @@ namespace System.Net.Quic
     {
         public QuicOptions() { }
         public System.TimeSpan IdleTimeout { get { throw null; } set { } }
-        public long MaxBidirectionalStreams { get { throw null; } set { } }
-        public long MaxUnidirectionalStreams { get { throw null; } set { } }
+        public int MaxBidirectionalStreams { get { throw null; } set { } }
+        public int MaxUnidirectionalStreams { get { throw null; } set { } }
     }
     public sealed partial class QuicStream : System.IO.Stream
     {
@@ -101,8 +103,8 @@ namespace System.Net.Quic
         public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
         public override void SetLength(long value) { }
         public void Shutdown() { }
-        public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+        public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public override void Write(byte[] buffer, int offset, int count) { }
         public override void Write(System.ReadOnlySpan<byte> buffer) { }
         public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
index 9a24bc1..3a876d2 100644 (file)
@@ -4,6 +4,7 @@
 using System.Diagnostics;
 using System.Net;
 using System.Net.Security;
+using System.Runtime.ExceptionServices;
 using System.Threading;
 using System.Threading.Channels;
 using System.Threading.Tasks;
@@ -20,11 +21,16 @@ namespace System.Net.Quic.Implementations.Mock
         private object _syncObject = new object();
         private long _nextOutboundBidirectionalStream;
         private long _nextOutboundUnidirectionalStream;
+        private readonly int _maxUnidirectionalStreams;
+        private readonly int _maxBidirectionalStreams;
 
         private ConnectionState? _state;
 
+        internal PeerStreamLimit? LocalStreamLimit => _isClient ? _state?._clientStreamLimit : _state?._serverStreamLimit;
+        internal PeerStreamLimit? RemoteStreamLimit => _isClient ? _state?._serverStreamLimit : _state?._clientStreamLimit;
+
         // Constructor for outbound connections
-        internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null)
+        internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null, int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
         {
             if (remoteEndPoint is null)
             {
@@ -43,6 +49,8 @@ namespace System.Net.Quic.Implementations.Mock
             _sslClientAuthenticationOptions = sslClientAuthenticationOptions;
             _nextOutboundBidirectionalStream = 0;
             _nextOutboundUnidirectionalStream = 2;
+            _maxUnidirectionalStreams = maxUnidirectionalStreams;
+            _maxBidirectionalStreams = maxBidirectionalStreams;
 
             // _state is not initialized until ConnectAsync
         }
@@ -129,7 +137,10 @@ namespace System.Net.Quic.Implementations.Mock
             }
 
             // TODO: deal with protocol negotiation
-            _state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0]);
+            _state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0])
+            {
+                _clientStreamLimit = new PeerStreamLimit(_maxUnidirectionalStreams, _maxBidirectionalStreams)
+            };
             if (!listener.TryConnect(_state))
             {
                 throw new QuicException("Connection refused");
@@ -138,8 +149,41 @@ namespace System.Net.Quic.Implementations.Mock
             return ValueTask.CompletedTask;
         }
 
+        internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default)
+        {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
+
+            return streamLimit.Unidirectional.WaitForAvailableStreams(cancellationToken);
+        }
+
+        internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default)
+        {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
+
+            return streamLimit.Bidirectional.WaitForAvailableStreams(cancellationToken);
+        }
+
         internal override QuicStreamProvider OpenUnidirectionalStream()
         {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
+
+            if (!streamLimit.Unidirectional.TryIncrement())
+            {
+                throw new QuicException("No available unidirectional stream");
+            }
+
             long streamId;
             lock (_syncObject)
             {
@@ -152,6 +196,17 @@ namespace System.Net.Quic.Implementations.Mock
 
         internal override QuicStreamProvider OpenBidirectionalStream()
         {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
+
+            if (!streamLimit.Bidirectional.TryIncrement())
+            {
+                throw new QuicException("No available bidirectional stream");
+            }
+
             long streamId;
             lock (_syncObject)
             {
@@ -174,12 +229,30 @@ namespace System.Net.Quic.Implementations.Mock
             Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
             streamChannel.Writer.TryWrite(streamState);
 
-            return new MockStream(streamState, true);
+            return new MockStream(this, streamState, true);
         }
 
-        internal override long GetRemoteAvailableUnidirectionalStreamCount() => long.MaxValue;
+        internal override int GetRemoteAvailableUnidirectionalStreamCount()
+        {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
+
+            return streamLimit.Unidirectional.AvailableCount;
+        }
+
+        internal override int GetRemoteAvailableBidirectionalStreamCount()
+        {
+            PeerStreamLimit? streamLimit = RemoteStreamLimit;
+            if (streamLimit is null)
+            {
+                throw new InvalidOperationException("Not connected");
+            }
 
-        internal override long GetRemoteAvailableBidirectionalStreamCount() => long.MaxValue;
+            return streamLimit.Bidirectional.AvailableCount;
+        }
 
         internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(CancellationToken cancellationToken = default)
         {
@@ -196,7 +269,7 @@ namespace System.Net.Quic.Implementations.Mock
             try
             {
                 MockStream.StreamState streamState = await streamChannel.Reader.ReadAsync(cancellationToken).ConfigureAwait(false);
-                return new MockStream(streamState, false);
+                return new MockStream(this, streamState, false);
             }
             catch (ChannelClosedException)
             {
@@ -251,6 +324,14 @@ namespace System.Net.Quic.Implementations.Mock
                         Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
                         streamChannel.Writer.Complete();
                     }
+
+
+                    PeerStreamLimit? streamLimit = LocalStreamLimit;
+                    if (streamLimit is not null)
+                    {
+                        streamLimit.Unidirectional.CloseWaiters();
+                        streamLimit.Bidirectional.CloseWaiters();
+                    }
                 }
 
                 // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below.
@@ -271,11 +352,77 @@ namespace System.Net.Quic.Implementations.Mock
             GC.SuppressFinalize(this);
         }
 
+        internal sealed class StreamLimit
+        {
+            public readonly int MaxCount;
+
+            private int _actualCount;
+            // Since this is mock, we don't need to be conservative with the allocations.
+            // We keep the TCSes allocated all the time for the simplicity of the code.
+            private TaskCompletionSource _availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            private readonly object _syncRoot = new object();
+
+            public StreamLimit(int maxCount)
+            {
+                MaxCount = maxCount;
+            }
+
+            public int AvailableCount => MaxCount - _actualCount;
+
+            public void Decrement()
+            {
+                lock (_syncRoot)
+                {
+                    --_actualCount;
+                    if (!_availableTcs.Task.IsCompleted)
+                    {
+                        _availableTcs.SetResult();
+                        _availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+                    }
+                }
+            }
+
+            public bool TryIncrement()
+            {
+                lock (_syncRoot)
+                {
+                    if (_actualCount < MaxCount)
+                    {
+                        ++_actualCount;
+                        return true;
+                    }
+                    return false;
+                }
+            }
+
+            public ValueTask WaitForAvailableStreams(CancellationToken cancellationToken)
+                => new ValueTask(_availableTcs.Task.WaitAsync(cancellationToken));
+
+            public void CloseWaiters()
+                => _availableTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
+        }
+
+        internal class PeerStreamLimit
+        {
+            public readonly StreamLimit Unidirectional;
+            public readonly StreamLimit Bidirectional;
+
+            public PeerStreamLimit(int maxUnidirectional, int maxBidirectional)
+            {
+                Unidirectional = new StreamLimit(maxUnidirectional);
+                Bidirectional = new StreamLimit(maxBidirectional);
+            }
+        }
+
         internal sealed class ConnectionState
         {
             public readonly SslApplicationProtocol _applicationProtocol;
             public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
             public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
+
+            public PeerStreamLimit? _clientStreamLimit;
+            public PeerStreamLimit? _serverStreamLimit;
+
             public long _clientErrorCode;
             public long _serverErrorCode;
             public bool _closed;
index 03b5361..a46b169 100644 (file)
@@ -16,7 +16,11 @@ namespace System.Net.Quic.Implementations.Mock
 
         internal override QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options)
         {
-            return new MockConnection(options.RemoteEndPoint, options.ClientAuthenticationOptions, options.LocalEndPoint);
+            return new MockConnection(options.RemoteEndPoint,
+                                      options.ClientAuthenticationOptions,
+                                      options.LocalEndPoint,
+                                      options.MaxUnidirectionalStreams,
+                                      options.MaxBidirectionalStreams);
         }
     }
 }
index 826746a..48ebf8a 100644 (file)
@@ -69,6 +69,7 @@ namespace System.Net.Quic.Implementations.Mock
         // Returns false if backlog queue is full.
         internal bool TryConnect(MockConnection.ConnectionState state)
         {
+            state._serverStreamLimit = new MockConnection.PeerStreamLimit(_options.MaxUnidirectionalStreams, _options.MaxBidirectionalStreams);
             return _listenQueue.Writer.TryWrite(state);
         }
 
index 68964fc..bd814f6 100644 (file)
@@ -14,12 +14,14 @@ namespace System.Net.Quic.Implementations.Mock
     {
         private bool _disposed;
         private readonly bool _isInitiator;
+        private readonly MockConnection _connection;
 
         private readonly StreamState _streamState;
         private bool _writesCanceled;
 
-        internal MockStream(StreamState streamState, bool isInitiator)
+        internal MockStream(MockConnection connection, StreamState streamState, bool isInitiator)
         {
+            _connection = connection;
             _streamState = streamState;
             _isInitiator = isInitiator;
         }
@@ -186,7 +188,6 @@ namespace System.Net.Quic.Implementations.Mock
             WriteStreamBuffer?.EndWrite();
         }
 
-
         internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default)
         {
             CheckDisposed();
@@ -208,6 +209,15 @@ namespace System.Net.Quic.Implementations.Mock
 
             // This seems to mean shutdown send, in particular, not both.
             WriteStreamBuffer?.EndWrite();
+
+            if (_streamState._inboundStreamBuffer is null) // unidirectional stream
+            {
+                _connection.LocalStreamLimit!.Unidirectional.Decrement();
+            }
+            else
+            {
+                _connection.LocalStreamLimit!.Bidirectional.Decrement();
+            }
         }
 
         private void CheckDisposed()
index 66c8add..76b3693 100644 (file)
@@ -51,6 +51,12 @@ namespace System.Net.Quic.Implementations.MsQuic
             public readonly TaskCompletionSource<uint> ConnectTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
             public readonly TaskCompletionSource<uint> ShutdownTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
 
+            // Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result.
+            // We also cannot use solutions like ManualResetValueTaskSourceCore, since we can have multiple waiters on the same TCS.
+            // As a result, we allocate a new TCS when needed, which is when someone explicitely asks for them in WaitForAvailableStreamsAsync.
+            public TaskCompletionSource? NewUnidirectionalStreamsAvailable;
+            public TaskCompletionSource? NewBidirectionalStreamsAvailable;
+
             public bool Connected;
             public long AbortErrorCode = -1;
 
@@ -192,6 +198,26 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             // Stop accepting new streams.
             state.AcceptQueue.Writer.Complete();
+
+            // Stop notifying about available streams.
+            TaskCompletionSource? unidirectionalTcs = null;
+            TaskCompletionSource? bidirectionalTcs = null;
+            lock (state)
+            {
+                unidirectionalTcs = state.NewBidirectionalStreamsAvailable;
+                bidirectionalTcs = state.NewBidirectionalStreamsAvailable;
+                state.NewUnidirectionalStreamsAvailable = null;
+                state.NewBidirectionalStreamsAvailable = null;
+            }
+
+            if (unidirectionalTcs is not null)
+            {
+                unidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
+            }
+            if (bidirectionalTcs is not null)
+            {
+                bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
+            }
             return MsQuicStatusCodes.Success;
         }
 
@@ -206,6 +232,32 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventStreamsAvailable(State state, ref ConnectionEvent connectionEvent)
         {
+            TaskCompletionSource? unidirectionalTcs = null;
+            TaskCompletionSource? bidirectionalTcs = null;
+            lock (state)
+            {
+                if (connectionEvent.Data.StreamsAvailable.UniDirectionalCount > 0)
+                {
+                    unidirectionalTcs = state.NewUnidirectionalStreamsAvailable;
+                    state.NewUnidirectionalStreamsAvailable = null;
+                }
+
+                if (connectionEvent.Data.StreamsAvailable.BiDirectionalCount > 0)
+                {
+                    bidirectionalTcs = state.NewBidirectionalStreamsAvailable;
+                    state.NewBidirectionalStreamsAvailable = null;
+                }
+            }
+
+            if (unidirectionalTcs is not null)
+            {
+                unidirectionalTcs.SetResult();
+            }
+            if (bidirectionalTcs is not null)
+            {
+                bidirectionalTcs.SetResult();
+            }
+
             return MsQuicStatusCodes.Success;
         }
 
@@ -329,24 +381,82 @@ namespace System.Net.Quic.Implementations.MsQuic
             return stream;
         }
 
+        internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default)
+        {
+            TaskCompletionSource? tcs = _state.NewUnidirectionalStreamsAvailable;
+            if (tcs is null)
+            {
+                lock (_state)
+                {
+                    if (_state.NewUnidirectionalStreamsAvailable is null)
+                    {
+                        if (_state.ShutdownTcs.Task.IsCompleted)
+                        {
+                            throw new QuicOperationAbortedException();
+                        }
+
+                        if (GetRemoteAvailableUnidirectionalStreamCount() > 0)
+                        {
+                            return ValueTask.CompletedTask;
+                        }
+
+                        _state.NewUnidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+                    }
+                    tcs = _state.NewUnidirectionalStreamsAvailable;
+                }
+            }
+
+            return new ValueTask(tcs.Task.WaitAsync(cancellationToken));
+        }
+
+        internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default)
+        {
+            TaskCompletionSource? tcs = _state.NewBidirectionalStreamsAvailable;
+            if (tcs is null)
+            {
+                lock (_state)
+                {
+                    if (_state.NewBidirectionalStreamsAvailable is null)
+                    {
+                        if (_state.ShutdownTcs.Task.IsCompleted)
+                        {
+                            throw new QuicOperationAbortedException();
+                        }
+
+                        if (GetRemoteAvailableBidirectionalStreamCount() > 0)
+                        {
+                            return ValueTask.CompletedTask;
+                        }
+
+                        _state.NewBidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+                    }
+                    tcs = _state.NewBidirectionalStreamsAvailable;
+                }
+            }
+
+            return new ValueTask(tcs.Task.WaitAsync(cancellationToken));
+        }
+
         internal override QuicStreamProvider OpenUnidirectionalStream()
         {
             ThrowIfDisposed();
+
             return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
         }
 
         internal override QuicStreamProvider OpenBidirectionalStream()
         {
             ThrowIfDisposed();
+
             return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.NONE);
         }
 
-        internal override long GetRemoteAvailableUnidirectionalStreamCount()
+        internal override int GetRemoteAvailableUnidirectionalStreamCount()
         {
             return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_UNIDI_STREAM_COUNT);
         }
 
-        internal override long GetRemoteAvailableBidirectionalStreamCount()
+        internal override int GetRemoteAvailableBidirectionalStreamCount()
         {
             return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_BIDI_STREAM_COUNT);
         }
index bb1468c..f6c768d 100644 (file)
@@ -64,7 +64,6 @@ namespace System.Net.Quic.Implementations.MsQuic
             // Set once writes have been shutdown.
             public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
 
-
             public ShutdownState ShutdownState;
 
             // Set once stream have been shutdown.
@@ -124,7 +123,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer.");
 
-                status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.ASYNC);
+                status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED);
                 QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream.");
             }
             catch
@@ -492,6 +491,7 @@ namespace System.Net.Quic.Implementations.MsQuic
         internal override void Shutdown()
         {
             ThrowIfDisposed();
+
             // it is ok to send shutdown several times, MsQuic will ignore it
             StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
         }
@@ -592,7 +592,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                     // Stream has started.
                     // Will only be done for outbound streams (inbound streams have already started)
                     case QUIC_STREAM_EVENT_TYPE.START_COMPLETE:
-                        return HandleStartComplete(state);
+                        return HandleEventStartComplete(state);
                     // Received data on the stream
                     case QUIC_STREAM_EVENT_TYPE.RECEIVE:
                         return HandleEventRecv(state, ref evt);
@@ -678,7 +678,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             return MsQuicStatusCodes.Success;
         }
 
-        private static uint HandleStartComplete(State state)
+        private static uint HandleEventStartComplete(State state)
         {
             bool shouldComplete = false;
             lock (state)
index 9425833..e5153c3 100644 (file)
@@ -16,13 +16,17 @@ namespace System.Net.Quic.Implementations
 
         internal abstract ValueTask ConnectAsync(CancellationToken cancellationToken = default);
 
+        internal abstract ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default);
+
+        internal abstract ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default);
+
         internal abstract QuicStreamProvider OpenUnidirectionalStream();
 
         internal abstract QuicStreamProvider OpenBidirectionalStream();
 
-        internal abstract long GetRemoteAvailableUnidirectionalStreamCount();
+        internal abstract int GetRemoteAvailableUnidirectionalStreamCount();
 
-        internal abstract long GetRemoteAvailableBidirectionalStreamCount();
+        internal abstract int GetRemoteAvailableBidirectionalStreamCount();
 
         internal abstract ValueTask<QuicStreamProvider> AcceptStreamAsync(CancellationToken cancellationToken = default);
 
index fd21c41..e91913f 100644 (file)
@@ -68,6 +68,18 @@ namespace System.Net.Quic
         public ValueTask ConnectAsync(CancellationToken cancellationToken = default) => _provider.ConnectAsync(cancellationToken);
 
         /// <summary>
+        /// Waits for available unidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately.
+        /// </summary>
+        /// <returns></returns>
+        public ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableUnidirectionalStreamsAsync(cancellationToken);
+
+        /// <summary>
+        /// Waits for available bidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately.
+        /// </summary>
+        /// <returns></returns>
+        public ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableBidirectionalStreamsAsync(cancellationToken);
+
+        /// <summary>
         /// Create an outbound unidirectional stream.
         /// </summary>
         /// <returns></returns>
@@ -95,11 +107,11 @@ namespace System.Net.Quic
         /// <summary>
         /// Gets the maximum number of bidirectional streams that can be made to the peer.
         /// </summary>
-        public long GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount();
+        public int GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount();
 
         /// <summary>
         /// Gets the maximum number of unidirectional streams that can be made to the peer.
         /// </summary>
-        public long GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount();
+        public int GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount();
     }
 }
index 86dd644..3d02ee3 100644 (file)
@@ -19,14 +19,14 @@ namespace System.Net.Quic
         /// Default is 100.
         /// </summary>
         // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using.
-        public long MaxBidirectionalStreams { get; set; } = 100;
+        public int MaxBidirectionalStreams { get; set; } = 100;
 
         /// <summary>
         /// Limit on the number of unidirectional streams the remote peer connection can create on an open connection.
         /// Default is 100.
         /// </summary>
         // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using.
-        public long MaxUnidirectionalStreams { get; set; } = 100;
+        public int MaxUnidirectionalStreams { get; set; } = 100;
 
         /// <summary>
         /// Idle timeout for connections, after which the connection will be closed.
index 0e2a227..443f759 100644 (file)
@@ -96,6 +96,56 @@ namespace System.Net.Quic.Tests
         }
 
         [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
+        public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
+        {
+            using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
+            using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
+
+            ValueTask clientTask = clientConnection.ConnectAsync();
+            using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
+            await clientTask;
+
+            // No stream openned yet, should return immediately.
+            Assert.True(clientConnection.WaitForAvailableUnidirectionalStreamsAsync().IsCompletedSuccessfully);
+
+            // Open one stream, should wait till it closes.
+            QuicStream stream = clientConnection.OpenUnidirectionalStream();
+            ValueTask waitTask = clientConnection.WaitForAvailableUnidirectionalStreamsAsync();
+            Assert.False(waitTask.IsCompleted);
+            Assert.Throws<QuicException>(() => clientConnection.OpenUnidirectionalStream());
+
+            // Close the stream, the waitTask should finish as a result.
+            stream.Dispose();
+            await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
+        public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
+        {
+            using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1);
+            using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
+
+            ValueTask clientTask = clientConnection.ConnectAsync();
+            using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
+            await clientTask;
+
+            // No stream openned yet, should return immediately.
+            Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully);
+
+            // Open one stream, should wait till it closes.
+            QuicStream stream = clientConnection.OpenBidirectionalStream();
+            ValueTask waitTask = clientConnection.WaitForAvailableBidirectionalStreamsAsync();
+            Assert.False(waitTask.IsCompleted);
+            Assert.Throws<QuicException>(() => clientConnection.OpenBidirectionalStream());
+
+            // Close the stream, the waitTask should finish as a result.
+            stream.Dispose();
+            await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
+        }
+
+        [Fact]
         [OuterLoop("May take several seconds")]
         public async Task SetListenerTimeoutWorksWithSmallTimeout()
         {
@@ -234,7 +284,7 @@ namespace System.Net.Quic.Tests
             int res = await serverStream.ReadAsync(memory);
             Assert.Equal(12, res);
             ReadOnlyMemory<ReadOnlyMemory<byte>> romrom = new ReadOnlyMemory<ReadOnlyMemory<byte>>(new ReadOnlyMemory<byte>[] { helloWorld, helloWorld });
-            
+
             await clientStream.WriteAsync(romrom);
 
             res = await serverStream.ReadAsync(memory);
@@ -254,7 +304,7 @@ namespace System.Net.Quic.Tests
                 {
                     var acceptTask = serverConnection.AcceptStreamAsync();
                     await serverConnection.CloseAsync(errorCode: 0);
-                    // make sure 
+                    // make sure
                     await Assert.ThrowsAsync<QuicOperationAbortedException>(() => acceptTask.AsTask());
                 });
         }
index ad7b74c..6c3670b 100644 (file)
@@ -92,7 +92,7 @@ namespace System.Net.Quic.Tests
                 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate()
             };
         }
-        
+
         protected abstract QuicImplementationProvider Provider { get; }
 
         protected override async Task<StreamPair> CreateConnectedStreamsAsync()
index 027d0ad..ee75018 100644 (file)
@@ -53,16 +53,30 @@ namespace System.Net.Quic.Tests
             return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions());
         }
 
-        internal QuicListener CreateQuicListener()
+        internal QuicListener CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
         {
-            return CreateQuicListener(new IPEndPoint(IPAddress.Loopback, 0));
+            var options = new QuicListenerOptions()
+            {
+                ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
+                ServerAuthenticationOptions = GetSslServerAuthenticationOptions(),
+                MaxUnidirectionalStreams = maxUnidirectionalStreams,
+                MaxBidirectionalStreams = maxBidirectionalStreams
+            };
+            return CreateQuicListener(options);
         }
 
         internal QuicListener CreateQuicListener(IPEndPoint endpoint)
         {
-            return new QuicListener(ImplementationProvider, endpoint, GetSslServerAuthenticationOptions());
+            var options = new QuicListenerOptions()
+            {
+                ListenEndPoint = endpoint,
+                ServerAuthenticationOptions = GetSslServerAuthenticationOptions()
+            };
+            return CreateQuicListener(options);
         }
 
+        private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options);
+
         internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, Func<QuicConnection, Task> serverFunction, int iterations = 1, int millisecondsTimeout = 10_000)
         {
             using QuicListener listener = CreateQuicListener();