More robust handling of CERT_CONTEXT with multiple threads
authorKevin Jones <kevin@vcsjones.com>
Mon, 21 Nov 2022 23:46:41 +0000 (18:46 -0500)
committerGitHub <noreply@github.com>
Mon, 21 Nov 2022 23:46:41 +0000 (18:46 -0500)
src/libraries/Common/src/Microsoft/Win32/SafeHandles/SafeCertContextHandle.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/CertificatePal.Windows.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/FindPal.Windows.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/StorePal.Windows.Export.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/StorePal.Windows.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/X509Pal.Windows.PublicKey.cs
src/libraries/System.Security.Cryptography/tests/X509Certificates/CertTests.cs

index cca22c6..0aea95b 100644 (file)
@@ -48,7 +48,7 @@ namespace Microsoft.Win32.SafeHandles
             return true;
         }
 
-        public unsafe CERT_CONTEXT* CertContext
+        public unsafe CERT_CONTEXT* DangerousCertContext
         {
             get { return (CERT_CONTEXT*)handle; }
         }
index c9815a2..13354c3 100644 (file)
@@ -80,10 +80,10 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    string keyAlgorithm = Marshal.PtrToStringAnsi(pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.pszObjId)!;
-                    GC.KeepAlive(this);
-                    return keyAlgorithm;
+                    return InvokeWithCertContext(static certContext =>
+                    {
+                        return Marshal.PtrToStringAnsi(certContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.pszObjId)!;
+                    });
                 }
             }
         }
@@ -94,39 +94,40 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    string keyAlgorithmOid = Marshal.PtrToStringAnsi(pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.pszObjId)!;
-
-                    int algId;
-                    if (keyAlgorithmOid == Oids.Rsa)
-                        algId = AlgId.CALG_RSA_KEYX;  // Fast-path for the most common case.
-                    else
-                        algId = Interop.Crypt32.FindOidInfo(Interop.Crypt32.CryptOidInfoKeyType.CRYPT_OID_INFO_OID_KEY, keyAlgorithmOid, OidGroup.PublicKeyAlgorithm, fallBackToAllGroups: true).AlgId;
-
-                    unsafe
+                    return InvokeWithCertContext(pCertContext =>
                     {
-                        byte* NULL_ASN_TAG = (byte*)0x5;
+                        string keyAlgorithmOid = Marshal.PtrToStringAnsi(pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.pszObjId)!;
 
-                        byte[] keyAlgorithmParameters;
-
-                        if (algId == AlgId.CALG_DSS_SIGN
-                            && pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.cbData == 0
-                            && pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.pbData.ToPointer() == NULL_ASN_TAG)
-                        {
-                            //
-                            // DSS certificates may not have the DSS parameters in the certificate. In this case, we try to build
-                            // the certificate chain and propagate the parameters down from the certificate chain.
-                            //
-                            keyAlgorithmParameters = PropagateKeyAlgorithmParametersFromChain();
-                        }
+                        int algId;
+                        if (keyAlgorithmOid == Oids.Rsa)
+                            algId = AlgId.CALG_RSA_KEYX;  // Fast-path for the most common case.
                         else
+                            algId = Interop.Crypt32.FindOidInfo(Interop.Crypt32.CryptOidInfoKeyType.CRYPT_OID_INFO_OID_KEY, keyAlgorithmOid, OidGroup.PublicKeyAlgorithm, fallBackToAllGroups: true).AlgId;
+
+                        unsafe
                         {
-                            keyAlgorithmParameters = pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.ToByteArray();
+                            byte* NULL_ASN_TAG = (byte*)0x5;
+
+                            byte[] keyAlgorithmParameters;
+
+                            if (algId == AlgId.CALG_DSS_SIGN
+                                && pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.cbData == 0
+                                && pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.pbData.ToPointer() == NULL_ASN_TAG)
+                            {
+                                //
+                                // DSS certificates may not have the DSS parameters in the certificate. In this case, we try to build
+                                // the certificate chain and propagate the parameters down from the certificate chain.
+                                //
+                                keyAlgorithmParameters = PropagateKeyAlgorithmParametersFromChain();
+                            }
+                            else
+                            {
+                                keyAlgorithmParameters = pCertContext->pCertInfo->SubjectPublicKeyInfo.Algorithm.Parameters.ToByteArray();
+                            }
+
+                            return keyAlgorithmParameters;
                         }
-
-                        GC.KeepAlive(this);
-                        return keyAlgorithmParameters;
-                    }
+                    });
                 }
             }
         }
@@ -168,10 +169,10 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    byte[] publicKey = pCertContext->pCertInfo->SubjectPublicKeyInfo.PublicKey.ToByteArray();
-                    GC.KeepAlive(this);
-                    return publicKey;
+                    return InvokeWithCertContext(static pCertContext =>
+                    {
+                        return pCertContext->pCertInfo->SubjectPublicKeyInfo.PublicKey.ToByteArray();
+                    });
                 }
             }
         }
@@ -182,11 +183,12 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    byte[] serialNumber = pCertContext->pCertInfo->SerialNumber.ToByteArray();
-                    Array.Reverse(serialNumber);
-                    GC.KeepAlive(this);
-                    return serialNumber;
+                    return InvokeWithCertContext(static pCertContext =>
+                    {
+                        byte[] serialNumber = pCertContext->pCertInfo->SerialNumber.ToByteArray();
+                        Array.Reverse(serialNumber);
+                        return serialNumber;
+                    });
                 }
             }
         }
@@ -197,10 +199,10 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    string signatureAlgorithm = Marshal.PtrToStringAnsi(pCertContext->pCertInfo->SignatureAlgorithm.pszObjId)!;
-                    GC.KeepAlive(this);
-                    return signatureAlgorithm;
+                    return InvokeWithCertContext(static pCertContext =>
+                    {
+                        return Marshal.PtrToStringAnsi(pCertContext->pCertInfo->SignatureAlgorithm.pszObjId)!;
+                    });
                 }
             }
         }
@@ -211,10 +213,7 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    DateTime notAfter = pCertContext->pCertInfo->NotAfter.ToDateTime();
-                    GC.KeepAlive(this);
-                    return notAfter;
+                    return InvokeWithCertContext(static pCertContext => pCertContext->pCertInfo->NotAfter.ToDateTime());
                 }
             }
         }
@@ -225,10 +224,7 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    DateTime notBefore = pCertContext->pCertInfo->NotBefore.ToDateTime();
-                    GC.KeepAlive(this);
-                    return notBefore;
+                    return InvokeWithCertContext(static pCertContext => pCertContext->pCertInfo->NotBefore.ToDateTime());
                 }
             }
         }
@@ -239,10 +235,10 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    byte[] rawData = new Span<byte>(pCertContext->pbCertEncoded, pCertContext->cbCertEncoded).ToArray();
-                    GC.KeepAlive(this);
-                    return rawData;
+                    return InvokeWithCertContext(static pCertContext =>
+                    {
+                        return new Span<byte>(pCertContext->pbCertEncoded, pCertContext->cbCertEncoded).ToArray();
+                    });
                 }
             }
         }
@@ -253,10 +249,7 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_CONTEXT* pCertContext = _certContext.CertContext;
-                    int version = pCertContext->pCertInfo->dwVersion + 1;
-                    GC.KeepAlive(this);
-                    return version;
+                    return InvokeWithCertContext(static pCertContext => pCertContext->pCertInfo->dwVersion + 1);
                 }
             }
         }
@@ -332,11 +325,12 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    // X500DN creates a copy of the data for itself; data is kept alive with GC.KeepAlive.
-                    ReadOnlySpan<byte> encodedSubjectName = _certContext.CertContext->pCertInfo->Subject.DangerousAsSpan();
-                    X500DistinguishedName subjectName = new X500DistinguishedName(encodedSubjectName);
-                    GC.KeepAlive(this);
-                    return subjectName;
+                    return InvokeWithCertContext(static certContext =>
+                    {
+                        ReadOnlySpan<byte> encodedSubjectName = certContext->pCertInfo->Subject.DangerousAsSpan();
+                        X500DistinguishedName subjectName = new X500DistinguishedName(encodedSubjectName);
+                        return subjectName;
+                    });
                 }
             }
         }
@@ -347,11 +341,12 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    // X500DN creates a copy of the data for itself; data is kept alive with GC.KeepAlive.
-                    ReadOnlySpan<byte> encodedIssuerName = _certContext.CertContext->pCertInfo->Issuer.DangerousAsSpan();
-                    X500DistinguishedName issuerName = new X500DistinguishedName(encodedIssuerName);
-                    GC.KeepAlive(this);
-                    return issuerName;
+                    return InvokeWithCertContext(static certContext =>
+                    {
+                        ReadOnlySpan<byte> encodedIssuerName = certContext->pCertInfo->Issuer.DangerousAsSpan();
+                        X500DistinguishedName issuerName = new X500DistinguishedName(encodedIssuerName);
+                        return issuerName;
+                    });
                 }
             }
         }
@@ -367,25 +362,26 @@ namespace System.Security.Cryptography.X509Certificates
             {
                 unsafe
                 {
-                    Interop.Crypt32.CERT_INFO* pCertInfo = _certContext.CertContext->pCertInfo;
-                    int numExtensions = pCertInfo->cExtension;
-                    X509Extension[] extensions = new X509Extension[numExtensions];
-
-                    for (int i = 0; i < numExtensions; i++)
+                    return InvokeWithCertContext(static certContext =>
                     {
-                        Interop.Crypt32.CERT_EXTENSION* pCertExtension = (Interop.Crypt32.CERT_EXTENSION*)pCertInfo->rgExtension.ToPointer() + i;
-                        string oidValue = Marshal.PtrToStringAnsi(pCertExtension->pszObjId)!;
-                        Oid oid = new Oid(oidValue, friendlyName: null);
-                        bool critical = pCertExtension->fCritical != 0;
-
-                        // X509Extension creates a copy of the data for itself. The underlying data
-                        // is kept alive with the KeepAlive below.
-                        ReadOnlySpan<byte> rawData = pCertExtension->Value.DangerousAsSpan();
-                        extensions[i] = new X509Extension(oid, rawData, critical);
-                    }
+                        Interop.Crypt32.CERT_INFO* pCertInfo = certContext->pCertInfo;
+                        int numExtensions = pCertInfo->cExtension;
+                        X509Extension[] extensions = new X509Extension[numExtensions];
+
+                        for (int i = 0; i < numExtensions; i++)
+                        {
+                            Interop.Crypt32.CERT_EXTENSION* pCertExtension = (Interop.Crypt32.CERT_EXTENSION*)pCertInfo->rgExtension.ToPointer() + i;
+                            string oidValue = Marshal.PtrToStringAnsi(pCertExtension->pszObjId)!;
+                            Oid oid = new Oid(oidValue, friendlyName: null);
+                            bool critical = pCertExtension->fCritical != 0;
+
+                            // X509Extension creates a copy of the data for itself.
+                            ReadOnlySpan<byte> rawData = pCertExtension->Value.DangerousAsSpan();
+                            extensions[i] = new X509Extension(oid, rawData, critical);
+                        }
 
-                    GC.KeepAlive(this);
-                    return extensions;
+                        return extensions;
+                    });
                 }
             }
         }
@@ -477,9 +473,13 @@ namespace System.Security.Cryptography.X509Certificates
 
         internal SafeCertContextHandle GetCertContext()
         {
-            SafeCertContextHandle certContext = Interop.Crypt32.CertDuplicateCertificateContext(_certContext.DangerousGetHandle());
-            GC.KeepAlive(_certContext);
-            return certContext;
+            unsafe
+            {
+                return InvokeWithCertContext(static certContext =>
+                {
+                    return Interop.Crypt32.CertDuplicateCertificateContext((IntPtr)certContext);
+                });
+            }
         }
 
         private static Interop.Crypt32.CertNameType MapNameType(X509NameType nameType)
@@ -544,5 +544,25 @@ namespace System.Security.Cryptography.X509Certificates
                 return exported;
             }
         }
+
+        private unsafe T InvokeWithCertContext<T>(CertContextCallback<T> callback)
+        {
+            bool added = false;
+            _certContext.DangerousAddRef(ref added);
+
+            try
+            {
+                return callback(_certContext.DangerousCertContext);
+            }
+            finally
+            {
+                if (added)
+                {
+                    _certContext.DangerousRelease();
+                }
+            }
+        }
+
+        private unsafe delegate T CertContextCallback<T>(Interop.Crypt32.CERT_CONTEXT* certContext);
     }
 }
index 5d22c6b..84acc7b 100644 (file)
@@ -93,7 +93,9 @@ namespace System.Security.Cryptography.X509Certificates
                 (hexValue, decimalValue),
                 static (state, pCertContext) =>
                 {
-                    ReadOnlySpan<byte> actual = pCertContext.CertContext->pCertInfo->SerialNumber.DangerousAsSpan();
+                    // FindCore owns the lifetime of the CERT_CONTEXT and doesn't escape, so it can't be disposed of
+                    // by another thread.
+                    ReadOnlySpan<byte> actual = pCertContext.DangerousCertContext->pCertInfo->SerialNumber.DangerousAsSpan();
 
                     // Convert to BigInteger as the comparison must not fail due to spurious leading zeros
                     BigInteger actualAsBigInteger = new BigInteger(actual, isUnsigned: true);
@@ -119,7 +121,7 @@ namespace System.Security.Cryptography.X509Certificates
                 {
                     int comparison = Interop.Crypt32.CertVerifyTimeValidity(
                         ref state.fileTime,
-                        pCertContext.CertContext->pCertInfo);
+                        pCertContext.DangerousCertContext->pCertInfo);
                     GC.KeepAlive(pCertContext);
                     return comparison == state.compareResult;
                 });
@@ -158,8 +160,10 @@ namespace System.Security.Cryptography.X509Certificates
                     // V2 format (XP only) can be a friendly name or an OID.
                     // An example of Template Name can be "ClientAuth".
 
+                    // FindCore owns the lifetime of the CERT_CONTEXT and doesn't escape, so it can't be disposed of
+                    // by another thread.
                     bool foundMatch = false;
-                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
+                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.DangerousCertContext->pCertInfo;
                     Interop.Crypt32.CERT_EXTENSION* pV1Template = Interop.Crypt32.CertFindExtension(
                         Oids.EnrollCertTypeExtension,
                         pCertInfo->cExtension,
@@ -273,7 +277,7 @@ namespace System.Security.Cryptography.X509Certificates
                 oidValue,
                 static (oidValue, pCertContext) =>
                 {
-                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
+                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.DangerousCertContext->pCertInfo;
                     Interop.Crypt32.CERT_EXTENSION* pCertExtension = Interop.Crypt32.CertFindExtension(
                         Oids.CertPolicies,
                         pCertInfo->cExtension,
@@ -303,7 +307,7 @@ namespace System.Security.Cryptography.X509Certificates
                 oidValue,
                 static (oidValue, pCertContext) =>
                 {
-                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
+                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.DangerousCertContext->pCertInfo;
                     Interop.Crypt32.CERT_EXTENSION* pCertExtension = Interop.Crypt32.CertFindExtension(oidValue, pCertInfo->cExtension, pCertInfo->rgExtension);
                     GC.KeepAlive(pCertContext);
                     return pCertExtension != null;
@@ -316,7 +320,7 @@ namespace System.Security.Cryptography.X509Certificates
                 keyUsage,
                 static (keyUsage, pCertContext) =>
                 {
-                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
+                    Interop.Crypt32.CERT_INFO* pCertInfo = pCertContext.DangerousCertContext->pCertInfo;
                     X509KeyUsageFlags actual;
 
                     if (!Interop.crypt32.CertGetIntendedKeyUsage(Interop.Crypt32.CertEncodingType.All, pCertInfo, out actual, sizeof(X509KeyUsageFlags)))
index 6e52703..920f262 100644 (file)
@@ -32,8 +32,10 @@ namespace System.Security.Cryptography.X509Certificates
                         {
                             unsafe
                             {
-                                byte[] rawData = new byte[pCertContext.CertContext->cbCertEncoded];
-                                Marshal.Copy((IntPtr)(pCertContext.CertContext->pbCertEncoded), rawData, 0, rawData.Length);
+                                // We can use the DangerousCertContext because the safehandle never leaves this method
+                                // and can't be disposed of by another thread.
+                                byte[] rawData = new byte[pCertContext.DangerousCertContext->cbCertEncoded];
+                                Marshal.Copy((IntPtr)(pCertContext.DangerousCertContext->pbCertEncoded), rawData, 0, rawData.Length);
                                 GC.KeepAlive(pCertContext);
                                 return rawData;
                             }
index de077b7..f368ba7 100644 (file)
@@ -61,7 +61,9 @@ namespace System.Security.Cryptography.X509Certificates
             using (SafeCertContextHandle existingCertContext = ((CertificatePal)certificate).GetCertContext())
             {
                 SafeCertContextHandle? enumCertContext = null;
-                Interop.Crypt32.CERT_CONTEXT* pCertContext = existingCertContext.CertContext;
+                // We can use DangerousCertContext safely here because GetCertContext returns a duplicated context
+                // that we own and doesn't escape.
+                Interop.Crypt32.CERT_CONTEXT* pCertContext = existingCertContext.DangerousCertContext;
                 if (!Interop.crypt32.CertFindCertificateInStore(_certStore, Interop.Crypt32.CertFindType.CERT_FIND_EXISTING, pCertContext, ref enumCertContext))
                     return; // The certificate is not present in the store, simply return.
 
index 525b31b..bc3b2b0 100644 (file)
@@ -135,7 +135,13 @@ namespace System.Security.Cryptography.X509Certificates
                 {
                     unsafe
                     {
-                        bool success = Interop.Crypt32.CryptImportPublicKeyInfoEx2(Interop.Crypt32.CertEncodingType.X509_ASN_ENCODING, &(certContext.CertContext->pCertInfo->SubjectPublicKeyInfo), importFlags, null, out bCryptKeyHandle);
+                        bool success = Interop.Crypt32.CryptImportPublicKeyInfoEx2(
+                            Interop.Crypt32.CertEncodingType.X509_ASN_ENCODING,
+                            &(certContext.DangerousCertContext->pCertInfo->SubjectPublicKeyInfo),
+                            importFlags,
+                            null,
+                            out bCryptKeyHandle);
+
                         if (!success)
                         {
                             Exception e = Marshal.GetHRForLastWin32Error().ToCryptographicException();
index 068d7b7..ed0db08 100644 (file)
@@ -26,6 +26,47 @@ namespace System.Security.Cryptography.X509Certificates.Tests
         }
 
         [Fact]
+        public static void RaceDisposeAndKeyAccess()
+        {
+            using RSA rsa = RSA.Create();
+            CertificateRequest req = new CertificateRequest("CN=potato", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
+            using X509Certificate2 cert = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now);
+
+            for (int i = 0; i < 100; i++)
+            {
+                X509Certificate2 w = new X509Certificate2(cert.RawData.AsSpan());
+                X509Certificate2 y = w.CopyWithPrivateKey(rsa);
+                w.Dispose();
+
+                Thread t1 = new Thread(cert => {
+                    Thread.Sleep(Random.Shared.Next(0, 20));
+                    X509Certificate2 c = (X509Certificate2)cert!;
+                    c.Dispose();
+                    GC.Collect();
+                });
+
+                Thread t2 = new Thread(cert => {
+                    Thread.Sleep(Random.Shared.Next(0, 20));
+                    X509Certificate2 c = (X509Certificate2)cert!;
+
+                    try
+                    {
+                        c.GetRSAPrivateKey()!.ExportParameters(false);
+                    }
+                    catch
+                    {
+                        // don't care about managed exceptions.
+                    }
+                });
+
+                t1.Start(y);
+                t2.Start(y);
+                t1.Join();
+                t2.Join();
+            }
+        }
+
+        [Fact]
         public static void RaceUseAndDisposeDoesNotCrash()
         {
             X509Certificate2 cert = new X509Certificate2(TestFiles.MicrosoftRootCertFile);