Rewrite NegotiateStream.XxAsync operations with async/await (#36583)
authorStephen Toub <stoub@microsoft.com>
Tue, 19 May 2020 11:42:06 +0000 (07:42 -0400)
committerGitHub <noreply@github.com>
Tue, 19 May 2020 11:42:06 +0000 (07:42 -0400)
* Rewrite NegotiateStream.Read/Write* operations with async/await

Gets rid of a bunch of IAsyncResult cruft and makes the XxAsync APIs cancelable.

* Combine NegoState into NegotiateStream

* Rewrite AuthenticateAs* with async/await

* Add more NegotiateStream tests

Including for cancellation and a product fix to enable cancellation.

* Update ref with overrides

* Remove custom IAsyncResults from System.Net.Security

* Fix UnitTests project

25 files changed:
src/libraries/Common/src/Interop/Unix/System.Net.Security.Native/Interop.NetSecurityNative.cs
src/libraries/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs
src/libraries/Common/src/System/Threading/Tasks/TaskToApm.cs
src/libraries/Common/tests/System/Net/VirtualNetwork/VirtualNetworkStream.cs
src/libraries/Native/Unix/System.Net.Security.Native/pal_gssapi.c
src/libraries/Native/Unix/System.Net.Security.Native/pal_gssapi.h
src/libraries/System.Net.Security/ref/System.Net.Security.cs
src/libraries/System.Net.Security/src/System.Net.Security.csproj
src/libraries/System.Net.Security/src/System/Net/FixedSizeReader.cs [deleted file]
src/libraries/System.Net.Security/src/System/Net/NTAuthentication.cs
src/libraries/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs [deleted file]
src/libraries/System.Net.Security/src/System/Net/Security/NegoState.cs [deleted file]
src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs
src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStreamPal.Windows.cs
src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs [new file with mode: 0644]
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs [deleted file]
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs
src/libraries/System.Net.Security/src/System/Net/SslStreamContext.cs
src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs
src/libraries/System.Net.Security/tests/FunctionalTests/NegotiateStreamInvalidOperationTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeLazyAsyncResult.cs [deleted file]
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs
src/libraries/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj

index c80283f..b86d5ca 100644 (file)
@@ -127,12 +127,11 @@ internal static partial class Interop
             ref GssBuffer token);
 
         [DllImport(Interop.Libraries.NetSecurityNative, EntryPoint="NetSecurityNative_Wrap")]
-        private static extern Status Wrap(
+        private static extern unsafe Status Wrap(
             out Status minorStatus,
             SafeGssContextHandle? contextHandle,
             bool isEncrypt,
-            byte[] inputBytes,
-            int offset,
+            byte* inputBytes,
             int count,
             ref GssBuffer outBuffer);
 
@@ -145,20 +144,17 @@ internal static partial class Interop
             int count,
             ref GssBuffer outBuffer);
 
-        internal static Status WrapBuffer(
+        internal static unsafe Status WrapBuffer(
             out Status minorStatus,
             SafeGssContextHandle? contextHandle,
             bool isEncrypt,
-            byte[] inputBytes,
-            int offset,
-            int count,
+            ReadOnlySpan<byte> inputBytes,
             ref GssBuffer outBuffer)
         {
-            Debug.Assert(inputBytes != null, "inputBytes must be valid value");
-            Debug.Assert(offset >= 0 && offset <= inputBytes.Length, "offset must be valid");
-            Debug.Assert(count >= 0 && count <= (inputBytes.Length - offset), "count must be valid");
-
-            return Wrap(out minorStatus, contextHandle, isEncrypt, inputBytes, offset, count, ref outBuffer);
+            fixed (byte* inputBytesPtr = inputBytes)
+            {
+                return Wrap(out minorStatus, contextHandle, isEncrypt, inputBytesPtr, inputBytes.Length, ref outBuffer);
+            }
         }
 
         internal static Status UnwrapBuffer(
index 21edf85..8ba6978 100644 (file)
@@ -43,19 +43,13 @@ namespace System.Net.Security
         private static byte[] GssWrap(
             SafeGssContextHandle? context,
             bool encrypt,
-            byte[] buffer,
-            int offset,
-            int count)
+            ReadOnlySpan<byte> buffer)
         {
-            Debug.Assert((buffer != null) && (buffer.Length > 0), "Invalid input buffer passed to Encrypt");
-            Debug.Assert((offset >= 0) && (offset < buffer.Length), "Invalid input offset passed to Encrypt");
-            Debug.Assert((count >= 0) && (count <= (buffer.Length - offset)), "Invalid input count passed to Encrypt");
-
-            Interop.NetSecurityNative.GssBuffer encryptedBuffer = default(Interop.NetSecurityNative.GssBuffer);
+            Interop.NetSecurityNative.GssBuffer encryptedBuffer = default;
             try
             {
                 Interop.NetSecurityNative.Status minorStatus;
-                Interop.NetSecurityNative.Status status = Interop.NetSecurityNative.WrapBuffer(out minorStatus, context, encrypt, buffer, offset, count, ref encryptedBuffer);
+                Interop.NetSecurityNative.Status status = Interop.NetSecurityNative.WrapBuffer(out minorStatus, context, encrypt, buffer, ref encryptedBuffer);
                 if (status != Interop.NetSecurityNative.Status.GSS_S_COMPLETE)
                 {
                     throw new Interop.NetSecurityNative.GssApiException(status, minorStatus);
@@ -555,16 +549,14 @@ namespace System.Net.Security
 
         internal static int Encrypt(
             SafeDeleteContext securityContext,
-            byte[] buffer,
-            int offset,
-            int count,
+            ReadOnlySpan<byte> buffer,
             bool isConfidential,
             bool isNtlm,
-            ref byte[]? output,
+            [NotNull] ref byte[]? output,
             uint sequenceNumber)
         {
             SafeDeleteNegoContext gssContext = (SafeDeleteNegoContext) securityContext;
-            byte[] tempOutput = GssWrap(gssContext.GssContext, isConfidential, buffer, offset, count);
+            byte[] tempOutput = GssWrap(gssContext.GssContext, isConfidential, buffer);
 
             // Create space for prefixing with the length
             const int prefixLength = 4;
@@ -628,7 +620,7 @@ namespace System.Net.Security
         internal static int MakeSignature(SafeDeleteContext securityContext, byte[] buffer, int offset, int count, [AllowNull] ref byte[] output)
         {
             SafeDeleteNegoContext gssContext = (SafeDeleteNegoContext)securityContext;
-            byte[] tempOutput = GssWrap(gssContext.GssContext, false, buffer, offset, count);
+            byte[] tempOutput = GssWrap(gssContext.GssContext, false, new ReadOnlySpan<byte>(buffer, offset, count));
             // Create space for prefixing with the length
             const int prefixLength = 4;
             output = new byte[tempOutput.Length + prefixLength];
index 96b4150..7be4563 100644 (file)
@@ -14,6 +14,7 @@
 
 #nullable enable
 using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
 
 namespace System.Threading.Tasks
 {
@@ -43,7 +44,7 @@ namespace System.Threading.Tasks
                 return;
             }
 
-            throw new ArgumentNullException(nameof(asyncResult));
+            ThrowArgumentException(asyncResult);
         }
 
         /// <summary>Processes an IAsyncResult returned by Begin.</summary>
@@ -55,9 +56,17 @@ namespace System.Threading.Tasks
                 return task.GetAwaiter().GetResult();
             }
 
-            throw new ArgumentNullException(nameof(asyncResult));
+            ThrowArgumentException(asyncResult);
+            return default!; // unreachable
         }
 
+        /// <summary>Throws an argument exception for the invalid <paramref name="asyncResult"/>.</summary>
+        [DoesNotReturn]
+        private static void ThrowArgumentException(IAsyncResult asyncResult) =>
+            throw (asyncResult is null ?
+                new ArgumentNullException(nameof(asyncResult)) :
+                new ArgumentException(null, nameof(asyncResult)));
+
         /// <summary>Provides a simple IAsyncResult that wraps a Task.</summary>
         /// <remarks>
         /// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state,
index 0f45acc..daf5caf 100644 (file)
@@ -22,6 +22,8 @@ namespace System.Net.Test.Common
             _isServer = isServer;
         }
 
+        public int DelayMilliseconds { get; set; }
+
         public bool Disposed { get; private set; }
 
         public override bool CanRead => true;
@@ -87,6 +89,11 @@ namespace System.Net.Test.Common
             await _readStreamLock.WaitAsync(cancellationToken).ConfigureAwait(false);
             try
             {
+                if (DelayMilliseconds > 0)
+                {
+                    await Task.Delay(DelayMilliseconds, cancellationToken);
+                }
+
                 if (_readStream == null || (_readStream.Position >= _readStream.Length))
                 {
                     _readStream = new MemoryStream(await _network.ReadFrameAsync(_isServer, cancellationToken).ConfigureAwait(false));
@@ -105,22 +112,16 @@ namespace System.Net.Test.Common
             _network.WriteFrame(_isServer, buffer.AsSpan(offset, count).ToArray());
         }
 
-        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
-            if (cancellationToken.IsCancellationRequested)
-            {
-                return Task.FromCanceled<int>(cancellationToken);
-            }
+            cancellationToken.ThrowIfCancellationRequested();
 
-            try
+            if (DelayMilliseconds > 0)
             {
-                Write(buffer, offset, count);
-                return Task.CompletedTask;
-            }
-            catch (Exception exc)
-            {
-                return Task.FromException(exc);
+                await Task.Delay(DelayMilliseconds, cancellationToken);
             }
+
+            Write(buffer, offset, count);
         }
 
         public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) =>
index 0fbd518..03eef62 100644 (file)
@@ -417,7 +417,6 @@ uint32_t NetSecurityNative_Wrap(uint32_t* minorStatus,
                                 GssCtxId* contextHandle,
                                 int32_t isEncrypt,
                                 uint8_t* inputBytes,
-                                int32_t offset,
                                 int32_t count,
                                 PAL_GssBuffer* outBuffer)
 {
@@ -425,14 +424,13 @@ uint32_t NetSecurityNative_Wrap(uint32_t* minorStatus,
     assert(contextHandle != NULL);
     assert(isEncrypt == 1 || isEncrypt == 0);
     assert(inputBytes != NULL);
-    assert(offset >= 0);
     assert(count >= 0);
     assert(outBuffer != NULL);
     // count refers to the length of the input message. That is, number of bytes of inputBytes
-    // starting at offset that need to be wrapped.
+    // that need to be wrapped.
 
     int confState;
-    GssBuffer inputMessageBuffer = {.length = (size_t)count, .value = inputBytes + offset};
+    GssBuffer inputMessageBuffer = {.length = (size_t)count, .value = inputBytes};
     GssBuffer gssBuffer;
     uint32_t majorStatus =
         gss_wrap(minorStatus, contextHandle, isEncrypt, GSS_C_QOP_DEFAULT, &inputMessageBuffer, &confState, &gssBuffer);
index 11e232d..489b66b 100644 (file)
@@ -159,7 +159,6 @@ PALEXPORT uint32_t NetSecurityNative_Wrap(uint32_t* minorStatus,
                                           GssCtxId* contextHandle,
                                           int32_t isEncrypt,
                                           uint8_t* inputBytes,
-                                          int32_t offset,
                                           int32_t count,
                                           PAL_GssBuffer* outBuffer);
 
index ff6cee1..d657bb1 100644 (file)
@@ -91,9 +91,13 @@ namespace System.Net.Security
         public override void Flush() { }
         public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
         public override int Read(byte[] buffer, int offset, int count) { throw null; }
+        public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
+        public override System.Threading.Tasks.ValueTask<int> ReadAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; }
         public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
         public override void SetLength(long value) { }
         public override void Write(byte[] buffer, int offset, int count) { }
+        public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
+        public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; }
     }
     public enum ProtectionLevel
     {
index f2daece..73d613d 100644 (file)
   </PropertyGroup>
   <ItemGroup>
     <Compile Include="System\Net\CertificateValidationPal.cs" />
-    <Compile Include="System\Net\FixedSizeReader.cs" />
-    <Compile Include="System\Net\HelperAsyncResults.cs" />
     <Compile Include="System\Net\Logging\NetEventSource.cs" />
+    <Compile Include="System\Net\SslStreamContext.cs" />
+    <Compile Include="System\Net\Security\AuthenticatedStream.cs" />
+    <Compile Include="System\Security\Authentication\AuthenticationException.cs" />
+    <Compile Include="System\Net\Security\CipherSuitesPolicy.cs" />
+    <Compile Include="System\Net\Security\NetEventSource.Security.cs" />
+    <Compile Include="System\Net\Security\ReadWriteAdapter.cs" />
+    <Compile Include="System\Net\Security\ProtectionLevel.cs" />
     <Compile Include="System\Net\Security\SniHelper.cs" />
     <Compile Include="System\Net\Security\SslApplicationProtocol.cs" />
     <Compile Include="System\Net\Security\SslAuthenticationOptions.cs" />
     <Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
     <Compile Include="System\Net\Security\SslServerAuthenticationOptions.cs" />
-    <Compile Include="System\Net\Security\SslStream.Implementation.Adapters.cs" />
-    <Compile Include="System\Net\SslStreamContext.cs" />
-    <Compile Include="System\Net\Security\AuthenticatedStream.cs" />
-    <Compile Include="System\Net\Security\CipherSuitesPolicy.cs" />
-    <Compile Include="System\Net\Security\NetEventSource.Security.cs" />
     <Compile Include="System\Net\Security\SecureChannel.cs" />
     <Compile Include="System\Net\Security\SslSessionsCache.cs" />
     <Compile Include="System\Net\Security\SslStream.cs" />
     <Compile Include="System\Net\Security\SslStream.Implementation.cs" />
-    <Compile Include="System\Net\Security\ProtectionLevel.cs" />
     <Compile Include="System\Net\Security\SslConnectionInfo.cs" />
     <Compile Include="System\Net\Security\StreamSizes.cs" />
     <Compile Include="System\Net\Security\TlsAlertType.cs" />
     <Compile Include="System\Net\Security\TlsAlertMessage.cs" />
     <Compile Include="System\Net\Security\TlsFrameHelper.cs" />
-    <Compile Include="System\Security\Authentication\AuthenticationException.cs" />
     <!-- NegotiateStream -->
-    <Compile Include="System\Net\BufferAsyncResult.cs" />
     <Compile Include="System\Net\NTAuthentication.cs" />
     <Compile Include="System\Net\StreamFramer.cs" />
     <Compile Include="System\Net\Security\NegotiateStream.cs" />
-    <Compile Include="System\Net\Security\NegoState.cs" />
-    <Compile Include="System\Net\Security\InternalNegotiateStream.cs" />
     <Compile Include="System\Security\Authentication\ExtendedProtection\ExtendedProtectionPolicy.cs" />
     <Compile Include="System\Security\Authentication\ExtendedProtection\PolicyEnforcement.cs" />
     <Compile Include="System\Security\Authentication\ExtendedProtection\ProtectionScenario.cs" />
@@ -65,8 +60,6 @@
     </Compile>
     <Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs"
              Link="Common\System\Net\ExceptionCheck.cs" />
-    <Compile Include="$(CommonPath)System\Net\LazyAsyncResult.cs"
-             Link="Common\System\Net\LazyAsyncResult.cs" />
     <Compile Include="$(CommonPath)System\Net\SecurityProtocol.cs"
              Link="Common\System\Net\SecurityProtocol.cs" />
     <Compile Include="$(CommonPath)System\Net\UriScheme.cs"
diff --git a/src/libraries/System.Net.Security/src/System/Net/FixedSizeReader.cs b/src/libraries/System.Net.Security/src/System/Net/FixedSizeReader.cs
deleted file mode 100644 (file)
index 3152739..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Diagnostics;
-using System.IO;
-using System.Threading.Tasks;
-
-namespace System.Net
-{
-    /// <summary>
-    /// The class is a simple wrapper on top of a read stream. It will read the exact number of bytes requested.
-    /// It will throw if EOF is reached before the expected number of bytes is returned.
-    /// </summary>
-    internal static class FixedSizeReader
-    {
-        /// <summary>
-        /// Returns 0 on legitimate EOF or if 0 bytes were requested, otherwise reads as directed or throws.
-        /// Returns count on success.
-        /// </summary>
-        public static int ReadPacket(Stream transport, byte[] buffer, int offset, int count)
-        {
-            int remainingCount = count;
-            do
-            {
-                int bytes = transport.Read(buffer, offset, remainingCount);
-                if (bytes == 0)
-                {
-                    if (remainingCount != count)
-                    {
-                        throw new IOException(SR.net_io_eof);
-                    }
-
-                    return 0;
-                }
-
-                remainingCount -= bytes;
-                offset += bytes;
-            } while (remainingCount > 0);
-
-            Debug.Assert(remainingCount == 0);
-            return count;
-        }
-
-        /// <summary>
-        /// Completes "request" with 0 if 0 bytes was requested or legitimate EOF received.
-        /// Otherwise, reads as directed or completes "request" with an Exception.
-        /// </summary>
-        public static async Task ReadPacketAsync(Stream transport, AsyncProtocolRequest request)
-        {
-            try
-            {
-                int remainingCount = request.Count, offset = request.Offset;
-                do
-                {
-                    int bytes = await transport.ReadAsync(new Memory<byte>(request.Buffer, offset, remainingCount), request.CancellationToken).ConfigureAwait(false);
-                    if (bytes == 0)
-                    {
-                        if (remainingCount != request.Count)
-                        {
-                            throw new IOException(SR.net_io_eof);
-                        }
-                        request.CompleteRequest(0);
-                        return;
-                    }
-
-                    offset += bytes;
-                    remainingCount -= bytes;
-                } while (remainingCount > 0);
-
-                Debug.Assert(remainingCount == 0);
-                request.CompleteRequest(request.Count);
-            }
-            catch (Exception e)
-            {
-                request.CompleteUserWithError(e);
-            }
-        }
-    }
-}
index af93167..0f4f7af 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.ComponentModel;
+using System.Diagnostics.CodeAnalysis;
 using System.Net.Security;
 using System.Security.Authentication.ExtendedProtection;
 
@@ -114,13 +115,11 @@ namespace System.Net
             context.ThisPtr.Initialize(context.IsServer, context.Package, context.Credential, context.Spn, context.RequestedContextFlags, context.ChannelBinding);
         }
 
-        internal int Encrypt(byte[] buffer, int offset, int count, ref byte[]? output, uint sequenceNumber)
+        internal int Encrypt(ReadOnlySpan<byte> buffer, [NotNull] ref byte[]? output, uint sequenceNumber)
         {
             return NegotiateStreamPal.Encrypt(
                 _securityContext!,
                 buffer,
-                offset,
-                count,
                 IsConfidentialityFlag,
                 IsNTLM,
                 ref output,
diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs
deleted file mode 100644 (file)
index 4c748bd..0000000
+++ /dev/null
@@ -1,446 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace System.Net.Security
-{
-    //
-    // This is a wrapping stream that does data encryption/decryption based on a successfully authenticated SSPI context.
-    // This file contains the private implementation.
-    //
-    public partial class NegotiateStream : AuthenticatedStream
-    {
-        private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback);
-        private static readonly AsyncProtocolCallback s_readCallback = new AsyncProtocolCallback(ReadCallback);
-
-        private int _NestedWrite;
-        private int _NestedRead;
-        private byte[] _ReadHeader = null!; // will be initialized by ctor helper
-
-        // Never updated directly, special properties are used.
-        private byte[]? _InternalBuffer;
-        private int _InternalOffset;
-        private int _InternalBufferCount;
-
-        private void InitializeStreamPart()
-        {
-            _ReadHeader = new byte[4];
-        }
-
-        private byte[]? InternalBuffer
-        {
-            get
-            {
-                return _InternalBuffer;
-            }
-        }
-
-        private int InternalOffset
-        {
-            get
-            {
-                return _InternalOffset;
-            }
-        }
-
-        private int InternalBufferCount
-        {
-            get
-            {
-                return _InternalBufferCount;
-            }
-        }
-
-        private void DecrementInternalBufferCount(int decrCount)
-        {
-            _InternalOffset += decrCount;
-            _InternalBufferCount -= decrCount;
-        }
-
-        private void EnsureInternalBufferSize(int bytes)
-        {
-            _InternalBufferCount = bytes;
-            _InternalOffset = 0;
-            if (InternalBuffer == null || InternalBuffer.Length < bytes)
-            {
-                _InternalBuffer = new byte[bytes];
-            }
-        }
-
-        private void AdjustInternalBufferOffsetSize(int bytes, int offset)
-        {
-            _InternalBufferCount = bytes;
-            _InternalOffset = offset;
-        }
-
-        //
-        // Validates user parameters for all Read/Write methods.
-        //
-        private void ValidateParameters(byte[] buffer, int offset, int count)
-        {
-            if (buffer == null)
-            {
-                throw new ArgumentNullException(nameof(buffer));
-            }
-
-            if (offset < 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(offset));
-            }
-
-            if (count < 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(count));
-            }
-
-            if (count > buffer.Length - offset)
-            {
-                throw new ArgumentOutOfRangeException(nameof(count), SR.net_offset_plus_count);
-            }
-        }
-
-        //
-        // Combined sync/async write method. For sync request asyncRequest==null.
-        //
-        private void ProcessWrite(byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            ValidateParameters(buffer, offset, count);
-
-            if (Interlocked.Exchange(ref _NestedWrite, 1) == 1)
-            {
-                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, (asyncRequest != null ? "BeginWrite" : "Write"), "write"));
-            }
-
-            bool failed = false;
-            try
-            {
-                StartWriting(buffer, offset, count, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                failed = true;
-                if (e is IOException)
-                {
-                    throw;
-                }
-
-                throw new IOException(SR.net_io_write, e);
-            }
-            finally
-            {
-                if (asyncRequest == null || failed)
-                {
-                    _NestedWrite = 0;
-                }
-            }
-        }
-
-        private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            // We loop to this method from the callback.
-            // If the last chunk was just completed from async callback (count < 0), we complete user request.
-            if (count >= 0)
-            {
-                byte[]? outBuffer = null;
-                do
-                {
-                    int chunkBytes = Math.Min(count, NegoState.MaxWriteDataSize);
-                    int encryptedBytes;
-
-                    try
-                    {
-                        encryptedBytes = _negoState.EncryptData(buffer, offset, chunkBytes, ref outBuffer);
-                    }
-                    catch (Exception e)
-                    {
-                        throw new IOException(SR.net_io_encrypt, e);
-                    }
-
-                    if (asyncRequest != null)
-                    {
-                        // prepare for the next request
-                        asyncRequest.SetNextRequest(buffer, offset + chunkBytes, count - chunkBytes, null);
-                        Task t = InnerStream.WriteAsync(outBuffer!, 0, encryptedBytes);
-                        if (t.IsCompleted)
-                        {
-                            t.GetAwaiter().GetResult();
-                        }
-                        else
-                        {
-                            IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest);
-                            if (!ar.CompletedSynchronously)
-                            {
-                                return;
-                            }
-                            TaskToApm.End(ar);
-                        }
-                    }
-                    else
-                    {
-                        InnerStream.Write(outBuffer!, 0, encryptedBytes);
-                    }
-
-                    offset += chunkBytes;
-                    count -= chunkBytes;
-                } while (count != 0);
-            }
-
-            if (asyncRequest != null)
-            {
-                asyncRequest.CompleteUser();
-            }
-        }
-
-        //
-        // Combined sync/async read method. For sync request asyncRequest==null.
-        // There is a little overhead because we need to pass buffer/offset/count used only in sync.
-        // Still the benefit is that we have a common sync/async code path.
-        //
-        private int ProcessRead(byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            ValidateParameters(buffer, offset, count);
-
-            if (Interlocked.Exchange(ref _NestedRead, 1) == 1)
-            {
-                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, (asyncRequest != null ? "BeginRead" : "Read"), "read"));
-            }
-
-            bool failed = false;
-            try
-            {
-                if (InternalBufferCount != 0)
-                {
-                    int copyBytes = InternalBufferCount > count ? count : InternalBufferCount;
-                    if (copyBytes != 0)
-                    {
-                        Buffer.BlockCopy(InternalBuffer!, InternalOffset, buffer, offset, copyBytes);
-                        DecrementInternalBufferCount(copyBytes);
-                    }
-                    asyncRequest?.CompleteUser(copyBytes);
-                    return copyBytes;
-                }
-
-                // Performing actual I/O.
-                return StartReading(buffer, offset, count, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                failed = true;
-                if (e is IOException)
-                {
-                    throw;
-                }
-                throw new IOException(SR.net_io_read, e);
-            }
-            finally
-            {
-                if (asyncRequest == null || failed)
-                {
-                    _NestedRead = 0;
-                }
-            }
-        }
-
-        //
-        // To avoid recursion when 0 bytes have been decrypted, loop until decryption results in at least 1 byte.
-        //
-        private int StartReading(byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            int result;
-            // When we read -1 bytes means we have decrypted 0 bytes, need looping.
-            while ((result = StartFrameHeader(buffer, offset, count, asyncRequest)) == -1);
-
-            return result;
-        }
-
-        private int StartFrameHeader(byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            int readBytes = 0;
-            if (asyncRequest != null)
-            {
-                asyncRequest.SetNextRequest(_ReadHeader, 0, _ReadHeader.Length, s_readCallback);
-                _ = FixedSizeReader.ReadPacketAsync(InnerStream, asyncRequest);
-                if (!asyncRequest.MustCompleteSynchronously)
-                {
-                    return 0;
-                }
-
-                readBytes = asyncRequest.Result;
-            }
-            else
-            {
-                readBytes = FixedSizeReader.ReadPacket(InnerStream, _ReadHeader, 0, _ReadHeader.Length);
-            }
-
-            return StartFrameBody(readBytes, buffer, offset, count, asyncRequest);
-        }
-
-        private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            if (readBytes == 0)
-            {
-                //EOF
-                asyncRequest?.CompleteUser(0);
-                return 0;
-            }
-
-            if (!(readBytes == _ReadHeader.Length))
-            {
-                NetEventSource.Fail(this, $"Frame size must be 4 but received {readBytes} bytes.");
-            }
-
-            // Replace readBytes with the body size recovered from the header content.
-            readBytes = _ReadHeader[3];
-            readBytes = (readBytes << 8) | _ReadHeader[2];
-            readBytes = (readBytes << 8) | _ReadHeader[1];
-            readBytes = (readBytes << 8) | _ReadHeader[0];
-
-            //
-            // The body carries 4 bytes for trailer size slot plus trailer, hence <=4 frame size is always an error.
-            // Additionally we'd like to restrict the read frame size to 64k.
-            //
-            if (readBytes <= 4 || readBytes > NegoState.MaxReadFrameSize)
-            {
-                throw new IOException(SR.net_frame_read_size);
-            }
-
-            //
-            // Always pass InternalBuffer for SSPI "in place" decryption.
-            // A user buffer can be shared by many threads in that case decryption/integrity check may fail cause of data corruption.
-            //
-            EnsureInternalBufferSize(readBytes);
-            if (asyncRequest != null)
-            {
-                asyncRequest.SetNextRequest(InternalBuffer, 0, readBytes, s_readCallback);
-
-                _ = FixedSizeReader.ReadPacketAsync(InnerStream, asyncRequest);
-
-                if (!asyncRequest.MustCompleteSynchronously)
-                {
-                    return 0;
-                }
-
-                readBytes = asyncRequest.Result;
-            }
-            else //Sync
-            {
-                readBytes = FixedSizeReader.ReadPacket(InnerStream, InternalBuffer!, 0, readBytes);
-            }
-
-            return ProcessFrameBody(readBytes, buffer, offset, count, asyncRequest);
-        }
-
-        private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest? asyncRequest)
-        {
-            if (readBytes == 0)
-            {
-                // We already checked that the frame body is bigger than 0 bytes
-                // Hence, this is an EOF ... fire.
-                throw new IOException(SR.net_io_eof);
-            }
-
-            // Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
-            int internalOffset;
-            readBytes = _negoState.DecryptData(InternalBuffer!, 0, readBytes, out internalOffset);
-
-            // Decrypted data start from zero offset, the size can be shrunk after decryption.
-            AdjustInternalBufferOffsetSize(readBytes, internalOffset);
-
-            if (readBytes == 0 && count != 0)
-            {
-                // Read again.
-                return -1;
-            }
-
-            if (readBytes > count)
-            {
-                readBytes = count;
-            }
-
-            Buffer.BlockCopy(InternalBuffer!, InternalOffset, buffer, offset, readBytes);
-
-            // This will adjust both the remaining internal buffer count and the offset.
-            DecrementInternalBufferCount(readBytes);
-
-            asyncRequest?.CompleteUser(readBytes);
-
-            return readBytes;
-        }
-
-        private static void WriteCallback(IAsyncResult transportResult)
-        {
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            if (!(transportResult.AsyncState is AsyncProtocolRequest))
-            {
-                NetEventSource.Fail(transportResult, "State type is wrong, expected AsyncProtocolRequest.");
-            }
-
-            AsyncProtocolRequest asyncRequest = (AsyncProtocolRequest)transportResult.AsyncState!;
-
-            try
-            {
-                NegotiateStream negoStream = (NegotiateStream)asyncRequest.AsyncObject!;
-                TaskToApm.End(transportResult);
-                if (asyncRequest.Count == 0)
-                {
-                    // This was the last chunk.
-                    asyncRequest.Count = -1;
-                }
-
-                negoStream.StartWriting(asyncRequest.Buffer!, asyncRequest.Offset, asyncRequest.Count, asyncRequest);
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                asyncRequest.CompleteUserWithError(e);
-            }
-        }
-
-        private static void ReadCallback(AsyncProtocolRequest asyncRequest)
-        {
-            // Async ONLY completion.
-            try
-            {
-                NegotiateStream negoStream = (NegotiateStream)asyncRequest.AsyncObject!;
-                BufferAsyncResult bufferResult = (BufferAsyncResult)asyncRequest.UserAsyncResult;
-
-                // This is an optimization to avoid an additional callback.
-                if ((object?)asyncRequest.Buffer == (object?)negoStream._ReadHeader)
-                {
-                    negoStream.StartFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest);
-                }
-                else
-                {
-                    if (-1 == negoStream.ProcessFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest))
-                    {
-                        // In case we decrypted 0 bytes, start another reading.
-                        negoStream.StartReading(bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest);
-                    }
-                }
-            }
-            catch (Exception e)
-            {
-                if (asyncRequest.IsUserCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                asyncRequest.CompleteUserWithError(e);
-            }
-        }
-    }
-}
diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/NegoState.cs b/src/libraries/System.Net.Security/src/System/Net/Security/NegoState.cs
deleted file mode 100644 (file)
index b59ffb9..0000000
+++ /dev/null
@@ -1,840 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Diagnostics;
-using System.IO;
-using System.Security.Principal;
-using System.Threading;
-using System.ComponentModel;
-using System.Runtime.ExceptionServices;
-using System.Security.Authentication;
-using System.Security.Authentication.ExtendedProtection;
-
-namespace System.Net.Security
-{
-    //
-    // The class maintains the state of the authentication process and the security context.
-    // It encapsulates security context and does the real work in authentication and
-    // user data encryption
-    //
-    internal class NegoState
-    {
-#pragma warning disable CA1825 // used in reference comparison, requires unique object identity
-        private static readonly byte[] s_emptyMessage = new byte[0];
-#pragma warning restore CA1825
-        private static readonly AsyncCallback s_readCallback = new AsyncCallback(ReadCallback);
-        private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback);
-
-        private readonly Stream _innerStream;
-
-        private Exception? _exception;
-
-        private StreamFramer? _framer;
-        private NTAuthentication? _context;
-
-        private int _nestedAuth;
-
-        internal const int ERROR_TRUST_FAILURE = 1790;   // Used to serialize protectionLevel or impersonationLevel mismatch error to the remote side.
-        internal const int MaxReadFrameSize = 64 * 1024;
-        internal const int MaxWriteDataSize = 63 * 1024; // 1k for the framing and trailer that is always less as per SSPI.
-
-        private bool _canRetryAuthentication;
-        private ProtectionLevel _expectedProtectionLevel;
-        private TokenImpersonationLevel _expectedImpersonationLevel;
-        private uint _writeSequenceNumber;
-        private uint _readSequenceNumber;
-
-        private ExtendedProtectionPolicy? _extendedProtectionPolicy;
-
-        // SSPI does not send a server ack on successful auth.
-        // This is a state variable used to gracefully handle auth confirmation.
-        private bool _remoteOk = false;
-
-        internal NegoState(Stream innerStream)
-        {
-            Debug.Assert(innerStream != null);
-
-            _innerStream = innerStream;
-        }
-
-        internal static string DefaultPackage
-        {
-            get
-            {
-                return NegotiationInfoClass.Negotiate;
-            }
-        }
-
-        internal IIdentity GetIdentity()
-        {
-            CheckThrow(true);
-            return NegotiateStreamPal.GetIdentity(_context!);
-        }
-
-        internal void ValidateCreateContext(
-            string package,
-            NetworkCredential credential,
-            string servicePrincipalName,
-            ExtendedProtectionPolicy? policy,
-            ProtectionLevel protectionLevel,
-            TokenImpersonationLevel impersonationLevel)
-        {
-            if (policy != null)
-            {
-                // One of these must be set if EP is turned on
-                if (policy.CustomChannelBinding == null && policy.CustomServiceNames == null)
-                {
-                    throw new ArgumentException(SR.net_auth_must_specify_extended_protection_scheme, nameof(policy));
-                }
-
-                _extendedProtectionPolicy = policy;
-            }
-            else
-            {
-                _extendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Never);
-            }
-
-            ValidateCreateContext(package, true, credential, servicePrincipalName, _extendedProtectionPolicy!.CustomChannelBinding, protectionLevel, impersonationLevel);
-        }
-
-        internal void ValidateCreateContext(
-            string package,
-            bool isServer,
-            NetworkCredential credential,
-            string? servicePrincipalName,
-            ChannelBinding? channelBinding,
-            ProtectionLevel protectionLevel,
-            TokenImpersonationLevel impersonationLevel)
-        {
-            if (_exception != null && !_canRetryAuthentication)
-            {
-                ExceptionDispatchInfo.Throw(_exception);
-            }
-
-            if (_context != null && _context.IsValidContext)
-            {
-                throw new InvalidOperationException(SR.net_auth_reauth);
-            }
-
-            if (credential == null)
-            {
-                throw new ArgumentNullException(nameof(credential));
-            }
-
-            if (servicePrincipalName == null)
-            {
-                throw new ArgumentNullException(nameof(servicePrincipalName));
-            }
-
-            NegotiateStreamPal.ValidateImpersonationLevel(impersonationLevel);
-            if (_context != null && IsServer != isServer)
-            {
-                throw new InvalidOperationException(SR.net_auth_client_server);
-            }
-
-            _exception = null;
-            _remoteOk = false;
-            _framer = new StreamFramer(_innerStream);
-            _framer.WriteHeader.MessageId = FrameHeader.HandshakeId;
-
-            _expectedProtectionLevel = protectionLevel;
-            _expectedImpersonationLevel = isServer ? impersonationLevel : TokenImpersonationLevel.None;
-            _writeSequenceNumber = 0;
-            _readSequenceNumber = 0;
-
-            ContextFlagsPal flags = ContextFlagsPal.Connection;
-
-            // A workaround for the client when talking to Win9x on the server side.
-            if (protectionLevel == ProtectionLevel.None && !isServer)
-            {
-                package = NegotiationInfoClass.NTLM;
-            }
-            else if (protectionLevel == ProtectionLevel.EncryptAndSign)
-            {
-                flags |= ContextFlagsPal.Confidentiality;
-            }
-            else if (protectionLevel == ProtectionLevel.Sign)
-            {
-                // Assuming user expects NT4 SP4 and above.
-                flags |= (ContextFlagsPal.ReplayDetect | ContextFlagsPal.SequenceDetect | ContextFlagsPal.InitIntegrity);
-            }
-
-            if (isServer)
-            {
-                if (_extendedProtectionPolicy!.PolicyEnforcement == PolicyEnforcement.WhenSupported)
-                {
-                    flags |= ContextFlagsPal.AllowMissingBindings;
-                }
-
-                if (_extendedProtectionPolicy.PolicyEnforcement != PolicyEnforcement.Never &&
-                    _extendedProtectionPolicy.ProtectionScenario == ProtectionScenario.TrustedProxy)
-                {
-                    flags |= ContextFlagsPal.ProxyBindings;
-                }
-            }
-            else
-            {
-                // Server side should not request any of these flags.
-                if (protectionLevel != ProtectionLevel.None)
-                {
-                    flags |= ContextFlagsPal.MutualAuth;
-                }
-
-                if (impersonationLevel == TokenImpersonationLevel.Identification)
-                {
-                    flags |= ContextFlagsPal.InitIdentify;
-                }
-
-                if (impersonationLevel == TokenImpersonationLevel.Delegation)
-                {
-                    flags |= ContextFlagsPal.Delegate;
-                }
-            }
-
-            _canRetryAuthentication = false;
-
-            try
-            {
-                _context = new NTAuthentication(isServer, package, credential, servicePrincipalName, flags, channelBinding!);
-            }
-            catch (Win32Exception e)
-            {
-                throw new AuthenticationException(SR.net_auth_SSPI, e);
-            }
-        }
-
-        private Exception SetException(Exception e)
-        {
-            if (_exception == null || !(_exception is ObjectDisposedException))
-            {
-                _exception = e;
-            }
-
-            if (_exception != null && _context != null)
-            {
-                _context.CloseContext();
-            }
-
-            return _exception!;
-        }
-
-        internal bool IsAuthenticated
-        {
-            get
-            {
-                return _context != null && HandshakeComplete && _exception == null && _remoteOk;
-            }
-        }
-
-        internal bool IsMutuallyAuthenticated
-        {
-            get
-            {
-                if (!IsAuthenticated)
-                {
-                    return false;
-                }
-
-                // Suppressing for NTLM since SSPI does not return correct value in the context flags.
-                if (_context!.IsNTLM)
-                {
-                    return false;
-                }
-
-                return _context.IsMutualAuthFlag;
-            }
-        }
-
-        internal bool IsEncrypted
-        {
-            get
-            {
-                return IsAuthenticated && _context!.IsConfidentialityFlag;
-            }
-        }
-
-        internal bool IsSigned
-        {
-            get
-            {
-                return IsAuthenticated && (_context!.IsIntegrityFlag || _context.IsConfidentialityFlag);
-            }
-        }
-
-        internal bool IsServer
-        {
-            get
-            {
-                return _context != null && _context.IsServer;
-            }
-        }
-
-        internal bool CanGetSecureStream
-        {
-            get
-            {
-                return (_context!.IsConfidentialityFlag || _context.IsIntegrityFlag);
-            }
-        }
-
-        internal TokenImpersonationLevel AllowedImpersonation
-        {
-            get
-            {
-                CheckThrow(true);
-                return PrivateImpersonationLevel;
-            }
-        }
-
-        private TokenImpersonationLevel PrivateImpersonationLevel
-        {
-            get
-            {
-                // We should suppress the delegate flag in NTLM case.
-                return (_context!.IsDelegationFlag && _context.ProtocolName != NegotiationInfoClass.NTLM) ? TokenImpersonationLevel.Delegation
-                        : _context.IsIdentifyFlag ? TokenImpersonationLevel.Identification
-                        : TokenImpersonationLevel.Impersonation;
-            }
-        }
-
-        private bool HandshakeComplete
-        {
-            get
-            {
-                return _context!.IsCompleted && _context.IsValidContext;
-            }
-        }
-
-        internal void CheckThrow(bool authSucessCheck)
-        {
-            if (_exception != null)
-            {
-                ExceptionDispatchInfo.Throw(_exception);
-            }
-
-            if (authSucessCheck && !IsAuthenticated)
-            {
-                throw new InvalidOperationException(SR.net_auth_noauth);
-            }
-        }
-
-        //
-        // This is to not depend on GC&SafeHandle class if the context is not needed anymore.
-        //
-        internal void Close()
-        {
-            // Mark this instance as disposed.
-            _exception = new ObjectDisposedException("NegotiateStream");
-            if (_context != null)
-            {
-                _context.CloseContext();
-            }
-        }
-
-        internal void ProcessAuthentication(LazyAsyncResult? lazyResult)
-        {
-            CheckThrow(false);
-            if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
-            {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, lazyResult == null ? "BeginAuthenticate" : "Authenticate", "authenticate"));
-            }
-
-            try
-            {
-                if (_context!.IsServer)
-                {
-                    // Listen for a client blob.
-                    StartReceiveBlob(lazyResult);
-                }
-                else
-                {
-                    // Start with the first blob.
-                    StartSendBlob(null, lazyResult);
-                }
-            }
-            catch (Exception e)
-            {
-                // Round-trip it through SetException().
-                e = SetException(e);
-                throw;
-            }
-            finally
-            {
-                if (lazyResult == null || _exception != null)
-                {
-                    _nestedAuth = 0;
-                }
-            }
-        }
-
-        internal void EndProcessAuthentication(IAsyncResult result)
-        {
-            if (result == null)
-            {
-                throw new ArgumentNullException("asyncResult");
-            }
-
-            LazyAsyncResult? lazyResult = result as LazyAsyncResult;
-            if (lazyResult == null)
-            {
-                throw new ArgumentException(SR.Format(SR.net_io_async_result, result.GetType().FullName), "asyncResult");
-            }
-
-            if (Interlocked.Exchange(ref _nestedAuth, 0) == 0)
-            {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndAuthenticate"));
-            }
-
-            // No "artificial" timeouts implemented so far, InnerStream controls that.
-            lazyResult.InternalWaitForCompletion();
-
-            Exception? e = lazyResult.Result as Exception;
-
-            if (e != null)
-            {
-                // Round-trip it through the SetException().
-                e = SetException(e);
-                ExceptionDispatchInfo.Throw(e);
-            }
-        }
-
-        private bool CheckSpn()
-        {
-            if (_context!.IsKerberos)
-            {
-                return true;
-            }
-
-            if (_extendedProtectionPolicy!.PolicyEnforcement == PolicyEnforcement.Never ||
-                    _extendedProtectionPolicy.CustomServiceNames == null)
-            {
-                return true;
-            }
-
-            string? clientSpn = _context.ClientSpecifiedSpn;
-
-            if (string.IsNullOrEmpty(clientSpn))
-            {
-                if (_extendedProtectionPolicy.PolicyEnforcement == PolicyEnforcement.WhenSupported)
-                {
-                    return true;
-                }
-            }
-            else
-            {
-                return _extendedProtectionPolicy.CustomServiceNames.Contains(clientSpn);
-            }
-
-            return false;
-        }
-
-        //
-        // Client side starts here, but server also loops through this method.
-        //
-        private void StartSendBlob(byte[]? message, LazyAsyncResult? lazyResult)
-        {
-            Exception? exception = null;
-            if (message != s_emptyMessage)
-            {
-                message = GetOutgoingBlob(message, ref exception);
-            }
-
-            if (exception != null)
-            {
-                // Signal remote side on a failed attempt.
-                StartSendAuthResetSignal(lazyResult, message!, exception);
-                return;
-            }
-
-            if (HandshakeComplete)
-            {
-                if (_context!.IsServer && !CheckSpn())
-                {
-                    exception = new AuthenticationException(SR.net_auth_bad_client_creds_or_target_mismatch);
-                    int statusCode = ERROR_TRUST_FAILURE;
-                    message = new byte[8];  //sizeof(long)
-
-                    for (int i = message.Length - 1; i >= 0; --i)
-                    {
-                        message[i] = (byte)(statusCode & 0xFF);
-                        statusCode = (int)((uint)statusCode >> 8);
-                    }
-
-                    StartSendAuthResetSignal(lazyResult, message, exception);
-                    return;
-                }
-
-                if (PrivateImpersonationLevel < _expectedImpersonationLevel)
-                {
-                    exception = new AuthenticationException(SR.Format(SR.net_auth_context_expectation, _expectedImpersonationLevel.ToString(), PrivateImpersonationLevel.ToString()));
-                    int statusCode = ERROR_TRUST_FAILURE;
-                    message = new byte[8];  //sizeof(long)
-
-                    for (int i = message.Length - 1; i >= 0; --i)
-                    {
-                        message[i] = (byte)(statusCode & 0xFF);
-                        statusCode = (int)((uint)statusCode >> 8);
-                    }
-
-                    StartSendAuthResetSignal(lazyResult, message, exception);
-                    return;
-                }
-
-                ProtectionLevel result = _context.IsConfidentialityFlag ? ProtectionLevel.EncryptAndSign : _context.IsIntegrityFlag ? ProtectionLevel.Sign : ProtectionLevel.None;
-
-                if (result < _expectedProtectionLevel)
-                {
-                    exception = new AuthenticationException(SR.Format(SR.net_auth_context_expectation, result.ToString(), _expectedProtectionLevel.ToString()));
-                    int statusCode = ERROR_TRUST_FAILURE;
-                    message = new byte[8];  //sizeof(long)
-
-                    for (int i = message.Length - 1; i >= 0; --i)
-                    {
-                        message[i] = (byte)(statusCode & 0xFF);
-                        statusCode = (int)((uint)statusCode >> 8);
-                    }
-
-                    StartSendAuthResetSignal(lazyResult, message, exception);
-                    return;
-                }
-
-                // Signal remote party that we are done
-                _framer!.WriteHeader.MessageId = FrameHeader.HandshakeDoneId;
-                if (_context.IsServer)
-                {
-                    // Server may complete now because client SSPI would not complain at this point.
-                    _remoteOk = true;
-
-                    // However the client will wait for server to send this ACK
-                    //Force signaling server OK to the client
-                    if (message == null)
-                    {
-                        message = s_emptyMessage;
-                    }
-                }
-            }
-            else if (message == null || message == s_emptyMessage)
-            {
-                throw new InternalException();
-            }
-
-            if (message != null)
-            {
-                //even if we are completed, there could be a blob for sending.
-                if (lazyResult == null)
-                {
-                    _framer!.WriteMessage(message);
-                }
-                else
-                {
-                    IAsyncResult ar = _framer!.BeginWriteMessage(message, s_writeCallback, lazyResult);
-                    if (!ar.CompletedSynchronously)
-                    {
-                        return;
-                    }
-                    _framer.EndWriteMessage(ar);
-                }
-            }
-            CheckCompletionBeforeNextReceive(lazyResult);
-        }
-
-        //
-        // This will check and logically complete the auth handshake.
-        //
-        private void CheckCompletionBeforeNextReceive(LazyAsyncResult? lazyResult)
-        {
-            if (HandshakeComplete && _remoteOk)
-            {
-                // We are done with success.
-                if (lazyResult != null)
-                {
-                    lazyResult.InvokeCallback();
-                }
-
-                return;
-            }
-
-            StartReceiveBlob(lazyResult);
-        }
-
-        //
-        // Server side starts here, but client also loops through this method.
-        //
-        private void StartReceiveBlob(LazyAsyncResult? lazyResult)
-        {
-            Debug.Assert(_framer != null);
-
-            byte[]? message;
-            if (lazyResult == null)
-            {
-                message = _framer.ReadMessage();
-            }
-            else
-            {
-                IAsyncResult ar = _framer.BeginReadMessage(s_readCallback, lazyResult);
-                if (!ar.CompletedSynchronously)
-                {
-                    return;
-                }
-
-                message = _framer.EndReadMessage(ar);
-            }
-
-            ProcessReceivedBlob(message, lazyResult);
-        }
-
-        private void ProcessReceivedBlob(byte[]? message, LazyAsyncResult? lazyResult)
-        {
-            // This is an EOF otherwise we would get at least *empty* message but not a null one.
-            if (message == null)
-            {
-                throw new AuthenticationException(SR.net_auth_eof, null);
-            }
-
-            // Process Header information.
-            if (_framer!.ReadHeader.MessageId == FrameHeader.HandshakeErrId)
-            {
-                if (message.Length >= 8)    // sizeof(long)
-                {
-                    // Try to recover remote win32 Exception.
-                    long error = 0;
-                    for (int i = 0; i < 8; ++i)
-                    {
-                        error = (error << 8) + message[i];
-                    }
-
-                    ThrowCredentialException(error);
-                }
-
-                throw new AuthenticationException(SR.net_auth_alert, null);
-            }
-
-            if (_framer.ReadHeader.MessageId == FrameHeader.HandshakeDoneId)
-            {
-                _remoteOk = true;
-            }
-            else if (_framer.ReadHeader.MessageId != FrameHeader.HandshakeId)
-            {
-                throw new AuthenticationException(SR.Format(SR.net_io_header_id, "MessageId", _framer.ReadHeader.MessageId, FrameHeader.HandshakeId), null);
-            }
-
-            CheckCompletionBeforeNextSend(message, lazyResult);
-        }
-
-        //
-        // This will check and logically complete the auth handshake.
-        //
-        private void CheckCompletionBeforeNextSend(byte[] message, LazyAsyncResult? lazyResult)
-        {
-            //If we are done don't go into send.
-            if (HandshakeComplete)
-            {
-                if (!_remoteOk)
-                {
-                    throw new AuthenticationException(SR.Format(SR.net_io_header_id, "MessageId", _framer!.ReadHeader.MessageId, FrameHeader.HandshakeDoneId), null);
-                }
-                if (lazyResult != null)
-                {
-                    lazyResult.InvokeCallback();
-                }
-
-                return;
-            }
-
-            // Not yet done, get a new blob and send it if any.
-            StartSendBlob(message, lazyResult);
-        }
-
-        //
-        //  This is to reset auth state on the remote side.
-        //  If this write succeeds we will allow auth retrying.
-        //
-        private void StartSendAuthResetSignal(LazyAsyncResult? lazyResult, byte[] message, Exception exception)
-        {
-            _framer!.WriteHeader.MessageId = FrameHeader.HandshakeErrId;
-
-            if (IsLogonDeniedException(exception))
-            {
-                if (IsServer)
-                {
-                    exception = new InvalidCredentialException(SR.net_auth_bad_client_creds, exception);
-                }
-                else
-                {
-                    exception = new InvalidCredentialException(SR.net_auth_bad_client_creds_or_target_mismatch, exception);
-                }
-            }
-
-            if (!(exception is AuthenticationException))
-            {
-                exception = new AuthenticationException(SR.net_auth_SSPI, exception);
-            }
-
-            if (lazyResult == null)
-            {
-                _framer.WriteMessage(message);
-            }
-            else
-            {
-                lazyResult.Result = exception;
-                IAsyncResult ar = _framer.BeginWriteMessage(message, s_writeCallback, lazyResult);
-                if (!ar.CompletedSynchronously)
-                {
-                    return;
-                }
-
-                _framer.EndWriteMessage(ar);
-            }
-
-            _canRetryAuthentication = true;
-            ExceptionDispatchInfo.Throw(exception);
-        }
-
-        private static void WriteCallback(IAsyncResult transportResult)
-        {
-            if (!(transportResult.AsyncState is LazyAsyncResult))
-            {
-                NetEventSource.Fail(transportResult, "State type is wrong, expected LazyAsyncResult.");
-            }
-
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            LazyAsyncResult lazyResult = (LazyAsyncResult)transportResult.AsyncState!;
-
-            // Async completion.
-            try
-            {
-                NegoState authState = (NegoState)lazyResult.AsyncObject!;
-                authState._framer!.EndWriteMessage(transportResult);
-
-                // Special case for an error notification.
-                if (lazyResult.Result is Exception e)
-                {
-                    authState._canRetryAuthentication = true;
-                    ExceptionDispatchInfo.Throw(e);
-                }
-
-                authState.CheckCompletionBeforeNextReceive(lazyResult);
-            }
-            catch (Exception e)
-            {
-                if (lazyResult.InternalPeekCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                lazyResult.InvokeCallback(e);
-            }
-        }
-
-        private static void ReadCallback(IAsyncResult transportResult)
-        {
-            if (!(transportResult.AsyncState is LazyAsyncResult))
-            {
-                NetEventSource.Fail(transportResult, "State type is wrong, expected LazyAsyncResult.");
-            }
-
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            LazyAsyncResult lazyResult = (LazyAsyncResult)transportResult.AsyncState!;
-
-            // Async completion.
-            try
-            {
-                NegoState authState = (NegoState)lazyResult.AsyncObject!;
-                byte[]? message = authState._framer!.EndReadMessage(transportResult);
-                authState.ProcessReceivedBlob(message, lazyResult);
-            }
-            catch (Exception e)
-            {
-                if (lazyResult.InternalPeekCompleted)
-                {
-                    // This will throw on a worker thread.
-                    throw;
-                }
-
-                lazyResult.InvokeCallback(e);
-            }
-        }
-
-        internal static bool IsError(SecurityStatusPal status)
-        {
-            return ((int)status.ErrorCode >= (int)SecurityStatusPalErrorCode.OutOfMemory);
-        }
-
-        private unsafe byte[]? GetOutgoingBlob(byte[]? incomingBlob, ref Exception? e)
-        {
-            byte[]? message = _context!.GetOutgoingBlob(incomingBlob, false, out SecurityStatusPal statusCode);
-
-            if (IsError(statusCode))
-            {
-                e = NegotiateStreamPal.CreateExceptionFromError(statusCode);
-                uint error = (uint)e.HResult;
-
-                message = new byte[sizeof(long)];
-                for (int i = message.Length - 1; i >= 0; --i)
-                {
-                    message[i] = (byte)(error & 0xFF);
-                    error = (error >> 8);
-                }
-            }
-
-            if (message != null && message.Length == 0)
-            {
-                message = s_emptyMessage;
-            }
-
-            return message;
-        }
-
-        internal int EncryptData(byte[] buffer, int offset, int count, ref byte[]? outBuffer)
-        {
-            CheckThrow(true);
-
-            // SSPI seems to ignore this sequence number.
-            ++_writeSequenceNumber;
-            return _context!.Encrypt(buffer, offset, count, ref outBuffer, _writeSequenceNumber);
-        }
-
-        internal int DecryptData(byte[] buffer, int offset, int count, out int newOffset)
-        {
-            CheckThrow(true);
-
-            // SSPI seems to ignore this sequence number.
-            ++_readSequenceNumber;
-            return _context!.Decrypt(buffer, offset, count, out newOffset, _readSequenceNumber);
-        }
-
-        internal static void ThrowCredentialException(long error)
-        {
-            Win32Exception e = new Win32Exception((int)error);
-
-            if (e.NativeErrorCode == (int)SecurityStatusPalErrorCode.LogonDenied)
-            {
-                throw new InvalidCredentialException(SR.net_auth_bad_client_creds, e);
-            }
-
-            if (e.NativeErrorCode == NegoState.ERROR_TRUST_FAILURE)
-            {
-                throw new AuthenticationException(SR.net_auth_context_expectation_remote, e);
-            }
-
-            throw new AuthenticationException(SR.net_auth_alert, e);
-        }
-
-        internal static bool IsLogonDeniedException(Exception exception)
-        {
-            Win32Exception? win32exception = exception as Win32Exception;
-
-            return (win32exception != null) && (win32exception.NativeErrorCode == (int)SecurityStatusPalErrorCode.LogonDenied);
-        }
-    }
-}
index 0b221f4..8a420ed 100644 (file)
@@ -2,34 +2,59 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.ComponentModel;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
 using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
+using System.Runtime.CompilerServices;
 using System.Runtime.ExceptionServices;
+using System.Security.Authentication;
 using System.Security.Authentication.ExtendedProtection;
 using System.Security.Principal;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Net.Security
 {
-    /*
-        An authenticated stream based on NEGO SSP.
-
-            The class that can be used by client and server side applications
-            - to transfer Identities across the stream
-            - to encrypt data based on NEGO SSP package
-
-            In most cases the innerStream will be of type NetworkStream.
-            On Win9x data encryption is not available and both sides have
-            to explicitly drop SecurityLevel and MuatualAuth requirements.
-
-            This is a simple wrapper class.
-            All real work is done by internal NegoState class and the other partial implementation files.
-    */
+    /// <summary>
+    /// Provides a stream that uses the Negotiate security protocol to authenticate the client, and optionally the server, in client-server communication.
+    /// </summary>
     public partial class NegotiateStream : AuthenticatedStream
     {
-        private readonly NegoState _negoState;
-        private readonly string _package;
+        private const int ERROR_TRUST_FAILURE = 1790;   // Used to serialize protectionLevel or impersonationLevel mismatch error to the remote side.
+        private const int MaxReadFrameSize = 64 * 1024;
+        private const int MaxWriteDataSize = 63 * 1024; // 1k for the framing and trailer that is always less as per SSPI.
+        private const string DefaultPackage = NegotiationInfoClass.Negotiate;
+
+#pragma warning disable CA1825 // used in reference comparison, requires unique object identity
+        private static readonly byte[] s_emptyMessage = new byte[0];
+#pragma warning restore CA1825
+
+        private readonly byte[] _readHeader;
         private IIdentity? _remoteIdentity;
+        private byte[] _buffer;
+        private int _bufferOffset;
+        private int _bufferCount;
+
+        private volatile int _writeInProgress;
+        private volatile int _readInProgress;
+        private volatile int _authInProgress;
+
+        private Exception? _exception;
+        private StreamFramer? _framer;
+        private NTAuthentication? _context;
+        private bool _canRetryAuthentication;
+        private ProtectionLevel _expectedProtectionLevel;
+        private TokenImpersonationLevel _expectedImpersonationLevel;
+        private uint _writeSequenceNumber;
+        private uint _readSequenceNumber;
+        private ExtendedProtectionPolicy? _extendedProtectionPolicy;
+
+        /// <summary>
+        /// SSPI does not send a server ack on successful auth.
+        /// This is a state variable used to gracefully handle auth confirmation.
+        /// </summary>
+        private bool _remoteOk = false;
 
         public NegotiateStream(Stream innerStream) : this(innerStream, false)
         {
@@ -37,530 +62,920 @@ namespace System.Net.Security
 
         public NegotiateStream(Stream innerStream, bool leaveInnerStreamOpen) : base(innerStream, leaveInnerStreamOpen)
         {
-            _negoState = new NegoState(innerStream);
-            _package = NegoState.DefaultPackage;
-            InitializeStreamPart();
+            _readHeader = new byte[4];
+            _buffer = Array.Empty<byte>();
         }
 
-        public virtual IAsyncResult BeginAuthenticateAsClient(AsyncCallback? asyncCallback, object? asyncState)
+        protected override void Dispose(bool disposing)
         {
-            return BeginAuthenticateAsClient((NetworkCredential)CredentialCache.DefaultCredentials, null, string.Empty,
-                                           ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
-                                           asyncCallback, asyncState);
+            try
+            {
+                _exception = new ObjectDisposedException(nameof(NegotiateStream));
+                _context?.CloseContext();
+            }
+            finally
+            {
+                base.Dispose(disposing);
+            }
         }
 
-        public virtual IAsyncResult BeginAuthenticateAsClient(NetworkCredential credential, string targetName, AsyncCallback? asyncCallback, object? asyncState)
+        public override async ValueTask DisposeAsync()
         {
-            return BeginAuthenticateAsClient(credential, null, targetName,
-                                           ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
-                                           asyncCallback, asyncState);
+            try
+            {
+                _exception = new ObjectDisposedException(nameof(NegotiateStream));
+                _context?.CloseContext();
+            }
+            finally
+            {
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
         }
 
-        public virtual IAsyncResult BeginAuthenticateAsClient(NetworkCredential credential, ChannelBinding? binding, string targetName, AsyncCallback? asyncCallback, object? asyncState)
-        {
-            return BeginAuthenticateAsClient(credential, binding, targetName,
-                                             ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
-                                             asyncCallback, asyncState);
-        }
+        public virtual IAsyncResult BeginAuthenticateAsClient(AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsClient((NetworkCredential)CredentialCache.DefaultCredentials, binding: null, string.Empty, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
+                                      asyncCallback, asyncState);
 
-        public virtual IAsyncResult BeginAuthenticateAsClient(
-            NetworkCredential credential,
-            string targetName,
-            ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel allowedImpersonationLevel,
-            AsyncCallback? asyncCallback,
-            object? asyncState)
-        {
-            return BeginAuthenticateAsClient(credential, null, targetName,
-                                             requiredProtectionLevel, allowedImpersonationLevel,
-                                             asyncCallback, asyncState);
-        }
+        public virtual IAsyncResult BeginAuthenticateAsClient(NetworkCredential credential, string targetName, AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsClient(credential, binding: null, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
+                                      asyncCallback, asyncState);
+
+        public virtual IAsyncResult BeginAuthenticateAsClient(NetworkCredential credential, ChannelBinding? binding, string targetName, AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsClient(credential, binding, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification,
+                                      asyncCallback, asyncState);
 
         public virtual IAsyncResult BeginAuthenticateAsClient(
-            NetworkCredential credential,
-            ChannelBinding? binding,
-            string targetName,
-            ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel allowedImpersonationLevel,
-            AsyncCallback? asyncCallback,
-            object? asyncState)
-        {
-            _negoState.ValidateCreateContext(_package, false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
+            NetworkCredential credential, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel,
+            AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsClient(credential, binding: null, targetName, requiredProtectionLevel, allowedImpersonationLevel,
+                                      asyncCallback, asyncState);
 
-            LazyAsyncResult result = new LazyAsyncResult(_negoState, asyncState, asyncCallback);
-            _negoState.ProcessAuthentication(result);
+        public virtual IAsyncResult BeginAuthenticateAsClient(
+            NetworkCredential credential, ChannelBinding? binding, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel,
+            AsyncCallback? asyncCallback, object? asyncState) =>
+            TaskToApm.Begin(AuthenticateAsClientAsync(credential, binding, targetName, requiredProtectionLevel, allowedImpersonationLevel), asyncCallback, asyncState);
 
-            return result;
-        }
+        public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
 
-        public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult)
-        {
-            _negoState.EndProcessAuthentication(asyncResult);
-        }
+        public virtual void AuthenticateAsServer() =>
+            AuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, policy: null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual void AuthenticateAsServer()
-        {
-            AuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
-        }
-
-        public virtual void AuthenticateAsServer(ExtendedProtectionPolicy? policy)
-        {
+        public virtual void AuthenticateAsServer(ExtendedProtectionPolicy? policy) =>
             AuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, policy, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
-        }
 
-        public virtual void AuthenticateAsServer(NetworkCredential credential, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
-        {
-            AuthenticateAsServer(credential, null, requiredProtectionLevel, requiredImpersonationLevel);
-        }
+        public virtual void AuthenticateAsServer(NetworkCredential credential, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel) =>
+            AuthenticateAsServer(credential, policy: null, requiredProtectionLevel, requiredImpersonationLevel);
 
         public virtual void AuthenticateAsServer(NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
         {
-            _negoState.ValidateCreateContext(_package, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
-            _negoState.ProcessAuthentication(null);
+            ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
+            AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult();
         }
 
-        public virtual IAsyncResult BeginAuthenticateAsServer(AsyncCallback? asyncCallback, object? asyncState)
-        {
-            return BeginAuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, asyncCallback, asyncState);
-        }
+        public virtual IAsyncResult BeginAuthenticateAsServer(AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, policy: null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, asyncCallback, asyncState);
 
-        public virtual IAsyncResult BeginAuthenticateAsServer(ExtendedProtectionPolicy? policy, AsyncCallback? asyncCallback, object? asyncState)
-        {
-            return BeginAuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, policy, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, asyncCallback, asyncState);
-        }
+        public virtual IAsyncResult BeginAuthenticateAsServer(ExtendedProtectionPolicy? policy, AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, policy, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, asyncCallback, asyncState);
 
         public virtual IAsyncResult BeginAuthenticateAsServer(
-            NetworkCredential credential,
-            ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel requiredImpersonationLevel,
-            AsyncCallback? asyncCallback,
-            object? asyncState)
-        {
-            return BeginAuthenticateAsServer(credential, null, requiredProtectionLevel, requiredImpersonationLevel, asyncCallback, asyncState);
-        }
+            NetworkCredential credential, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel,
+            AsyncCallback? asyncCallback, object? asyncState) =>
+            BeginAuthenticateAsServer(credential, policy: null, requiredProtectionLevel, requiredImpersonationLevel, asyncCallback, asyncState);
 
         public virtual IAsyncResult BeginAuthenticateAsServer(
-            NetworkCredential credential,
-            ExtendedProtectionPolicy? policy,
-            ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel requiredImpersonationLevel,
-            AsyncCallback? asyncCallback,
-            object? asyncState)
-        {
-            _negoState.ValidateCreateContext(_package, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
+            NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel,
+            AsyncCallback? asyncCallback, object? asyncState) =>
+            TaskToApm.Begin(AuthenticateAsServerAsync(credential, policy, requiredProtectionLevel, requiredImpersonationLevel), asyncCallback, asyncState);
 
-            LazyAsyncResult result = new LazyAsyncResult(_negoState, asyncState, asyncCallback);
-            _negoState.ProcessAuthentication(result);
+        public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
 
-            return result;
-        }
-        //
-        public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult)
-        {
-            _negoState.EndProcessAuthentication(asyncResult);
-        }
+        public virtual void AuthenticateAsClient() =>
+            AuthenticateAsClient((NetworkCredential)CredentialCache.DefaultCredentials, binding: null, string.Empty, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual void AuthenticateAsClient()
-        {
-            AuthenticateAsClient((NetworkCredential)CredentialCache.DefaultCredentials, null, string.Empty, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
-        }
+        public virtual void AuthenticateAsClient(NetworkCredential credential, string targetName) =>
+            AuthenticateAsClient(credential, binding: null, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual void AuthenticateAsClient(NetworkCredential credential, string targetName)
-        {
-            AuthenticateAsClient(credential, null, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
-        }
-
-        public virtual void AuthenticateAsClient(NetworkCredential credential, ChannelBinding? binding, string targetName)
-        {
+        public virtual void AuthenticateAsClient(NetworkCredential credential, ChannelBinding? binding, string targetName) =>
             AuthenticateAsClient(credential, binding, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
-        }
 
         public virtual void AuthenticateAsClient(
-            NetworkCredential credential, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel)
-        {
-            AuthenticateAsClient(credential, null, targetName, requiredProtectionLevel, allowedImpersonationLevel);
-        }
+            NetworkCredential credential, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel) =>
+            AuthenticateAsClient(credential, binding: null, targetName, requiredProtectionLevel, allowedImpersonationLevel);
 
         public virtual void AuthenticateAsClient(
             NetworkCredential credential, ChannelBinding? binding, string targetName, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel allowedImpersonationLevel)
         {
-            _negoState.ValidateCreateContext(_package, false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
-            _negoState.ProcessAuthentication(null);
+            ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
+            AuthenticateAsync(new SyncReadWriteAdapter(InnerStream)).GetAwaiter().GetResult();
         }
 
-        public virtual Task AuthenticateAsClientAsync()
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsClient, EndAuthenticateAsClient, null);
-        }
+        public virtual Task AuthenticateAsClientAsync() =>
+            AuthenticateAsClientAsync((NetworkCredential)CredentialCache.DefaultCredentials, binding: null, string.Empty, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual Task AuthenticateAsClientAsync(NetworkCredential credential, string targetName)
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsClient, EndAuthenticateAsClient, credential, targetName, null);
-        }
+        public virtual Task AuthenticateAsClientAsync(NetworkCredential credential, string targetName) =>
+            AuthenticateAsClientAsync(credential, binding: null, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
         public virtual Task AuthenticateAsClientAsync(
             NetworkCredential credential, string targetName,
             ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel allowedImpersonationLevel)
-        {
-            return Task.Factory.FromAsync((callback, state) => BeginAuthenticateAsClient(credential, targetName, requiredProtectionLevel, allowedImpersonationLevel, callback, state), EndAuthenticateAsClient, null);
-        }
+            TokenImpersonationLevel allowedImpersonationLevel) =>
+            AuthenticateAsClientAsync(credential, binding: null, targetName, requiredProtectionLevel, allowedImpersonationLevel);
 
-        public virtual Task AuthenticateAsClientAsync(NetworkCredential credential, ChannelBinding? binding, string targetName)
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsClient, EndAuthenticateAsClient, credential, binding, targetName, null);
-        }
+        public virtual Task AuthenticateAsClientAsync(NetworkCredential credential, ChannelBinding? binding, string targetName) =>
+            AuthenticateAsClientAsync(credential, binding, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
         public virtual Task AuthenticateAsClientAsync(
-            NetworkCredential credential, ChannelBinding? binding,
-            string targetName, ProtectionLevel requiredProtectionLevel,
+            NetworkCredential credential, ChannelBinding? binding, string targetName, ProtectionLevel requiredProtectionLevel,
             TokenImpersonationLevel allowedImpersonationLevel)
         {
-            return Task.Factory.FromAsync((callback, state) => BeginAuthenticateAsClient(credential, binding, targetName, requiredProtectionLevel, allowedImpersonationLevel, callback, state), EndAuthenticateAsClient, null);
+            ValidateCreateContext(DefaultPackage, isServer: false, credential, targetName, binding, requiredProtectionLevel, allowedImpersonationLevel);
+            return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default));
         }
 
-        public virtual Task AuthenticateAsServerAsync()
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsServer, EndAuthenticateAsServer, null);
-        }
+        public virtual Task AuthenticateAsServerAsync() =>
+            AuthenticateAsServerAsync((NetworkCredential)CredentialCache.DefaultCredentials, policy: null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual Task AuthenticateAsServerAsync(ExtendedProtectionPolicy? policy)
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsServer, EndAuthenticateAsServer, policy, null);
-        }
+        public virtual Task AuthenticateAsServerAsync(ExtendedProtectionPolicy? policy) =>
+            AuthenticateAsServerAsync((NetworkCredential)CredentialCache.DefaultCredentials, policy, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
 
-        public virtual Task AuthenticateAsServerAsync(NetworkCredential credential, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
-        {
-            return Task.Factory.FromAsync(BeginAuthenticateAsServer, EndAuthenticateAsServer, credential, requiredProtectionLevel, requiredImpersonationLevel, null);
-        }
+        public virtual Task AuthenticateAsServerAsync(NetworkCredential credential, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel) =>
+            AuthenticateAsServerAsync(credential, policy: null, requiredProtectionLevel, requiredImpersonationLevel);
 
         public virtual Task AuthenticateAsServerAsync(
-            NetworkCredential credential, ExtendedProtectionPolicy? policy,
-            ProtectionLevel requiredProtectionLevel,
-            TokenImpersonationLevel requiredImpersonationLevel)
+            NetworkCredential credential, ExtendedProtectionPolicy? policy, ProtectionLevel requiredProtectionLevel, TokenImpersonationLevel requiredImpersonationLevel)
         {
-            return Task.Factory.FromAsync((callback, state) => BeginAuthenticateAsServer(credential, policy, requiredProtectionLevel, requiredImpersonationLevel, callback, state), EndAuthenticateAsClient, null);
+            ValidateCreateContext(DefaultPackage, credential, string.Empty, policy, requiredProtectionLevel, requiredImpersonationLevel);
+            return AuthenticateAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken: default));
         }
 
-        public override bool IsAuthenticated
+        public override bool IsAuthenticated => IsAuthenticatedCore;
+
+        [MemberNotNullWhen(true, nameof(_context))]
+        private bool IsAuthenticatedCore => _context != null && HandshakeComplete && _exception == null && _remoteOk;
+
+        public override bool IsMutuallyAuthenticated =>
+            IsAuthenticatedCore &&
+            !_context.IsNTLM && // suppressing for NTLM since SSPI does not return correct value in the context flags.
+            _context.IsMutualAuthFlag;
+
+        public override bool IsEncrypted => IsAuthenticatedCore && _context.IsConfidentialityFlag;
+
+        public override bool IsSigned => IsAuthenticatedCore && (_context.IsIntegrityFlag || _context.IsConfidentialityFlag);
+
+        public override bool IsServer => _context != null && _context.IsServer;
+
+        public virtual TokenImpersonationLevel ImpersonationLevel
         {
             get
             {
-                return _negoState.IsAuthenticated;
+                ThrowIfFailed(authSuccessCheck: true);
+                return PrivateImpersonationLevel;
             }
         }
 
-        public override bool IsMutuallyAuthenticated
+        private TokenImpersonationLevel PrivateImpersonationLevel =>
+            _context!.IsDelegationFlag && _context.ProtocolName != NegotiationInfoClass.NTLM ? TokenImpersonationLevel.Delegation : // We should suppress the delegate flag in NTLM case.
+            _context.IsIdentifyFlag ? TokenImpersonationLevel.Identification :
+            TokenImpersonationLevel.Impersonation;
+
+        private bool HandshakeComplete => _context!.IsCompleted && _context.IsValidContext;
+
+        private bool CanGetSecureStream => _context!.IsConfidentialityFlag || _context.IsIntegrityFlag;
+
+        public virtual IIdentity RemoteIdentity
         {
             get
             {
-                return _negoState.IsMutuallyAuthenticated;
+                IIdentity? identity = _remoteIdentity;
+                if (identity is null)
+                {
+                    ThrowIfFailed(authSuccessCheck: true);
+                    _remoteIdentity = identity = NegotiateStreamPal.GetIdentity(_context!);
+                }
+                return identity;
             }
         }
 
-        public override bool IsEncrypted
+        public override bool CanSeek => false;
+
+        public override bool CanRead => IsAuthenticated && InnerStream.CanRead;
+
+        public override bool CanTimeout => InnerStream.CanTimeout;
+
+        public override bool CanWrite => IsAuthenticated && InnerStream.CanWrite;
+
+        public override int ReadTimeout
+        {
+            get => InnerStream.ReadTimeout;
+            set => InnerStream.ReadTimeout = value;
+        }
+
+        public override int WriteTimeout
         {
-            get
-            {
-                return _negoState.IsEncrypted;
-            }
+            get => InnerStream.WriteTimeout;
+            set => InnerStream.WriteTimeout = value;
         }
 
-        public override bool IsSigned
+        public override long Length => InnerStream.Length;
+
+        public override long Position
         {
-            get
+            get => InnerStream.Position;
+            set => throw new NotSupportedException(SR.net_noseek);
+        }
+
+        public override void SetLength(long value) =>
+            InnerStream.SetLength(value);
+
+        public override long Seek(long offset, SeekOrigin origin) =>
+            throw new NotSupportedException(SR.net_noseek);
+
+        public override void Flush() =>
+            InnerStream.Flush();
+
+        public override Task FlushAsync(CancellationToken cancellationToken) =>
+            InnerStream.FlushAsync(cancellationToken);
+
+        public override int Read(byte[] buffer, int offset, int count)
+        {
+            ValidateParameters(buffer, offset, count);
+
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return _negoState.IsSigned;
+                return InnerStream.Read(buffer, offset, count);
             }
+
+            ValueTask<int> vt = ReadAsync(new SyncReadWriteAdapter(InnerStream), new Memory<byte>(buffer, offset, count));
+            Debug.Assert(vt.IsCompleted, "Should have completed synchroously with sync adapter");
+            return vt.GetAwaiter().GetResult();
         }
 
-        public override bool IsServer
+        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
-            get
+            ValidateParameters(buffer, offset, count);
+
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return _negoState.IsServer;
+                return InnerStream.ReadAsync(buffer, offset, count, cancellationToken);
             }
+
+            return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new Memory<byte>(buffer, offset, count)).AsTask();
         }
 
-        public virtual TokenImpersonationLevel ImpersonationLevel
+        public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
         {
-            get
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return _negoState.AllowedImpersonation;
+                return InnerStream.ReadAsync(buffer, cancellationToken);
             }
+
+            return ReadAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer);
         }
 
-        public virtual IIdentity RemoteIdentity
+        private async ValueTask<int> ReadAsync<TAdapter>(TAdapter adapter, Memory<byte> buffer, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter
         {
-            get
+            if (Interlocked.Exchange(ref _readInProgress, 1) == 1)
+            {
+                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, callerName, "read"));
+            }
+
+            try
             {
-                if (_remoteIdentity == null)
+                if (_bufferCount != 0)
                 {
-                    _remoteIdentity = _negoState.GetIdentity();
+                    int copyBytes = Math.Min(_bufferCount, buffer.Length);
+                    if (copyBytes != 0)
+                    {
+                        _buffer.AsMemory(_bufferOffset, copyBytes).CopyTo(buffer);
+                        _bufferOffset += copyBytes;
+                        _bufferCount -= copyBytes;
+                    }
+                    return copyBytes;
                 }
 
-                return _remoteIdentity;
+                while (true)
+                {
+                    int readBytes = await adapter.ReadAllAsync(_readHeader).ConfigureAwait(false);
+                    if (readBytes == 0)
+                    {
+                        return 0;
+                    }
+
+                    // Replace readBytes with the body size recovered from the header content.
+                    readBytes = BitConverter.ToInt32(_readHeader, 0);
+
+                    // The body carries 4 bytes for trailer size slot plus trailer, hence <= 4 frame size is always an error.
+                    // Additionally we'd like to restrict the read frame size to 64k.
+                    if (readBytes <= 4 || readBytes > MaxReadFrameSize)
+                    {
+                        throw new IOException(SR.net_frame_read_size);
+                    }
+
+                    // Always pass InternalBuffer for SSPI "in place" decryption.
+                    // A user buffer can be shared by many threads in that case decryption/integrity check may fail cause of data corruption.
+                    _bufferCount = readBytes;
+                    _bufferOffset = 0;
+                    if (_buffer.Length < readBytes)
+                    {
+                        _buffer = new byte[readBytes];
+                    }
+                    readBytes = await adapter.ReadAllAsync(new Memory<byte>(_buffer, 0, readBytes)).ConfigureAwait(false);
+                    if (readBytes == 0)
+                    {
+                        // We already checked that the frame body is bigger than 0 bytes. Hence, this is an EOF.
+                        throw new IOException(SR.net_io_eof);
+                    }
+
+                    // Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
+                    // Decrypted data start from zero offset, the size can be shrunk after decryption.
+                    _bufferCount = readBytes = DecryptData(_buffer!, 0, readBytes, out _bufferOffset);
+                    if (readBytes == 0 && buffer.Length != 0)
+                    {
+                        // Read again.
+                        continue;
+                    }
+
+                    if (readBytes > buffer.Length)
+                    {
+                        readBytes = buffer.Length;
+                    }
+
+                    _buffer.AsMemory(_bufferOffset, readBytes).CopyTo(buffer);
+                    _bufferOffset += readBytes;
+                    _bufferCount -= readBytes;
+
+                    return readBytes;
+                }
+            }
+            catch (Exception e) when (!(e is IOException || e is OperationCanceledException))
+            {
+                throw new IOException(SR.net_io_read, e);
+            }
+            finally
+            {
+                _readInProgress = 0;
             }
         }
 
-        //
-        // Stream contract implementation
-        //
-        public override bool CanSeek
+        public override void Write(byte[] buffer, int offset, int count)
         {
-            get
+            ValidateParameters(buffer, offset, count);
+
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return false;
+                InnerStream.Write(buffer, offset, count);
+                return;
             }
+
+            WriteAsync(new SyncReadWriteAdapter(InnerStream), new ReadOnlyMemory<byte>(buffer, offset, count)).GetAwaiter().GetResult();
         }
 
-        public override bool CanRead
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
-            get
+            ValidateParameters(buffer, offset, count);
+
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return IsAuthenticated && InnerStream.CanRead;
+                return InnerStream.WriteAsync(buffer, offset, count, cancellationToken);
             }
+
+            return WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new ReadOnlyMemory<byte>(buffer, offset, count));
         }
 
-        public override bool CanTimeout
+        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
         {
-            get
+            ThrowIfFailed(authSuccessCheck: true);
+            if (!CanGetSecureStream)
             {
-                return InnerStream.CanTimeout;
+                return InnerStream.WriteAsync(buffer, cancellationToken);
             }
+
+            return new ValueTask(WriteAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer));
         }
 
-        public override bool CanWrite
+        private async Task WriteAsync<TAdapter>(TAdapter adapter, ReadOnlyMemory<byte> buffer) where TAdapter : IReadWriteAdapter
         {
-            get
+            if (Interlocked.Exchange(ref _writeInProgress, 1) == 1)
             {
-                return IsAuthenticated && InnerStream.CanWrite;
+                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(Write), "write"));
             }
-        }
 
-        public override int ReadTimeout
-        {
-            get
+            try
             {
-                return InnerStream.ReadTimeout;
+                byte[]? outBuffer = null;
+                while (!buffer.IsEmpty)
+                {
+                    int chunkBytes = Math.Min(buffer.Length, MaxWriteDataSize);
+                    int encryptedBytes;
+                    try
+                    {
+                        encryptedBytes = EncryptData(buffer.Slice(0, chunkBytes).Span, ref outBuffer);
+                    }
+                    catch (Exception e)
+                    {
+                        throw new IOException(SR.net_io_encrypt, e);
+                    }
+
+                    await adapter.WriteAsync(outBuffer, 0, encryptedBytes).ConfigureAwait(false);
+                    buffer = buffer.Slice(chunkBytes);
+                }
+            }
+            catch (Exception e) when (!(e is IOException || e is OperationCanceledException))
+            {
+                throw new IOException(SR.net_io_write, e);
             }
-            set
+            finally
             {
-                InnerStream.ReadTimeout = value;
+                _writeInProgress = 0;
             }
         }
 
-        public override int WriteTimeout
+        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
+            TaskToApm.Begin(ReadAsync(buffer, offset, count), asyncCallback, asyncState);
+
+        public override int EndRead(IAsyncResult asyncResult) =>
+            TaskToApm.End<int>(asyncResult);
+
+        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
+            TaskToApm.Begin(WriteAsync(buffer, offset, count), asyncCallback, asyncState);
+
+        public override void EndWrite(IAsyncResult asyncResult) =>
+            TaskToApm.End(asyncResult);
+
+        /// <summary>Validates user parameters for all Read/Write methods.</summary>
+        private static void ValidateParameters(byte[] buffer, int offset, int count)
         {
-            get
+            if (buffer == null)
             {
-                return InnerStream.WriteTimeout;
+                throw new ArgumentNullException(nameof(buffer));
             }
-            set
+
+            if (offset < 0)
             {
-                InnerStream.WriteTimeout = value;
+                throw new ArgumentOutOfRangeException(nameof(offset));
             }
-        }
 
-        public override long Length
-        {
-            get
+            if (count < 0)
             {
-                return InnerStream.Length;
+                throw new ArgumentOutOfRangeException(nameof(count));
+            }
+
+            if (count > buffer.Length - offset)
+            {
+                throw new ArgumentOutOfRangeException(nameof(count), SR.net_offset_plus_count);
             }
         }
 
-        public override long Position
+        private void ValidateCreateContext(
+            string package,
+            NetworkCredential credential,
+            string servicePrincipalName,
+            ExtendedProtectionPolicy? policy,
+            ProtectionLevel protectionLevel,
+            TokenImpersonationLevel impersonationLevel)
         {
-            get
+            if (policy != null)
             {
-                return InnerStream.Position;
+                // One of these must be set if EP is turned on
+                if (policy.CustomChannelBinding == null && policy.CustomServiceNames == null)
+                {
+                    throw new ArgumentException(SR.net_auth_must_specify_extended_protection_scheme, nameof(policy));
+                }
+
+                _extendedProtectionPolicy = policy;
             }
-            set
+            else
             {
-                throw new NotSupportedException(SR.net_noseek);
+                _extendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Never);
             }
-        }
 
-        public override void SetLength(long value)
-        {
-            InnerStream.SetLength(value);
+            ValidateCreateContext(package, isServer: true, credential, servicePrincipalName, _extendedProtectionPolicy.CustomChannelBinding, protectionLevel, impersonationLevel);
         }
 
-        public override long Seek(long offset, SeekOrigin origin)
+        private void ValidateCreateContext(
+            string package,
+            bool isServer,
+            NetworkCredential credential,
+            string? servicePrincipalName,
+            ChannelBinding? channelBinding,
+            ProtectionLevel protectionLevel,
+            TokenImpersonationLevel impersonationLevel)
         {
-            throw new NotSupportedException(SR.net_noseek);
-        }
+            if (_exception != null && !_canRetryAuthentication)
+            {
+                ExceptionDispatchInfo.Throw(_exception);
+            }
 
-        public override void Flush()
-        {
-            InnerStream.Flush();
+            if (_context != null && _context.IsValidContext)
+            {
+                throw new InvalidOperationException(SR.net_auth_reauth);
+            }
+
+            if (credential == null)
+            {
+                throw new ArgumentNullException(nameof(credential));
+            }
+
+            if (servicePrincipalName == null)
+            {
+                throw new ArgumentNullException(nameof(servicePrincipalName));
+            }
+
+            NegotiateStreamPal.ValidateImpersonationLevel(impersonationLevel);
+            if (_context != null && IsServer != isServer)
+            {
+                throw new InvalidOperationException(SR.net_auth_client_server);
+            }
+
+            _exception = null;
+            _remoteOk = false;
+            _framer = new StreamFramer();
+            _framer.WriteHeader.MessageId = FrameHeader.HandshakeId;
+
+            _expectedProtectionLevel = protectionLevel;
+            _expectedImpersonationLevel = isServer ? impersonationLevel : TokenImpersonationLevel.None;
+            _writeSequenceNumber = 0;
+            _readSequenceNumber = 0;
+
+            ContextFlagsPal flags = ContextFlagsPal.Connection;
+
+            // A workaround for the client when talking to Win9x on the server side.
+            if (protectionLevel == ProtectionLevel.None && !isServer)
+            {
+                package = NegotiationInfoClass.NTLM;
+            }
+            else if (protectionLevel == ProtectionLevel.EncryptAndSign)
+            {
+                flags |= ContextFlagsPal.Confidentiality;
+            }
+            else if (protectionLevel == ProtectionLevel.Sign)
+            {
+                // Assuming user expects NT4 SP4 and above.
+                flags |= ContextFlagsPal.ReplayDetect | ContextFlagsPal.SequenceDetect | ContextFlagsPal.InitIntegrity;
+            }
+
+            if (isServer)
+            {
+                if (_extendedProtectionPolicy!.PolicyEnforcement == PolicyEnforcement.WhenSupported)
+                {
+                    flags |= ContextFlagsPal.AllowMissingBindings;
+                }
+
+                if (_extendedProtectionPolicy.PolicyEnforcement != PolicyEnforcement.Never &&
+                    _extendedProtectionPolicy.ProtectionScenario == ProtectionScenario.TrustedProxy)
+                {
+                    flags |= ContextFlagsPal.ProxyBindings;
+                }
+            }
+            else
+            {
+                // Server side should not request any of these flags.
+                if (protectionLevel != ProtectionLevel.None)
+                {
+                    flags |= ContextFlagsPal.MutualAuth;
+                }
+
+                if (impersonationLevel == TokenImpersonationLevel.Identification)
+                {
+                    flags |= ContextFlagsPal.InitIdentify;
+                }
+
+                if (impersonationLevel == TokenImpersonationLevel.Delegation)
+                {
+                    flags |= ContextFlagsPal.Delegate;
+                }
+            }
+
+            _canRetryAuthentication = false;
+
+            try
+            {
+                _context = new NTAuthentication(isServer, package, credential, servicePrincipalName, flags, channelBinding!);
+            }
+            catch (Win32Exception e)
+            {
+                throw new AuthenticationException(SR.net_auth_SSPI, e);
+            }
         }
 
-        public override Task FlushAsync(CancellationToken cancellationToken)
+        private void SetFailed(Exception e)
         {
-            return InnerStream.FlushAsync(cancellationToken);
+            if (!(_exception is ObjectDisposedException))
+            {
+                _exception = e;
+            }
+
+            _context?.CloseContext();
         }
 
-        protected override void Dispose(bool disposing)
+        private void ThrowIfFailed(bool authSuccessCheck)
         {
-            try
+            if (_exception != null)
             {
-                _negoState.Close();
+                ExceptionDispatchInfo.Throw(_exception);
             }
-            finally
+
+            if (authSuccessCheck && !IsAuthenticatedCore)
             {
-                base.Dispose(disposing);
+                throw new InvalidOperationException(SR.net_auth_noauth);
             }
         }
 
-        public override async ValueTask DisposeAsync()
+        private async Task AuthenticateAsync<TAdapter>(TAdapter adapter, [CallerMemberName] string? callerName = null) where TAdapter : IReadWriteAdapter
         {
+            Debug.Assert(_context != null);
+
+            ThrowIfFailed(authSuccessCheck: false);
+            if (Interlocked.Exchange(ref _authInProgress, 1) == 1)
+            {
+                throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, callerName, "authenticate"));
+            }
+
             try
             {
-                _negoState.Close();
+                await (_context.IsServer ?
+                    ReceiveBlobAsync(adapter) : // server should listen for a client blob
+                    SendBlobAsync(adapter, message: null)).ConfigureAwait(false); // client should send the first blob
+            }
+            catch (Exception e)
+            {
+                SetFailed(e);
+                throw;
             }
             finally
             {
-                await base.DisposeAsync().ConfigureAwait(false);
+                _authInProgress = 0;
             }
         }
 
-        public override int Read(byte[] buffer, int offset, int count)
+        private bool CheckSpn()
         {
-            _negoState.CheckThrow(true);
+            Debug.Assert(_context != null);
 
-            if (!_negoState.CanGetSecureStream)
+            if (_context.IsKerberos ||
+                _extendedProtectionPolicy!.PolicyEnforcement == PolicyEnforcement.Never ||
+                _extendedProtectionPolicy.CustomServiceNames == null)
             {
-                return InnerStream.Read(buffer, offset, count);
+                return true;
+            }
+
+            string? clientSpn = _context.ClientSpecifiedSpn;
+
+            if (string.IsNullOrEmpty(clientSpn))
+            {
+                return _extendedProtectionPolicy.PolicyEnforcement == PolicyEnforcement.WhenSupported;
             }
 
-            return ProcessRead(buffer, offset, count, null);
+            return _extendedProtectionPolicy.CustomServiceNames.Contains(clientSpn);
         }
 
-        public override void Write(byte[] buffer, int offset, int count)
+        // Client authentication starts here, but server also loops through this method.
+        private async Task SendBlobAsync<TAdapter>(TAdapter adapter, byte[]? message) where TAdapter : IReadWriteAdapter
         {
-            _negoState.CheckThrow(true);
+            Debug.Assert(_context != null);
 
-            if (!_negoState.CanGetSecureStream)
+            Exception? exception = null;
+            if (message != s_emptyMessage)
             {
-                InnerStream.Write(buffer, offset, count);
-                return;
+                message = GetOutgoingBlob(message, ref exception);
             }
 
-            ProcessWrite(buffer, offset, count, null);
-        }
+            if (exception != null)
+            {
+                // Signal remote side on a failed attempt.
+                await SendAuthResetSignalAndThrowAsync(adapter, message!, exception).ConfigureAwait(false);
+                Debug.Fail("Unreachable");
+            }
 
-        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState)
-        {
-            _negoState.CheckThrow(true);
+            if (HandshakeComplete)
+            {
+                if (_context.IsServer && !CheckSpn())
+                {
+                    exception = new AuthenticationException(SR.net_auth_bad_client_creds_or_target_mismatch);
+                    int statusCode = ERROR_TRUST_FAILURE;
+                    message = new byte[sizeof(long)];
+
+                    for (int i = message.Length - 1; i >= 0; --i)
+                    {
+                        message[i] = (byte)(statusCode & 0xFF);
+                        statusCode = (int)((uint)statusCode >> 8);
+                    }
+
+                    await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
+                    Debug.Fail("Unreachable");
+                }
 
-            if (!_negoState.CanGetSecureStream)
+                if (PrivateImpersonationLevel < _expectedImpersonationLevel)
+                {
+                    exception = new AuthenticationException(SR.Format(SR.net_auth_context_expectation, _expectedImpersonationLevel.ToString(), PrivateImpersonationLevel.ToString()));
+                    int statusCode = ERROR_TRUST_FAILURE;
+                    message = new byte[sizeof(long)];
+
+                    for (int i = message.Length - 1; i >= 0; --i)
+                    {
+                        message[i] = (byte)(statusCode & 0xFF);
+                        statusCode = (int)((uint)statusCode >> 8);
+                    }
+
+                    await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
+                    Debug.Fail("Unreachable");
+                }
+
+                ProtectionLevel result = _context.IsConfidentialityFlag ? ProtectionLevel.EncryptAndSign : _context.IsIntegrityFlag ? ProtectionLevel.Sign : ProtectionLevel.None;
+
+                if (result < _expectedProtectionLevel)
+                {
+                    exception = new AuthenticationException(SR.Format(SR.net_auth_context_expectation, result.ToString(), _expectedProtectionLevel.ToString()));
+                    int statusCode = ERROR_TRUST_FAILURE;
+                    message = new byte[sizeof(long)];
+
+                    for (int i = message.Length - 1; i >= 0; --i)
+                    {
+                        message[i] = (byte)(statusCode & 0xFF);
+                        statusCode = (int)((uint)statusCode >> 8);
+                    }
+
+                    await SendAuthResetSignalAndThrowAsync(adapter, message, exception).ConfigureAwait(false);
+                    Debug.Fail("Unreachable");
+                }
+
+                // Signal remote party that we are done
+                _framer!.WriteHeader.MessageId = FrameHeader.HandshakeDoneId;
+                if (_context.IsServer)
+                {
+                    // Server may complete now because client SSPI would not complain at this point.
+                    _remoteOk = true;
+
+                    // However the client will wait for server to send this ACK
+                    // Force signaling server OK to the client
+                    message ??= s_emptyMessage;
+                }
+            }
+            else if (message == null || message == s_emptyMessage)
             {
-                return TaskToApm.Begin(InnerStream.ReadAsync(buffer, offset, count), asyncCallback, asyncState);
+                throw new InternalException();
             }
 
-            BufferAsyncResult bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback);
-            AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(bufferResult);
-            ProcessRead(buffer, offset, count, asyncRequest);
-            return bufferResult;
+            if (message != null)
+            {
+                //even if we are completed, there could be a blob for sending.
+                await _framer!.WriteMessageAsync(adapter, message).ConfigureAwait(false);
+            }
+
+            if (HandshakeComplete && _remoteOk)
+            {
+                // We are done with success.
+                return;
+            }
+
+            await ReceiveBlobAsync(adapter).ConfigureAwait(false);
         }
 
-        public override int EndRead(IAsyncResult asyncResult)
+        // Server authentication starts here, but client also loops through this method.
+        private async Task ReceiveBlobAsync<TAdapter>(TAdapter adapter) where TAdapter : IReadWriteAdapter
         {
-            _negoState.CheckThrow(true);
+            Debug.Assert(_framer != null);
 
-            if (!_negoState.CanGetSecureStream)
+            byte[]? message = await _framer.ReadMessageAsync(adapter).ConfigureAwait(false);
+            if (message == null)
             {
-                return TaskToApm.End<int>(asyncResult);
+                // This is an EOF otherwise we would get at least *empty* message but not a null one.
+                throw new AuthenticationException(SR.net_auth_eof);
             }
 
-            if (asyncResult == null)
+            // Process Header information.
+            if (_framer.ReadHeader.MessageId == FrameHeader.HandshakeErrId)
             {
-                throw new ArgumentNullException(nameof(asyncResult));
+                if (message.Length >= sizeof(long))
+                {
+                    // Try to recover remote win32 Exception.
+                    long error = 0;
+                    for (int i = 0; i < 8; ++i)
+                    {
+                        error = (error << 8) + message[i];
+                    }
+
+                    ThrowCredentialException(error);
+                }
+
+                throw new AuthenticationException(SR.net_auth_alert);
             }
 
-            BufferAsyncResult? bufferResult = asyncResult as BufferAsyncResult;
-            if (bufferResult == null)
+            if (_framer.ReadHeader.MessageId == FrameHeader.HandshakeDoneId)
             {
-                throw new ArgumentException(SR.Format(SR.net_io_async_result, asyncResult.GetType().FullName), nameof(asyncResult));
+                _remoteOk = true;
             }
-
-            if (Interlocked.Exchange(ref _NestedRead, 0) == 0)
+            else if (_framer.ReadHeader.MessageId != FrameHeader.HandshakeId)
             {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndRead"));
+                throw new AuthenticationException(SR.Format(SR.net_io_header_id, nameof(FrameHeader.MessageId), _framer.ReadHeader.MessageId, FrameHeader.HandshakeId));
             }
 
-            // No "artificial" timeouts implemented so far, InnerStream controls timeout.
-            bufferResult.InternalWaitForCompletion();
-
-            if (bufferResult.Result is Exception e)
+            // If we are done don't go into send.
+            if (HandshakeComplete)
             {
-                if (e is IOException)
+                if (!_remoteOk)
                 {
-                    ExceptionDispatchInfo.Throw(e);
+                    throw new AuthenticationException(SR.Format(SR.net_io_header_id, nameof(FrameHeader.MessageId), _framer.ReadHeader.MessageId, FrameHeader.HandshakeDoneId));
                 }
 
-                throw new IOException(SR.net_io_read, e);
+                return;
             }
 
-            return bufferResult.Int32Result;
+            // Not yet done, get a new blob and send it if any.
+            await SendBlobAsync(adapter, message).ConfigureAwait(false);
         }
 
-        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState)
+        //  This is to reset auth state on the remote side.
+        //  If this write succeeds we will allow auth retrying.
+        private async Task SendAuthResetSignalAndThrowAsync<TAdapter>(TAdapter adapter, byte[] message, Exception exception) where TAdapter : IReadWriteAdapter
         {
-            _negoState.CheckThrow(true);
+            _framer!.WriteHeader.MessageId = FrameHeader.HandshakeErrId;
+
+            if (IsLogonDeniedException(exception))
+            {
+                exception = new InvalidCredentialException(IsServer ? SR.net_auth_bad_client_creds : SR.net_auth_bad_client_creds_or_target_mismatch, exception);
+            }
 
-            if (!_negoState.CanGetSecureStream)
+            if (!(exception is AuthenticationException))
             {
-                return TaskToApm.Begin(InnerStream.WriteAsync(buffer, offset, count), asyncCallback, asyncState);
+                exception = new AuthenticationException(SR.net_auth_SSPI, exception);
             }
 
-            BufferAsyncResult bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback);
-            AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(bufferResult);
+            await _framer.WriteMessageAsync(adapter, message).ConfigureAwait(false);
 
-            ProcessWrite(buffer, offset, count, asyncRequest);
-            return bufferResult;
+            _canRetryAuthentication = true;
+            ExceptionDispatchInfo.Throw(exception);
         }
 
-        public override void EndWrite(IAsyncResult asyncResult)
+        private static bool IsError(SecurityStatusPal status) =>
+            (int)status.ErrorCode >= (int)SecurityStatusPalErrorCode.OutOfMemory;
+
+        private unsafe byte[]? GetOutgoingBlob(byte[]? incomingBlob, ref Exception? e)
         {
-            _negoState.CheckThrow(true);
+            Debug.Assert(_context != null);
 
-            if (!_negoState.CanGetSecureStream)
-            {
-                TaskToApm.End(asyncResult);
-                return;
-            }
+            byte[]? message = _context.GetOutgoingBlob(incomingBlob, false, out SecurityStatusPal statusCode);
 
-            if (asyncResult == null)
+            if (IsError(statusCode))
             {
-                throw new ArgumentNullException(nameof(asyncResult));
-            }
+                e = NegotiateStreamPal.CreateExceptionFromError(statusCode);
+                uint error = (uint)e.HResult;
 
-            BufferAsyncResult? bufferResult = asyncResult as BufferAsyncResult;
-            if (bufferResult == null)
-            {
-                throw new ArgumentException(SR.Format(SR.net_io_async_result, asyncResult.GetType().FullName), nameof(asyncResult));
+                message = new byte[sizeof(long)];
+                for (int i = message.Length - 1; i >= 0; --i)
+                {
+                    message[i] = (byte)(error & 0xFF);
+                    error >>= 8;
+                }
             }
 
-            if (Interlocked.Exchange(ref _NestedWrite, 0) == 0)
+            if (message != null && message.Length == 0)
             {
-                throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndWrite"));
+                message = s_emptyMessage;
             }
 
-            // No "artificial" timeouts implemented so far, InnerStream controls timeout.
-            bufferResult.InternalWaitForCompletion();
+            return message;
+        }
 
-            if (bufferResult.Result is Exception e)
-            {
-                if (e is IOException)
-                {
-                    ExceptionDispatchInfo.Throw(e);
-                }
+        private int EncryptData(ReadOnlySpan<byte> buffer, [NotNull] ref byte[]? outBuffer)
+        {
+            Debug.Assert(_context != null);
+            ThrowIfFailed(authSuccessCheck: true);
 
-                throw new IOException(SR.net_io_write, e);
-            }
+            // SSPI seems to ignore this sequence number.
+            ++_writeSequenceNumber;
+            return _context.Encrypt(buffer, ref outBuffer, _writeSequenceNumber);
+        }
+
+        private int DecryptData(byte[] buffer, int offset, int count, out int newOffset)
+        {
+            Debug.Assert(_context != null);
+            ThrowIfFailed(authSuccessCheck: true);
+
+            // SSPI seems to ignore this sequence number.
+            ++_readSequenceNumber;
+            return _context.Decrypt(buffer, offset, count, out newOffset, _readSequenceNumber);
         }
+
+        private static void ThrowCredentialException(long error)
+        {
+            var e = new Win32Exception((int)error);
+            throw e.NativeErrorCode switch
+            {
+                (int)SecurityStatusPalErrorCode.LogonDenied => new InvalidCredentialException(SR.net_auth_bad_client_creds, e),
+                ERROR_TRUST_FAILURE => new AuthenticationException(SR.net_auth_context_expectation_remote, e),
+                _ => new AuthenticationException(SR.net_auth_alert, e)
+            };
+        }
+
+        private static bool IsLogonDeniedException(Exception exception) =>
+            exception is Win32Exception win32exception &&
+            win32exception.NativeErrorCode == (int)SecurityStatusPalErrorCode.LogonDenied;
     }
 }
index a03fe69..51d6963 100644 (file)
@@ -4,6 +4,7 @@
 
 using System.ComponentModel;
 using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
 using System.Runtime.InteropServices;
 using System.Security;
 using System.Security.Principal;
@@ -85,12 +86,10 @@ namespace System.Net.Security
 
         internal static int Encrypt(
             SafeDeleteContext securityContext,
-            byte[] buffer,
-            int offset,
-            int count,
+            ReadOnlySpan<byte> buffer,
             bool isConfidential,
             bool isNtlm,
-            ref byte[]? output,
+            [NotNull] ref byte[]? output,
             uint sequenceNumber)
         {
             SecPkgContext_Sizes sizes = default;
@@ -101,9 +100,9 @@ namespace System.Net.Security
             {
                 int maxCount = checked(int.MaxValue - 4 - sizes.cbBlockSize - sizes.cbSecurityTrailer);
 
-                if (count > maxCount || count < 0)
+                if (buffer.Length > maxCount)
                 {
-                    throw new ArgumentOutOfRangeException(nameof(count), SR.Format(SR.net_io_out_range, maxCount));
+                    throw new ArgumentOutOfRangeException(nameof(buffer.Length), SR.Format(SR.net_io_out_range, maxCount));
                 }
             }
             catch (Exception e) when (!ExceptionCheck.IsFatal(e))
@@ -112,21 +111,21 @@ namespace System.Net.Security
                 throw;
             }
 
-            int resultSize = count + sizes.cbSecurityTrailer + sizes.cbBlockSize;
+            int resultSize = buffer.Length + sizes.cbSecurityTrailer + sizes.cbBlockSize;
             if (output == null || output.Length < resultSize + 4)
             {
                 output = new byte[resultSize + 4];
             }
 
             // Make a copy of user data for in-place encryption.
-            Buffer.BlockCopy(buffer, offset, output, 4 + sizes.cbSecurityTrailer, count);
+            buffer.CopyTo(output.AsSpan(4 + sizes.cbSecurityTrailer));
 
             // Prepare buffers TOKEN(signature), DATA and Padding.
             ThreeSecurityBuffers buffers = default;
             var securityBuffer = MemoryMarshal.CreateSpan(ref buffers._item0, 3);
             securityBuffer[0] = new SecurityBuffer(output, 4, sizes.cbSecurityTrailer, SecurityBufferType.SECBUFFER_TOKEN);
-            securityBuffer[1] = new SecurityBuffer(output, 4 + sizes.cbSecurityTrailer, count, SecurityBufferType.SECBUFFER_DATA);
-            securityBuffer[2] = new SecurityBuffer(output, 4 + sizes.cbSecurityTrailer + count, sizes.cbBlockSize, SecurityBufferType.SECBUFFER_PADDING);
+            securityBuffer[1] = new SecurityBuffer(output, 4 + sizes.cbSecurityTrailer, buffer.Length, SecurityBufferType.SECBUFFER_DATA);
+            securityBuffer[2] = new SecurityBuffer(output, 4 + sizes.cbSecurityTrailer + buffer.Length, sizes.cbBlockSize, SecurityBufferType.SECBUFFER_PADDING);
 
             int errorCode;
             if (isConfidential)
@@ -160,7 +159,7 @@ namespace System.Net.Security
             }
 
             resultSize += securityBuffer[1].size;
-            if (securityBuffer[2].size != 0 && (forceCopy || resultSize != (count + sizes.cbSecurityTrailer)))
+            if (securityBuffer[2].size != 0 && (forceCopy || resultSize != (buffer.Length + sizes.cbSecurityTrailer)))
             {
                 Buffer.BlockCopy(output, securityBuffer[2].offset, output, 4 + resultSize, securityBuffer[2].size);
             }
diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs b/src/libraries/System.Net.Security/src/System/Net/Security/ReadWriteAdapter.cs
new file mode 100644 (file)
index 0000000..497baeb
--- /dev/null
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Net.Security
+{
+    internal interface IReadWriteAdapter
+    {
+        ValueTask<int> ReadAsync(Memory<byte> buffer);
+
+        ValueTask WriteAsync(byte[] buffer, int offset, int count);
+
+        Task WaitAsync(TaskCompletionSource<bool> waiter);
+
+        CancellationToken CancellationToken { get; }
+
+        public async ValueTask<int> ReadAllAsync(Memory<byte> buffer)
+        {
+            int length = buffer.Length;
+
+            do
+            {
+                int bytes = await ReadAsync(buffer).ConfigureAwait(false);
+                if (bytes == 0)
+                {
+                    if (!buffer.IsEmpty)
+                    {
+                        throw new IOException(SR.net_io_eof);
+                    }
+                    break;
+                }
+
+                buffer = buffer.Slice(bytes);
+            }
+            while (!buffer.IsEmpty);
+
+            return length;
+        }
+    }
+
+    internal readonly struct AsyncReadWriteAdapter : IReadWriteAdapter
+    {
+        private readonly Stream _stream;
+
+        public AsyncReadWriteAdapter(Stream stream, CancellationToken cancellationToken)
+        {
+            _stream = stream;
+            CancellationToken = cancellationToken;
+        }
+
+        public ValueTask<int> ReadAsync(Memory<byte> buffer) =>
+            _stream.ReadAsync(buffer, CancellationToken);
+
+        public ValueTask WriteAsync(byte[] buffer, int offset, int count) =>
+            _stream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), CancellationToken);
+
+        public Task WaitAsync(TaskCompletionSource<bool> waiter) => waiter.Task;
+
+        public CancellationToken CancellationToken { get; }
+    }
+
+    internal readonly struct SyncReadWriteAdapter : IReadWriteAdapter
+    {
+        private readonly Stream _stream;
+
+        public SyncReadWriteAdapter(Stream stream) => _stream = stream;
+
+        public ValueTask<int> ReadAsync(Memory<byte> buffer) =>
+            new ValueTask<int>(_stream.Read(buffer.Span));
+
+        public ValueTask WriteAsync(byte[] buffer, int offset, int count)
+        {
+            _stream.Write(buffer, offset, count);
+            return default;
+        }
+
+        public Task WaitAsync(TaskCompletionSource<bool> waiter)
+        {
+            waiter.Task.GetAwaiter().GetResult();
+            return Task.CompletedTask;
+        }
+
+        public CancellationToken CancellationToken => default;
+    }
+}
diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs
deleted file mode 100644 (file)
index 4995f3f..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace System.Net.Security
-{
-    // This contains adapters to allow a single code path for sync/async logic
-    public partial class SslStream
-    {
-        private interface ISslIOAdapter
-        {
-            ValueTask<int> ReadAsync(Memory<byte> buffer);
-            ValueTask WriteAsync(byte[] buffer, int offset, int count);
-            Task WaitAsync(TaskCompletionSource<bool> waiter);
-            CancellationToken CancellationToken { get; }
-        }
-
-        private readonly struct AsyncSslIOAdapter : ISslIOAdapter
-        {
-            private readonly SslStream _sslStream;
-            private readonly CancellationToken _cancellationToken;
-
-            public AsyncSslIOAdapter(SslStream sslStream, CancellationToken cancellationToken)
-            {
-                _cancellationToken = cancellationToken;
-                _sslStream = sslStream;
-            }
-
-            public ValueTask<int> ReadAsync(Memory<byte> buffer) => _sslStream.InnerStream.ReadAsync(buffer, _cancellationToken);
-
-            public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), _cancellationToken);
-
-            public Task WaitAsync(TaskCompletionSource<bool> waiter) => waiter.Task;
-
-            public CancellationToken CancellationToken => _cancellationToken;
-        }
-
-        private readonly struct SyncSslIOAdapter : ISslIOAdapter
-        {
-            private readonly SslStream _sslStream;
-
-            public SyncSslIOAdapter(SslStream sslStream) => _sslStream = sslStream;
-
-            public ValueTask<int> ReadAsync(Memory<byte> buffer) => new ValueTask<int>(_sslStream.InnerStream.Read(buffer.Span));
-
-            public ValueTask WriteAsync(byte[] buffer, int offset, int count)
-            {
-                _sslStream.InnerStream.Write(buffer, offset, count);
-                return default;
-            }
-
-            public Task WaitAsync(TaskCompletionSource<bool> waiter)
-            {
-                waiter.Task.Wait();
-                return Task.CompletedTask;
-            }
-
-            public CancellationToken CancellationToken => default;
-        }
-    }
-}
index d9ff52d..68b1425 100644 (file)
@@ -196,11 +196,11 @@ namespace System.Net.Security
 
             if (isAsync)
             {
-                result = ForceAuthenticationAsync(new AsyncSslIOAdapter(this, cancellationToken), _context!.IsServer, null, isApm);
+                result = ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm);
             }
             else
             {
-                ForceAuthenticationAsync(new SyncSslIOAdapter(this), _context!.IsServer, null).GetAwaiter().GetResult();
+                ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null).GetAwaiter().GetResult();
                 result = null;
             }
 
@@ -211,7 +211,7 @@ namespace System.Net.Security
         // This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read().
         //
         private async Task ReplyOnReAuthenticationAsync<TIOAdapter>(TIOAdapter adapter, byte[]? buffer)
-            where TIOAdapter : ISslIOAdapter
+            where TIOAdapter : IReadWriteAdapter
         {
             try
             {
@@ -226,7 +226,7 @@ namespace System.Net.Security
 
         // reAuthenticationData is only used on Windows in case of renegotiation.
         private async Task ForceAuthenticationAsync<TIOAdapter>(TIOAdapter adapter, bool receiveFirst, byte[]? reAuthenticationData, bool isApm = false)
-             where TIOAdapter : ISslIOAdapter
+             where TIOAdapter : IReadWriteAdapter
         {
             ProtocolToken message;
             bool handshakeCompleted = false;
@@ -339,7 +339,7 @@ namespace System.Net.Security
         }
 
         private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter adapter)
-                 where TIOAdapter : ISslIOAdapter
+                 where TIOAdapter : IReadWriteAdapter
         {
             int readBytes = await FillHandshakeBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
             if (readBytes == 0)
@@ -486,7 +486,7 @@ namespace System.Net.Security
         }
 
         private async ValueTask WriteAsyncChunked<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
-            where TIOAdapter : struct, ISslIOAdapter
+            where TIOAdapter : struct, IReadWriteAdapter
         {
             do
             {
@@ -497,7 +497,7 @@ namespace System.Net.Security
         }
 
         private ValueTask WriteSingleChunk<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
-            where TIOAdapter : struct, ISslIOAdapter
+            where TIOAdapter : struct, IReadWriteAdapter
         {
             byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length + FrameOverhead);
             byte[] outBuffer = rentedBuffer;
@@ -643,7 +643,7 @@ namespace System.Net.Security
         }
 
         private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(TIOAdapter adapter, Memory<byte> buffer)
-            where TIOAdapter : ISslIOAdapter
+            where TIOAdapter : IReadWriteAdapter
         {
             if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
             {
@@ -790,7 +790,7 @@ namespace System.Net.Security
         // If we have enough data, it returns synchronously. If not, it will try to read
         // remaining bytes from given stream.
         private ValueTask<int> FillHandshakeBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
-             where TIOAdapter : ISslIOAdapter
+             where TIOAdapter : IReadWriteAdapter
         {
             if (_handshakeBuffer.ActiveLength >= minSize)
             {
@@ -840,7 +840,7 @@ namespace System.Net.Security
         }
 
         private async ValueTask FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int numBytesRequired)
-            where TIOAdapter : ISslIOAdapter
+            where TIOAdapter : IReadWriteAdapter
         {
             Debug.Assert(_internalBufferCount > 0);
             Debug.Assert(_internalBufferCount < numBytesRequired);
@@ -858,7 +858,7 @@ namespace System.Net.Security
         }
 
         private async ValueTask WriteAsyncInternal<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
-            where TIOAdapter : struct, ISslIOAdapter
+            where TIOAdapter : struct, IReadWriteAdapter
         {
             ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
 
index 79c16bf..f363879 100644 (file)
@@ -762,8 +762,7 @@ namespace System.Net.Security
         {
             ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
-            SyncSslIOAdapter reader = new SyncSslIOAdapter(this);
-            ValueTask<int> vt = ReadAsyncInternal(reader, new Memory<byte>(buffer, offset, count));
+            ValueTask<int> vt = ReadAsyncInternal(new SyncReadWriteAdapter(InnerStream), new Memory<byte>(buffer, offset, count));
             Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously");
             return vt.GetAwaiter().GetResult();
         }
@@ -775,8 +774,7 @@ namespace System.Net.Security
             ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
 
-            SyncSslIOAdapter writeAdapter = new SyncSslIOAdapter(this);
-            ValueTask vt = WriteAsyncInternal(writeAdapter, new ReadOnlyMemory<byte>(buffer, offset, count));
+            ValueTask vt = WriteAsyncInternal(new SyncReadWriteAdapter(InnerStream), new ReadOnlyMemory<byte>(buffer, offset, count));
             Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously");
             vt.GetAwaiter().GetResult();
         }
@@ -815,23 +813,20 @@ namespace System.Net.Security
         public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
         {
             ThrowIfExceptionalOrNotAuthenticated();
-            AsyncSslIOAdapter writeAdapter = new AsyncSslIOAdapter(this, cancellationToken);
-            return WriteAsyncInternal(writeAdapter, buffer);
+            return WriteAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer);
         }
 
         public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
             ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
-            AsyncSslIOAdapter read = new AsyncSslIOAdapter(this, cancellationToken);
-            return ReadAsyncInternal(read, new Memory<byte>(buffer, offset, count)).AsTask();
+            return ReadAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), new Memory<byte>(buffer, offset, count)).AsTask();
         }
 
         public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
         {
             ThrowIfExceptionalOrNotAuthenticated();
-            AsyncSslIOAdapter read = new AsyncSslIOAdapter(this, cancellationToken);
-            return ReadAsyncInternal(read, buffer);
+            return ReadAsyncInternal(new AsyncReadWriteAdapter(InnerStream, cancellationToken), buffer);
         }
 
         private void ThrowIfExceptional()
index eda52c9..338e3c8 100644 (file)
@@ -2,28 +2,23 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Diagnostics;
 using System.Net.Security;
 using System.Security.Authentication.ExtendedProtection;
 
 namespace System.Net
 {
-    internal class SslStreamContext : TransportContext
+    internal sealed class SslStreamContext : TransportContext
     {
+        private readonly SslStream _sslStream;
+
         internal SslStreamContext(SslStream sslStream)
         {
-            if (sslStream == null)
-            {
-                NetEventSource.Fail(this, "Not expecting a null sslStream!");
-            }
-
+            Debug.Assert(sslStream != null);
             _sslStream = sslStream!;
         }
 
-        public override ChannelBinding? GetChannelBinding(ChannelBindingKind kind)
-        {
-            return _sslStream.GetChannelBinding(kind);
-        }
-
-        private readonly SslStream _sslStream;
+        public override ChannelBinding? GetChannelBinding(ChannelBindingKind kind) =>
+            _sslStream.GetChannelBinding(kind);
     }
 }
index dc494ca..264fb4d 100644 (file)
@@ -4,83 +4,37 @@
 
 using System.IO;
 using System.Globalization;
-using System.Runtime.ExceptionServices;
+using System.Net.Security;
 using System.Threading.Tasks;
 
 namespace System.Net
 {
-    internal class StreamFramer
+    internal sealed class StreamFramer
     {
-        private readonly Stream _transport;
-
-        private bool _eof;
-
         private readonly FrameHeader _writeHeader = new FrameHeader();
         private readonly FrameHeader _curReadHeader = new FrameHeader();
-        private readonly FrameHeader _readVerifier = new FrameHeader(
-                                                    FrameHeader.IgnoreValue,
-                                                    FrameHeader.IgnoreValue,
-                                                    FrameHeader.IgnoreValue);
-
-        private readonly byte[] _readHeaderBuffer;
-        private readonly byte[] _writeHeaderBuffer;
 
-        private readonly AsyncCallback _readFrameCallback;
-        private readonly AsyncCallback _beginWriteCallback;
-
-        public StreamFramer(Stream Transport)
-        {
-            if (Transport == null || Transport == Stream.Null)
-            {
-                throw new ArgumentNullException(nameof(Transport));
-            }
-
-            _transport = Transport;
-            _readHeaderBuffer = new byte[_curReadHeader.Size];
-            _writeHeaderBuffer = new byte[_writeHeader.Size];
-
-            _readFrameCallback = new AsyncCallback(ReadFrameCallback);
-            _beginWriteCallback = new AsyncCallback(BeginWriteCallback);
-        }
-
-        public FrameHeader ReadHeader
-        {
-            get
-            {
-                return _curReadHeader;
-            }
-        }
-
-        public FrameHeader WriteHeader
-        {
-            get
-            {
-                return _writeHeader;
-            }
-        }
+        private readonly byte[] _readHeaderBuffer = new byte[FrameHeader.Size];
+        private readonly byte[] _writeHeaderBuffer = new byte[FrameHeader.Size];
+        private bool _eof;
 
-        public Stream Transport
-        {
-            get
-            {
-                return _transport;
-            }
-        }
+        public FrameHeader ReadHeader => _curReadHeader;
+        public FrameHeader WriteHeader => _writeHeader;
 
-        public byte[]? ReadMessage()
+        public async ValueTask<byte[]?> ReadMessageAsync<TAdapter>(TAdapter adapter) where TAdapter : IReadWriteAdapter
         {
             if (_eof)
             {
                 return null;
             }
 
-            int offset = 0;
             byte[] buffer = _readHeaderBuffer;
 
             int bytesRead;
+            int offset = 0;
             while (offset < buffer.Length)
             {
-                bytesRead = Transport.Read(buffer, offset, buffer.Length - offset);
+                bytesRead = await adapter.ReadAsync(buffer.AsMemory(offset)).ConfigureAwait(false);
                 if (bytesRead == 0)
                 {
                     if (offset == 0)
@@ -89,20 +43,18 @@ namespace System.Net
                         _eof = true;
                         return null;
                     }
-                    else
-                    {
-                        throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
-                    }
+
+                    throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
                 }
 
                 offset += bytesRead;
             }
 
-            _curReadHeader.CopyFrom(buffer, 0, _readVerifier);
-            if (_curReadHeader.PayloadSize > _curReadHeader.MaxMessageSize)
+            _curReadHeader.CopyFrom(buffer, 0);
+            if (_curReadHeader.PayloadSize > FrameHeader.MaxMessageSize)
             {
                 throw new InvalidOperationException(SR.Format(SR.net_frame_size,
-                                                               _curReadHeader.MaxMessageSize.ToString(NumberFormatInfo.InvariantInfo),
+                                                               FrameHeader.MaxMessageSize,
                                                                _curReadHeader.PayloadSize.ToString(NumberFormatInfo.InvariantInfo)));
             }
 
@@ -111,7 +63,7 @@ namespace System.Net
             offset = 0;
             while (offset < buffer.Length)
             {
-                bytesRead = Transport.Read(buffer, offset, buffer.Length - offset);
+                bytesRead = await adapter.ReadAsync(buffer.AsMemory(offset)).ConfigureAwait(false);
                 if (bytesRead == 0)
                 {
                     throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
@@ -122,226 +74,7 @@ namespace System.Net
             return buffer;
         }
 
-        public IAsyncResult BeginReadMessage(AsyncCallback asyncCallback, object stateObject)
-        {
-            WorkerAsyncResult workerResult;
-
-            if (_eof)
-            {
-                workerResult = new WorkerAsyncResult(this, stateObject, asyncCallback, null, 0, 0);
-                workerResult.InvokeCallback(-1);
-                return workerResult;
-            }
-
-            workerResult = new WorkerAsyncResult(this, stateObject, asyncCallback,
-                                                                   _readHeaderBuffer, 0,
-                                                                   _readHeaderBuffer.Length);
-
-            IAsyncResult result = TaskToApm.Begin(_transport.ReadAsync(_readHeaderBuffer, 0, _readHeaderBuffer.Length),
-                _readFrameCallback, workerResult);
-
-            if (result.CompletedSynchronously)
-            {
-                ReadFrameComplete(result);
-            }
-
-            return workerResult;
-        }
-
-        private void ReadFrameCallback(IAsyncResult transportResult)
-        {
-            if (!(transportResult.AsyncState is WorkerAsyncResult))
-            {
-                NetEventSource.Fail(this, $"The state expected to be WorkerAsyncResult, received {transportResult}.");
-            }
-
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            WorkerAsyncResult workerResult = (WorkerAsyncResult)transportResult.AsyncState!;
-
-            try
-            {
-                ReadFrameComplete(transportResult);
-            }
-            catch (Exception e)
-            {
-                if (e is OutOfMemoryException)
-                {
-                    throw;
-                }
-
-                if (!(e is IOException))
-                {
-                    e = new System.IO.IOException(SR.Format(SR.net_io_readfailure, e.Message), e);
-                }
-
-                workerResult.InvokeCallback(e);
-            }
-        }
-
-        // IO COMPLETION CALLBACK
-        //
-        // This callback is responsible for getting the complete protocol frame.
-        // 1. it reads the header.
-        // 2. it determines the frame size.
-        // 3. loops while not all frame received or an error.
-        //
-        private void ReadFrameComplete(IAsyncResult transportResult)
-        {
-            do
-            {
-                if (!(transportResult.AsyncState is WorkerAsyncResult))
-                {
-                    NetEventSource.Fail(this, $"The state expected to be WorkerAsyncResult, received {transportResult}.");
-                }
-
-                WorkerAsyncResult workerResult = (WorkerAsyncResult)transportResult.AsyncState!;
-
-                int bytesRead = TaskToApm.End<int>(transportResult);
-                workerResult.Offset += bytesRead;
-
-                if (!(workerResult.Offset <= workerResult.End))
-                {
-                    NetEventSource.Fail(this, $"WRONG: offset - end = {workerResult.Offset - workerResult.End}");
-                }
-
-                if (bytesRead <= 0)
-                {
-                    // (by design) This indicates the stream has receives EOF
-                    // If we are in the middle of a Frame - fail, otherwise - produce EOF
-                    object? result = null;
-                    if (!workerResult.HeaderDone && workerResult.Offset == 0)
-                    {
-                        result = (object)-1;
-                    }
-                    else
-                    {
-                        result = new System.IO.IOException(SR.net_frame_read_io);
-                    }
-
-                    workerResult.InvokeCallback(result);
-                    return;
-                }
-
-                if (workerResult.Offset >= workerResult.End)
-                {
-                    if (!workerResult.HeaderDone)
-                    {
-                        workerResult.HeaderDone = true;
-                        // This indicates the header has been read successfully
-                        _curReadHeader.CopyFrom(workerResult.Buffer!, 0, _readVerifier);
-                        int payloadSize = _curReadHeader.PayloadSize;
-                        if (payloadSize < 0)
-                        {
-                            // Let's call user callback and they call us back and we will throw
-                            workerResult.InvokeCallback(new System.IO.IOException(SR.net_frame_read_size));
-                        }
-
-                        if (payloadSize == 0)
-                        {
-                            // report empty frame (NOT eof!) to the caller, he might be interested in
-                            workerResult.InvokeCallback(0);
-                            return;
-                        }
-
-                        if (payloadSize > _curReadHeader.MaxMessageSize)
-                        {
-                            throw new InvalidOperationException(SR.Format(SR.net_frame_size,
-                                                                            _curReadHeader.MaxMessageSize.ToString(NumberFormatInfo.InvariantInfo),
-                                                                            payloadSize.ToString(NumberFormatInfo.InvariantInfo)));
-                        }
-
-                        // Start reading the remaining frame data (note header does not count).
-                        byte[] frame = new byte[payloadSize];
-                        // Save the ref of the data block
-                        workerResult.Buffer = frame;
-                        workerResult.End = frame.Length;
-                        workerResult.Offset = 0;
-
-                        // Transport.ReadAsync below will pickup those changes.
-                    }
-                    else
-                    {
-                        workerResult.HeaderDone = false; // Reset for optional object reuse.
-                        workerResult.InvokeCallback(workerResult.End);
-                        return;
-                    }
-                }
-
-                // This means we need more data to complete the data block.
-                transportResult = TaskToApm.Begin(_transport.ReadAsync(workerResult.Buffer!, workerResult.Offset, workerResult.End - workerResult.Offset),
-                                            _readFrameCallback, workerResult);
-            } while (transportResult.CompletedSynchronously);
-        }
-
-        //
-        // User code will call this when workerResult gets signaled.
-        //
-        // On BeginRead, the user always gets back our WorkerAsyncResult.
-        // The Result property represents either a number of bytes read or an
-        // exception put by our async state machine.
-        //
-        public byte[]? EndReadMessage(IAsyncResult asyncResult)
-        {
-            if (asyncResult == null)
-            {
-                throw new ArgumentNullException(nameof(asyncResult));
-            }
-            WorkerAsyncResult? workerResult = asyncResult as WorkerAsyncResult;
-
-            if (workerResult == null)
-            {
-                throw new ArgumentException(SR.Format(SR.net_io_async_result, typeof(WorkerAsyncResult).FullName), nameof(asyncResult));
-            }
-
-            if (!workerResult.InternalPeekCompleted)
-            {
-                workerResult.InternalWaitForCompletion();
-            }
-
-            if (workerResult.Result is Exception e)
-            {
-                ExceptionDispatchInfo.Throw(e);
-            }
-
-            int size = (int)workerResult.Result!;
-            if (size == -1)
-            {
-                _eof = true;
-                return null;
-            }
-            else if (size == 0)
-            {
-                // Empty frame.
-                return Array.Empty<byte>();
-            }
-
-            return workerResult.Buffer;
-        }
-
-        public void WriteMessage(byte[] message)
-        {
-            if (message == null)
-            {
-                throw new ArgumentNullException(nameof(message));
-            }
-
-            _writeHeader.PayloadSize = message.Length;
-            _writeHeader.CopyTo(_writeHeaderBuffer, 0);
-
-            Transport.Write(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length);
-            if (message.Length == 0)
-            {
-                return;
-            }
-
-            Transport.Write(message, 0, message.Length);
-        }
-
-        public IAsyncResult BeginWriteMessage(byte[] message, AsyncCallback asyncCallback, object stateObject)
+        public async Task WriteMessageAsync<TAdapter>(TAdapter adapter, byte[] message) where TAdapter : IReadWriteAdapter
         {
             if (message == null)
             {
@@ -351,141 +84,16 @@ namespace System.Net
             _writeHeader.PayloadSize = message.Length;
             _writeHeader.CopyTo(_writeHeaderBuffer, 0);
 
-            if (message.Length == 0)
-            {
-                return TaskToApm.Begin(_transport.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length),
-                                                   asyncCallback, stateObject);
-            }
-
-            // Will need two async writes. Prepare the second:
-            WorkerAsyncResult workerResult = new WorkerAsyncResult(this, stateObject, asyncCallback,
-                                                                   message, 0, message.Length);
-
-            // Charge the first:
-            IAsyncResult result = TaskToApm.Begin(_transport.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length),
-                                 _beginWriteCallback, workerResult);
-
-            if (result.CompletedSynchronously)
-            {
-                BeginWriteComplete(result);
-            }
-
-            return workerResult;
-        }
-
-        private void BeginWriteCallback(IAsyncResult transportResult)
-        {
-            if (!(transportResult.AsyncState is WorkerAsyncResult))
-            {
-                NetEventSource.Fail(this, $"The state expected to be WorkerAsyncResult, received {transportResult}.");
-            }
-
-            if (transportResult.CompletedSynchronously)
-            {
-                return;
-            }
-
-            var workerResult = (WorkerAsyncResult)transportResult.AsyncState!;
-
-            try
-            {
-                BeginWriteComplete(transportResult);
-            }
-            catch (Exception e)
-            {
-                if (e is OutOfMemoryException)
-                {
-                    throw;
-                }
-
-                workerResult.InvokeCallback(e);
-            }
-        }
-
-        // IO COMPLETION CALLBACK
-        //
-        // Called when user IO request was wrapped to do several underlined IO.
-        //
-        private void BeginWriteComplete(IAsyncResult transportResult)
-        {
-            do
-            {
-                WorkerAsyncResult workerResult = (WorkerAsyncResult)transportResult.AsyncState!;
-
-                // First, complete the previous portion write.
-                TaskToApm.End(transportResult);
-
-                // Check on exit criterion.
-                if (workerResult.Offset == workerResult.End)
-                {
-                    workerResult.InvokeCallback();
-                    return;
-                }
-
-                // Setup exit criterion.
-                workerResult.Offset = workerResult.End;
-
-                // Write next portion (frame body) using Async IO.
-                transportResult = TaskToApm.Begin(_transport.WriteAsync(workerResult.Buffer!, 0, workerResult.End),
-                                            _beginWriteCallback, workerResult);
-            }
-            while (transportResult.CompletedSynchronously);
-        }
-
-        public void EndWriteMessage(IAsyncResult asyncResult)
-        {
-            if (asyncResult == null)
-            {
-                throw new ArgumentNullException(nameof(asyncResult));
-            }
-
-            WorkerAsyncResult? workerResult = asyncResult as WorkerAsyncResult;
-
-            if (workerResult != null)
+            await adapter.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length).ConfigureAwait(false);
+            if (message.Length != 0)
             {
-                if (!workerResult.InternalPeekCompleted)
-                {
-                    workerResult.InternalWaitForCompletion();
-                }
-
-                if (workerResult.Result is Exception e)
-                {
-                    ExceptionDispatchInfo.Throw(e);
-                }
-            }
-            else
-            {
-                TaskToApm.End(asyncResult);
+                await adapter.WriteAsync(message, 0, message.Length).ConfigureAwait(false);
             }
         }
     }
 
-    //
-    // This class wraps an Async IO request. It is based on our internal LazyAsyncResult helper.
-    // - If ParentResult is not null then the base class (LazyAsyncResult) methods must not be used.
-    // - If ParentResult == null, then real user IO request is wrapped.
-    //
-
-    internal class WorkerAsyncResult : LazyAsyncResult
-    {
-        public byte[]? Buffer;
-        public int Offset;
-        public int End;
-        public bool HeaderDone; // This might be reworked so we read both header and frame in one chunk.
-
-        public WorkerAsyncResult(object asyncObject, object asyncState,
-                                   AsyncCallback savedAsyncCallback,
-                                   byte[]? buffer, int offset, int end)
-            : base(asyncObject, asyncState, savedAsyncCallback)
-        {
-            Buffer = buffer;
-            Offset = offset;
-            End = end;
-        }
-    }
-
     // Describes the header used in framing of the stream data.
-    internal class FrameHeader
+    internal sealed class FrameHeader
     {
         public const int IgnoreValue = -1;
         public const int HandshakeDoneId = 20;
@@ -493,121 +101,44 @@ namespace System.Net
         public const int HandshakeId = 22;
         public const int DefaultMajorV = 1;
         public const int DefaultMinorV = 0;
+        public const int Size = 5;
+        public const int MaxMessageSize = 0xFFFF;
 
-        private int _MessageId;
-        private int _MajorV;
-        private int _MinorV;
-        private int _PayloadSize;
-
-        public FrameHeader()
-        {
-            _MessageId = HandshakeId;
-            _MajorV = DefaultMajorV;
-            _MinorV = DefaultMinorV;
-            _PayloadSize = -1;
-        }
-
-        public FrameHeader(int messageId, int majorV, int minorV)
-        {
-            _MessageId = messageId;
-            _MajorV = majorV;
-            _MinorV = minorV;
-            _PayloadSize = -1;
-        }
-
-        public int Size
-        {
-            get
-            {
-                return 5;
-            }
-        }
-
-        public int MaxMessageSize
-        {
-            get
-            {
-                return 0xFFFF;
-            }
-        }
-
-        public int MessageId
-        {
-            get
-            {
-                return _MessageId;
-            }
-            set
-            {
-                _MessageId = value;
-            }
-        }
+        private int _payloadSize = -1;
 
-        public int MajorV
-        {
-            get
-            {
-                return _MajorV;
-            }
-        }
-
-        public int MinorV
-        {
-            get
-            {
-                return _MinorV;
-            }
-        }
+        public int MessageId { get; set; } = HandshakeId;
+        public int MajorV { get; private set; } = DefaultMajorV;
+        public int MinorV { get; private set; } = DefaultMinorV;
 
         public int PayloadSize
         {
-            get
-            {
-                return _PayloadSize;
-            }
+            get => _payloadSize;
             set
             {
                 if (value > MaxMessageSize)
                 {
-                    throw new ArgumentException(SR.Format(SR.net_frame_max_size,
-                        MaxMessageSize.ToString(NumberFormatInfo.InvariantInfo),
-                        value.ToString(NumberFormatInfo.InvariantInfo)), "PayloadSize");
+                    throw new ArgumentException(SR.Format(SR.net_frame_max_size, MaxMessageSize, value), nameof(PayloadSize));
                 }
 
-                _PayloadSize = value;
+                _payloadSize = value;
             }
         }
 
         public void CopyTo(byte[] dest, int start)
         {
-            dest[start++] = (byte)_MessageId;
-            dest[start++] = (byte)_MajorV;
-            dest[start++] = (byte)_MinorV;
-            dest[start++] = (byte)((_PayloadSize >> 8) & 0xFF);
-            dest[start] = (byte)(_PayloadSize & 0xFF);
+            dest[start++] = (byte)MessageId;
+            dest[start++] = (byte)MajorV;
+            dest[start++] = (byte)MinorV;
+            dest[start++] = (byte)((_payloadSize >> 8) & 0xFF);
+            dest[start] = (byte)(_payloadSize & 0xFF);
         }
 
-        public void CopyFrom(byte[] bytes, int start, FrameHeader verifier)
+        public void CopyFrom(byte[] bytes, int start)
         {
-            _MessageId = bytes[start++];
-            _MajorV = bytes[start++];
-            _MinorV = bytes[start++];
-            _PayloadSize = (int)((bytes[start++] << 8) | bytes[start]);
-
-            if (verifier.MessageId != FrameHeader.IgnoreValue && MessageId != verifier.MessageId)
-            {
-                throw new InvalidOperationException(SR.Format(SR.net_io_header_id, "MessageId", MessageId, verifier.MessageId));
-            }
-
-            if (verifier.MajorV != FrameHeader.IgnoreValue && MajorV != verifier.MajorV)
-            {
-                throw new InvalidOperationException(SR.Format(SR.net_io_header_id, "MajorV", MajorV, verifier.MajorV));
-            }
-
-            if (verifier.MinorV != FrameHeader.IgnoreValue && MinorV != verifier.MinorV)
-            {
-                throw new InvalidOperationException(SR.Format(SR.net_io_header_id, "MinorV", MinorV, verifier.MinorV));
-            }
+            MessageId = bytes[start++];
+            MajorV = bytes[start++];
+            MinorV = bytes[start++];
+            _payloadSize = (bytes[start++] << 8) | bytes[start];
         }
     }
 }
index bdde8c6..68c6368 100644 (file)
@@ -99,56 +99,6 @@ namespace System.Net.Security.Tests
         }
 
         [Fact]
-        public async Task NegotiateStream_ConcurrentAsyncReadOrWrite_ThrowsNotSupportedException()
-        {
-            byte[] recvBuf = new byte[s_sampleMsg.Length];
-            var network = new VirtualNetwork();
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var client = new NegotiateStream(clientStream))
-            using (var server = new NegotiateStream(serverStream))
-            {
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
-                    client.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, string.Empty),
-                    server.AuthenticateAsServerAsync());
-
-                // Custom EndWrite/Read will not reset the variable which monitors concurrent write/read.
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
-                    Task.Factory.FromAsync(client.BeginWrite, (ar) => { Assert.NotNull(ar); }, s_sampleMsg, 0, s_sampleMsg.Length, client),
-                    Task.Factory.FromAsync(server.BeginRead, (ar) => { Assert.NotNull(ar); }, recvBuf, 0, s_sampleMsg.Length, server));
-
-                Assert.Throws<NotSupportedException>(() => client.BeginWrite(s_sampleMsg, 0, s_sampleMsg.Length, (ar) => { Assert.Null(ar); }, null));
-                Assert.Throws<NotSupportedException>(() => server.BeginRead(recvBuf, 0, s_sampleMsg.Length, (ar) => { Assert.Null(ar); }, null));
-            }
-        }
-
-        [Fact]
-        public async Task NegotiateStream_ConcurrentSyncReadOrWrite_ThrowsNotSupportedException()
-        {
-            byte[] recvBuf = new byte[s_sampleMsg.Length];
-            var network = new VirtualNetwork();
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var client = new NegotiateStream(clientStream))
-            using (var server = new NegotiateStream(serverStream))
-            {
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
-                    client.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, string.Empty),
-                    server.AuthenticateAsServerAsync());
-
-                // Custom EndWrite/Read will not reset the variable which monitors concurrent write/read.
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
-                    Task.Factory.FromAsync(client.BeginWrite, (ar) => { Assert.NotNull(ar); }, s_sampleMsg, 0, s_sampleMsg.Length, client),
-                    Task.Factory.FromAsync(server.BeginRead, (ar) => { Assert.NotNull(ar); }, recvBuf, 0, s_sampleMsg.Length, server));
-
-                Assert.Throws<NotSupportedException>(() => client.Write(s_sampleMsg, 0, s_sampleMsg.Length));
-                Assert.Throws<NotSupportedException>(() => server.Read(recvBuf, 0, s_sampleMsg.Length));
-            }
-        }
-
-        [Fact]
         public async Task NegotiateStream_DisposeTooEarly_Throws()
         {
             byte[] recvBuf = new byte[s_sampleMsg.Length];
@@ -336,7 +286,6 @@ namespace System.Net.Security.Tests
                         AssertExtensions.Throws<ArgumentException>(nameof(asyncResult), () => authStream.EndAuthenticateAsClient(result));
 
                         authStream.EndAuthenticateAsClient(asyncResult);
-                        Assert.Throws<InvalidOperationException>(() => authStream.EndAuthenticateAsClient(asyncResult));
                     }, CredentialCache.DefaultNetworkCredentials, string.Empty, client),
 
                     Task.Factory.FromAsync(server.BeginAuthenticateAsServer, (asyncResult) =>
@@ -348,7 +297,6 @@ namespace System.Net.Security.Tests
                         AssertExtensions.Throws<ArgumentException>(nameof(asyncResult), () => authStream.EndAuthenticateAsServer(result));
 
                         authStream.EndAuthenticateAsServer(asyncResult);
-                        Assert.Throws<InvalidOperationException>(() => authStream.EndAuthenticateAsServer(asyncResult));
                     }, server));
             }
         }
index a507020..1ffe221 100644 (file)
@@ -2,11 +2,13 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System.Diagnostics;
+using System.IO;
 using System.Linq;
 using System.Net.Test.Common;
+using System.Security.Authentication.ExtendedProtection;
 using System.Security.Principal;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -19,7 +21,7 @@ namespace System.Net.Security.Tests
         public static bool IsNtlmInstalled => Capability.IsNtlmInstalled();
 
         private const int PartialBytesToRead = 5;
-        private static readonly byte[] s_sampleMsg = Encoding.UTF8.GetBytes("Sample Test Message");
+        protected static readonly byte[] s_sampleMsg = Encoding.UTF8.GetBytes("Sample Test Message");
 
         private const int MaxWriteDataSize = 63 * 1024; // NegoState.MaxWriteDataSize
         private static string s_longString = new string('A', MaxWriteDataSize) + 'Z';
@@ -27,14 +29,20 @@ namespace System.Net.Security.Tests
 
         protected abstract Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName);
         protected abstract Task AuthenticateAsServerAsync(NegotiateStream server);
-
-        [ConditionalFact(nameof(IsNtlmInstalled))]
-        public async Task NegotiateStream_StreamToStream_Authentication_Success()
+        protected abstract Task<int> ReadAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+        protected abstract Task WriteAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+        protected virtual bool SupportsCancelableReadsWrites => false;
+        protected virtual bool IsEncryptedAndSigned => true;
+
+        [ConditionalTheory(nameof(IsNtlmInstalled))]
+        [InlineData(0)]
+        [InlineData(1)]
+        public async Task NegotiateStream_StreamToStream_Authentication_Success(int delay)
         {
             VirtualNetwork network = new VirtualNetwork();
 
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
+            using (var clientStream = new VirtualNetworkStream(network, isServer: false) { DelayMilliseconds = delay })
+            using (var serverStream = new VirtualNetworkStream(network, isServer: true) { DelayMilliseconds = delay })
             using (var client = new NegotiateStream(clientStream))
             using (var server = new NegotiateStream(serverStream))
             {
@@ -49,10 +57,10 @@ namespace System.Net.Security.Tests
                 // Expected Client property values:
                 Assert.True(client.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, client.ImpersonationLevel);
-                Assert.True(client.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, client.IsEncrypted);
                 Assert.False(client.IsMutuallyAuthenticated);
                 Assert.False(client.IsServer);
-                Assert.True(client.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, client.IsSigned);
                 Assert.False(client.LeaveInnerStreamOpen);
 
                 IIdentity serverIdentity = client.RemoteIdentity;
@@ -63,10 +71,10 @@ namespace System.Net.Security.Tests
                 // Expected Server property values:
                 Assert.True(server.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, server.ImpersonationLevel);
-                Assert.True(server.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, server.IsEncrypted);
                 Assert.False(server.IsMutuallyAuthenticated);
                 Assert.True(server.IsServer);
-                Assert.True(server.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, server.IsSigned);
                 Assert.False(server.LeaveInnerStreamOpen);
 
                 IIdentity clientIdentity = server.RemoteIdentity;
@@ -78,6 +86,43 @@ namespace System.Net.Security.Tests
             }
         }
 
+        [ConditionalTheory(nameof(IsNtlmInstalled))]
+        [InlineData(0)]
+        [InlineData(1)]
+        public async Task NegotiateStream_StreamToStream_Authenticated_DisposeAsync(int delay)
+        {
+            var network = new VirtualNetwork();
+            await using (var client = new NegotiateStream(new VirtualNetworkStream(network, isServer: false) { DelayMilliseconds = delay }))
+            await using (var server = new NegotiateStream(new VirtualNetworkStream(network, isServer: true) { DelayMilliseconds = delay }))
+            {
+                Assert.False(client.IsServer);
+                Assert.False(server.IsServer);
+
+                Assert.False(client.IsAuthenticated);
+                Assert.False(server.IsAuthenticated);
+
+                Assert.False(client.IsMutuallyAuthenticated);
+                Assert.False(server.IsMutuallyAuthenticated);
+
+                Assert.False(client.IsEncrypted);
+                Assert.False(server.IsEncrypted);
+
+                Assert.False(client.IsSigned);
+                Assert.False(server.IsSigned);
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                    AuthenticateAsClientAsync(client, CredentialCache.DefaultNetworkCredentials, string.Empty),
+                    AuthenticateAsServerAsync(server));
+            }
+        }
+
+        [ConditionalFact(nameof(IsNtlmInstalled))]
+        public async Task NegotiateStream_StreamToStream_Unauthenticated_Dispose()
+        {
+            new NegotiateStream(new MemoryStream()).Dispose();
+            await new NegotiateStream(new MemoryStream()).DisposeAsync();
+        }
+
         [ConditionalFact(nameof(IsNtlmInstalled))]
         public async Task NegotiateStream_StreamToStream_Authentication_TargetName_Success()
         {
@@ -105,10 +150,10 @@ namespace System.Net.Security.Tests
                 // Expected Client property values:
                 Assert.True(client.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, client.ImpersonationLevel);
-                Assert.True(client.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, client.IsEncrypted);
                 Assert.False(client.IsMutuallyAuthenticated);
                 Assert.False(client.IsServer);
-                Assert.True(client.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, client.IsSigned);
                 Assert.False(client.LeaveInnerStreamOpen);
 
                 IIdentity serverIdentity = client.RemoteIdentity;
@@ -119,10 +164,10 @@ namespace System.Net.Security.Tests
                 // Expected Server property values:
                 Assert.True(server.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, server.ImpersonationLevel);
-                Assert.True(server.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, server.IsEncrypted);
                 Assert.False(server.IsMutuallyAuthenticated);
                 Assert.True(server.IsServer);
-                Assert.True(server.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, server.IsSigned);
                 Assert.False(server.LeaveInnerStreamOpen);
 
                 IIdentity clientIdentity = server.RemoteIdentity;
@@ -165,10 +210,10 @@ namespace System.Net.Security.Tests
                 // Expected Client property values:
                 Assert.True(client.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, client.ImpersonationLevel);
-                Assert.True(client.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, client.IsEncrypted);
                 Assert.False(client.IsMutuallyAuthenticated);
                 Assert.False(client.IsServer);
-                Assert.True(client.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, client.IsSigned);
                 Assert.False(client.LeaveInnerStreamOpen);
 
                 IIdentity serverIdentity = client.RemoteIdentity;
@@ -179,10 +224,10 @@ namespace System.Net.Security.Tests
                 // Expected Server property values:
                 Assert.True(server.IsAuthenticated);
                 Assert.Equal(TokenImpersonationLevel.Identification, server.ImpersonationLevel);
-                Assert.True(server.IsEncrypted);
+                Assert.Equal(IsEncryptedAndSigned, server.IsEncrypted);
                 Assert.False(server.IsMutuallyAuthenticated);
                 Assert.True(server.IsServer);
-                Assert.True(server.IsSigned);
+                Assert.Equal(IsEncryptedAndSigned, server.IsSigned);
                 Assert.False(server.LeaveInnerStreamOpen);
 
                 IIdentity clientIdentity = server.RemoteIdentity;
@@ -195,15 +240,17 @@ namespace System.Net.Security.Tests
             }
         }
 
-        [ConditionalFact(nameof(IsNtlmInstalled))]
-        public async Task NegotiateStream_StreamToStream_Successive_ClientWrite_Sync_Success()
+        [ConditionalTheory(nameof(IsNtlmInstalled))]
+        [InlineData(0)]
+        [InlineData(1)]
+        public async Task NegotiateStream_StreamToStream_Successive_ClientWrite_Success(int delay)
         {
             byte[] recvBuf = new byte[s_sampleMsg.Length];
             VirtualNetwork network = new VirtualNetwork();
             int bytesRead = 0;
 
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
+            using (var clientStream = new VirtualNetworkStream(network, isServer: false) { DelayMilliseconds = delay })
+            using (var serverStream = new VirtualNetworkStream(network, isServer: true) { DelayMilliseconds = delay })
             using (var client = new NegotiateStream(clientStream))
             using (var server = new NegotiateStream(serverStream))
             {
@@ -216,99 +263,35 @@ namespace System.Net.Security.Tests
 
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(auth);
 
-                client.Write(s_sampleMsg, 0, s_sampleMsg.Length);
-                server.Read(recvBuf, 0, s_sampleMsg.Length);
-
-                Assert.True(s_sampleMsg.SequenceEqual(recvBuf));
-
-                client.Write(s_sampleMsg, 0, s_sampleMsg.Length);
-
-                // Test partial sync read.
-                bytesRead = server.Read(recvBuf, 0, PartialBytesToRead);
-                Assert.Equal(PartialBytesToRead, bytesRead);
-
-                bytesRead = server.Read(recvBuf, PartialBytesToRead, s_sampleMsg.Length - PartialBytesToRead);
-                Assert.Equal(s_sampleMsg.Length - PartialBytesToRead, bytesRead);
-
-                Assert.True(s_sampleMsg.SequenceEqual(recvBuf));
-            }
-        }
-
-        [ConditionalFact(nameof(IsNtlmInstalled))]
-        public async Task NegotiateStream_StreamToStream_Successive_ClientWrite_Async_Success()
-        {
-            byte[] recvBuf = new byte[s_sampleMsg.Length];
-            VirtualNetwork network = new VirtualNetwork();
-            int bytesRead = 0;
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var client = new NegotiateStream(clientStream))
-            using (var server = new NegotiateStream(serverStream))
-            {
-                Assert.False(client.IsAuthenticated);
-                Assert.False(server.IsAuthenticated);
-
-                Task[] auth = new Task[2];
-                auth[0] = AuthenticateAsClientAsync(client, CredentialCache.DefaultNetworkCredentials, string.Empty);
-                auth[1] = AuthenticateAsServerAsync(server);
-
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(auth);
-
-                auth[0] = client.WriteAsync(s_sampleMsg, 0, s_sampleMsg.Length);
-                auth[1] = server.ReadAsync(recvBuf, 0, s_sampleMsg.Length);
+                auth[0] = WriteAsync(client, s_sampleMsg, 0, s_sampleMsg.Length);
+                auth[1] = ReadAsync(server, recvBuf, 0, s_sampleMsg.Length);
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(auth);
                 Assert.True(s_sampleMsg.SequenceEqual(recvBuf));
 
-                await client.WriteAsync(s_sampleMsg, 0, s_sampleMsg.Length);
+                await WriteAsync(client, s_sampleMsg, 0, s_sampleMsg.Length);
 
                 // Test partial async read.
-                bytesRead = await server.ReadAsync(recvBuf, 0, PartialBytesToRead);
+                bytesRead = await ReadAsync(server, recvBuf, 0, PartialBytesToRead);
                 Assert.Equal(PartialBytesToRead, bytesRead);
 
-                bytesRead = await server.ReadAsync(recvBuf, PartialBytesToRead, s_sampleMsg.Length - PartialBytesToRead);
+                bytesRead = await ReadAsync(server, recvBuf, PartialBytesToRead, s_sampleMsg.Length - PartialBytesToRead);
                 Assert.Equal(s_sampleMsg.Length - PartialBytesToRead, bytesRead);
 
                 Assert.True(s_sampleMsg.SequenceEqual(recvBuf));
             }
         }
 
-        [ConditionalFact(nameof(IsNtlmInstalled))]
-        public async Task NegotiateStream_ReadWriteLongMsgSync_Success()
-        {
-            byte[] recvBuf = new byte[s_longMsg.Length];
-            var network = new VirtualNetwork();
-            int bytesRead = 0;
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var client = new NegotiateStream(clientStream))
-            using (var server = new NegotiateStream(serverStream))
-            {
-                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
-                    client.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, string.Empty),
-                    server.AuthenticateAsServerAsync());
-
-                client.Write(s_longMsg, 0, s_longMsg.Length);
-
-                while (bytesRead < s_longMsg.Length)
-                {
-                    bytesRead += server.Read(recvBuf, bytesRead, s_longMsg.Length - bytesRead);
-                }
-
-                Assert.True(s_longMsg.SequenceEqual(recvBuf));
-            }
-        }
-
-        [ConditionalFact(nameof(IsNtlmInstalled))]
-        public async Task NegotiateStream_ReadWriteLongMsgAsync_Success()
+        [ConditionalTheory(nameof(IsNtlmInstalled))]
+        [InlineData(0)]
+        [InlineData(1)]
+        public async Task NegotiateStream_ReadWriteLongMsg_Success(int delay)
         {
             byte[] recvBuf = new byte[s_longMsg.Length];
             var network = new VirtualNetwork();
             int bytesRead = 0;
 
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
+            using (var clientStream = new VirtualNetworkStream(network, isServer: false) { DelayMilliseconds = delay })
+            using (var serverStream = new VirtualNetworkStream(network, isServer: true) { DelayMilliseconds = delay })
             using (var client = new NegotiateStream(clientStream))
             using (var server = new NegotiateStream(serverStream))
             {
@@ -316,11 +299,11 @@ namespace System.Net.Security.Tests
                     client.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, string.Empty),
                     server.AuthenticateAsServerAsync());
 
-                await client.WriteAsync(s_longMsg, 0, s_longMsg.Length);
+                await WriteAsync(client, s_longMsg, 0, s_longMsg.Length);
 
                 while (bytesRead < s_longMsg.Length)
                 {
-                    bytesRead += await server.ReadAsync(recvBuf, bytesRead, s_longMsg.Length - bytesRead);
+                    bytesRead += await ReadAsync(server, recvBuf, bytesRead, s_longMsg.Length - bytesRead);
                 }
 
                 Assert.True(s_longMsg.SequenceEqual(recvBuf));
@@ -356,18 +339,91 @@ namespace System.Net.Security.Tests
                 Assert.True(task.IsCompleted);
             }
         }
+
+        [ConditionalFact(nameof(IsNtlmInstalled))]
+        public async Task NegotiateStream_StreamToStream_Successive_CancelableReadsWrites()
+        {
+            if (!SupportsCancelableReadsWrites)
+            {
+                return;
+            }
+
+            byte[] recvBuf = new byte[s_sampleMsg.Length];
+            VirtualNetwork network = new VirtualNetwork();
+
+            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
+            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
+            using (var client = new NegotiateStream(clientStream))
+            using (var server = new NegotiateStream(serverStream))
+            {
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                    AuthenticateAsClientAsync(client, CredentialCache.DefaultNetworkCredentials, string.Empty),
+                    AuthenticateAsServerAsync(server));
+
+                clientStream.DelayMilliseconds = int.MaxValue;
+                serverStream.DelayMilliseconds = int.MaxValue;
+
+                var cts = new CancellationTokenSource();
+                Task t = WriteAsync(client, s_sampleMsg, 0, s_sampleMsg.Length, cts.Token);
+                Assert.False(t.IsCompleted);
+                cts.Cancel();
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
+
+                cts = new CancellationTokenSource();
+                t = ReadAsync(server, s_sampleMsg, 0, s_sampleMsg.Length, cts.Token);
+                Assert.False(t.IsCompleted);
+                cts.Cancel();
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
+            }
+        }
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Async : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_Async_Array : NegotiateStreamStreamToStreamTest
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             client.AuthenticateAsClientAsync(credential, targetName);
 
         protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
             server.AuthenticateAsServerAsync();
+
+        protected override Task<int> ReadAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            stream.ReadAsync(buffer, offset, count, cancellationToken);
+
+        protected override Task WriteAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            stream.WriteAsync(buffer, offset, count, cancellationToken);
+
+        protected override bool SupportsCancelableReadsWrites => true;
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadNullBinding : NegotiateStreamStreamToStreamTest
+    public class NegotiateStreamStreamToStreamTest_Async_Memory : NegotiateStreamStreamToStreamTest
+    {
+        protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
+            client.AuthenticateAsClientAsync(credential, targetName);
+
+        protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
+            server.AuthenticateAsServerAsync();
+
+        protected override Task<int> ReadAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            stream.ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+
+        protected override Task WriteAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            stream.WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+
+        protected override bool SupportsCancelableReadsWrites => true;
+    }
+
+    public class NegotiateStreamStreamToStreamTest_Async_Memory_NotEncrypted : NegotiateStreamStreamToStreamTest_Async_Memory
+    {
+        protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
+            client.AuthenticateAsClientAsync(credential, targetName, ProtectionLevel.None, TokenImpersonationLevel.Identification);
+
+        protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
+            server.AuthenticateAsServerAsync(CredentialCache.DefaultNetworkCredentials, ProtectionLevel.None, TokenImpersonationLevel.Identification);
+
+        protected override bool IsEncryptedAndSigned => false;
+    }
+
+    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadNullBinding : NegotiateStreamStreamToStreamTest_Async_Memory
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             client.AuthenticateAsClientAsync(credential, null, targetName);
@@ -376,7 +432,7 @@ namespace System.Net.Security.Tests
             server.AuthenticateAsServerAsync(null);
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadProtectionLevel : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadProtectionLevel : NegotiateStreamStreamToStreamTest_Async_Memory
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             client.AuthenticateAsClientAsync(credential, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
@@ -385,7 +441,7 @@ namespace System.Net.Security.Tests
             server.AuthenticateAsServerAsync((NetworkCredential)CredentialCache.DefaultCredentials, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadAllParameters : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_Async_TestOverloadAllParameters : NegotiateStreamStreamToStreamTest_Async_Memory
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             client.AuthenticateAsClientAsync(credential, null, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
@@ -394,25 +450,62 @@ namespace System.Net.Security.Tests
             server.AuthenticateAsServerAsync((NetworkCredential)CredentialCache.DefaultCredentials, null, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification);
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_BeginEnd : NegotiateStreamStreamToStreamTest
+    public class NegotiateStreamStreamToStreamTest_BeginEnd : NegotiateStreamStreamToStreamTest
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             Task.Factory.FromAsync(client.BeginAuthenticateAsClient, client.EndAuthenticateAsClient, credential, targetName, null);
 
         protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
             Task.Factory.FromAsync(server.BeginAuthenticateAsServer, server.EndAuthenticateAsServer, null);
+
+        protected override Task<int> ReadAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            Task.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, offset, count, null);
+
+        protected override Task WriteAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            Task.Factory.FromAsync(stream.BeginWrite, stream.EndWrite, buffer, offset, count, null);
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Sync : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_BeginEnd_TestOverloadNullBinding : NegotiateStreamStreamToStreamTest_BeginEnd
+    {
+        protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
+            Task.Factory.FromAsync(client.BeginAuthenticateAsClient, client.EndAuthenticateAsClient, credential, (ChannelBinding)null, targetName, null);
+
+        protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
+            Task.Factory.FromAsync(server.BeginAuthenticateAsServer, server.EndAuthenticateAsServer, (ExtendedProtectionPolicy)null, null);
+    }
+
+    public sealed class NegotiateStreamStreamToStreamTest_BeginEnd_TestOverloadProtectionLevel : NegotiateStreamStreamToStreamTest_BeginEnd
+    {
+        protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
+            Task.Factory.FromAsync(
+                (callback, state) => client.BeginAuthenticateAsClient(credential, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, callback, state),
+                client.EndAuthenticateAsClient, null);
+
+        protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
+            Task.Factory.FromAsync(
+                (callback, state) => server.BeginAuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification, callback, state),
+                server.EndAuthenticateAsServer, null);
+    }
+
+    public class NegotiateStreamStreamToStreamTest_Sync : NegotiateStreamStreamToStreamTest
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             Task.Run(() => client.AuthenticateAsClient(credential, targetName));
 
         protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
             Task.Run(() => server.AuthenticateAsServer());
+
+        protected override Task<int> ReadAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+            Task.FromResult(stream.Read(buffer, offset, count));
+
+        protected override Task WriteAsync(NegotiateStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            stream.Write(buffer, offset, count);
+            return Task.CompletedTask;
+        }
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Sync_TestOverloadNullBinding : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_Sync_TestOverloadNullBinding : NegotiateStreamStreamToStreamTest_Sync
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             Task.Run(() => client.AuthenticateAsClient(credential, null, targetName));
@@ -421,7 +514,7 @@ namespace System.Net.Security.Tests
             Task.Run(() => server.AuthenticateAsServer(null));
     }
 
-    public sealed class NegotiateStreamStreamToStreamTest_Sync_TestOverloadAllParameters : NegotiateStreamStreamToStreamTest
+    public sealed class NegotiateStreamStreamToStreamTest_Sync_TestOverloadAllParameters : NegotiateStreamStreamToStreamTest_Sync
     {
         protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
             Task.Run(() => client.AuthenticateAsClient(credential, targetName, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification));
@@ -429,4 +522,15 @@ namespace System.Net.Security.Tests
         protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
             Task.Run(() => server.AuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, ProtectionLevel.EncryptAndSign, TokenImpersonationLevel.Identification));
     }
+
+    public class NegotiateStreamStreamToStreamTest_Sync_NotEncrypted : NegotiateStreamStreamToStreamTest_Sync
+    {
+        protected override Task AuthenticateAsClientAsync(NegotiateStream client, NetworkCredential credential, string targetName) =>
+            Task.Run(() => client.AuthenticateAsClient(credential, targetName, ProtectionLevel.None, TokenImpersonationLevel.Identification));
+
+        protected override Task AuthenticateAsServerAsync(NegotiateStream server) =>
+            Task.Run(() => server.AuthenticateAsServer((NetworkCredential)CredentialCache.DefaultCredentials, ProtectionLevel.None, TokenImpersonationLevel.Identification));
+
+        protected override bool IsEncryptedAndSigned => false;
+    }
 }
diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeLazyAsyncResult.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeLazyAsyncResult.cs
deleted file mode 100644 (file)
index 709fc58..0000000
+++ /dev/null
@@ -1,43 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Threading;
-
-namespace System.Net.Security
-{
-    internal class LazyAsyncResult : IAsyncResult
-    {
-        public LazyAsyncResult(SslStream sslState, object asyncState, AsyncCallback asyncCallback)
-        {
-            AsyncState = asyncState;
-            asyncCallback?.Invoke(this);
-        }
-
-        public object AsyncState { get; }
-
-        public WaitHandle AsyncWaitHandle
-        {
-            get
-            {
-                throw new NotImplementedException();
-            }
-        }
-
-        public bool CompletedSynchronously
-        {
-            get
-            {
-                return true;
-            }
-        }
-
-        public bool IsCompleted
-        {
-            get
-            {
-                return true;
-            }
-        }
-    }
-}
index 3dffb86..ccd294b 100644 (file)
@@ -38,7 +38,7 @@ namespace System.Net.Security
         }
 
         private ValueTask WriteAsyncInternal<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
-            where TWriteAdapter : struct, ISslIOAdapter => default;
+            where TWriteAdapter : struct, IReadWriteAdapter => default;
 
         private ValueTask<int> ReadAsyncInternal<TReadAdapter>(TReadAdapter adapter, Memory<byte> buffer) => default;
 
index 49823f8..50d702a 100644 (file)
@@ -24,7 +24,6 @@
     <!-- Fakes -->
     <Compile Include="Fakes\FakeSslStream.Implementation.cs" />
     <Compile Include="Fakes\FakeAuthenticatedStream.cs" />
-    <Compile Include="Fakes\FakeLazyAsyncResult.cs" />
     <!-- Common test files -->
     <Compile Include="$(CommonTestPath)System\Net\SslProtocolSupport.cs"
              Link="CommonTest\System\Net\SslProtocolSupport.cs" />
@@ -45,8 +44,8 @@
              Link="ProductionCode\System\Net\Security\SslApplicationProtocol.cs" />
     <Compile Include="..\..\src\System\Net\Security\SslConnectionInfo.cs"
              Link="ProductionCode\System\Net\Security\SslConnectionInfo.cs" />
-    <Compile Include="..\..\src\System\Net\Security\SslStream.Implementation.Adapters.cs"
-             Link="ProductionCode\System\Net\Security\SslStream.Implementation.Adapters.cs" />
+    <Compile Include="..\..\src\System\Net\Security\ReadWriteAdapter.cs"
+             Link="ProductionCode\System\Net\Security\ReadWriteAdapter.cs" />
     <Compile Include="..\..\src\System\Net\SslStreamContext.cs"
              Link="ProductionCode\System\Net\SslStreamContext.cs" />
     <Compile Include="$(CommonPath)System\Net\SecurityProtocol.cs"