Avoid using X509Certificate2's IntPtr ctor in SecureChannel (dotnet/corefx#38344)
authorStephen Toub <stoub@microsoft.com>
Fri, 7 Jun 2019 20:10:54 +0000 (16:10 -0400)
committerGitHub <noreply@github.com>
Fri, 7 Jun 2019 20:10:54 +0000 (16:10 -0400)
* Avoid using X509Certificate2's IntPtr ctor in SecureChannel

* Address PR feedback

Commit migrated from https://github.com/dotnet/corefx/commit/82e49157906d6c9c1ed7013e99d50dca74e7fafc

src/libraries/Common/tests/System/Net/Configuration.Certificates.cs
src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.netcoreapp.cs

index 439fcb7..03d1aad 100644 (file)
@@ -75,7 +75,7 @@ namespace System.Net.Test.Common
                     return new X509Certificate2(
                         File.ReadAllBytes(Path.Combine(TestDataFolder, certificateFileName)),
                         CertificatePassword,
-                        X509KeyStorageFlags.DefaultKeySet);
+                        X509KeyStorageFlags.DefaultKeySet | X509KeyStorageFlags.Exportable);
                 }
                 catch (Exception ex)
                 {
index cf8d0e6..f43f355 100644 (file)
@@ -268,7 +268,7 @@ namespace System.Net.Security
             {
                 if (certificate.Handle != IntPtr.Zero)
                 {
-                    certificateEx = new X509Certificate2(certificate.Handle);
+                    certificateEx = new X509Certificate2(certificate);
                 }
             }
             catch (SecurityException) { }
index 76fd9ad..0e36c49 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Net.Test.Common;
@@ -10,7 +11,6 @@ using System.Security.Cryptography.X509Certificates;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
-
 using Xunit;
 
 namespace System.Net.Security.Tests
@@ -21,27 +21,66 @@ namespace System.Net.Security.Tests
     {
         private readonly byte[] _sampleMsg = Encoding.UTF8.GetBytes("Sample Test Message");
 
-        protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream);
+        protected static async Task WithServerCertificate(X509Certificate serverCertificate, Func<X509Certificate, string, Task> func)
+        {
+            X509Certificate certificate = serverCertificate ?? Configuration.Certificates.GetServerCertificate();
+            try
+            {
+                string name;
+                if (certificate is X509Certificate2 cert2)
+                {
+                    name = cert2.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
+                }
+                else
+                {
+                    using (cert2 = new X509Certificate2(certificate))
+                    {
+                        name = cert2.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
+                    }
+                }
+
+                await func(certificate, name).ConfigureAwait(false);
+            }
+            finally
+            {
+                if (certificate != serverCertificate)
+                {
+                    certificate.Dispose();
+                }
+            }
+        }
+
+        protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null);
 
         protected abstract Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
 
         protected abstract Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
 
-        [Fact]
-        public async Task SslStream_StreamToStream_Authentication_Success()
+        public static IEnumerable<object[]> SslStream_StreamToStream_Authentication_Success_MemberData()
         {
-            VirtualNetwork network = new VirtualNetwork();
+            using (X509Certificate2 serverCert = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCert = Configuration.Certificates.GetClientCertificate())
+            {
+                yield return new object[] { new X509Certificate2(serverCert), new X509Certificate2(clientCert) };
+                yield return new object[] { new X509Certificate(serverCert.Export(X509ContentType.Pfx)), new X509Certificate(clientCert.Export(X509ContentType.Pfx)) };
+            }
+        }
 
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var client = new SslStream(clientStream, false, AllowAnyServerCertificate))
-            using (var server = new SslStream(serverStream))
-            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+        [Theory]
+        [MemberData(nameof(SslStream_StreamToStream_Authentication_Success_MemberData))]
+        public async Task SslStream_StreamToStream_Authentication_Success(X509Certificate serverCert = null, X509Certificate clientCert = null)
+        {
+            var network = new VirtualNetwork();
+            using (var client = new SslStream(new VirtualNetworkStream(network, isServer: false), false, AllowAnyServerCertificate))
+            using (var server = new SslStream(new VirtualNetworkStream(network, isServer: true), false, delegate { return true; }))
             {
-                await DoHandshake(client, server);
+                await DoHandshake(client, server, serverCert, clientCert);
                 Assert.True(client.IsAuthenticated);
                 Assert.True(server.IsAuthenticated);
             }
+
+            clientCert?.Dispose();
+            serverCert?.Dispose();
         }
 
         [Fact]
@@ -739,14 +778,15 @@ namespace System.Net.Security.Tests
 
     public sealed class SslStreamStreamToStreamTest_Async : SslStreamStreamToStreamTest_CancelableReadWriteAsync
     {
-        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
         {
-            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+            await WithServerCertificate(serverCertificate, async(certificate, name) =>
             {
-                Task t1 = clientSslStream.AuthenticateAsClientAsync(certificate.GetNameInfo(X509NameType.SimpleName, false));
-                Task t2 = serverSslStream.AuthenticateAsServerAsync(certificate);
+                Task t1 = clientSslStream.AuthenticateAsClientAsync(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false);
+                Task t2 = serverSslStream.AuthenticateAsServerAsync(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false);
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
-            }
+            });
         }
 
         protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
@@ -758,14 +798,15 @@ namespace System.Net.Security.Tests
 
     public sealed class SslStreamStreamToStreamTest_BeginEnd : SslStreamStreamToStreamTest
     {
-        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
         {
-            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+            await WithServerCertificate(serverCertificate, async (certificate, name) =>
             {
-                Task t1 = Task.Factory.FromAsync(clientSslStream.BeginAuthenticateAsClient(certificate.GetNameInfo(X509NameType.SimpleName, false), null, null), clientSslStream.EndAuthenticateAsClient);
-                Task t2 = Task.Factory.FromAsync(serverSslStream.BeginAuthenticateAsServer(certificate, null, null), serverSslStream.EndAuthenticateAsServer);
+                Task t1 = Task.Factory.FromAsync(clientSslStream.BeginAuthenticateAsClient(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false, null, null), clientSslStream.EndAuthenticateAsClient);
+                Task t2 = Task.Factory.FromAsync(serverSslStream.BeginAuthenticateAsServer(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false, null, null), serverSslStream.EndAuthenticateAsServer);
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
-            }
+            });
         }
 
         protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
@@ -781,14 +822,15 @@ namespace System.Net.Security.Tests
 
     public sealed class SslStreamStreamToStreamTest_Sync : SslStreamStreamToStreamTest
     {
-        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
         {
-            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+            await WithServerCertificate(serverCertificate, async (certificate, name) =>
             {
-                Task t1 = Task.Run(() => clientSslStream.AuthenticateAsClient(certificate.GetNameInfo(X509NameType.SimpleName, false)));
-                Task t2 = Task.Run(() => serverSslStream.AuthenticateAsServer(certificate));
+                Task t1 = Task.Run(() => clientSslStream.AuthenticateAsClient(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false));
+                Task t2 = Task.Run(() => serverSslStream.AuthenticateAsServer(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false));
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
-            }
+            });
         }
 
         protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
index 62de050..1103a91 100644 (file)
@@ -16,14 +16,15 @@ namespace System.Net.Security.Tests
 
     public sealed class SslStreamStreamToStreamTest_MemoryAsync : SslStreamStreamToStreamTest_CancelableReadWriteAsync
     {
-        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+        protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
         {
-            using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+            X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+            await WithServerCertificate(serverCertificate, async(certificate, name) =>
             {
-                Task t1 = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, CancellationToken.None);
-                Task t2 = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, CancellationToken.None);
+                Task t1 = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = name, ClientCertificates = clientCerts }, CancellationToken.None);
+                Task t2 = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate, ClientCertificateRequired = clientCertificate != null }, CancellationToken.None);
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
-            }
+            });
         }
 
         protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>