propagate TLS alerts from OS layers (dotnet/corefx#41967)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Wed, 13 Nov 2019 21:50:57 +0000 (13:50 -0800)
committerGitHub <noreply@github.com>
Wed, 13 Nov 2019 21:50:57 +0000 (13:50 -0800)
* initial alerts with openssl

* get alerts from schannel

* update tests to work with openssl 1.1.x

* fix ClientAsyncAuthenticate_ServerNoEncryption_NoConnect to work properly with Tls13

* remove extra comment

* feedback from review

* feedback from review

* remove unused variable

Commit migrated from https://github.com/dotnet/corefx/commit/784cb6b5b7e947d3a69c7183847652a0e9335ff0

src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs
src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ClientAsyncAuthenticateTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ClientDefaultEncryptionTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ServerNoEncryptionTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ServerRequireEncryptionTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamAlpnTests.cs
src/libraries/System.Net.Security/tests/FunctionalTests/TestConfiguration.cs

index 58b4ef1..e505f3e 100644 (file)
@@ -257,6 +257,7 @@ internal static partial class Interop
         {
             sendBuf = null;
             sendCount = 0;
+            Exception handshakeException = null;
 
             if ((recvBuf != null) && (recvCount > 0))
             {
@@ -275,7 +276,10 @@ internal static partial class Interop
 
                 if ((retVal != -1) || (error != Ssl.SslErrorCode.SSL_ERROR_WANT_READ))
                 {
-                    throw new SslException(SR.Format(SR.net_ssl_handshake_failed_error, error), innerError);
+                    // Handshake failed, but even if the handshake does not need to read, there may be an Alert going out.
+                    // To handle that we will fall-through the block below to pull it out, and we will fail after.
+                    handshakeException = new SslException(SR.Format(SR.net_ssl_handshake_failed_error, error), innerError);
+                    Crypto.ErrClearError();
                 }
             }
 
@@ -288,6 +292,10 @@ internal static partial class Interop
                 {
                     sendCount = BioRead(context.OutputBio, sendBuf, sendCount);
                 }
+                catch (Exception) when (handshakeException != null)
+                {
+                    // If we already have handshake exception, ignore any exception from BioRead().
+                }
                 finally
                 {
                     if (sendCount <= 0)
@@ -300,6 +308,11 @@ internal static partial class Interop
                 }
             }
 
+            if (handshakeException != null)
+            {
+                throw handshakeException;
+            }
+
             bool stateOk = Ssl.IsSslStateOK(context);
             if (stateOk)
             {
index a32c0d4..e5b0249 100644 (file)
@@ -644,7 +644,7 @@ namespace System.Net.Security
             }
 
             Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length);
-            Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(1);
+            Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(count: 2);
 
             // Actually, this is returned in outFlags.
             bool isSspiAllocated = (inFlags & Interop.SspiCli.ContextFlags.AllocateMemory) != 0 ? true : false;
@@ -659,12 +659,15 @@ namespace System.Net.Security
 
             // Optional output buffer that may need to be freed.
             SafeFreeContextBuffer outFreeContextBuffer = null;
+            Span<Interop.SspiCli.SecBuffer> outUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[2];
+            outUnmanagedBuffer[1].pvBuffer = IntPtr.Zero;
             try
             {
                 Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers];
                 inUnmanagedBuffer.Clear();
 
                 fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer)
+                fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer)
                 fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null)
                 fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null)
                 fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles
@@ -694,16 +697,18 @@ namespace System.Net.Security
                     fixed (byte* pinnedOutBytes = outSecBuffer.token)
                     {
                         // Fix Descriptor pointer that points to unmanaged SecurityBuffers.
-                        Interop.SspiCli.SecBuffer outUnmanagedBuffer = default;
-                        outSecurityBufferDescriptor.pBuffers = &outUnmanagedBuffer;
+                        outSecurityBufferDescriptor.pBuffers = outUnmanagedBufferPtr;
 
                         // Copy the SecurityBuffer content into unmanaged place holder.
-                        outUnmanagedBuffer.cbBuffer = outSecBuffer.size;
-                        outUnmanagedBuffer.BufferType = outSecBuffer.type;
-                        outUnmanagedBuffer.pvBuffer = outSecBuffer.token == null || outSecBuffer.token.Length == 0 ?
+                        outUnmanagedBuffer[0].cbBuffer = outSecBuffer.size;
+                        outUnmanagedBuffer[0].BufferType = outSecBuffer.type;
+                        outUnmanagedBuffer[0].pvBuffer = outSecBuffer.token == null || outSecBuffer.token.Length == 0 ?
                             IntPtr.Zero :
                             (IntPtr)(pinnedOutBytes + outSecBuffer.offset);
 
+                        outUnmanagedBuffer[1].cbBuffer = 0;
+                        outUnmanagedBuffer[1].BufferType = SecurityBufferType.SECBUFFER_ALERT;
+
                         if (isSspiAllocated)
                         {
                             outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle();
@@ -731,18 +736,31 @@ namespace System.Net.Security
 
                         if (NetEventSource.IsEnabled) NetEventSource.Info(null, "Marshaling OUT buffer");
 
-                        // Get unmanaged buffer with index 0 as the only one passed into PInvoke.
-                        outSecBuffer.size = outUnmanagedBuffer.cbBuffer;
-                        outSecBuffer.type = outUnmanagedBuffer.BufferType;
-                        outSecBuffer.token = outUnmanagedBuffer.cbBuffer > 0 ?
-                            new Span<byte>((byte*)outUnmanagedBuffer.pvBuffer, outUnmanagedBuffer.cbBuffer).ToArray() :
-                            null;
+                        // No data written out but there is Alert
+                        if (outUnmanagedBuffer[0].cbBuffer == 0 && outUnmanagedBuffer[1].cbBuffer > 0)
+                        {
+                            outSecBuffer.size = outUnmanagedBuffer[1].cbBuffer;
+                            outSecBuffer.type = outUnmanagedBuffer[1].BufferType;
+                            outSecBuffer.token = new Span<byte>((byte*)outUnmanagedBuffer[1].pvBuffer, outUnmanagedBuffer[1].cbBuffer).ToArray();
+                        }
+                        else
+                        {
+                             outSecBuffer.size = outUnmanagedBuffer[0].cbBuffer;
+                             outSecBuffer.type = outUnmanagedBuffer[0].BufferType;
+                             outSecBuffer.token = outUnmanagedBuffer[0].cbBuffer > 0 ?
+                                 new Span<byte>((byte*)outUnmanagedBuffer[0].pvBuffer, outUnmanagedBuffer[0].cbBuffer).ToArray() :
+                                 null;
+                        }
                     }
                 }
             }
             finally
             {
                 outFreeContextBuffer?.Dispose();
+                if (outUnmanagedBuffer[1].pvBuffer != IntPtr.Zero)
+                {
+                    Interop.SspiCli.FreeContextBuffer(outUnmanagedBuffer[1].pvBuffer);
+                }
             }
 
             if (NetEventSource.IsEnabled) NetEventSource.Exit(null, $"errorCode:0x{errorCode:x8}, refContext:{refContext}");
index 8857db6..4abb400 100644 (file)
@@ -104,6 +104,9 @@ namespace System.Net.Security
         {
             Debug.Assert(!credential.IsInvalid);
 
+            byte[] output = null;
+            int outputSize = 0;
+
             try
             {
                 if ((null == context) || context.IsInvalid)
@@ -111,8 +114,6 @@ namespace System.Net.Security
                     context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions);
                 }
 
-                byte[] output = null;
-                int outputSize;
                 bool done;
 
                 if (inputBuffer.Array == null)
@@ -143,6 +144,12 @@ namespace System.Net.Security
             }
             catch (Exception exc)
             {
+                // Even if handshake failed we may have Alert to sent.
+                if (outputSize > 0)
+                {
+                    outputBuffer = outputSize == output.Length ? output : new Span<byte>(output, 0, outputSize).ToArray();
+                }
+
                 return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, exc);
             }
         }
index 4d0d13b..5b9fa73 100644 (file)
@@ -25,7 +25,7 @@ namespace System.Net.Security
             Interop.SspiCli.ContextFlags.AllocateMemory;
 
         private const Interop.SspiCli.ContextFlags ServerRequiredFlags =
-            RequiredFlags | Interop.SspiCli.ContextFlags.AcceptStream;
+            RequiredFlags | Interop.SspiCli.ContextFlags.AcceptStream | Interop.SspiCli.ContextFlags.AcceptExtendedError;
 
         public static Exception GetException(SecurityStatusPal status)
         {
index 545111f..69c9673 100644 (file)
@@ -43,7 +43,15 @@ namespace System.Net.Security.Tests
         [Fact]
         public async Task ClientAsyncAuthenticate_ServerNoEncryption_NoConnect()
         {
-            await Assert.ThrowsAsync<IOException>(() => ClientAsyncSslHelper(EncryptionPolicy.NoEncryption));
+            // Don't use Tls13 since we are trying to use NullEncryption
+            Type expectedExceptionType = TestConfiguration.SupportsHandshakeAlerts && TestConfiguration.SupportsNullEncryption ?
+                typeof(AuthenticationException) :
+                typeof(IOException);
+
+            await Assert.ThrowsAsync(expectedExceptionType,
+                () => ClientAsyncSslHelper(
+                    EncryptionPolicy.NoEncryption,
+                    SslProtocolSupport.DefaultSslProtocols,  SslProtocols.Tls | SslProtocols.Tls11 |  SslProtocols.Tls12 ));
         }
 
         [Theory]
@@ -112,12 +120,12 @@ namespace System.Net.Security.Tests
             yield return new object[] { SslProtocols.Ssl2, SslProtocols.Tls12, typeof(Exception) };
             yield return new object[] { SslProtocols.Ssl3, SslProtocols.Tls12, typeof(Exception) };
 #pragma warning restore 0618
-            yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, typeof(IOException) };
-            yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, typeof(IOException) };
+            yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
+            yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
             yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, typeof(AuthenticationException) };
             yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, typeof(AuthenticationException) };
             yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, typeof(AuthenticationException) };
-            yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, typeof(IOException) };
+            yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
         }
 
         #region Helpers
index 99d72b8..4a57ca9 100644 (file)
@@ -84,7 +84,7 @@ namespace System.Net.Security.Tests
 
                 using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null))
                 {
-                    await Assert.ThrowsAsync<IOException>(() =>
+                    await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
                         sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocolSupport.DefaultSslProtocols, false));
                 }
             }
index c2fe64a..46d6670 100644 (file)
@@ -79,10 +79,10 @@ namespace System.Net.Security.Tests
 #pragma warning restore 0618
             yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, typeof(AuthenticationException) };
             yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, typeof(AuthenticationException) };
-            yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, typeof(TimeoutException) };
+            yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
             yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, typeof(AuthenticationException) };
-            yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, typeof(TimeoutException) };
-            yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, typeof(TimeoutException) };
+            yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
+            yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
         }
 
         #region Helpers
index c1a3ddb..5cade97 100644 (file)
@@ -43,7 +43,7 @@ namespace System.Net.Security.Tests
 
                 using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null, EncryptionPolicy.RequireEncryption))
                 {
-                    await Assert.ThrowsAsync<IOException>(() =>
+                    await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
                         sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocolSupport.DefaultSslProtocols, false));
                 }
             }
index 8d6c3ba..1b15a40 100644 (file)
@@ -82,7 +82,7 @@ namespace System.Net.Security.Tests
                 await client.ConnectAsync(serverRequireEncryption.RemoteEndPoint.Address, serverRequireEncryption.RemoteEndPoint.Port);
                 using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null, EncryptionPolicy.NoEncryption))
                 {
-                    await Assert.ThrowsAsync<IOException>(() =>
+                    await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
                         sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocols.Tls | SslProtocols.Tls11 |  SslProtocols.Tls12, false));
                 }
             }
index 27d84e0..7c48c81 100644 (file)
@@ -161,7 +161,10 @@ namespace System.Net.Security.Tests
                         // Test alpn failure only on platforms that supports ALPN.
                         if (BackendSupportsAlpn)
                         {
-                            Task t1 = Assert.ThrowsAsync<IOException>(() => clientStream.AuthenticateAsClientAsync(clientOptions, CancellationToken.None));
+                            // schannel sends alert on ALPN failure, openssl does not.
+                            Task t1 = Assert.ThrowsAsync(TestConfiguration.SupportsAlpnAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
+                                clientStream.AuthenticateAsClientAsync(clientOptions, CancellationToken.None));
+
                             try
                             {
                                 await serverStream.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
index 92903c9..9769590 100644 (file)
@@ -30,6 +30,12 @@ namespace System.Net.Security.Tests
 
         public static bool SupportsNullEncryption { get { return s_supportsNullEncryption.Value; } }
 
+        public static bool SupportsHandshakeAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.Windows); } }
+
+        public static bool SupportsAlpnAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)  && PlatformDetection.OpenSslVersion.CompareTo(new Version(1,1,0)) >= 0); } }
+
+        public static bool SupportsVersionAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Linux)  && PlatformDetection.OpenSslVersion.CompareTo(new Version(1,1,0)) >= 0; } }
+
         public static Task WhenAllOrAnyFailedWithTimeout(params Task[] tasks)
             => tasks.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);