simplify SslStream.AuthenticateAs*Async() (#453)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Wed, 18 Dec 2019 18:52:52 +0000 (10:52 -0800)
committerGitHub <noreply@github.com>
Wed, 18 Dec 2019 18:52:52 +0000 (10:52 -0800)
* port from corefx

* add TestHelper.cs

* capture handshake exception

* feedback from review

* cleanup more apm and renegotiation

* fix unit test

* feedback from review

* feedback from review

* feedback from review

* fix test after merge

* feedback from review

* add back ProcessAuthentication as main entry point

18 files changed:
src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs
src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs
src/libraries/Common/tests/System/Net/Capability.Security.cs
src/libraries/Common/tests/System/Net/Configuration.Security.cs
src/libraries/System.Net.Security/src/System/Net/HelperAsyncResults.cs
src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs
src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs
src/libraries/System.Net.Security/src/System/Net/Security/SniHelper.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNegotiatedCipherSuiteTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj
src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs [new file with mode: 0644]
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs

index e505f3e..75aa0d4 100644 (file)
@@ -253,15 +253,15 @@ internal static partial class Interop
             return context;
         }
 
-        internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int recvOffset, int recvCount, out byte[] sendBuf, out int sendCount)
+        internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[] sendBuf, out int sendCount)
         {
             sendBuf = null;
             sendCount = 0;
             Exception handshakeException = null;
 
-            if ((recvBuf != null) && (recvCount > 0))
+            if (input.Length > 0)
             {
-                if (BioWrite(context.InputBio, recvBuf, recvOffset, recvCount) <= 0)
+                if (Ssl.BioWrite(context.InputBio, ref MemoryMarshal.GetReference(input), input.Length) != input.Length)
                 {
                     // Make sure we clear out the error that is stored in the queue
                     throw Crypto.CreateOpenSslCryptographicException();
@@ -321,7 +321,7 @@ internal static partial class Interop
             return stateOk;
         }
 
-        internal static int Encrypt(SafeSslHandle context, ReadOnlyMemory<byte> input, ref byte[] output, out Ssl.SslErrorCode errorCode)
+        internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref byte[] output, out Ssl.SslErrorCode errorCode)
         {
 #if DEBUG
             ulong assertNoError = Crypto.ErrPeekError();
@@ -334,13 +334,7 @@ internal static partial class Interop
 
             lock (context)
             {
-                unsafe
-                {
-                    using (MemoryHandle handle = input.Pin())
-                    {
-                        retVal = Ssl.SslWrite(context, (byte*)handle.Pointer, input.Length);
-                    }
-                }
+                retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length);
 
                 if (retVal != input.Length)
                 {
index 02148f3..2246b5f 100644 (file)
@@ -72,7 +72,7 @@ internal static partial class Interop
         }
 
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslWrite")]
-        internal static extern unsafe int SslWrite(SafeSslHandle ssl, byte* buf, int num);
+        internal static extern unsafe int SslWrite(SafeSslHandle ssl, ref byte buf, int num);
 
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslRead")]
         internal static extern unsafe int SslRead(SafeSslHandle ssl, byte* buf, int num);
@@ -101,6 +101,9 @@ internal static partial class Interop
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")]
         internal static extern unsafe int BioWrite(SafeBioHandle b, byte* data, int len);
 
+        [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")]
+        internal static extern unsafe int BioWrite(SafeBioHandle b, ref byte data, int len);
+
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")]
         internal static extern SafeX509Handle SslGetPeerCertificate(SafeSslHandle ssl);
 
index 3de857a..fe78e01 100644 (file)
@@ -50,6 +50,18 @@ namespace System.Net.Test.Common
             {
                 return true;
             }
+
+            return false;
+        }
+
+        public static bool SecurityForceSocketStreams()
+        {
+            string value = Configuration.Security.SecurityForceSocketStreams;
+            if (value != null && (value.Equals("true", StringComparison.OrdinalIgnoreCase) || value.Equals("1")))
+            {
+                return true;
+            }
+
             return false;
         }
 
index c60e158..b09ad67 100644 (file)
@@ -35,6 +35,8 @@ namespace System.Net.Test.Common
             //      127.0.0.1 testclienteku.contoso.com
 
             public static string HostsFileNamesInstalled => GetValue("COREFX_NET_SECURITY_HOSTS_FILE_INSTALLED");
+            // Allows packet captures.
+            public static string SecurityForceSocketStreams => GetValue("COREFX_NET_SECURITY_FORCE_SOCKET_STREAMS");
         }
     }
 }
index 5d87ae5..a6a2ffd 100644 (file)
@@ -20,10 +20,6 @@ namespace System.Net
     //
     internal class AsyncProtocolRequest
     {
-#if DEBUG
-        internal object _DebugAsyncChain;         // Optionally used to track chains of async calls.
-#endif
-
         private AsyncProtocolCallback _callback;
         private int _completionStatus;
 
@@ -33,7 +29,6 @@ namespace System.Net
 
         public LazyAsyncResult UserAsyncResult;
         public int Result;
-        public object AsyncState;
         public readonly CancellationToken CancellationToken;
 
         public byte[] Buffer; // Temporary buffer reused by a protocol.
index 82a273a..1db55ec 100644 (file)
@@ -217,15 +217,18 @@ namespace System.Net
             Debug.Assert(count >= 0);
             Debug.Assert(count <= buf.Length - offset);
 
+            Write(buf.AsSpan(offset, count));
+        }
 
+        internal void Write(ReadOnlySpan<byte> buf)
+        {
             lock (_fromConnection)
             {
-                for (int i = 0; i < count; i++)
+                foreach (byte b in buf)
                 {
-                    _fromConnection.Enqueue(buf[offset + i]);
+                    _fromConnection.Enqueue(b);
                 }
             }
-
         }
 
         internal int BytesReadyForConnection => _toConnection.Count;
index e3cd006..7486f63 100644 (file)
@@ -626,7 +626,7 @@ namespace System.Net.Security
         //
         // Acquire Server Side Certificate information and set it on the class.
         //
-        private bool AcquireServerCredentials(ref byte[] thumbPrint, byte[] clientHello)
+        private bool AcquireServerCredentials(ref byte[] thumbPrint, ReadOnlySpan<byte> clientHello)
         {
             if (NetEventSource.IsEnabled)
                 NetEventSource.Enter(this);
@@ -797,7 +797,7 @@ namespace System.Net.Security
                     if (_refreshCredentialNeeded)
                     {
                         cachedCreds = _sslAuthenticationOptions.IsServer
-                                        ? AcquireServerCredentials(ref thumbPrint, input)
+                                        ? AcquireServerCredentials(ref thumbPrint, new ReadOnlySpan<byte>(input, offset, count))
                                         : AcquireClientCredentials(ref thumbPrint);
                     }
 
@@ -806,7 +806,7 @@ namespace System.Net.Security
                         status = SslStreamPal.AcceptSecurityContext(
                                       ref _credentialsHandle,
                                       ref _securityContext,
-                                      input != null ? new ArraySegment<byte>(input, offset, count) : default,
+                                      input, offset, count,
                                       ref result,
                                       _sslAuthenticationOptions);
                     }
@@ -816,7 +816,7 @@ namespace System.Net.Security
                                        ref _credentialsHandle,
                                        ref _securityContext,
                                        _sslAuthenticationOptions.TargetHost,
-                                      input != null ? new ArraySegment<byte>(input, offset, count) : default,
+                                       input, offset, count,
                                        ref result,
                                        _sslAuthenticationOptions);
                     }
index 63eb8b4..c52e0ba 100644 (file)
@@ -16,12 +16,7 @@ namespace System.Net.Security
         private static readonly IdnMapping s_idnMapping = CreateIdnMapping();
         private static readonly Encoding s_encoding = CreateEncoding();
 
-        public static string GetServerName(byte[] clientHello)
-        {
-            return GetSniFromSslPlainText(clientHello);
-        }
-
-        private static string GetSniFromSslPlainText(ReadOnlySpan<byte> sslPlainText)
+        public static string GetServerName(ReadOnlySpan<byte> sslPlainText)
         {
             // https://tools.ietf.org/html/rfc6101#section-5.2.1
             // struct {
index be9300a..e400afb 100644 (file)
@@ -19,9 +19,6 @@ namespace System.Net.Security
     public partial class SslStream
     {
         private static int s_uniqueNameInteger = 123;
-        private static readonly AsyncProtocolCallback s_partialFrameCallback = new AsyncProtocolCallback(PartialFrameCallback);
-        private static readonly AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback);
-        private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback);
 
         private SslAuthenticationOptions _sslAuthenticationOptions;
 
@@ -38,11 +35,30 @@ namespace System.Net.Security
         }
         private CachedSessionStatus _CachedSession;
 
+        private enum Framing
+        {
+            Unknown = 0,
+            BeforeSSL3,
+            SinceSSL3,
+            Unified,
+            Invalid
+        }
+
+        // This is set on the first packet to figure out the framing style.
+        private Framing _framing = Framing.Unknown;
+
+        // SSL3/TLS protocol frames definitions.
+        private enum FrameType : byte
+        {
+            ChangeCipherSpec = 20,
+            Alert = 21,
+            Handshake = 22,
+            AppData = 23
+        }
+
         // This block is used by re-handshake code to buffer data decrypted with the old key.
         private byte[] _queuedReadData;
         private int _queuedReadCount;
-        private bool _pendingReHandshake;
-        private const int MaxQueuedReadBytes = 1024 * 128;
 
         //
         // This block is used to rule the >>re-handshakes<< that are concurrent with read/write I/O requests.
@@ -192,31 +208,6 @@ namespace System.Net.Security
         }
 
         //
-        //  Called by re-handshake if found data decrypted with the old key
-        //
-        private Exception EnqueueOldKeyDecryptedData(byte[] buffer, int offset, int count)
-        {
-            lock (SyncLock)
-            {
-                if (_queuedReadCount + count > MaxQueuedReadBytes)
-                {
-                    return ExceptionDispatchInfo.SetCurrentStackTrace(
-                        new IOException(SR.Format(SR.net_auth_ignored_reauth, MaxQueuedReadBytes.ToString(NumberFormatInfo.CurrentInfo))));
-                }
-
-                if (count != 0)
-                {
-                    // This is inefficient yet simple and that should be a rare case of receiving data encrypted with "old" key.
-                    _queuedReadData = EnsureBufferSize(_queuedReadData, _queuedReadCount, _queuedReadCount + count);
-                    Buffer.BlockCopy(buffer, offset, _queuedReadData, _queuedReadCount, count);
-                    _queuedReadCount += count;
-                    FinishHandshakeRead(LockHandshake);
-                }
-            }
-            return null;
-        }
-
-        //
         // When re-handshaking the "old" key decrypted data are queued until the handshake is done.
         // When stream calls for decryption we will feed it queued data left from "old" encryption key.
         //
@@ -249,34 +240,29 @@ namespace System.Net.Security
         // This method assumes that a SSPI context is already in a good shape.
         // For example it is either a fresh context or already authenticated context that needs renegotiation.
         //
-        private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken)
+        private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default)
         {
+            Task result = null;
             if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
             {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, lazyResult == null ? "BeginAuthenticate" : "Authenticate", "authenticate"));
+                throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" :  "Authenticate", "authenticate"));
             }
 
             try
             {
                 ThrowIfExceptional();
-                AsyncProtocolRequest asyncRequest = null;
-                if (lazyResult != null)
-                {
-                    asyncRequest = new AsyncProtocolRequest(lazyResult, cancellationToken);
-                    asyncRequest.Buffer = null;
-#if DEBUG
-                    lazyResult._debugAsyncChain = asyncRequest;
-#endif
-                }
 
                 //  A trick to discover and avoid cached sessions.
                 _CachedSession = CachedSessionStatus.Unknown;
 
-                ForceAuthentication(_context.IsServer, null, asyncRequest);
-
-                // Not aync so the connection is completed at this point.
-                if (lazyResult == null && NetEventSource.IsEnabled)
+                if (isAsync)
                 {
+                    result = ForceAuthenticationAsync(_context.IsServer, null, cancellationToken);
+                }
+                else
+                {
+                    ForceAuthentication(_context.IsServer, null);
+
                     if (NetEventSource.IsEnabled)
                         NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication),
                                                                     SslProtocol,
@@ -288,50 +274,28 @@ namespace System.Net.Security
                                                                     KeyExchangeStrength);
                 }
             }
-            catch (Exception)
-            {
-                // If an exception emerges synchronously, the asynchronous operation was not
-                // initiated, so no operation is in progress.
-                _nestedAuth = 0;
-                throw;
-            }
             finally
             {
-                // For synchronous operations, the operation has completed.
-                if (lazyResult == null)
-                {
-                    _nestedAuth = 0;
-                }
+                // Operation has completed.
+                _nestedAuth = 0;
             }
+
+            return result;
         }
 
         //
         // This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read().
         //
-        private void ReplyOnReAuthentication(byte[] buffer, CancellationToken cancellationToken)
+        private async Task ReplyOnReAuthenticationAsync(byte[] buffer, CancellationToken cancellationToken)
         {
             lock (SyncLock)
             {
                 // Note we are already inside the read, so checking for already going concurrent handshake.
                 _lockReadState = LockHandshake;
-
-                if (_pendingReHandshake)
-                {
-                    // A concurrent handshake is pending, resume.
-                    FinishRead(buffer);
-                    return;
-                }
             }
 
-            // Start rehandshake from here.
-
-            // Forcing async mode.  The caller will queue another Read as soon as we return using its preferred
-            // calling convention, which will be woken up when the handshake completes.  The callback is just
-            // to capture any SocketErrors that happen during the handshake so they can be surfaced from the Read.
-            AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(new LazyAsyncResult(this, null, new AsyncCallback(RehandshakeCompleteCallback)), cancellationToken);
-            // Buffer contains a result from DecryptMessage that will be passed to ISC/ASC
-            asyncRequest.Buffer = buffer;
-            ForceAuthentication(false, buffer, asyncRequest);
+            await ForceAuthenticationAsync(false, buffer, cancellationToken).ConfigureAwait(false);
+            FinishHandshakeRead(LockNone);
         }
 
         //
@@ -339,35 +303,28 @@ namespace System.Net.Security
         // Incoming buffer is either null or is the result of "renegotiate" decrypted message
         // If write is in progress the method will either wait or be put on hold
         //
-        private void ForceAuthentication(bool receiveFirst, byte[] buffer, AsyncProtocolRequest asyncRequest)
+        private void ForceAuthentication(bool receiveFirst, byte[] buffer)
         {
-            if (CheckEnqueueHandshake(buffer, asyncRequest))
-            {
-                // Async handshake is enqueued and will resume later.
-                return;
-            }
-            // Either Sync handshake is ready to go or async handshake won the race over write.
-
             // This will tell that we don't know the framing yet (what SSL version is)
-            _Framing = Framing.Unknown;
+            _framing = Framing.Unknown;
 
             try
             {
                 if (receiveFirst)
                 {
                     // Listen for a client blob.
-                    StartReceiveBlob(buffer, asyncRequest);
+                    ReceiveBlob(buffer);
                 }
                 else
                 {
                     // We start with the first blob.
-                    StartSendBlob(buffer, (buffer == null ? 0 : buffer.Length), asyncRequest);
+                    SendBlob(buffer, (buffer == null ? 0 : buffer.Length));
                 }
             }
             catch (Exception e)
             {
                 // Failed auth, reset the framing if any.
-                _Framing = Framing.Unknown;
+                _framing = Framing.Unknown;
                 _handshakeCompleted = false;
 
                 SetException(e);
@@ -382,67 +339,66 @@ namespace System.Net.Security
                 if (_exception != null)
                 {
                     // This a failed handshake. Release waiting IO if any.
-                    FinishHandshake(null, null);
+                    FinishHandshake(null);
                 }
             }
         }
 
-        private void EndProcessAuthentication(IAsyncResult result)
+        internal async Task ForceAuthenticationAsync(bool receiveFirst, byte[] buffer, CancellationToken cancellationToken)
         {
-            if (result == null)
-            {
-                throw new ArgumentNullException("asyncResult");
-            }
+            _framing = Framing.Unknown;
+            ProtocolToken message;
+            SslReadAsync adapter = new SslReadAsync(this, cancellationToken);
 
-            LazyAsyncResult lazyResult = result as LazyAsyncResult;
-            if (lazyResult == null)
+            if (!receiveFirst)
             {
-                throw new ArgumentException(SR.Format(SR.net_io_async_result, result.GetType().FullName), "asyncResult");
+                message = _context.NextMessage(buffer, 0, (buffer == null ? 0 : buffer.Length));
+                if (message.Failed)
+                {
+                    // tracing done in NextMessage()
+                    throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+                }
+
+                await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
             }
 
-            if (Interlocked.Exchange(ref _nestedAuth, 0) == 0)
+            do
             {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndAuthenticate"));
-            }
+                message  = await ReceiveBlobAsync(adapter, buffer, cancellationToken).ConfigureAwait(false);
+                if (message.Size > 0)
+                {
+                    // If there is message send it out even if call failed. It may contain TLS Alert.
+                    await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
+                }
 
-            InternalEndProcessAuthentication(lazyResult);
+                if (message.Failed)
+                {
+                    throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+                }
+            } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
 
-            // Connection is completed at this point.
-            if (NetEventSource.IsEnabled)
+            ProtocolToken alertToken = null;
+            if (!CompleteHandshake(ref alertToken))
             {
-                if (NetEventSource.IsEnabled)
-                    NetEventSource.Log.SspiSelectedCipherSuite(nameof(EndProcessAuthentication),
-                                                                SslProtocol,
-                                                                CipherAlgorithm,
-                                                                CipherStrength,
-                                                                HashAlgorithm,
-                                                                HashStrength,
-                                                                KeyExchangeAlgorithm,
-                                                                KeyExchangeStrength);
+                SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
             }
-        }
-
-        private void InternalEndProcessAuthentication(LazyAsyncResult lazyResult)
-        {
-            // No "artificial" timeouts implemented so far, InnerStream controls that.
-            lazyResult.InternalWaitForCompletion();
-            Exception e = lazyResult.Result as Exception;
 
-            if (e != null)
-            {
-                // Failed auth, reset the framing if any.
-                _Framing = Framing.Unknown;
-                _handshakeCompleted = false;
+            if (NetEventSource.IsEnabled)
+                NetEventSource.Log.SspiSelectedCipherSuite(nameof(ForceAuthenticationAsync),
+                                                                    SslProtocol,
+                                                                    CipherAlgorithm,
+                                                                    CipherStrength,
+                                                                    HashAlgorithm,
+                                                                    HashStrength,
+                                                                    KeyExchangeAlgorithm,
+                                                                    KeyExchangeStrength);
 
-                SetException(e);
-                ThrowIfExceptional();
-            }
         }
 
         //
         // Client side starts here, but server also loops through this method.
         //
-        private void StartSendBlob(byte[] incoming, int count, AsyncProtocolRequest asyncRequest)
+        private void SendBlob(byte[] incoming, int count)
         {
             ProtocolToken message = _context.NextMessage(incoming, 0, count);
             _securityStatus = message.Status;
@@ -458,125 +414,65 @@ namespace System.Net.Security
                     _CachedSession = message.Size < 200 ? CachedSessionStatus.IsCached : CachedSessionStatus.IsNotCached;
                 }
 
-                if (_Framing == Framing.Unified)
+                if (_framing == Framing.Unified)
                 {
-                    _Framing = DetectFraming(message.Payload, message.Payload.Length);
+                    _framing = DetectFraming(message.Payload, message.Payload.Length);
                 }
 
-                if (asyncRequest == null)
-                {
-                    InnerStream.Write(message.Payload, 0, message.Size);
-                }
-                else
-                {
-                    asyncRequest.AsyncState = message;
-                    Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken);
-                    if (t.IsCompleted)
-                    {
-                        t.GetAwaiter().GetResult();
-                    }
-                    else
-                    {
-                        IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest);
-                        if (!ar.CompletedSynchronously)
-                        {
-#if DEBUG
-                            asyncRequest._DebugAsyncChain = ar;
-#endif
-                            return;
-                        }
-                        TaskToApm.End(ar);
-                    }
-                }
+                InnerStream.Write(message.Payload, 0, message.Size);
             }
 
-            CheckCompletionBeforeNextReceive(message, asyncRequest);
+            CheckCompletionBeforeNextReceive(message);
         }
 
         //
         // This will check and logically complete / fail the auth handshake.
         //
-        private void CheckCompletionBeforeNextReceive(ProtocolToken message, AsyncProtocolRequest asyncRequest)
+        private void CheckCompletionBeforeNextReceive(ProtocolToken message)
         {
             if (message.Failed)
             {
-                StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException())));
+                SendAuthResetSignal(null, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException())));
                 return;
             }
-            else if (message.Done && !_pendingReHandshake)
+            else if (message.Done)
             {
                 ProtocolToken alertToken = null;
 
                 if (!CompleteHandshake(ref alertToken))
                 {
-                    StartSendAuthResetSignal(alertToken, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
+                    SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
                     return;
                 }
 
                 // Release waiting IO if any. Presumably it should not throw.
                 // Otherwise application may get not expected type of the exception.
-                FinishHandshake(null, asyncRequest);
+                FinishHandshake(null);
                 return;
             }
 
-            StartReceiveBlob(message.Payload, asyncRequest);
+            ReceiveBlob(message.Payload);
         }
 
         //
         // Server side starts here, but client also loops through this method.
         //
-        private void StartReceiveBlob(byte[] buffer, AsyncProtocolRequest asyncRequest)
+        private void ReceiveBlob(byte[] buffer)
         {
-            if (_pendingReHandshake)
-            {
-                if (CheckEnqueueHandshakeRead(ref buffer, asyncRequest))
-                {
-                    return;
-                }
-
-                if (!_pendingReHandshake)
-                {
-                    // Renegotiate: proceed to the next step.
-                    ProcessReceivedBlob(buffer, buffer.Length, asyncRequest);
-                    return;
-                }
-            }
-
             //This is first server read.
             buffer = EnsureBufferSize(buffer, 0, SecureChannel.ReadHeaderSize);
 
-            int readBytes = 0;
-            if (asyncRequest == null)
-            {
-                readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize);
-            }
-            else
-            {
-                asyncRequest.SetNextRequest(buffer, 0, SecureChannel.ReadHeaderSize, s_partialFrameCallback);
-                _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest);
-                if (!asyncRequest.MustCompleteSynchronously)
-                {
-                    return;
-                }
-
-                readBytes = asyncRequest.Result;
-            }
+            int readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize);
 
-            StartReadFrame(buffer, readBytes, asyncRequest);
-        }
-
-        //
-        private void StartReadFrame(byte[] buffer, int readBytes, AsyncProtocolRequest asyncRequest)
-        {
             if (readBytes == 0)
             {
                 // EOF received
                 throw new IOException(SR.net_auth_eof);
             }
 
-            if (_Framing == Framing.Unknown)
+            if (_framing == Framing.Unknown)
             {
-                _Framing = DetectFraming(buffer, readBytes);
+                _framing = DetectFraming(buffer, readBytes);
             }
 
             int restBytes = GetRemainingFrameSize(buffer, 0, readBytes);
@@ -594,81 +490,57 @@ namespace System.Net.Security
 
             buffer = EnsureBufferSize(buffer, readBytes, readBytes + restBytes);
 
-            if (asyncRequest == null)
-            {
-                restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes);
-            }
-            else
-            {
-                asyncRequest.SetNextRequest(buffer, readBytes, restBytes, s_readFrameCallback);
-                _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest);
-                if (!asyncRequest.MustCompleteSynchronously)
-                {
-                    return;
-                }
+            restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes);
 
-                restBytes = asyncRequest.Result;
-                if (restBytes == 0)
-                {
-                    //EOF received: fail.
-                    readBytes = 0;
-                }
-            }
-            ProcessReceivedBlob(buffer, readBytes + restBytes, asyncRequest);
+            SendBlob(buffer, readBytes + restBytes);
         }
 
-        private void ProcessReceivedBlob(byte[] buffer, int count, AsyncProtocolRequest asyncRequest)
+        private async ValueTask<ProtocolToken> ReceiveBlobAsync(SslReadAsync adapter, byte[] buffer, CancellationToken cancellationToken)
         {
-            if (count == 0)
+            ResetReadBuffer();
+            int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
+            if (readBytes == 0)
             {
-                // EOF received.
-                throw new AuthenticationException(SR.net_auth_eof, null);
+                throw new IOException(SR.net_io_eof);
             }
 
-            if (_pendingReHandshake)
+            if (_framing == Framing.Unified || _framing == Framing.Unknown)
             {
-                int offset = 0;
-                SecurityStatusPal status = PrivateDecryptData(buffer, ref offset, ref count);
+                _framing = DetectFraming(_internalBuffer, readBytes);
+            }
 
-                if (status.ErrorCode == SecurityStatusPalErrorCode.OK)
-                {
-                    Exception e = EnqueueOldKeyDecryptedData(buffer, offset, count);
-                    if (e != null)
-                    {
-                        StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(e));
-                        return;
-                    }
+            int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
+            if (payloadBytes < 0)
+            {
+                throw new IOException(SR.net_frame_read_size);
+            }
 
-                    _Framing = Framing.Unknown;
-                    StartReceiveBlob(buffer, asyncRequest);
-                    return;
-                }
-                else if (status.ErrorCode != SecurityStatusPalErrorCode.Renegotiate)
-                {
-                    // Fail re-handshake.
-                    ProtocolToken message = new ProtocolToken(null, status);
-                    StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(
-                        ExceptionDispatchInfo.SetCurrentStackTrace(new AuthenticationException(SR.net_auth_SSPI, message.GetException()))));
-                    return;
-                }
+            int frameSize = SecureChannel.ReadHeaderSize + payloadBytes;
 
-                // We expect only handshake messages from now.
-                _pendingReHandshake = false;
-                if (offset != 0)
+            if (readBytes < frameSize)
+            {
+                readBytes = await FillBufferAsync(adapter, frameSize).ConfigureAwait(false);
+                Debug.Assert(readBytes >= 0);
+                if (readBytes == 0)
                 {
-                    Buffer.BlockCopy(buffer, offset, buffer, 0, count);
+                    throw new IOException(SR.net_io_eof);
                 }
             }
 
-            StartSendBlob(buffer, count, asyncRequest);
+            ProtocolToken token = _context.NextMessage(_internalBuffer, _internalOffset, frameSize);
+            ConsumeBufferedBytes(frameSize);
+
+            return token;
         }
 
         //
         //  This is to reset auth state on remote side.
         //  If this write succeeds we will allow auth retrying.
         //
-        private void StartSendAuthResetSignal(ProtocolToken message, AsyncProtocolRequest asyncRequest, ExceptionDispatchInfo exception)
+        private void SendAuthResetSignal(ProtocolToken message, ExceptionDispatchInfo exception)
         {
+            SetException(exception.SourceException);
+
             if (message == null || message.Size == 0)
             {
                 //
@@ -677,28 +549,7 @@ namespace System.Net.Security
                 exception.Throw();
             }
 
-            if (asyncRequest == null)
-            {
-                InnerStream.Write(message.Payload, 0, message.Size);
-            }
-            else
-            {
-                asyncRequest.AsyncState = exception;
-                Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken);
-                if (t.IsCompleted)
-                {
-                    t.GetAwaiter().GetResult();
-                }
-                else
-                {
-                    IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest);
-                    if (!ar.CompletedSynchronously)
-                    {
-                        return;
-                    }
-                    TaskToApm.End(ar);
-                }
-            }
+            InnerStream.Write(message.Payload, 0, message.Size);
 
             exception.Throw();
         }
@@ -734,144 +585,6 @@ namespace System.Net.Security
             return true;
         }
 
-        private static void WriteCallback(IAsyncResult transportResult)
-        {
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            AsyncProtocolRequest asyncRequest;
-            SslStream sslState;
-
-#if DEBUG
-            try
-            {
-#endif
-                asyncRequest = (AsyncProtocolRequest)transportResult.AsyncState;
-                sslState = (SslStream)asyncRequest.AsyncObject;
-#if DEBUG
-            }
-            catch (Exception exception) when (!ExceptionCheck.IsFatal(exception))
-            {
-                NetEventSource.Fail(null, $"Exception while decoding context: {exception}");
-                throw;
-            }
-#endif
-
-            // Async completion.
-            try
-            {
-                TaskToApm.End(transportResult);
-
-                // Special case for an error notification.
-                object asyncState = asyncRequest.AsyncState;
-                ExceptionDispatchInfo exception = asyncState as ExceptionDispatchInfo;
-                if (exception != null)
-                {
-                    exception.Throw();
-                }
-
-                sslState.CheckCompletionBeforeNextReceive((ProtocolToken)asyncState, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                sslState.FinishHandshake(e, asyncRequest);
-            }
-        }
-
-        private static void PartialFrameCallback(AsyncProtocolRequest asyncRequest)
-        {
-            if (NetEventSource.IsEnabled)
-                NetEventSource.Enter(null);
-
-            // Async ONLY completion.
-            SslStream sslState = (SslStream)asyncRequest.AsyncObject;
-            try
-            {
-                sslState.StartReadFrame(asyncRequest.Buffer, asyncRequest.Result, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                sslState.FinishHandshake(e, asyncRequest);
-            }
-        }
-
-        //
-        //
-        private static void ReadFrameCallback(AsyncProtocolRequest asyncRequest)
-        {
-            if (NetEventSource.IsEnabled)
-                NetEventSource.Enter(null);
-
-            // Async ONLY completion.
-            SslStream sslState = (SslStream)asyncRequest.AsyncObject;
-            try
-            {
-                if (asyncRequest.Result == 0)
-                {
-                    //EOF received: will fail.
-                    asyncRequest.Offset = 0;
-                }
-
-                sslState.ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Offset + asyncRequest.Result, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                sslState.FinishHandshake(e, asyncRequest);
-            }
-        }
-
-        private bool CheckEnqueueHandshakeRead(ref byte[] buffer, AsyncProtocolRequest request)
-        {
-            LazyAsyncResult lazyResult = null;
-            lock (SyncLock)
-            {
-                if (_lockReadState == LockPendingRead)
-                {
-                    return false;
-                }
-
-                int lockState = Interlocked.Exchange(ref _lockReadState, LockHandshake);
-                if (lockState != LockRead)
-                {
-                    return false;
-                }
-
-                if (request != null)
-                {
-                    _queuedReadStateRequest = request;
-                    return true;
-                }
-
-                lazyResult = new LazyAsyncResult(null, null, /*must be */ null);
-                _queuedReadStateRequest = lazyResult;
-            }
-
-            // Need to exit from lock before waiting.
-            lazyResult.InternalWaitForCompletion();
-            buffer = (byte[])lazyResult.Result;
-            return false;
-        }
-
         private void FinishHandshakeRead(int newState)
         {
             lock (SyncLock)
@@ -885,7 +598,6 @@ namespace System.Net.Security
                 }
 
                 _lockReadState = LockRead;
-                HandleQueuedCallback(ref _queuedReadStateRequest);
             }
         }
 
@@ -966,33 +678,6 @@ namespace System.Net.Security
             }
         }
 
-        private void FinishRead(byte[] renegotiateBuffer)
-        {
-            int lockState = Interlocked.CompareExchange(ref _lockReadState, LockNone, LockRead);
-
-            if (lockState != LockHandshake)
-            {
-                return;
-            }
-
-            lock (SyncLock)
-            {
-                LazyAsyncResult ar = _queuedReadStateRequest as LazyAsyncResult;
-                if (ar != null)
-                {
-                    _queuedReadStateRequest = null;
-                    ar.InvokeCallback(renegotiateBuffer);
-                }
-                else
-                {
-                    AsyncProtocolRequest request = (AsyncProtocolRequest)_queuedReadStateRequest;
-                    request.Buffer = renegotiateBuffer;
-                    _queuedReadStateRequest = null;
-                    ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshakeRead(s.request), (sslState: this, request), preferLocal: false);
-                }
-            }
-        }
-
         private Task CheckEnqueueWriteAsync()
         {
             // Clear previous request.
@@ -1057,123 +742,28 @@ namespace System.Net.Security
             {
                 return;
             }
-
-            lock (SyncLock)
-            {
-                HandleQueuedCallback(ref _queuedWriteStateRequest);
-            }
-        }
-
-        private void HandleQueuedCallback(ref object queuedStateRequest)
-        {
-            object obj = queuedStateRequest;
-            if (obj == null)
-            {
-                return;
-            }
-            queuedStateRequest = null;
-
-            switch (obj)
-            {
-                case LazyAsyncResult lazy:
-                    lazy.InvokeCallback();
-                    break;
-                case TaskCompletionSource<int> taskCompletionSource when taskCompletionSource.Task.AsyncState != null:
-                    Memory<byte> array = (Memory<byte>)taskCompletionSource.Task.AsyncState;
-                    int oldKeyResult = -1;
-                    try
-                    {
-                        oldKeyResult = CheckOldKeyDecryptedData(array);
-                    }
-                    catch (Exception exc)
-                    {
-                        taskCompletionSource.SetException(exc);
-                        break;
-                    }
-                    taskCompletionSource.SetResult(oldKeyResult);
-                    break;
-                case TaskCompletionSource<int> taskCompletionSource:
-                    taskCompletionSource.SetResult(0);
-                    break;
-                default:
-                    ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshake(s.obj), (sslState: this, obj), preferLocal: false);
-                    break;
-            }
         }
 
-        // Returns:
-        // true  - operation queued
-        // false - operation can proceed
-        private bool CheckEnqueueHandshake(byte[] buffer, AsyncProtocolRequest asyncRequest)
+        private void FinishHandshake(Exception e)
         {
-            LazyAsyncResult lazyResult = null;
-
             lock (SyncLock)
             {
-                if (_lockWriteState == LockPendingWrite)
+                if (e != null)
                 {
-                    return false;
+                    SetException(e);
                 }
 
-                int lockState = Interlocked.Exchange(ref _lockWriteState, LockHandshake);
-                if (lockState != LockWrite)
-                {
-                    // Proceed with handshake.
-                    return false;
-                }
+                // Release read if any.
+                FinishHandshakeRead(LockNone);
 
-                if (asyncRequest != null)
+                // If there is a pending write we want to keep it's lock state.
+                int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake);
+                if (lockState != LockPendingWrite)
                 {
-                    asyncRequest.Buffer = buffer;
-                    _queuedWriteStateRequest = asyncRequest;
-                    return true;
+                    return;
                 }
 
-                lazyResult = new LazyAsyncResult(null, null, /*must be*/null);
-                _queuedWriteStateRequest = lazyResult;
-            }
-            lazyResult.InternalWaitForCompletion();
-            return false;
-        }
-
-        private void FinishHandshake(Exception e, AsyncProtocolRequest asyncRequest)
-        {
-            try
-            {
-                lock (SyncLock)
-                {
-                    if (e != null)
-                    {
-                        SetException(e);
-                    }
-
-                    // Release read if any.
-                    FinishHandshakeRead(LockNone);
-
-                    // If there is a pending write we want to keep it's lock state.
-                    int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake);
-                    if (lockState != LockPendingWrite)
-                    {
-                        return;
-                    }
-
-                    _lockWriteState = LockWrite;
-                    HandleQueuedCallback(ref _queuedWriteStateRequest);
-                }
-            }
-            finally
-            {
-                if (asyncRequest != null)
-                {
-                    if (e != null)
-                    {
-                        asyncRequest.CompleteUserWithError(e);
-                    }
-                    else
-                    {
-                        asyncRequest.CompleteUser();
-                    }
-                }
+                _lockWriteState = LockWrite;
             }
         }
 
@@ -1307,8 +897,6 @@ namespace System.Net.Security
                     {
                         copyBytes = CopyDecryptedData(buffer);
 
-                        FinishRead(null);
-
                         return copyBytes;
                     }
 
@@ -1368,18 +956,17 @@ namespace System.Net.Security
                         {
                             if (!_sslAuthenticationOptions.AllowRenegotiation)
                             {
+                                if (NetEventSource.IsEnabled) NetEventSource.Fail(this, "Renegotiation was requested but it is disallowed");
                                 throw new IOException(SR.net_ssl_io_renego);
                             }
 
-                            ReplyOnReAuthentication(extraBuffer, adapter.CancellationToken);
-
+                            await ReplyOnReAuthenticationAsync(extraBuffer, adapter.CancellationToken).ConfigureAwait(false);
                             // Loop on read.
                             continue;
                         }
 
                         if (message.CloseConnection)
                         {
-                            FinishRead(null);
                             return 0;
                         }
 
@@ -1389,8 +976,6 @@ namespace System.Net.Security
             }
             catch (Exception e)
             {
-                FinishRead(null);
-
                 if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested))
                 {
                     throw;
@@ -1526,6 +1111,7 @@ namespace System.Net.Security
                 _decryptedBytesOffset += copyBytes;
                 _decryptedBytesCount -= copyBytes;
             }
+
             ReturnReadBufferIfEmpty();
             return copyBytes;
         }
@@ -1562,27 +1148,6 @@ namespace System.Net.Security
             return buffer;
         }
 
-        private enum Framing
-        {
-            Unknown = 0,
-            BeforeSSL3,
-            SinceSSL3,
-            Unified,
-            Invalid
-        }
-
-        // This is set on the first packet to figure out the framing style.
-        private Framing _Framing = Framing.Unknown;
-
-        // SSL3/TLS protocol frames definitions.
-        private enum FrameType : byte
-        {
-            ChangeCipherSpec = 20,
-            Alert = 21,
-            Handshake = 22,
-            AppData = 23
-        }
-
         // We need at least 5 bytes to determine what we have.
         private Framing DetectFraming(byte[] bytes, int length)
         {
@@ -1735,7 +1300,7 @@ namespace System.Net.Security
                 // If this is the first packet, the client may start with an SSL2 packet
                 // but stating that the version is 3.x, so check the full range.
                 // For the subsequent packets we assume that an SSL2 packet should have a 2.x version.
-                if (_Framing == Framing.Unknown)
+                if (_framing == Framing.Unknown)
                 {
                     if (version != 0x0002 && (version < 0x200 || version >= 0x500))
                     {
@@ -1752,7 +1317,7 @@ namespace System.Net.Security
             }
 
             // When server has replied the framing is already fixed depending on the prior client packet
-            if (!_context.IsServer || _Framing == Framing.Unified)
+            if (!_context.IsServer || _framing == Framing.Unified)
             {
                 return Framing.BeforeSSL3;
             }
@@ -1768,7 +1333,7 @@ namespace System.Net.Security
                 NetEventSource.Enter(this, buffer, offset, dataSize);
 
             int payloadSize = -1;
-            switch (_Framing)
+            switch (_framing)
             {
                 case Framing.Unified:
                 case Framing.BeforeSSL3:
@@ -1809,104 +1374,5 @@ namespace System.Net.Security
                 NetEventSource.Exit(this, payloadSize);
             return payloadSize;
         }
-
-        //
-        // Called with no user stack.
-        //
-        private void AsyncResumeHandshake(object state)
-        {
-            AsyncProtocolRequest request = state as AsyncProtocolRequest;
-            Debug.Assert(request != null, "Expected an AsyncProtocolRequest reference.");
-
-            try
-            {
-                ForceAuthentication(_context.IsServer, request.Buffer, request);
-            }
-            catch (Exception e)
-            {
-                request.CompleteUserWithError(e);
-            }
-        }
-
-        //
-        // Called with no user stack.
-        //
-        private void AsyncResumeHandshakeRead(AsyncProtocolRequest asyncRequest)
-        {
-            try
-            {
-                if (_pendingReHandshake)
-                {
-                    // Resume as read a blob.
-                    StartReceiveBlob(asyncRequest.Buffer, asyncRequest);
-                }
-                else
-                {
-                    // Resume as process the blob.
-                    ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Buffer == null ? 0 : asyncRequest.Buffer.Length, asyncRequest);
-                }
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                FinishHandshake(e, asyncRequest);
-            }
-        }
-
-        private void RehandshakeCompleteCallback(IAsyncResult result)
-        {
-            LazyAsyncResult lazyAsyncResult = (LazyAsyncResult)result;
-            if (lazyAsyncResult == null)
-            {
-                NetEventSource.Fail(this, "result is null!");
-            }
-
-            if (!lazyAsyncResult.InternalPeekCompleted)
-            {
-                NetEventSource.Fail(this, "result is not completed!");
-            }
-
-            // If the rehandshake succeeded, FinishHandshake has already been called; if there was a SocketException
-            // during the handshake, this gets called directly from FixedSizeReader, and we need to call
-            // FinishHandshake to wake up the Read that triggered this rehandshake so the error gets back to the caller
-            Exception exception = lazyAsyncResult.InternalWaitForCompletion() as Exception;
-            if (exception != null)
-            {
-                // We may be calling FinishHandshake reentrantly, as FinishHandshake can call
-                // asyncRequest.CompleteWithError, which will result in this method being called.
-                // This is not a problem because:
-                //
-                // 1. We pass null as the asyncRequest parameter, so this second call to FinishHandshake won't loop
-                //    back here.
-                //
-                // 2. _QueuedWriteStateRequest and _QueuedReadStateRequest are set to null after the first call,
-                //    so, we won't invoke their callbacks again.
-                //
-                // 3. SetException won't overwrite an already-set _Exception.
-                //
-                // 4. There are three possibilities for _LockReadState and _LockWriteState:
-                //
-                //    a. They were set back to None by the first call to FinishHandshake, and this will set them to
-                //       None again: a no-op.
-                //
-                //    b. They were set to None by the first call to FinishHandshake, but as soon as the lock was given
-                //       up, another thread took a read/write lock.  Calling FinishHandshake again will set them back
-                //       to None, but that's fine because that thread will be throwing _Exception before it actually
-                //       does any reading or writing and setting them back to None in a catch block anyways.
-                //
-                //    c. If there is a Read/Write going on another thread, and the second FinishHandshake clears its
-                //       read/write lock, it's fine because no other Read/Write can look at the lock until the current
-                //       one gives up _SslStream._NestedRead/Write, and no handshake will look at the lock because
-                //       handshakes are only triggered in response to successful reads (which won't happen once
-                //       _Exception is set).
-
-                FinishHandshake(exception, null);
-            }
-        }
     }
 }
index 568e44f..8dc2e8b 100644 (file)
@@ -226,19 +226,10 @@ namespace System.Net.Security
             return BeginAuthenticateAsClient(options, CancellationToken.None, asyncCallback, asyncState);
         }
 
-        internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
-        {
-            SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
-            SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
-
-            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+        internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) =>
+            TaskToApm.Begin(AuthenticateAsClientApm(sslClientAuthenticationOptions, cancellationToken), asyncCallback, asyncState);
 
-            LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback);
-            ProcessAuthentication(result, cancellationToken);
-            return result;
-        }
-
-        public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult);
+        public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
 
         //
         // Server side auth.
@@ -248,7 +239,7 @@ namespace System.Net.Security
         {
             return BeginAuthenticateAsServer(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false,
                                                           asyncCallback,
-                                                            asyncState);
+                                                          asyncState);
         }
 
         public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCertificate, bool clientCertificateRequired,
@@ -274,34 +265,14 @@ namespace System.Net.Security
             return BeginAuthenticateAsServer(options, CancellationToken.None, asyncCallback, asyncState);
         }
 
-        private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
-        {
-            SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
-
-            ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
-
-            LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback);
-            ProcessAuthentication(result, cancellationToken);
-            return result;
-        }
-
-        public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult);
-
-        internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState)
-        {
-            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+        private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) =>
+            TaskToApm.Begin(AuthenticateAsServerApm(sslServerAuthenticationOptions, cancellationToken), asyncCallback, asyncState);
 
-            ProtocolToken message = _context.CreateShutdownToken();
-            return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState);
-        }
+        public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
 
-        internal void EndShutdown(IAsyncResult asyncResult)
-        {
-            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+        internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) => TaskToApm.Begin(ShutdownAsync(), asyncCallback, asyncState);
 
-            TaskToApm.End(asyncResult);
-            _shutdown = true;
-        }
+        internal void EndShutdown(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
 
         public TransportContext TransportContext => new SslStreamContext(this);
 
@@ -338,7 +309,7 @@ namespace System.Net.Security
             SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
 
             ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
-            ProcessAuthentication(null, default);
+            ProcessAuthentication();
         }
 
         public virtual void AuthenticateAsServer(X509Certificate serverCertificate)
@@ -370,86 +341,103 @@ namespace System.Net.Security
             SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
 
             ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
-            ProcessAuthentication(null, default);
+            ProcessAuthentication();
         }
         #endregion
 
         #region Task-based async public methods
-        public virtual Task AuthenticateAsClientAsync(string targetHost) =>
-            Task.Factory.FromAsync(
-                (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
-                targetHost,
-                this);
-
-        public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) =>
-            Task.Factory.FromAsync(
-                (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
-                targetHost, clientCertificates, checkCertificateRevocation,
-                this);
+        public virtual Task AuthenticateAsClientAsync(string targetHost) => AuthenticateAsClientAsync(targetHost, null, false);
+
+        public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) => AuthenticateAsClientAsync(targetHost, clientCertificates, SecurityProtocol.SystemDefaultSecurityProtocols, checkCertificateRevocation);
 
         public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, SslProtocols enabledSslProtocols, bool checkCertificateRevocation)
         {
-            var beginMethod = checkCertificateRevocation ? (Func<string, X509CertificateCollection, SslProtocols, AsyncCallback, object, IAsyncResult>)
-                ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, true, callback, state)) :
-                ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, false, callback, state));
-            return Task.Factory.FromAsync(
-                beginMethod,
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
-                targetHost, clientCertificates, enabledSslProtocols,
-                this);
+            SslClientAuthenticationOptions options = new SslClientAuthenticationOptions()
+            {
+                TargetHost =  targetHost,
+                ClientCertificates =  clientCertificates,
+                EnabledSslProtocols = enabledSslProtocols,
+                CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+                EncryptionPolicy = _encryptionPolicy,
+            };
+
+            return AuthenticateAsClientAsync(options);
         }
 
         public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default)
         {
-            return Task.Factory.FromAsync(
-                (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
-                sslClientAuthenticationOptions, cancellationToken,
-                this);
+            SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
+            SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
+
+            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+
+            return ProcessAuthentication(true, false, cancellationToken);
+        }
+
+        private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default)
+        {
+            SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
+            SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
+
+            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+
+            return ProcessAuthentication(true, true, cancellationToken);
         }
 
         public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) =>
-            Task.Factory.FromAsync(
-                (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
-                serverCertificate,
-                this);
-
-        public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) =>
-            Task.Factory.FromAsync(
-                (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
-                serverCertificate, clientCertificateRequired, checkCertificateRevocation,
-                this);
+            AuthenticateAsServerAsync(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false);
+
+        public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation)
+        {
+            SslServerAuthenticationOptions options = new SslServerAuthenticationOptions
+            {
+                ServerCertificate = serverCertificate,
+                ClientCertificateRequired = clientCertificateRequired,
+                CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+                EncryptionPolicy = _encryptionPolicy,
+            };
+
+            return AuthenticateAsServerAsync(options);
+        }
 
         public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, SslProtocols enabledSslProtocols, bool checkCertificateRevocation)
         {
-            var beginMethod = checkCertificateRevocation ? (Func<X509Certificate, bool, SslProtocols, AsyncCallback, object, IAsyncResult>)
-                ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, true, callback, state)) :
-                ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, false, callback, state));
-            return Task.Factory.FromAsync(
-                beginMethod,
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
-                serverCertificate, clientCertificateRequired, enabledSslProtocols,
-                this);
+            SslServerAuthenticationOptions options = new SslServerAuthenticationOptions
+            {
+                ServerCertificate = serverCertificate,
+                ClientCertificateRequired = clientCertificateRequired,
+                EnabledSslProtocols = enabledSslProtocols,
+                CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+                EncryptionPolicy = _encryptionPolicy,
+            };
+
+            return AuthenticateAsServerAsync(options);
         }
 
         public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default)
         {
-            return Task.Factory.FromAsync(
-                (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, callback, state),
-                iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
-                sslServerAuthenticationOptions, cancellationToken,
-                this);
+            SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
+            ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
+
+            return ProcessAuthentication(true, false, cancellationToken);
         }
 
-        public virtual Task ShutdownAsync() =>
-            Task.Factory.FromAsync(
-                (callback, state) => ((SslStream)state).BeginShutdown(callback, state),
-                iar => ((SslStream)iar.AsyncState).EndShutdown(iar),
-                this);
+        private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default)
+        {
+            SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
+            ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
+
+            return ProcessAuthentication(true, true, cancellationToken);
+        }
+
+        public virtual Task ShutdownAsync()
+        {
+            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+
+            ProtocolToken message = _context.CreateShutdownToken();
+            _shutdown = true;
+            return InnerStream.WriteAsync(message.Payload, default).AsTask();
+        }
         #endregion
 
         public override bool IsAuthenticated => _context != null && _context.IsValidContext && _exception == null && _handshakeCompleted;
index f7d6246..eda667a 100644 (file)
@@ -35,22 +35,22 @@ namespace System.Net.Security
         public static SecurityStatusPal AcceptSecurityContext(
             ref SafeFreeCredentials credential,
             ref SafeDeleteSslContext context,
-            ArraySegment<byte> inputBuffer,
+            byte[] inputBuffer, int offset, int count,
             ref byte[] outputBuffer,
             SslAuthenticationOptions sslAuthenticationOptions)
         {
-            return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+            return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
         }
 
         public static SecurityStatusPal InitializeSecurityContext(
             ref SafeFreeCredentials credential,
             ref SafeDeleteSslContext context,
             string targetName,
-            ArraySegment<byte> inputBuffer,
+            byte[] inputBuffer, int offset, int count,
             ref byte[] outputBuffer,
             SslAuthenticationOptions sslAuthenticationOptions)
         {
-            return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+            return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
         }
 
         public static SafeFreeCredentials AcquireCredentialsHandle(
@@ -233,7 +233,7 @@ namespace System.Net.Security
         private static SecurityStatusPal HandshakeInternal(
             SafeFreeCredentials credential,
             ref SafeDeleteSslContext context,
-            ArraySegment<byte> inputBuffer,
+            ReadOnlySpan<byte> inputBuffer,
             ref byte[] outputBuffer,
             SslAuthenticationOptions sslAuthenticationOptions)
         {
@@ -260,9 +260,9 @@ namespace System.Net.Security
                     }
                 }
 
-                if (inputBuffer.Array != null && inputBuffer.Count > 0)
+                if (inputBuffer.Length > 0)
                 {
-                    sslContext.Write(inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count);
+                    sslContext.Write(inputBuffer);
                 }
 
                 SafeSslHandle sslHandle = sslContext.SslContext;
index 4abb400..4086764 100644 (file)
@@ -28,15 +28,15 @@ namespace System.Net.Security
         }
 
         public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context,
-            ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+            byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
         {
-            return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+            return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
         }
 
         public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName,
-            ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+            byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
         {
-            return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+            return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
         }
 
         public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate,
@@ -100,7 +100,7 @@ namespace System.Net.Security
         }
 
         private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credential, ref SafeDeleteSslContext context,
-            ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+            ReadOnlySpan<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
         {
             Debug.Assert(!credential.IsInvalid);
 
@@ -114,16 +114,7 @@ namespace System.Net.Security
                     context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions);
                 }
 
-                bool done;
-
-                if (inputBuffer.Array == null)
-                {
-                    done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, null, 0, 0, out output, out outputSize);
-                }
-                else
-                {
-                    done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count, out output, out outputSize);
-                }
+                bool done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer, out output, out outputSize);
 
                 // When the handshake is done, and the context is server, check if the alpnHandle target was set to null during ALPN.
                 // If it was, then that indicates ALPN failed, send failure.
@@ -172,7 +163,7 @@ namespace System.Net.Security
 
                 if (encrypt)
                 {
-                    resultSize = Interop.OpenSsl.Encrypt(scHandle, input, ref output, out errorCode);
+                    resultSize = Interop.OpenSsl.Encrypt(scHandle, input.Span, ref output, out errorCode);
                 }
                 else
                 {
index 5b9fa73..2e4b77c 100644 (file)
@@ -46,9 +46,10 @@ namespace System.Net.Security
             return Interop.Sec_Application_Protocols.ToByteArray(protocols);
         }
 
-        public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, ArraySegment<byte> input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+        public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
         {
             Interop.SspiCli.ContextFlags unusedAttributes = default;
+            ArraySegment<byte> input = inputBuffer != null ? new ArraySegment<byte>(inputBuffer, offset, count) : default;
 
             ThreeSecurityBuffers threeSecurityBuffers = default;
             SecurityBuffer? incomingSecurity = input.Array != null ?
@@ -73,9 +74,10 @@ namespace System.Net.Security
             return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode);
         }
 
-        public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, ArraySegment<byte> input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+        public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
         {
             Interop.SspiCli.ContextFlags unusedAttributes = default;
+            ArraySegment<byte> input = inputBuffer != null ? new ArraySegment<byte>(inputBuffer, offset, count) : default;
 
             ThreeSecurityBuffers threeSecurityBuffers = default;
             SecurityBuffer? incomingSecurity = input.Array != null ?
index b12ad9b..14b4918 100644 (file)
@@ -593,7 +593,7 @@ namespace System.Net.Security.Tests
             return new CipherSuitesPolicy(cipherSuites);
         }
 
-        private static async Task<Exception> WaitForSecureConnection(VirtualNetwork connection, Func<Task> server, Func<Task> client)
+        private static async Task<Exception> WaitForSecureConnection(SslStream client, SslClientAuthenticationOptions clientOptions, SslStream server, SslServerAuthenticationOptions serverOptions)
         {
             Task serverTask = null;
             Task clientTask = null;
@@ -601,12 +601,13 @@ namespace System.Net.Security.Tests
             // check if failed synchronously
             try
             {
-                serverTask = server();
-                clientTask = client();
+                serverTask = server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
+                clientTask = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
             }
             catch (Exception e)
             {
-                connection.BreakConnection();
+                client.Close();
+                server.Close();
 
                 if (!(e is AuthenticationException || e is Win32Exception))
                 {
@@ -625,6 +626,7 @@ namespace System.Net.Security.Tests
                     catch (AuthenticationException) { }
                     catch (Win32Exception) { }
                     catch (VirtualNetwork.VirtualNetworkConnectionBroken) { }
+                    catch (IOException) { }
                 }
 
                 return e;
@@ -635,32 +637,42 @@ namespace System.Net.Security.Tests
             // Now we expect both sides to fail or both to succeed
 
             Exception failure = null;
+            Task task = null;
 
             try
             {
-                await serverTask.ConfigureAwait(false);
+                task = await Task.WhenAny(serverTask, clientTask).TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false);
+                await task ;
             }
             catch (Exception e) when (e is AuthenticationException || e is Win32Exception)
             {
                 failure = e;
-
                 // avoid client waiting for server's response
-                connection.BreakConnection();
+                if (task == serverTask)
+                {
+                    server.Close();
+                }
+                else
+                {
+                    client.Close();
+                }
             }
 
             try
             {
-                await clientTask.ConfigureAwait(false);
+                // Now wait for the other task to finish.
+                task = (task == serverTask ? clientTask : serverTask);
+                await task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false);
 
                 // Fail if server has failed but client has succeeded
                 Assert.Null(failure);
             }
-            catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception)
+            catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception || e is IOException)
             {
                 // Fail if server has succeeded but client has failed
                 Assert.NotNull(failure);
 
-                if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken))
+                if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken) && e.GetType() != typeof(IOException))
                 {
                     failure = new AggregateException(new Exception[] { failure, e });
                 }
@@ -671,9 +683,10 @@ namespace System.Net.Security.Tests
 
         private static NegotiatedParams ConnectAndGetNegotiatedParams(ConnectionParams serverParams, ConnectionParams clientParams)
         {
-            VirtualNetwork vn = new VirtualNetwork();
-            using (VirtualNetworkStream serverStream = new VirtualNetworkStream(vn, isServer: true),
-                                        clientStream = new VirtualNetworkStream(vn, isServer: false))
+            (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
+
+            using (clientStream)
+            using (serverStream)
             using (SslStream server = new SslStream(serverStream, leaveInnerStreamOpen: false),
                              client = new SslStream(clientStream, leaveInnerStreamOpen: false))
             {
@@ -696,10 +709,7 @@ namespace System.Net.Security.Tests
                                                                  return true;
                                                              });
 
-                Func<Task> serverTask = () => server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
-                Func<Task> clientTask = () => client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
-
-                Exception failure = WaitForSecureConnection(vn, serverTask, clientTask).Result;
+                Exception failure = WaitForSecureConnection(client, clientOptions, server, serverOptions).GetAwaiter().GetResult();
 
                 if (failure == null)
                 {
index dfee6bc..838c673 100644 (file)
@@ -920,23 +920,18 @@ namespace System.Net.Security.Tests
         [Fact]
         public async Task AuthenticateAsClientAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException()
         {
-            using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
-            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
-            {
-                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
-                listener.Listen(1);
+            (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams();
 
-                await client.ConnectAsync(listener.LocalEndPoint);
-                using (Socket server = await listener.AcceptAsync())
-                using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate))
-                using (var serverSslStream = new SslStream(new NetworkStream(server)))
-                using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
-                {
-                    var cts = new CancellationTokenSource();
-                    Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token);
-                    cts.Cancel();
-                    await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
-                }
+            using (client)
+            using (server)
+            using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate))
+            using (var serverSslStream = new SslStream(server))
+            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            {
+                var cts = new CancellationTokenSource();
+                Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token);
+                cts.Cancel();
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
             }
         }
 
@@ -958,23 +953,18 @@ namespace System.Net.Security.Tests
         [Fact]
         public async Task AuthenticateAsServerAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException()
         {
-            using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
-            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
-            {
-                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
-                listener.Listen(1);
+            (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams();
 
-                await client.ConnectAsync(listener.LocalEndPoint);
-                using (Socket server = await listener.AcceptAsync())
-                using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate))
-                using (var serverSslStream = new SslStream(new NetworkStream(server)))
-                using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
-                {
-                    var cts = new CancellationTokenSource();
-                    Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token);
-                    cts.Cancel();
-                    await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
-                }
+            using (client)
+            using (server)
+            using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate))
+            using (var serverSslStream = new SslStream(server))
+            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            {
+                var cts = new CancellationTokenSource();
+                Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token);
+                cts.Cancel();
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
             }
         }
     }
index afddd56..b876bdb 100644 (file)
@@ -11,6 +11,7 @@
         <Compile Include="NotifyReadVirtualNetworkStream.cs" />
         <Compile Include="DummyTcpServer.cs" />
         <Compile Include="TestConfiguration.cs" />
+        <Compile Include="TestHelper.cs" />
         <!-- SslStream Tests -->
         <Compile Include="CertificateValidationClientServer.cs" />
         <Compile Include="CertificateValidationRemoteServer.cs" />
       </ItemGroup>
     </When>
   </Choose>
-</Project>
\ No newline at end of file
+</Project>
diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs
new file mode 100644 (file)
index 0000000..e695cde
--- /dev/null
@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Net;
+using System.Net.Sockets;
+using System.Net.Test.Common;
+
+namespace System.Net.Security.Tests
+{
+    public static class TestHelper
+    {
+        public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
+        {
+            if (Capability.SecurityForceSocketStreams())
+            {
+                return GetConnectedTcpStreams();
+            }
+
+            return GetConnectedVirtualStreams();
+        }
+
+        internal static (NetworkStream ClientStream, NetworkStream ServerStream) GetConnectedTcpStreams()
+        {
+            using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                listener.Listen(1);
+
+                var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+                clientSocket.Connect(listener.LocalEndPoint);
+                Socket serverSocket = listener.Accept();
+
+                return (new NetworkStream(clientSocket, ownsSocket: true), new NetworkStream(serverSocket, ownsSocket: true));
+            }
+
+        }
+
+        internal static (VirtualNetworkStream ClientStream, VirtualNetworkStream ServerStream) GetConnectedVirtualStreams()
+        {
+            VirtualNetwork vn = new VirtualNetwork();
+
+            return (new VirtualNetworkStream(vn, isServer: false), new VirtualNetworkStream(vn, isServer: true));
+        }
+    }
+}
index dbeaead..4cd25bd 100644 (file)
@@ -65,12 +65,9 @@ namespace System.Net.Security
         // This method assumes that a SSPI context is already in a good shape.
         // For example it is either a fresh context or already authenticated context that needs renegotiation.
         //
-        private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken)
-        {
-        }
-
-        private void EndProcessAuthentication(IAsyncResult result)
+        private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default)
         {
+            return Task.Run(() => {});
         }
 
         private void ReturnReadBufferIfEmpty()