Support sites with invalid IDN in SslStream (#82934)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Fri, 10 Mar 2023 19:00:18 +0000 (11:00 -0800)
committerGitHub <noreply@github.com>
Fri, 10 Mar 2023 19:00:18 +0000 (11:00 -0800)
* initial test

* 'update'

* feedback from review

* android

src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs
src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamSniTest.cs

index 3738267..485b08c 100644 (file)
@@ -5,7 +5,6 @@ using System;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Globalization;
 using System.IO;
 using System.Net;
 using System.Net.Security;
@@ -25,7 +24,6 @@ internal static partial class Interop
         private const string TlsCacheSizeCtxName = "System.Net.Security.TlsCacheSize";
         private const string TlsCacheSizeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_TLSCACHESIZE";
         private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1;   // used to distinguish server sessions with ALPN
-        private static readonly IdnMapping s_idnMapping = new IdnMapping();
         private static readonly ConcurrentDictionary<SslProtocols, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslProtocols, SafeSslContextHandle>();
 
         #region internal methods
@@ -385,21 +383,22 @@ internal static partial class Interop
 
                 if (sslAuthenticationOptions.IsClient)
                 {
-                    // The IdnMapping converts unicode input into the IDNA punycode sequence.
-                    string punyCode = string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ? string.Empty : s_idnMapping.GetAscii(sslAuthenticationOptions.TargetHost!);
-
-                    // Similar to windows behavior, set SNI on openssl by default for client context, ignore errors.
-                    if (!Ssl.SslSetTlsExtHostName(sslHandle, punyCode))
+                    if (!string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost))
                     {
-                        Crypto.ErrClearError();
-                    }
+                        // Similar to windows behavior, set SNI on openssl by default for client context, ignore errors.
+                        if (!Ssl.SslSetTlsExtHostName(sslHandle, sslAuthenticationOptions.TargetHost))
+                        {
+                            Crypto.ErrClearError();
+                        }
 
-                    if (cacheSslContext && !string.IsNullOrEmpty(punyCode))
-                    {
-                        sslCtxHandle.TrySetSession(sslHandle, punyCode);
-                        bool ignored = false;
-                        sslCtxHandle.DangerousAddRef(ref ignored);
-                        sslHandle.SslContextHandle = sslCtxHandle;
+
+                        if (cacheSslContext)
+                        {
+                            sslCtxHandle.TrySetSession(sslHandle, sslAuthenticationOptions.TargetHost);
+                            bool ignored = false;
+                            sslCtxHandle.DangerousAddRef(ref ignored);
+                            sslHandle.SslContextHandle = sslCtxHandle;
+                        }
                     }
 
                     // relevant to TLS 1.3 only: if user supplied a client cert or cert callback,
@@ -745,16 +744,18 @@ internal static partial class Interop
             Debug.Assert(session != IntPtr.Zero);
 
             IntPtr ptr = Ssl.SslGetData(ssl);
-            Debug.Assert(ptr != IntPtr.Zero);
-            GCHandle gch = GCHandle.FromIntPtr(ptr);
-
-            SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
-            // There is no relation between SafeSslContextHandle and SafeSslHandle so the handle
-            // may be released while the ssl session is still active.
-            if (ctxHandle != null && ctxHandle.TryAddSession(Ssl.SslGetServerName(ssl), session))
+            if (ptr != IntPtr.Zero)
             {
-                // offered session was stored in our cache.
-                return 1;
+                GCHandle gch = GCHandle.FromIntPtr(ptr);
+
+                SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
+                // There is no relation between SafeSslContextHandle and SafeSslHandle so the handle
+                // may be released while the ssl session is still active.
+                if (ctxHandle != null && ctxHandle.TryAddSession(Ssl.SslGetServerName(ssl), session))
+                {
+                    // offered session was stored in our cache.
+                    return 1;
+                }
             }
 
             // OpenSSL will destroy session.
index 10a2a7a..addbdd6 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics;
-using System.Globalization;
 using System.Runtime.InteropServices;
 using System.Security.Authentication.ExtendedProtection;
 using System.Security.Cryptography.X509Certificates;
@@ -334,9 +333,6 @@ namespace System.Net.Security
     internal abstract partial class SafeDeleteContext : SafeHandle
     {
 #endif
-        private const string dummyStr = " ";
-        private static readonly IdnMapping s_idnMapping = new IdnMapping();
-
         protected SafeFreeCredentials? _EffectiveCredential;
 
         //-------------------------------------------------------------------
@@ -453,18 +449,12 @@ namespace System.Net.Security
                             }
                         }
 
-                        if (targetName == null || targetName.Length == 0)
-                        {
-                            targetName = dummyStr;
-                        }
-
-                        string punyCode = s_idnMapping.GetAscii(targetName);
-                        fixed (char* namePtr = punyCode)
+                        fixed (char* namePtr = targetName)
                         {
                             errorCode = MustRunInitializeSecurityContext(
                                             ref inCredentials,
                                             isContextAbsent,
-                                            (byte*)(((object)targetName == (object)dummyStr) ? null : namePtr),
+                                            (byte*)namePtr,
                                             inFlags,
                                             endianness,
                                             &inSecurityBufferDescriptor,
@@ -514,7 +504,7 @@ namespace System.Net.Security
                                 errorCode = MustRunInitializeSecurityContext(
                                              ref inCredentials,
                                              isContextAbsent,
-                                             (byte*)(((object)targetName == (object)dummyStr) ? null : namePtr),
+                                             (byte*)namePtr,
                                              inFlags,
                                              endianness,
                                              &inSecurityBufferDescriptor,
index 38149db..4db9d20 100644 (file)
@@ -1,8 +1,10 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Buffers;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Globalization;
 using System.Runtime.InteropServices;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
@@ -11,6 +13,8 @@ namespace System.Net.Security
 {
     internal sealed class SslAuthenticationOptions
     {
+        private static readonly IdnMapping s_idnMapping = new IdnMapping();
+
         // Simplified version of IPAddressParser.Parse to avoid allocations and dependencies.
         // It purposely ignores scopeId as we don't really use so we do not need to map it to actual interface id.
         private static unsafe bool IsValidAddress(ReadOnlySpan<char> ipSpan)
@@ -46,6 +50,12 @@ namespace System.Net.Security
             return false;
         }
 
+        private static readonly IndexOfAnyValues<char> s_safeDnsChars =
+            IndexOfAnyValues.Create("-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");
+
+        private static bool IsSafeDnsString(ReadOnlySpan<char> name) =>
+            name.IndexOfAnyExcept(s_safeDnsChars) < 0;
+
         internal SslAuthenticationOptions()
         {
             TargetHost = string.Empty;
@@ -86,13 +96,25 @@ namespace System.Net.Security
             if (!string.IsNullOrEmpty(sslClientAuthenticationOptions.TargetHost))
             {
                 // RFC 6066 section 3 says to exclude trailing dot from fully qualified DNS hostname
-                TargetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.');
+                string targetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.');
 
                 // RFC 6066 forbids IP literals
-                if (IsValidAddress(TargetHost))
+                if (IsValidAddress(targetHost))
                 {
                     TargetHost = string.Empty;
                 }
+                else
+                {
+                    try
+                    {
+                        TargetHost = s_idnMapping.GetAscii(targetHost);
+                    }
+                    catch (ArgumentException) when (IsSafeDnsString(targetHost))
+                    {
+                        // Seems like name that does not confrom to IDN but apers somewhat valid according to orogional DNS rfc.
+                        TargetHost = targetHost;
+                    }
+                }
             }
 
             // Client specific options.
index 1a741bf..4a0cdf4 100644 (file)
@@ -202,6 +202,88 @@ namespace System.Net.Security.Tests
             Assert.Equal(string.Empty, server.TargetHostName);
         }
 
+        [Theory]
+        [InlineData("\u00E1b\u00E7d\u00EB.com")]
+        [InlineData("\u05D1\u05F1.com")]
+        [InlineData("\u30B6\u30C7\u30D8.com")]
+        public async Task SslStream_ValidIdn_Success(string name)
+        {
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            {
+                using X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate();
+                using X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate();
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = name,
+                    CertificateChainPolicy = new X509ChainPolicy() { VerificationFlags = X509VerificationFlags.IgnoreInvalidName },
+                    RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true
+                };
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, default),
+                                server.AuthenticateAsServerAsync(serverOptions, default));
+
+                await TestHelper.PingPong(client, server, default);
+                Assert.Equal(name, server.TargetHostName);
+            }
+        }
+
+        [Theory]
+        [InlineData("www-.volal.cz")]
+        [InlineData("www-.colorhexa.com")]
+        [InlineData("xn--www-7m0a.thegratuit.com")]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/68206", TestPlatforms.Android)]
+        public async Task SslStream_SafeInvalidIdn_Success(string name)
+        {
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            {
+                using X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate();
+                using X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate();
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = name,
+                    CertificateChainPolicy = new X509ChainPolicy() { VerificationFlags = X509VerificationFlags.IgnoreInvalidName },
+                    RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true
+                };
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, default),
+                                server.AuthenticateAsServerAsync(serverOptions, default));
+
+                await TestHelper.PingPong(client, server, default);
+                Assert.Equal(name, server.TargetHostName);
+            }
+        }
+
+        [Theory]
+        [InlineData("\u0000\u00E7d\u00EB.com")]
+        public async Task SslStream_UnsafeInvalidIdn_Throws(string name)
+        {
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            {
+                using X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate();
+
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = name,
+                    CertificateChainPolicy = new X509ChainPolicy() { VerificationFlags = X509VerificationFlags.IgnoreInvalidName },
+                    RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true
+                };
+
+                await Assert.ThrowsAsync<ArgumentException>(() => client.AuthenticateAsClientAsync(clientOptions, default));
+            }
+        }
+
         private static Func<Task> WithAggregateExceptionUnwrapping(Func<Task> a)
         {
             return async () => {