{
internal static class CertificateAssetDownloader
{
- private static readonly Func<string, CancellationToken, byte[]>? s_downloadBytes = CreateDownloadBytesFunc();
+ private static readonly Func<string, CancellationToken, byte[]?>? s_downloadBytes = CreateDownloadBytesFunc();
internal static X509Certificate2? DownloadCertificate(string uri, TimeSpan downloadTimeout)
{
return null;
}
- private static Func<string, CancellationToken, byte[]>? CreateDownloadBytesFunc()
+ private static Func<string, CancellationToken, byte[]?>? CreateDownloadBytesFunc()
{
try
{
// Get the relevant types needed.
Type? socketsHttpHandlerType = Type.GetType("System.Net.Http.SocketsHttpHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
+ Type? httpMessageHandlerType = Type.GetType("System.Net.Http.HttpMessageHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
Type? httpClientType = Type.GetType("System.Net.Http.HttpClient, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
Type? httpRequestMessageType = Type.GetType("System.Net.Http.HttpRequestMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
Type? httpResponseMessageType = Type.GetType("System.Net.Http.HttpResponseMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
+ Type? httpResponseHeadersType = Type.GetType("System.Net.Http.Headers.HttpResponseHeaders, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
Type? httpContentType = Type.GetType("System.Net.Http.HttpContent, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false);
- if (socketsHttpHandlerType == null || httpClientType == null || httpRequestMessageType == null || httpResponseMessageType == null || httpContentType == null)
+ if (socketsHttpHandlerType == null || httpMessageHandlerType == null || httpClientType == null || httpRequestMessageType == null ||
+ httpResponseMessageType == null || httpResponseHeadersType == null || httpContentType == null)
{
Debug.Fail("Unable to load required type.");
return null;
}
// Get the methods on those types.
+ ConstructorInfo? socketsHttpHandlerCtor = socketsHttpHandlerType.GetConstructor(Type.EmptyTypes);
PropertyInfo? pooledConnectionIdleTimeoutProp = socketsHttpHandlerType.GetProperty("PooledConnectionIdleTimeout");
+ PropertyInfo? allowAutoRedirectProp = socketsHttpHandlerType.GetProperty("AllowAutoRedirect");
+ ConstructorInfo? httpClientCtor = httpClientType.GetConstructor(new Type[] { httpMessageHandlerType });
PropertyInfo? requestUriProp = httpRequestMessageType.GetProperty("RequestUri");
ConstructorInfo? httpRequestMessageCtor = httpRequestMessageType.GetConstructor(Type.EmptyTypes);
MethodInfo? sendMethod = httpClientType.GetMethod("Send", new Type[] { httpRequestMessageType, typeof(CancellationToken) });
PropertyInfo? responseContentProp = httpResponseMessageType.GetProperty("Content");
+ PropertyInfo? responseStatusCodeProp = httpResponseMessageType.GetProperty("StatusCode");
+ PropertyInfo? responseHeadersProp = httpResponseMessageType.GetProperty("Headers");
+ PropertyInfo? responseHeadersLocationProp = httpResponseHeadersType.GetProperty("Location");
MethodInfo? readAsStreamMethod = httpContentType.GetMethod("ReadAsStream", Type.EmptyTypes);
- if (pooledConnectionIdleTimeoutProp == null || requestUriProp == null || httpRequestMessageCtor == null || sendMethod == null || responseContentProp == null || readAsStreamMethod == null)
+
+ if (socketsHttpHandlerCtor == null || pooledConnectionIdleTimeoutProp == null || allowAutoRedirectProp == null || httpClientCtor == null ||
+ requestUriProp == null || httpRequestMessageCtor == null || sendMethod == null || responseContentProp == null || responseStatusCodeProp == null ||
+ responseHeadersProp == null || responseHeadersLocationProp == null || readAsStreamMethod == null)
{
Debug.Fail("Unable to load required member.");
return null;
// Only keep idle connections around briefly, as a compromise between resource leakage and port exhaustion.
const int PooledConnectionIdleTimeoutSeconds = 15;
+ const int MaxRedirections = 10;
// Equivalent of:
- // var socketsHttpHandler = new SocketsHttpHandler() { PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds) };
+ // var socketsHttpHandler = new SocketsHttpHandler() {
+ // PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds),
+ // AllowAutoRedirect = false
+ // };
// var httpClient = new HttpClient(socketsHttpHandler);
- object? socketsHttpHandler = Activator.CreateInstance(socketsHttpHandlerType);
+ // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method.
+ object? socketsHttpHandler = socketsHttpHandlerCtor.Invoke(null);
pooledConnectionIdleTimeoutProp.SetValue(socketsHttpHandler, TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds));
- object? httpClient = Activator.CreateInstance(httpClientType, new object?[] { socketsHttpHandler });
+ allowAutoRedirectProp.SetValue(socketsHttpHandler, false);
+ object? httpClient = httpClientCtor.Invoke(new object?[] { socketsHttpHandler });
- // Return a delegate for getting the byte[] for a uri. This delegate references the HttpClient object and thus
- // all accesses will be through that singleton.
- return (string uri, CancellationToken cancellationToken) =>
+ return (string uriString, CancellationToken cancellationToken) =>
{
+ Uri uri = new Uri(uriString);
+
+ if (!IsAllowedScheme(uri.Scheme))
+ {
+ return null;
+ }
+
// Equivalent of:
- // HttpResponseMessage resp = httpClient.Send(new HttpRequestMessage() { RequestUri = new Uri(uri) });
- // using Stream responseStream = resp.Content.ReadAsStream();
+ // HttpRequestMessage requestMessage = new HttpRequestMessage() { RequestUri = new Uri(uri) };
+ // HttpResponseMessage responseMessage = httpClient.Send(requestMessage, cancellationToken);
// Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method.
object requestMessage = httpRequestMessageCtor.Invoke(null);
- requestUriProp.SetValue(requestMessage, new Uri(uri));
+ requestUriProp.SetValue(requestMessage, uri);
object responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!;
+
+ int redirections = 0;
+ Uri? redirectUri;
+ bool hasRedirect;
+ while (true)
+ {
+ int statusCode = (int)responseStatusCodeProp.GetValue(responseMessage)!;
+ object responseHeaders = responseHeadersProp.GetValue(responseMessage)!;
+ Uri? location = (Uri?)responseHeadersLocationProp.GetValue(responseHeaders);
+ redirectUri = GetUriForRedirect((Uri)requestUriProp.GetValue(requestMessage)!, statusCode, location, out hasRedirect);
+ if (redirectUri == null)
+ {
+ break;
+ }
+
+ ((IDisposable)responseMessage).Dispose();
+
+ redirections++;
+ if (redirections > MaxRedirections)
+ {
+ return null;
+ }
+
+ // Equivalent of:
+ // requestMessage = new HttpRequestMessage() { RequestUri = redirectUri };
+ // requestMessage.RequestUri = redirectUri;
+ // responseMessage = httpClient.Send(requestMessage, cancellationToken);
+ requestMessage = httpRequestMessageCtor.Invoke(null);
+ requestUriProp.SetValue(requestMessage, redirectUri);
+ responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!;
+ }
+
+ if (hasRedirect && redirectUri == null)
+ {
+ return null;
+ }
+
+ // Equivalent of:
+ // using Stream responseStream = resp.Content.ReadAsStream();
object content = responseContentProp.GetValue(responseMessage)!;
using Stream responseStream = (Stream)readAsStreamMethod.Invoke(content, null)!;
var result = new MemoryStream();
responseStream.CopyTo(result);
+ ((IDisposable)responseMessage).Dispose();
return result.ToArray();
};
}
return null;
}
}
+
+ private static Uri? GetUriForRedirect(Uri requestUri, int statusCode, Uri? location, out bool hasRedirect)
+ {
+ if (!IsRedirectStatusCode(statusCode))
+ {
+ hasRedirect = false;
+ return null;
+ }
+
+ hasRedirect = true;
+
+ if (location == null)
+ {
+ return null;
+ }
+
+ // Ensure the redirect location is an absolute URI.
+ if (!location.IsAbsoluteUri)
+ {
+ location = new Uri(requestUri, location);
+ }
+
+ // Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a
+ // fragment should inherit the fragment from the original URI.
+ string requestFragment = requestUri.Fragment;
+ if (!string.IsNullOrEmpty(requestFragment))
+ {
+ string redirectFragment = location.Fragment;
+ if (string.IsNullOrEmpty(redirectFragment))
+ {
+ location = new UriBuilder(location) { Fragment = requestFragment }.Uri;
+ }
+ }
+
+ if (!IsAllowedScheme(location.Scheme))
+ {
+ return null;
+ }
+
+ return location;
+ }
+
+ private static bool IsRedirectStatusCode(int statusCode)
+ {
+ // MultipleChoices (300), Moved (301), Found (302), SeeOther (303), TemporaryRedirect (307), PermanentRedirect (308)
+ return (statusCode >= 300 && statusCode <= 303) || statusCode == 307 || statusCode == 308;
+ }
+
+ private static bool IsAllowedScheme(string scheme)
+ {
+ return string.Equals(scheme, "http", StringComparison.OrdinalIgnoreCase);
+ }
}
}