fix SslStreamCertificateContext.Create with partial chain (#46664)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Mon, 11 Jan 2021 19:32:11 +0000 (11:32 -0800)
committerGitHub <noreply@github.com>
Mon, 11 Jan 2021 19:32:11 +0000 (11:32 -0800)
* fix SslStreamCertificateContext.Create with partial chain

* add test with long partial chain

src/libraries/System.Net.Security/src/System/Net/Security/SslStreamCertificateContext.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs

index 37f3e03..ed40024 100644 (file)
@@ -38,16 +38,22 @@ namespace System.Net.Security
                     NetEventSource.Error(null, $"Failed to build chain for {target.Subject}");
                 }
 
-                int count = chain.ChainElements.Count - (TrimRootCertificate ? 1 : 2);
-                foreach (X509ChainStatus status in chain.ChainStatus)
+                int count = chain.ChainElements.Count - 1;
+#pragma warning disable 0162 // Disable unreachable code warning. TrimRootCertificate is const bool = false on some platforms
+                if (TrimRootCertificate)
                 {
-                    if (status.Status.HasFlag(X509ChainStatusFlags.PartialChain))
+                    count--;
+                    foreach (X509ChainStatus status in chain.ChainStatus)
                     {
-                        // The last cert isn't a root cert
-                        count++;
-                        break;
+                        if (status.Status.HasFlag(X509ChainStatusFlags.PartialChain))
+                        {
+                            // The last cert isn't a root cert
+                            count++;
+                            break;
+                        }
                     }
                 }
+#pragma warning restore 0162
 
                 // Count can be zero for a self-signed certificate, or a cert issued directly from a root.
                 if (count > 0 && chain.ChainElements.Count > 1)
index 40c6b20..3678147 100644 (file)
@@ -24,7 +24,7 @@ namespace System.Net.Security.Tests
         static SslStreamNetworkStreamTest()
         {
             TestHelper.CleanupCertificates(nameof(SslStreamNetworkStreamTest));
-            (_serverCert, _serverChain) = TestHelper.GenerateCertificates("localhost", nameof(SslStreamNetworkStreamTest));
+            (_serverCert, _serverChain) = TestHelper.GenerateCertificates("localhost", nameof(SslStreamNetworkStreamTest), longChain: true);
         }
 
         [ConditionalFact]
@@ -213,15 +213,30 @@ namespace System.Net.Security.Tests
             }
         }
 
-        [Fact]
-        public async Task SslStream_UntrustedCaWithCustomCallback_OK()
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task SslStream_UntrustedCaWithCustomCallback_OK(bool usePartialChain)
         {
+            var rnd = new Random();
+            int split = rnd.Next(0, _serverChain.Count - 1);
+
             var clientOptions = new  SslClientAuthenticationOptions() { TargetHost = "localhost" };
             clientOptions.RemoteCertificateValidationCallback =
                 (sender, certificate, chain, sslPolicyErrors) =>
                 {
-                    chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
+                    // add our custom root CA
+                    chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count - 1]);
                     chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
+                    // Add only one CA to verify that peer did send intermediate CA cert.
+                    // In case of partial chain, we need to make missing certs available.
+                    if (usePartialChain)
+                    {
+                        for (int i = split; i < _serverChain.Count - 1; i++)
+                        {
+                            chain.ChainPolicy.ExtraStore.Add(_serverChain[i]);
+                        }
+                    }
 
                     bool result = chain.Build((X509Certificate2)certificate);
                     Assert.True(result);
@@ -230,7 +245,22 @@ namespace System.Net.Security.Tests
                 };
 
             var serverOptions = new SslServerAuthenticationOptions();
-            serverOptions.ServerCertificateContext = SslStreamCertificateContext.Create(_serverCert, _serverChain);
+            X509Certificate2Collection serverChain;
+            if (usePartialChain)
+            {
+                // give first few certificates without root CA
+                serverChain = new X509Certificate2Collection();
+                for (int i = 0; i < split; i++)
+                {
+                    serverChain.Add(_serverChain[i]);
+                }
+            }
+            else
+            {
+                serverChain = _serverChain;
+            }
+
+            serverOptions.ServerCertificateContext = SslStreamCertificateContext.Create(_serverCert, serverChain);
 
             (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
             using (clientStream)
@@ -258,6 +288,7 @@ namespace System.Net.Security.Tests
                 clientOptions.RemoteCertificateValidationCallback =
                     (sender, certificate, chain, sslPolicyErrors) =>
                     {
+                        // Add only root CA to verify that peer did send intermediate CA cert.
                         chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
                         chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
                         // This should work and we should be able to trust the chain.
@@ -270,7 +301,8 @@ namespace System.Net.Security.Tests
             }
             else
             {
-                errorMessage = "UntrustedRoot";
+                // On Windows we hand whole chain to OS so they can always see the root CA.
+                errorMessage = PlatformDetection.IsWindows ? "UntrustedRoot" : "PartialChain";
             }
 
             var serverOptions = new SslServerAuthenticationOptions();
index 8cd95d4..ecdfeb4 100644 (file)
@@ -107,8 +107,9 @@ namespace System.Net.Security.Tests
             catch { };
         }
 
-        internal static (X509Certificate2 certificate, X509Certificate2Collection) GenerateCertificates(string targetName, [CallerMemberName] string? testName = null)
+        internal static (X509Certificate2 certificate, X509Certificate2Collection) GenerateCertificates(string targetName, [CallerMemberName] string? testName = null, bool longChain = false)
         {
+            const int keySize = 2048;
             if (PlatformDetection.IsWindows && testName != null)
             {
                 CleanupCertificates(testName);
@@ -132,9 +133,43 @@ namespace System.Net.Security.Tests
                 out X509Certificate2 endEntity,
                 subjectName: targetName,
                 testName: testName,
-                keySize: 2048,
+                keySize: keySize,
                 extensions: extensions);
 
+            if (longChain)
+            {
+                using (RSA intermedKey2 = RSA.Create(keySize))
+                using (RSA intermedKey3 = RSA.Create(keySize))
+                {
+                    X509Certificate2 intermedPub2 = intermediate.CreateSubordinateCA(
+                        $"CN=\"A SSL Test CA 2\", O=\"testName\"",
+                        intermedKey2);
+
+                    X509Certificate2 intermedCert2 = intermedPub2.CopyWithPrivateKey(intermedKey2);
+                    intermedPub2.Dispose();
+                    CertificateAuthority intermediateAuthority2 = new CertificateAuthority(intermedCert2, null, null, null);
+
+                    X509Certificate2 intermedPub3 = intermediateAuthority2.CreateSubordinateCA(
+                        $"CN=\"A SSL Test CA 3\", O=\"testName\"",
+                        intermedKey3);
+
+                    X509Certificate2 intermedCert3 = intermedPub3.CopyWithPrivateKey(intermedKey3);
+                    intermedPub3.Dispose();
+                    CertificateAuthority intermediateAuthority3 = new CertificateAuthority(intermedCert3, null, null, null);
+
+                    RSA  eeKey = (RSA)endEntity.PrivateKey;
+                    endEntity = intermediateAuthority3.CreateEndEntity(
+                        $"CN=\"A SSL Test\", O=\"testName\"",
+                        eeKey,
+                        extensions);
+
+                    endEntity = endEntity.CopyWithPrivateKey(eeKey);
+
+                    chain.Add(intermedCert3);
+                    chain.Add(intermedCert2);
+                }
+            }
+
             chain.Add(intermediate.CloneIssuerCert());
             chain.Add(root.CloneIssuerCert());