Enable Marshal.SecureStringToBSTR and Marshal.ZeroFreeBSTR on Unix (#11234)
authorJohn Bottenberg <jobotten@microsoft.com>
Tue, 2 May 2017 20:54:52 +0000 (13:54 -0700)
committerJan Kotas <jkotas@microsoft.com>
Tue, 2 May 2017 20:54:52 +0000 (13:54 -0700)
13 files changed:
src/mscorlib/System.Private.CoreLib.csproj
src/mscorlib/shared/Interop/Windows/Interop.Libraries.cs
src/mscorlib/shared/Interop/Windows/NtDll/Interop.ZeroMemory.cs [deleted file]
src/mscorlib/shared/System.Private.CoreLib.Shared.projitems
src/mscorlib/shared/System/Security/SafeBSTRHandle.cs
src/mscorlib/shared/System/Security/SecureString.Unix.cs
src/mscorlib/shared/System/Security/SecureString.Windows.cs
src/mscorlib/src/Microsoft/Win32/Win32Native.cs
src/mscorlib/src/System/Runtime/InteropServices/Marshal.cs
src/mscorlib/src/System/Runtime/InteropServices/NonPortable.cs
src/mscorlib/src/System/Runtime/InteropServices/PInvokeMarshal.cs [new file with mode: 0644]
src/mscorlib/src/System/Runtime/RuntimeImports.cs
tests/src/Interop/MarshalAPI/String/StringMarshalingTest.cs

index 161134d..bb905eb 100644 (file)
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\Marshal.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\MarshalDirectiveException.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\PInvokeMap.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\PInvokeMarshal.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\RuntimeEnvironment.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\SEHException.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\SafeBuffer.cs" />
index 58bb12d..bf07a68 100644 (file)
@@ -9,7 +9,6 @@ internal static partial class Interop
         internal const string BCrypt = "BCrypt.dll";
         internal const string Crypt32 = "crypt32.dll";
         internal const string Kernel32 = "kernel32.dll";
-        internal const string NtDll = "ntdll.dll";
         internal const string OleAut32 = "oleaut32.dll";
     }
 }
diff --git a/src/mscorlib/shared/Interop/Windows/NtDll/Interop.ZeroMemory.cs b/src/mscorlib/shared/Interop/Windows/NtDll/Interop.ZeroMemory.cs
deleted file mode 100644 (file)
index 9bf7321..0000000
+++ /dev/null
@@ -1,16 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Runtime.InteropServices;
-using System.Security;
-
-internal partial class Interop
-{
-    internal partial class NtDll
-    {
-        [DllImport(Libraries.NtDll, CharSet = CharSet.Unicode, EntryPoint = "RtlZeroMemory")]
-        internal static extern void ZeroMemory(IntPtr address, UIntPtr length);
-    }
-}
index df66f1b..75d31c0 100644 (file)
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\Kernel32\Interop.WideCharToMultiByte.cs"/>
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\Kernel32\Interop.WriteFile_SafeHandle_IntPtr.cs"/>
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\Kernel32\Interop.WriteFile_SafeHandle_NativeOverlapped.cs"/>
-    <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\NtDll\Interop.ZeroMemory.cs"/>
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\OleAut32\Interop.SysAllocStringLen.cs"/>
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\OleAut32\Interop.SysFreeString.cs"/>
     <Compile Include="$(MSBuildThisFileDirectory)Interop\Windows\OleAut32\Interop.SysStringLen.cs"/>
index a1164dc..227fed3 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Runtime;
 using System.Diagnostics;
 using System.Runtime.InteropServices;
 
@@ -25,7 +26,7 @@ namespace System.Security
 
         override protected bool ReleaseHandle()
         {
-            Interop.NtDll.ZeroMemory(handle, (UIntPtr)(Interop.OleAut32.SysStringLen(handle) * sizeof(char)));
+            RuntimeImports.RhZeroMemory(handle, (UIntPtr)(Interop.OleAut32.SysStringLen(handle) * sizeof(char)));
             Interop.OleAut32.SysFreeString(handle);
             return true;
         }
@@ -36,7 +37,7 @@ namespace System.Security
             try
             {
                 AcquirePointer(ref bufferPtr);
-                Interop.NtDll.ZeroMemory((IntPtr)bufferPtr, (UIntPtr)(Interop.OleAut32.SysStringLen((IntPtr)bufferPtr) * sizeof(char)));
+                RuntimeImports.RhZeroMemory((IntPtr)bufferPtr, (UIntPtr)(Interop.OleAut32.SysStringLen((IntPtr)bufferPtr) * sizeof(char)));
             }
             finally
             {
index 0ef38e4..cfeebc1 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics;
+using System.Runtime;
 using System.Runtime.InteropServices;
 using System.Text;
 
@@ -142,6 +143,41 @@ namespace System.Security
             _buffer.Write((ulong)(index * sizeof(char)), c);
         }
 
+        internal unsafe IntPtr MarshalToBSTR()
+        {
+            int length = _decryptedLength;
+            IntPtr ptr = IntPtr.Zero;
+            IntPtr result = IntPtr.Zero;
+            byte* bufferPtr = null;
+            
+            try
+            {
+                _buffer.AcquirePointer(ref bufferPtr);
+                int resultByteLength = (length + 1) * sizeof(char);
+
+                ptr = PInvokeMarshal.AllocBSTR(length);
+
+                Buffer.MemoryCopy(bufferPtr, (byte*)ptr, resultByteLength, length * sizeof(char));
+
+                result = ptr;
+            }
+            finally
+            {
+                // If we failed for any reason, free the new buffer
+                if (result == IntPtr.Zero && ptr != IntPtr.Zero)
+                {
+                    RuntimeImports.RhZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
+                    PInvokeMarshal.FreeBSTR(ptr);
+                }
+
+                if (bufferPtr != null)
+                {
+                    _buffer.ReleasePointer();
+                }
+            }
+            return result;
+        }
+
         internal unsafe IntPtr MarshalToStringCore(bool globalAlloc, bool unicode)
         {
             int length = _decryptedLength;
@@ -179,7 +215,7 @@ namespace System.Security
                 // release the string if we had one.
                 if (stringPtr != IntPtr.Zero && result == IntPtr.Zero)
                 {
-                    UnmanagedBuffer.ZeroMemory((byte*)stringPtr, (ulong)(length * sizeof(char)));
+                    RuntimeImports.RhZeroMemory(stringPtr, (UIntPtr)(length * sizeof(char)));
                     MarshalFree(stringPtr, globalAlloc);
                 }
 
@@ -241,7 +277,7 @@ namespace System.Security
                 try
                 {
                     AcquirePointer(ref ptr);
-                    ZeroMemory(ptr, ByteLength);
+                    RuntimeImports.RhZeroMemory((IntPtr)ptr, (UIntPtr)ByteLength);
                 }
                 finally
                 {
@@ -284,12 +320,6 @@ namespace System.Security
                 Marshal.FreeHGlobal(handle);
                 return true;
             }
-
-            internal static unsafe void ZeroMemory(byte* ptr, ulong len)
-            {
-                for (ulong i = 0; i < len; i++) *ptr++ = 0;
-            }
         }
-
     }
 }
index 13f75a3..2a80081 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics;
+using System.Runtime;
 using System.Runtime.InteropServices;
 using System.Security.Cryptography;
 using Microsoft.Win32;
@@ -157,11 +158,7 @@ namespace System.Security
                 _buffer.AcquirePointer(ref bufferPtr);
                 int resultByteLength = (length + 1) * sizeof(char);
 
-                ptr = Interop.OleAut32.SysAllocStringLen(null, length);
-                if (ptr == IntPtr.Zero)
-                {
-                    throw new OutOfMemoryException();
-                }
+                ptr = PInvokeMarshal.AllocBSTR(length);
 
                 Buffer.MemoryCopy(bufferPtr, (byte*)ptr, resultByteLength, length * sizeof(char));
 
@@ -174,8 +171,8 @@ namespace System.Security
                 // If we failed for any reason, free the new buffer
                 if (result == IntPtr.Zero && ptr != IntPtr.Zero)
                 {
-                    Interop.NtDll.ZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
-                    Interop.OleAut32.SysFreeString(ptr);
+                    RuntimeImports.RhZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
+                    PInvokeMarshal.FreeBSTR(ptr);
                 }
 
                 if (bufferPtr != null)
@@ -223,7 +220,7 @@ namespace System.Security
                 // If we failed for any reason, free the new buffer
                 if (result == IntPtr.Zero && ptr != IntPtr.Zero)
                 {
-                    Interop.NtDll.ZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
+                    RuntimeImports.RhZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
                     MarshalFree(ptr, globalAlloc);
                 }
 
index 8543bc8..8145a95 100644 (file)
@@ -461,13 +461,11 @@ namespace Microsoft.Win32
         internal const String USER32 = "user32.dll";
         internal const String OLE32 = "ole32.dll";
         internal const String OLEAUT32 = "oleaut32.dll";
-        internal const String NTDLL = "ntdll.dll";
 #else //FEATURE_PAL
         internal const String KERNEL32 = "libcoreclr";
         internal const String USER32   = "libcoreclr";
         internal const String OLE32    = "libcoreclr";
         internal const String OLEAUT32 = "libcoreclr";
-        internal const String NTDLL    = "libcoreclr";
 #endif //FEATURE_PAL         
         internal const String ADVAPI32 = "advapi32.dll";
         internal const String SHELL32 = "shell32.dll";
@@ -509,10 +507,6 @@ namespace Microsoft.Win32
         [DllImport(KERNEL32, SetLastError = true)]
         internal static extern IntPtr LocalFree(IntPtr handle);
 
-        // MSDN says the length is a SIZE_T.
-        [DllImport(NTDLL, EntryPoint = "RtlZeroMemory")]
-        internal static extern void ZeroMemory(IntPtr address, UIntPtr length);
-
         internal static bool GlobalMemoryStatusEx(ref MEMORYSTATUSEX buffer)
         {
             buffer.length = Marshal.SizeOf(typeof(MEMORYSTATUSEX));
index 9d9a57b..97f1bc5 100644 (file)
@@ -15,6 +15,7 @@
 namespace System.Runtime.InteropServices
 {
     using System;
+    using System.Runtime;
     using System.Collections.Generic;
     using System.Reflection;
     using System.Reflection.Emit;
@@ -1843,11 +1844,7 @@ namespace System.Runtime.InteropServices
             }
             Contract.EndContractBlock();
 
-#if FEATURE_COMINTEROP
             return s.MarshalToBSTR();
-#else
-            throw new PlatformNotSupportedException(); // https://github.com/dotnet/coreclr/issues/10443
-#endif
         }
 
         public static IntPtr SecureStringToCoTaskMemAnsi(SecureString s)
@@ -1871,30 +1868,28 @@ namespace System.Runtime.InteropServices
 
             return s.MarshalToString(globalAlloc: false, unicode: true);
         }
-
-#if FEATURE_COMINTEROP
+        
         public static void ZeroFreeBSTR(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)(Win32Native.SysStringLen(s) * 2));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)(Win32Native.SysStringLen(s) * 2));
             FreeBSTR(s);
         }
-#endif
 
         public static void ZeroFreeCoTaskMemAnsi(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)(Win32Native.lstrlenA(s)));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)(Win32Native.lstrlenA(s)));
             FreeCoTaskMem(s);
         }
 
         public static void ZeroFreeCoTaskMemUnicode(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)(Win32Native.lstrlenW(s) * 2));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)(Win32Native.lstrlenW(s) * 2));
             FreeCoTaskMem(s);
         }
 
         unsafe public static void ZeroFreeCoTaskMemUTF8(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)System.StubHelpers.StubHelpers.strlen((sbyte*)s));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)System.StubHelpers.StubHelpers.strlen((sbyte*)s));
             FreeCoTaskMem(s);
         }
 
@@ -1922,13 +1917,13 @@ namespace System.Runtime.InteropServices
 
         public static void ZeroFreeGlobalAllocAnsi(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)(Win32Native.lstrlenA(s)));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)(Win32Native.lstrlenA(s)));
             FreeHGlobal(s);
         }
 
         public static void ZeroFreeGlobalAllocUnicode(IntPtr s)
         {
-            Win32Native.ZeroMemory(s, (UIntPtr)(Win32Native.lstrlenW(s) * 2));
+            RuntimeImports.RhZeroMemory(s, (UIntPtr)(Win32Native.lstrlenW(s) * 2));
             FreeHGlobal(s);
         }
     }
index 7b7c5ef..c79af8b 100644 (file)
@@ -171,11 +171,6 @@ namespace System.Runtime.InteropServices
         {
             throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop);
         }
-
-        public static void ZeroFreeBSTR(System.IntPtr s)
-        {
-            throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop);
-        }
     }
 
     public class DispatchWrapper
diff --git a/src/mscorlib/src/System/Runtime/InteropServices/PInvokeMarshal.cs b/src/mscorlib/src/System/Runtime/InteropServices/PInvokeMarshal.cs
new file mode 100644 (file)
index 0000000..9eb60bd
--- /dev/null
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Win32;
+
+namespace System.Runtime.InteropServices
+{
+    internal static class PInvokeMarshal
+    {
+        public static IntPtr AllocBSTR(int length)
+        {
+            IntPtr bstr = Win32Native.SysAllocStringLen(null, length);
+            if (bstr == IntPtr.Zero)
+                throw new OutOfMemoryException();
+            return bstr;
+        }
+
+        public static void FreeBSTR(IntPtr ptr)
+        {
+            Win32Native.SysFreeString(ptr);
+        }
+    }
+}
index 16d41d3..ed0c556 100644 (file)
@@ -8,7 +8,7 @@ using System.Runtime.InteropServices;
 #if BIT64
 using nuint = System.UInt64;
 #else
-    using nuint = System.UInt32;
+using nuint = System.UInt32;
 #endif
 
 namespace System.Runtime
@@ -26,8 +26,13 @@ namespace System.Runtime
             }
         }
 
+        internal unsafe static void RhZeroMemory(IntPtr p, UIntPtr byteLength)
+        {
+            RhZeroMemory((void*)p, (nuint)byteLength);
+        }
+
         [DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
-        extern private unsafe static void RhZeroMemory(byte* b, nuint byteLength);
+        extern private unsafe static void RhZeroMemory(void* b, nuint byteLength);
 
         [MethodImpl(MethodImplOptions.InternalCall)]
         internal extern unsafe static void RhBulkMoveWithWriteBarrier(ref byte destination, ref byte source, nuint byteCount);
index 714dac8..29ee83a 100644 (file)
@@ -52,6 +52,39 @@ public class StringMarshalingTest
         }
     }
 
+    private unsafe void SecureStringToBSTRToString()
+    {
+        foreach (String ts in TestStrings)
+        {
+            SecureString secureString = new SecureString();
+            foreach (char character in ts)
+            {
+                secureString.AppendChar(character);
+            }
+
+            IntPtr BStr = IntPtr.Zero;
+            String str;
+
+            try
+            {
+                BStr = Marshal.SecureStringToBSTR(secureString);
+                str = Marshal.PtrToStringBSTR(BStr);
+            }
+            finally
+            {
+                if (BStr != IntPtr.Zero)
+                {
+                    Marshal.ZeroFreeBSTR(BStr);
+                }
+            }
+
+            if (!str.Equals(ts))
+            {
+                throw new Exception();
+            }
+        }
+    }
+
     private void StringToCoTaskMemAnsiToString()
     {
         foreach (String ts in TestStrings)
@@ -201,6 +234,7 @@ public class StringMarshalingTest
     public  bool RunTests()
     {
         StringToBStrToString();
+        SecureStringToBSTRToString();
         StringToCoTaskMemAnsiToString();
         StringToCoTaskMemUniToString();
         StringToHGlobalAnsiToString();