Make sure we consider buffer length when marshalling back Unicode ByValTStr fields...
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Mon, 28 Jun 2021 19:41:39 +0000 (12:41 -0700)
committerGitHub <noreply@github.com>
Mon, 28 Jun 2021 19:41:39 +0000 (12:41 -0700)
* Use string constructor that takes length instead of the one that searches for a null terminator.

Fixes #54662

* Marshal back buffer size or string to first null terminator, whichever is shorter

* Add tests.

* Add unicode test.

* Use the same implementation style for the wstr case case as the cstr case

* Fix accidental deletion from test.

src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs
src/coreclr/vm/corelib.h
src/coreclr/vm/ilmarshalers.cpp
src/tests/Interop/StringMarshalling/LPTSTR/LPTSTRTest.cs
src/tests/Interop/StringMarshalling/LPTSTR/LPTStrTestNative.cpp
src/tests/Interop/StringMarshalling/LPTSTR/LPTStrTestNative.cs

index 3343568..54f5764 100644 (file)
@@ -508,6 +508,17 @@ namespace System.StubHelpers
             managed.Slice(0, numChars).CopyTo(native);
             native[numChars] = '\0';
         }
+
+        internal static unsafe string ConvertToManaged(IntPtr nativeHome, int length)
+        {
+            int end = SpanHelpers.IndexOf(ref *(char*)nativeHome, '\0', length);
+            if (end != -1)
+            {
+                length = end;
+            }
+
+            return new string((char*)nativeHome, 0, length);
+        }
     }  // class WSTRBufferMarshaler
 #if FEATURE_COMINTEROP
 
index 73ce65f..8f6afcc 100644 (file)
@@ -1069,6 +1069,7 @@ DEFINE_METHOD(CSTRMARSHALER,        CLEAR_NATIVE,           ClearNative,
 
 DEFINE_CLASS(FIXEDWSTRMARSHALER,   StubHelpers,            FixedWSTRMarshaler)
 DEFINE_METHOD(FIXEDWSTRMARSHALER,  CONVERT_TO_NATIVE,      ConvertToNative,            SM_Str_IntPtr_Int_RetVoid)
+DEFINE_METHOD(FIXEDWSTRMARSHALER,  CONVERT_TO_MANAGED,     ConvertToManaged,           SM_IntPtr_Int_RetStr)
 
 DEFINE_CLASS(BSTRMARSHALER,         StubHelpers,            BSTRMarshaler)
 DEFINE_METHOD(BSTRMARSHALER,        CONVERT_TO_NATIVE,      ConvertToNative,            SM_Str_IntPtr_RetIntPtr)
index 163222d..8f34846 100644 (file)
@@ -1906,11 +1906,8 @@ void ILFixedWSTRMarshaler::EmitConvertContentsNativeToCLR(ILCodeStream* pslILEmi
     STANDARD_VM_CONTRACT;
 
     EmitLoadNativeHomeAddr(pslILEmit);
-    pslILEmit->EmitDUP();
-    pslILEmit->EmitCALL(METHOD__STRING__WCSLEN, 1, 1);
-    pslILEmit->EmitCALL(METHOD__STUBHELPERS__CHECK_STRING_LENGTH, 1, 0);
-
-    pslILEmit->EmitNEWOBJ(METHOD__STRING__CTOR_CHARPTR, 1);
+    pslILEmit->EmitLDC(m_pargs->fs.fixedStringLength);
+    pslILEmit->EmitCALL(METHOD__FIXEDWSTRMARSHALER__CONVERT_TO_MANAGED, 2, 1);
     EmitStoreManagedValue(pslILEmit);
 }
 
index 8c789d1..5405443 100644 (file)
@@ -12,6 +12,8 @@ using static LPTStrTestNative;
 class LPTStrTest
 {
     private static readonly string InitialString = "Hello World";
+    private static readonly string LongString = "0123456789abcdefghi";
+    private static readonly string LongUnicodeString = "๐Ÿ‘จโ€๐Ÿ‘จโ€๐Ÿ‘งโ€๐Ÿ‘ง๐Ÿฑโ€๐Ÿ‘ค";
 
     public static int Main()
     {
@@ -58,7 +60,21 @@ class LPTStrTest
         };
 
         ReverseByValStringUni(ref uniStr);
-
         Assert.AreEqual(Helpers.Reverse(InitialString), uniStr.str);
+
+        ReverseCopyByValStringAnsi(new ByValStringInStructAnsi { str = LongString }, out ByValStringInStructSplitAnsi ansiStrSplit);
+
+        Assert.AreEqual(Helpers.Reverse(LongString[^10..]), ansiStrSplit.str1);
+        Assert.AreEqual(Helpers.Reverse(LongString[..^10]), ansiStrSplit.str2);
+
+        ReverseCopyByValStringUni(new ByValStringInStructUnicode { str = LongString }, out ByValStringInStructSplitUnicode uniStrSplit);
+
+        Assert.AreEqual(Helpers.Reverse(LongString[^10..]), uniStrSplit.str1);
+        Assert.AreEqual(Helpers.Reverse(LongString[..^10]), uniStrSplit.str2);
+
+        ReverseCopyByValStringUni(new ByValStringInStructUnicode { str = LongUnicodeString }, out ByValStringInStructSplitUnicode uniStrSplit2);
+
+        Assert.AreEqual(Helpers.Reverse(LongUnicodeString[^10..]), uniStrSplit2.str1);
+        Assert.AreEqual(Helpers.Reverse(LongUnicodeString[..^10]), uniStrSplit2.str2);
     }
 }
index 428ce48..c618a5d 100644 (file)
@@ -47,3 +47,15 @@ extern "C" DLL_EXPORT void STDMETHODCALLTYPE ReverseByValStringUni(ByValStringIn
 {
     StringMarshalingTests<LPWSTR, TP_slen>::ReverseInplace(str->str);
 }
+
+extern "C" DLL_EXPORT void STDMETHODCALLTYPE ReverseCopyByValStringAnsi(ByValStringInStructAnsi str, ByValStringInStructAnsi* out)
+{
+    *out = str;
+    StringMarshalingTests<char*, default_callconv_strlen>::ReverseInplace(out->str);
+}
+
+extern "C" DLL_EXPORT void STDMETHODCALLTYPE ReverseCopyByValStringUni(ByValStringInStructUnicode str, ByValStringInStructUnicode* out)
+{
+    *out = str;
+    StringMarshalingTests<LPWSTR, TP_slen>::ReverseInplace(out->str);
+}
index 022c3dc..443bc76 100644 (file)
@@ -14,6 +14,15 @@ class LPTStrTestNative
         public string str;
     }
 
+    [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)]
+    public struct ByValStringInStructSplitAnsi
+    {
+        [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 10)]
+        public string str1;
+        [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 10)]
+        public string str2;
+    }
+
     [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
     public struct ByValStringInStructUnicode
     {
@@ -21,6 +30,15 @@ class LPTStrTestNative
         public string str;
     }
 
+    [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+    public struct ByValStringInStructSplitUnicode
+    {
+        [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 10)]
+        public string str1;
+        [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 10)]
+        public string str2;
+    }
+
     [DllImport(nameof(LPTStrTestNative), CharSet = CharSet.Unicode)]
     public static extern bool Verify_NullTerminators_PastEnd(StringBuilder builder, int length);
 
@@ -36,4 +54,9 @@ class LPTStrTestNative
     public static extern void ReverseByValStringAnsi(ref ByValStringInStructAnsi str);
     [DllImport(nameof(LPTStrTestNative))]
     public static extern void ReverseByValStringUni(ref ByValStringInStructUnicode str);
+
+    [DllImport(nameof(LPTStrTestNative))]
+    public static extern void ReverseCopyByValStringAnsi(ByValStringInStructAnsi str, out ByValStringInStructSplitAnsi strOut);
+    [DllImport(nameof(LPTStrTestNative))]
+    public static extern void ReverseCopyByValStringUni(ByValStringInStructUnicode str, out ByValStringInStructSplitUnicode strOut);
 }