Spanify some Linux SslStreamPal internals and refactor EncryptDecryptHelper (#53512)
authorGeoff Kizer <geoffrek@microsoft.com>
Fri, 4 Jun 2021 19:59:03 +0000 (12:59 -0700)
committerGitHub <noreply@github.com>
Fri, 4 Jun 2021 19:59:03 +0000 (12:59 -0700)
* Spanify some SslStreamPal internals and refactor EncryptDecryptHelper

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs
src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs

index 71ee76a..42276ab 100644 (file)
@@ -348,7 +348,7 @@ internal static partial class Interop
             return retVal;
         }
 
-        internal static int Decrypt(SafeSslHandle context, byte[] outBuffer, int offset, int count, out Ssl.SslErrorCode errorCode)
+        internal static int Decrypt(SafeSslHandle context, Span<byte> buffer, out Ssl.SslErrorCode errorCode)
         {
 #if DEBUG
             ulong assertNoError = Crypto.ErrPeekError();
@@ -356,53 +356,33 @@ internal static partial class Interop
 #endif
             errorCode = Ssl.SslErrorCode.SSL_ERROR_NONE;
 
-            int retVal = BioWrite(context.InputBio!, outBuffer, offset, count);
-            Exception? innerError = null;
+            BioWrite(context.InputBio!, buffer);
 
-            if (retVal == count)
+            int retVal = Ssl.SslRead(context, ref MemoryMarshal.GetReference(buffer), buffer.Length);
+            if (retVal > 0)
             {
-                unsafe
-                {
-                    fixed (byte* fixedBuffer = outBuffer)
-                    {
-                        retVal = Ssl.SslRead(context, fixedBuffer + offset, outBuffer.Length);
-                    }
-                }
-
-                if (retVal > 0)
-                {
-                        count = retVal;
-                }
+                return retVal;
             }
 
-            if (retVal != count)
+            errorCode = GetSslError(context, retVal, out Exception? innerError);
+            switch (errorCode)
             {
-                errorCode = GetSslError(context, retVal, out innerError);
-            }
-
-            if (retVal != count)
-            {
-                retVal = 0;
-
-                switch (errorCode)
-                {
-                    // indicate end-of-file
-                    case Ssl.SslErrorCode.SSL_ERROR_ZERO_RETURN:
-                        break;
+                // indicate end-of-file
+                case Ssl.SslErrorCode.SSL_ERROR_ZERO_RETURN:
+                    break;
 
-                    case Ssl.SslErrorCode.SSL_ERROR_WANT_READ:
-                        // update error code to renegotiate if renegotiate is pending, otherwise make it SSL_ERROR_WANT_READ
-                        errorCode = Ssl.IsSslRenegotiatePending(context) ?
-                                    Ssl.SslErrorCode.SSL_ERROR_RENEGOTIATE :
-                                    Ssl.SslErrorCode.SSL_ERROR_WANT_READ;
-                        break;
+                case Ssl.SslErrorCode.SSL_ERROR_WANT_READ:
+                    // update error code to renegotiate if renegotiate is pending, otherwise make it SSL_ERROR_WANT_READ
+                    errorCode = Ssl.IsSslRenegotiatePending(context) ?
+                                Ssl.SslErrorCode.SSL_ERROR_RENEGOTIATE :
+                                Ssl.SslErrorCode.SSL_ERROR_WANT_READ;
+                    break;
 
-                    default:
-                        throw new SslException(SR.Format(SR.net_ssl_decrypt_failed, errorCode), innerError);
-                }
+                default:
+                    throw new SslException(SR.Format(SR.net_ssl_decrypt_failed, errorCode), innerError);
             }
 
-            return retVal;
+            return 0;
         }
 
         internal static SafeX509Handle GetPeerCertificate(SafeSslHandle context)
@@ -507,27 +487,13 @@ internal static partial class Interop
             return bytes;
         }
 
-        private static int BioWrite(SafeBioHandle bio, byte[] buffer, int offset, int count)
+        private static void BioWrite(SafeBioHandle bio, ReadOnlySpan<byte> buffer)
         {
-            Debug.Assert(buffer != null);
-            Debug.Assert(offset >= 0);
-            Debug.Assert(count >= 0);
-            Debug.Assert(buffer.Length >= offset + count);
-
-            int bytes;
-            unsafe
-            {
-                fixed (byte* bufPtr = buffer)
-                {
-                    bytes = Ssl.BioWrite(bio, bufPtr + offset, count);
-                }
-            }
-
-            if (bytes != count)
+            int bytes = Ssl.BioWrite(bio, ref MemoryMarshal.GetReference(buffer), buffer.Length);
+            if (bytes != buffer.Length)
             {
                 throw CreateSslException(SR.net_ssl_write_bio_failed_error);
             }
-            return bytes;
         }
 
         private static Ssl.SslErrorCode GetSslError(SafeSslHandle context, int result, out Exception? innerError)
index 057e4dc..d080cf2 100644 (file)
@@ -72,7 +72,7 @@ internal static partial class Interop
         internal static extern int SslWrite(SafeSslHandle ssl, ref byte buf, int num);
 
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslRead", SetLastError = true)]
-        internal static extern unsafe int SslRead(SafeSslHandle ssl, byte* buf, int num);
+        internal static extern int SslRead(SafeSslHandle ssl, ref byte buf, int num);
 
         [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_IsSslRenegotiatePending")]
         [return: MarshalAs(UnmanagedType.Bool)]
index 1416e39..98a21e1 100644 (file)
@@ -49,20 +49,51 @@ namespace System.Net.Security
 
         public static SecurityStatusPal EncryptMessage(SafeDeleteSslContext securityContext, ReadOnlyMemory<byte> input, int headerSize, int trailerSize, ref byte[] output, out int resultSize)
         {
-            return EncryptDecryptHelper(securityContext, input, offset: 0, size: 0, encrypt: true, output: ref output, resultSize: out resultSize);
+            try
+            {
+                resultSize = Interop.OpenSsl.Encrypt(securityContext.SslContext, input.Span, ref output, out Interop.Ssl.SslErrorCode errorCode);
+
+                return MapNativeErrorCode(errorCode);
+            }
+            catch (Exception ex)
+            {
+                resultSize = 0;
+                return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
+            }
         }
 
         public static SecurityStatusPal DecryptMessage(SafeDeleteSslContext securityContext, byte[] buffer, ref int offset, ref int count)
         {
-            SecurityStatusPal retVal = EncryptDecryptHelper(securityContext, buffer, offset, count, false, ref buffer, out int resultSize);
-            if (retVal.ErrorCode == SecurityStatusPalErrorCode.OK ||
-                retVal.ErrorCode == SecurityStatusPalErrorCode.Renegotiate)
+            try
             {
-                count = resultSize;
+                int resultSize = Interop.OpenSsl.Decrypt(securityContext.SslContext, new Span<byte>(buffer, offset, count), out Interop.Ssl.SslErrorCode errorCode);
+
+                SecurityStatusPal retVal = MapNativeErrorCode(errorCode);
+
+                if (retVal.ErrorCode == SecurityStatusPalErrorCode.OK ||
+                    retVal.ErrorCode == SecurityStatusPalErrorCode.Renegotiate)
+                {
+                    count = resultSize;
+                }
+
+                return retVal;
+            }
+            catch (Exception ex)
+            {
+                return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
             }
-            return retVal;
         }
 
+        private static SecurityStatusPal MapNativeErrorCode(Interop.Ssl.SslErrorCode errorCode) =>
+            errorCode switch
+            {
+                Interop.Ssl.SslErrorCode.SSL_ERROR_RENEGOTIATE => new SecurityStatusPal(SecurityStatusPalErrorCode.Renegotiate),
+                Interop.Ssl.SslErrorCode.SSL_ERROR_ZERO_RETURN => new SecurityStatusPal(SecurityStatusPalErrorCode.ContextExpired),
+                Interop.Ssl.SslErrorCode.SSL_ERROR_NONE or
+                Interop.Ssl.SslErrorCode.SSL_ERROR_WANT_READ => new SecurityStatusPal(SecurityStatusPalErrorCode.OK),
+                _ => new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, new Interop.OpenSsl.SslException((int)errorCode))
+            };
+
         public static ChannelBinding? QueryContextChannelBinding(SafeDeleteSslContext securityContext, ChannelBindingKind attribute)
         {
             ChannelBinding? bindingHandle;
@@ -155,42 +186,6 @@ namespace System.Net.Security
             return Interop.Ssl.SslGetAlpnSelected(context.SslContext);
         }
 
-        private static SecurityStatusPal EncryptDecryptHelper(SafeDeleteSslContext securityContext, ReadOnlyMemory<byte> input, int offset, int size, bool encrypt, ref byte[] output, out int resultSize)
-        {
-            resultSize = 0;
-            try
-            {
-                Interop.Ssl.SslErrorCode errorCode = Interop.Ssl.SslErrorCode.SSL_ERROR_NONE;
-                SafeSslHandle scHandle = securityContext.SslContext;
-
-                if (encrypt)
-                {
-                    resultSize = Interop.OpenSsl.Encrypt(scHandle, input.Span, ref output, out errorCode);
-                }
-                else
-                {
-                    resultSize = Interop.OpenSsl.Decrypt(scHandle, output, offset, size, out errorCode);
-                }
-
-                switch (errorCode)
-                {
-                    case Interop.Ssl.SslErrorCode.SSL_ERROR_RENEGOTIATE:
-                        return new SecurityStatusPal(SecurityStatusPalErrorCode.Renegotiate);
-                    case Interop.Ssl.SslErrorCode.SSL_ERROR_ZERO_RETURN:
-                        return new SecurityStatusPal(SecurityStatusPalErrorCode.ContextExpired);
-                    case Interop.Ssl.SslErrorCode.SSL_ERROR_NONE:
-                    case Interop.Ssl.SslErrorCode.SSL_ERROR_WANT_READ:
-                        return new SecurityStatusPal(SecurityStatusPalErrorCode.OK);
-                    default:
-                        return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, new Interop.OpenSsl.SslException((int)errorCode));
-                }
-            }
-            catch (Exception ex)
-            {
-                return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
-            }
-        }
-
         public static SecurityStatusPal ApplyAlertToken(ref SafeFreeCredentials? credentialsHandle, SafeDeleteContext? securityContext, TlsAlertType alertType, TlsAlertMessage alertMessage)
         {
             // There doesn't seem to be an exposed API for writing an alert,