Add UnmanagedCallersOnly attribute to SafeDeleteSslContext.ReadFromConnection/WriteTo...
authorMaxim Lipnin <v-maxlip@microsoft.com>
Tue, 17 Aug 2021 16:33:06 +0000 (19:33 +0300)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 16:33:06 +0000 (18:33 +0200)
* Add UnmanagedCallersOnly attribute to  SafeDeleteSslContext.ReadFromConnection/WriteToConnection methods

* Fix the build

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
* Fix the build

* Add SslSetConnection interop method to make sure the right SafeDeleteSslContext instance is associated to an ssl session

* Update entrypoints.c with new DllImport

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
Co-authored-by: Steve Pfister <steve.pfister@microsoft.com>
Co-authored-by: Alexander Köplinger <alex.koeplinger@outlook.com>
src/libraries/Common/src/Interop/OSX/System.Security.Cryptography.Native.Apple/Interop.Ssl.cs
src/libraries/Native/Unix/System.Security.Cryptography.Native.Apple/entrypoints.c
src/libraries/Native/Unix/System.Security.Cryptography.Native.Apple/pal_ssl.c
src/libraries/Native/Unix/System.Security.Cryptography.Native.Apple/pal_ssl.h
src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs

index 26bf0a3..f38e947 100644 (file)
@@ -60,6 +60,11 @@ internal static partial class Interop
         [DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslCreateContext")]
         internal static extern System.Net.SafeSslHandle SslCreateContext(int isServer);
 
+        [DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetConnection")]
+        internal static extern int SslSetConnection(
+            SafeSslHandle sslHandle,
+            IntPtr sslConnection);
+
         [DllImport(Interop.Libraries.AppleCryptoNative)]
         private static extern int AppleCryptoNative_SslSetMinProtocolVersion(
             SafeSslHandle sslHandle,
@@ -119,10 +124,10 @@ internal static partial class Interop
         private static extern int AppleCryptoNative_SslSetAcceptClientCert(SafeSslHandle sslHandle);
 
         [DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetIoCallbacks")]
-        internal static extern int SslSetIoCallbacks(
+        internal static extern unsafe int SslSetIoCallbacks(
             SafeSslHandle sslHandle,
-            SSLReadFunc readCallback,
-            SSLWriteFunc writeCallback);
+            delegate* unmanaged<IntPtr, byte*, void**, int> readCallback,
+            delegate* unmanaged<IntPtr, byte*, void**, int> writeCallback);
 
         [DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslWrite")]
         internal static extern unsafe PAL_TlsIo SslWrite(SafeSslHandle sslHandle, byte* writeFrom, int count, out int bytesWritten);
index 5162b7a..db75e90 100644 (file)
@@ -78,6 +78,7 @@ static const Entry s_cryptoAppleNative[] =
     DllImportEntry(AppleCryptoNative_SecKeyCopyExternalRepresentation)
     DllImportEntry(AppleCryptoNative_SecKeyCopyPublicKey)
     DllImportEntry(AppleCryptoNative_SslCreateContext)
+    DllImportEntry(AppleCryptoNative_SslSetConnection)
     DllImportEntry(AppleCryptoNative_SslSetAcceptClientCert)
     DllImportEntry(AppleCryptoNative_SslSetMinProtocolVersion)
     DllImportEntry(AppleCryptoNative_SslSetMaxProtocolVersion)
index 2b5d3a3..2d66847 100644 (file)
@@ -23,6 +23,14 @@ SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer)
 #pragma clang diagnostic pop
 }
 
+int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection)
+{
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wdeprecated-declarations"
+    return SSLSetConnection(sslContext, sslConnection);
+#pragma clang diagnostic pop
+}
+
 int32_t AppleCryptoNative_SslSetAcceptClientCert(SSLContextRef sslContext)
 {
 #pragma clang diagnostic push
index 8c3b61d..8a992bb 100644 (file)
@@ -36,6 +36,13 @@ Returns NULL if an invalid boolean is given for isServer, an SSLContextRef other
 PALEXPORT SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer);
 
 /*
+Data that is used to uniquely identify an SSL session.
+
+Returns the result of SSLSetConnection
+*/
+PALEXPORT int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection);
+
+/*
 Indicate that an SSL Context (in server mode) should allow a client to present a mutual auth cert.
 
 Returns The result of SSLSetClientSideAuthenticate
index 18585ee..3cd2b92 100644 (file)
@@ -5,6 +5,7 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.Net.Http;
 using System.Net.Security;
+using System.Runtime.InteropServices;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using Microsoft.Win32.SafeHandles;
@@ -22,8 +23,6 @@ namespace System.Net
         private const int OSStatus_errSSLWouldBlock = -9803;
         private const int InitialBufferSize = 2048;
         private SafeSslHandle _sslContext;
-        private Interop.AppleCrypto.SSLReadFunc _readCallback;
-        private Interop.AppleCrypto.SSLWriteFunc _writeCallback;
         private ArrayBuffer _inputBuffer = new ArrayBuffer(InitialBufferSize);
         private ArrayBuffer _outputBuffer = new ArrayBuffer(InitialBufferSize);
 
@@ -38,19 +37,20 @@ namespace System.Net
             {
                 int osStatus;
 
+                _sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);
+
+                // Make sure the class instance is associated to the session and is provided
+                // in the Read/Write callback connection parameter
+                SslSetConnection(_sslContext);
+
                 unsafe
                 {
-                    _readCallback = ReadFromConnection;
-                    _writeCallback = WriteToConnection;
+                    osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
+                        _sslContext,
+                        &ReadFromConnection,
+                        &WriteToConnection);
                 }
 
-                _sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);
-
-                osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
-                    _sslContext,
-                    _readCallback,
-                    _writeCallback);
-
                 if (osStatus != 0)
                 {
                     throw Interop.AppleCrypto.CreateExceptionForOSStatus(osStatus);
@@ -142,6 +142,13 @@ namespace System.Net
             return sslContext;
         }
 
+        private void SslSetConnection(SafeSslHandle sslContext)
+        {
+            GCHandle handle = GCHandle.Alloc(this, GCHandleType.Weak);
+
+            Interop.AppleCrypto.SslSetConnection(sslContext, GCHandle.ToIntPtr(handle));
+        }
+
         public override bool IsInvalid => _sslContext?.IsInvalid ?? true;
 
         protected override void Dispose(bool disposing)
@@ -160,8 +167,12 @@ namespace System.Net
             base.Dispose(disposing);
         }
 
-        private unsafe int WriteToConnection(void* connection, byte* data, void** dataLength)
+        [UnmanagedCallersOnly]
+        private static unsafe int WriteToConnection(IntPtr connection, byte* data, void** dataLength)
         {
+            SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
+            Debug.Assert(context != null);
+
             // We don't pool these buffers and we can't because there's a race between their us in the native
             // read/write callbacks and being disposed when the SafeHandle is disposed. This race is benign currently,
             // but if we were to pool the buffers we would have a potential use-after-free issue.
@@ -173,9 +184,9 @@ namespace System.Net
                 int toWrite = (int)length;
                 var inputBuffer = new ReadOnlySpan<byte>(data, toWrite);
 
-                _outputBuffer.EnsureAvailableSpace(toWrite);
-                inputBuffer.CopyTo(_outputBuffer.AvailableSpan);
-                _outputBuffer.Commit(toWrite);
+                context._outputBuffer.EnsureAvailableSpace(toWrite);
+                inputBuffer.CopyTo(context._outputBuffer.AvailableSpan);
+                context._outputBuffer.Commit(toWrite);
                 // Since we can enqueue everything, no need to re-assign *dataLength.
 
                 return OSStatus_noErr;
@@ -183,13 +194,17 @@ namespace System.Net
             catch (Exception e)
             {
                 if (NetEventSource.Log.IsEnabled())
-                    NetEventSource.Error(this, $"WritingToConnection failed: {e.Message}");
+                    NetEventSource.Error(context, $"WritingToConnection failed: {e.Message}");
                 return OSStatus_writErr;
             }
         }
 
-        private unsafe int ReadFromConnection(void* connection, byte* data, void** dataLength)
+        [UnmanagedCallersOnly]
+        private static unsafe int ReadFromConnection(IntPtr connection, byte* data, void** dataLength)
         {
+            SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
+            Debug.Assert(context != null);
+
             try
             {
                 ulong toRead = (ulong)*dataLength;
@@ -201,16 +216,16 @@ namespace System.Net
 
                 uint transferred = 0;
 
-                if (_inputBuffer.ActiveLength == 0)
+                if (context._inputBuffer.ActiveLength == 0)
                 {
                     *dataLength = (void*)0;
                     return OSStatus_errSSLWouldBlock;
                 }
 
-                int limit = Math.Min((int)toRead, _inputBuffer.ActiveLength);
+                int limit = Math.Min((int)toRead, context._inputBuffer.ActiveLength);
 
-                _inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
-                _inputBuffer.Discard(limit);
+                context._inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
+                context._inputBuffer.Discard(limit);
                 transferred = (uint)limit;
 
                 *dataLength = (void*)transferred;
@@ -219,7 +234,7 @@ namespace System.Net
             catch (Exception e)
             {
                 if (NetEventSource.Log.IsEnabled())
-                    NetEventSource.Error(this, $"ReadFromConnectionfailed: {e.Message}");
+                    NetEventSource.Error(context, $"ReadFromConnectionfailed: {e.Message}");
                 return OSStatus_readErr;
             }
         }