Allow only HTTP requests during X509 chain building
authorKrzysztof Wicher <mordotymoja@gmail.com>
Fri, 12 Feb 2021 16:55:01 +0000 (17:55 +0100)
committerGitHub <noreply@github.com>
Fri, 12 Feb 2021 16:55:01 +0000 (08:55 -0800)
Specifically, do not follow an HTTP->HTTPS redirect.  Primary HTTPS URLs were already ignored.

src/libraries/System.Security.Cryptography.X509Certificates/src/Internal/Cryptography/Pal.Unix/CertificateAssetDownloader.cs

index af89741..cd7ef4b 100644 (file)
@@ -15,7 +15,7 @@ namespace Internal.Cryptography.Pal
 {
     internal static class CertificateAssetDownloader
     {
-        private static readonly Func<string, CancellationToken, byte[]>? s_downloadBytes = CreateDownloadBytesFunc();
+        private static readonly Func<string, CancellationToken, byte[]?>? s_downloadBytes = CreateDownloadBytesFunc();
 
         internal static X509Certificate2? DownloadCertificate(string uri, TimeSpan downloadTimeout)
         {
@@ -120,7 +120,7 @@ namespace Internal.Cryptography.Pal
             return null;
         }
 
-        private static Func<string, CancellationToken, byte[]>? CreateDownloadBytesFunc()
+        private static Func<string, CancellationToken, byte[]?>? CreateDownloadBytesFunc()
         {
             try
             {
@@ -130,24 +130,36 @@ namespace Internal.Cryptography.Pal
 
                 // Get the relevant types needed.
                 Type? socketsHttpHandlerType = Type.GetType("System.Net.Http.SocketsHttpHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
+                Type? httpMessageHandlerType = Type.GetType("System.Net.Http.HttpMessageHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
                 Type? httpClientType = Type.GetType("System.Net.Http.HttpClient, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
                 Type? httpRequestMessageType = Type.GetType("System.Net.Http.HttpRequestMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
                 Type? httpResponseMessageType = Type.GetType("System.Net.Http.HttpResponseMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
+                Type? httpResponseHeadersType = Type.GetType("System.Net.Http.Headers.HttpResponseHeaders, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
                 Type? httpContentType = Type.GetType("System.Net.Http.HttpContent, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
-                if (socketsHttpHandlerType == null || httpClientType == null || httpRequestMessageType == null || httpResponseMessageType == null || httpContentType == null)
+                if (socketsHttpHandlerType == null || httpMessageHandlerType == null || httpClientType == null || httpRequestMessageType == null ||
+                    httpResponseMessageType == null || httpResponseHeadersType == null || httpContentType == null)
                 {
                     Debug.Fail("Unable to load required type.");
                     return null;
                 }
 
                 // Get the methods on those types.
+                ConstructorInfo? socketsHttpHandlerCtor = socketsHttpHandlerType.GetConstructor(Type.EmptyTypes);
                 PropertyInfo? pooledConnectionIdleTimeoutProp = socketsHttpHandlerType.GetProperty("PooledConnectionIdleTimeout");
+                PropertyInfo? allowAutoRedirectProp = socketsHttpHandlerType.GetProperty("AllowAutoRedirect");
+                ConstructorInfo? httpClientCtor = httpClientType.GetConstructor(new Type[] { httpMessageHandlerType });
                 PropertyInfo? requestUriProp = httpRequestMessageType.GetProperty("RequestUri");
                 ConstructorInfo? httpRequestMessageCtor = httpRequestMessageType.GetConstructor(Type.EmptyTypes);
                 MethodInfo? sendMethod = httpClientType.GetMethod("Send", new Type[] { httpRequestMessageType, typeof(CancellationToken) });
                 PropertyInfo? responseContentProp = httpResponseMessageType.GetProperty("Content");
+                PropertyInfo? responseStatusCodeProp = httpResponseMessageType.GetProperty("StatusCode");
+                PropertyInfo? responseHeadersProp = httpResponseMessageType.GetProperty("Headers");
+                PropertyInfo? responseHeadersLocationProp = httpResponseHeadersType.GetProperty("Location");
                 MethodInfo? readAsStreamMethod = httpContentType.GetMethod("ReadAsStream", Type.EmptyTypes);
-                if (pooledConnectionIdleTimeoutProp == null || requestUriProp == null || httpRequestMessageCtor == null || sendMethod == null || responseContentProp == null || readAsStreamMethod == null)
+
+                if (socketsHttpHandlerCtor == null || pooledConnectionIdleTimeoutProp == null || allowAutoRedirectProp == null || httpClientCtor == null ||
+                    requestUriProp == null || httpRequestMessageCtor == null || sendMethod == null || responseContentProp == null || responseStatusCodeProp == null ||
+                    responseHeadersProp == null || responseHeadersLocationProp == null || readAsStreamMethod == null)
                 {
                     Debug.Fail("Unable to load required member.");
                     return null;
@@ -155,30 +167,81 @@ namespace Internal.Cryptography.Pal
 
                 // Only keep idle connections around briefly, as a compromise between resource leakage and port exhaustion.
                 const int PooledConnectionIdleTimeoutSeconds = 15;
+                const int MaxRedirections = 10;
 
                 // Equivalent of:
-                // var socketsHttpHandler = new SocketsHttpHandler() { PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds) };
+                // var socketsHttpHandler = new SocketsHttpHandler() {
+                //     PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds),
+                //     AllowAutoRedirect = false
+                // };
                 // var httpClient = new HttpClient(socketsHttpHandler);
-                object? socketsHttpHandler = Activator.CreateInstance(socketsHttpHandlerType);
+                // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method.
+                object? socketsHttpHandler = socketsHttpHandlerCtor.Invoke(null);
                 pooledConnectionIdleTimeoutProp.SetValue(socketsHttpHandler, TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds));
-                object? httpClient = Activator.CreateInstance(httpClientType, new object?[] { socketsHttpHandler });
+                allowAutoRedirectProp.SetValue(socketsHttpHandler, false);
+                object? httpClient = httpClientCtor.Invoke(new object?[] { socketsHttpHandler });
 
-                // Return a delegate for getting the byte[] for a uri. This delegate references the HttpClient object and thus
-                // all accesses will be through that singleton.
-                return (string uri, CancellationToken cancellationToken) =>
+                return (string uriString, CancellationToken cancellationToken) =>
                 {
+                    Uri uri = new Uri(uriString);
+
+                    if (!IsAllowedScheme(uri.Scheme))
+                    {
+                        return null;
+                    }
+
                     // Equivalent of:
-                    // HttpResponseMessage resp = httpClient.Send(new HttpRequestMessage() { RequestUri = new Uri(uri) });
-                    // using Stream responseStream = resp.Content.ReadAsStream();
+                    // HttpRequestMessage requestMessage = new HttpRequestMessage() { RequestUri = new Uri(uri) };
+                    // HttpResponseMessage responseMessage = httpClient.Send(requestMessage, cancellationToken);
                     // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method.
                     object requestMessage = httpRequestMessageCtor.Invoke(null);
-                    requestUriProp.SetValue(requestMessage, new Uri(uri));
+                    requestUriProp.SetValue(requestMessage, uri);
                     object responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!;
+
+                    int redirections = 0;
+                    Uri? redirectUri;
+                    bool hasRedirect;
+                    while (true)
+                    {
+                        int statusCode = (int)responseStatusCodeProp.GetValue(responseMessage)!;
+                        object responseHeaders = responseHeadersProp.GetValue(responseMessage)!;
+                        Uri? location = (Uri?)responseHeadersLocationProp.GetValue(responseHeaders);
+                        redirectUri = GetUriForRedirect((Uri)requestUriProp.GetValue(requestMessage)!, statusCode, location, out hasRedirect);
+                        if (redirectUri == null)
+                        {
+                            break;
+                        }
+
+                        ((IDisposable)responseMessage).Dispose();
+
+                        redirections++;
+                        if (redirections > MaxRedirections)
+                        {
+                            return null;
+                        }
+
+                        // Equivalent of:
+                        // requestMessage = new HttpRequestMessage() { RequestUri = redirectUri };
+                        // requestMessage.RequestUri = redirectUri;
+                        // responseMessage = httpClient.Send(requestMessage, cancellationToken);
+                        requestMessage = httpRequestMessageCtor.Invoke(null);
+                        requestUriProp.SetValue(requestMessage, redirectUri);
+                        responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!;
+                    }
+
+                    if (hasRedirect && redirectUri == null)
+                    {
+                        return null;
+                    }
+
+                    // Equivalent of:
+                    // using Stream responseStream = resp.Content.ReadAsStream();
                     object content = responseContentProp.GetValue(responseMessage)!;
                     using Stream responseStream = (Stream)readAsStreamMethod.Invoke(content, null)!;
 
                     var result = new MemoryStream();
                     responseStream.CopyTo(result);
+                    ((IDisposable)responseMessage).Dispose();
                     return result.ToArray();
                 };
             }
@@ -188,5 +251,57 @@ namespace Internal.Cryptography.Pal
                 return null;
             }
         }
+
+        private static Uri? GetUriForRedirect(Uri requestUri, int statusCode, Uri? location, out bool hasRedirect)
+        {
+            if (!IsRedirectStatusCode(statusCode))
+            {
+                hasRedirect = false;
+                return null;
+            }
+
+            hasRedirect = true;
+
+            if (location == null)
+            {
+                return null;
+            }
+
+            // Ensure the redirect location is an absolute URI.
+            if (!location.IsAbsoluteUri)
+            {
+                location = new Uri(requestUri, location);
+            }
+
+            // Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a
+            // fragment should inherit the fragment from the original URI.
+            string requestFragment = requestUri.Fragment;
+            if (!string.IsNullOrEmpty(requestFragment))
+            {
+                string redirectFragment = location.Fragment;
+                if (string.IsNullOrEmpty(redirectFragment))
+                {
+                    location = new UriBuilder(location) { Fragment = requestFragment }.Uri;
+                }
+            }
+
+            if (!IsAllowedScheme(location.Scheme))
+            {
+                return null;
+            }
+
+            return location;
+        }
+
+        private static bool IsRedirectStatusCode(int statusCode)
+        {
+            // MultipleChoices (300), Moved (301), Found (302), SeeOther (303), TemporaryRedirect (307), PermanentRedirect (308)
+            return (statusCode >= 300 && statusCode <= 303) || statusCode == 307 || statusCode == 308;
+        }
+
+        private static bool IsAllowedScheme(string scheme)
+        {
+            return string.Equals(scheme, "http", StringComparison.OrdinalIgnoreCase);
+        }
     }
 }