[DllImportGenerator] Fix stub generation for char array marshalling (#61188)
authorElinor Fung <elfung@microsoft.com>
Thu, 4 Nov 2021 05:49:01 +0000 (22:49 -0700)
committerGitHub <noreply@github.com>
Thu, 4 Nov 2021 05:49:01 +0000 (22:49 -0700)
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs
src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.Tests/ArrayTests.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Arrays.cs

index 4852974..293b979 100644 (file)
@@ -96,7 +96,7 @@ namespace Microsoft.Interop
                 case StubCodeContext.Stage.Setup:
                     break;
                 case StubCodeContext.Stage.Marshal:
-                    if (info.IsByRef && info.RefKind != RefKind.Out)
+                    if ((info.IsByRef && info.RefKind != RefKind.Out) || !context.SingleFrameSpansNativeContext)
                     {
                         yield return ExpressionStatement(
                             AssignmentExpression(
index 82039c4..b52876d 100644 (file)
@@ -35,6 +35,12 @@ namespace DllImportGenerator.IntegrationTests
             [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")]
             public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 2)] out int[] res);
 
+            [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_char_array", CharSet = CharSet.Unicode)]
+            public static partial int SumChars(char[] chars, int numElements);
+
+            [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_char_array", CharSet = CharSet.Unicode)]
+            public static partial void ReverseChars([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] ref char[] chars, int numElements);
+
             [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")]
             public static partial int SumStringLengths([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] string[] strArray);
 
@@ -119,6 +125,22 @@ namespace DllImportGenerator.IntegrationTests
         }
 
         [Fact]
+        public void CharArrayMarshalledToNativeAsExpected()
+        {
+            char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray();
+            Assert.Equal(array.Sum(c => c), NativeExportsNE.Arrays.SumChars(array, array.Length));
+        }
+
+        [Fact]
+        public void CharArrayRefParameter()
+        {
+            char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray();
+            var newArray = array;
+            NativeExportsNE.Arrays.ReverseChars(ref newArray, array.Length);
+            Assert.Equal(array.Reverse(), newArray);
+        }
+
+        [Fact]
         public void ArraysReturnedFromNative()
         {
             int start = 5;
index 3e11f6c..a220162 100644 (file)
@@ -90,6 +90,34 @@ namespace NativeExports
             }
         }
 
+        [UnmanagedCallersOnly(EntryPoint = "sum_char_array")]
+        public static int SumChars(ushort* values, int numValues)
+        {
+            if (values == null)
+            {
+                return -1;
+            }
+
+            int sum = 0;
+            for (int i = 0; i < numValues; i++)
+            {
+                sum += values[i];
+            }
+            return sum;
+        }
+
+        [UnmanagedCallersOnly(EntryPoint = "reverse_char_array")]
+        public static void ReverseChars(ushort** values, int numValues)
+        {
+            if (*values == null)
+            {
+                return;
+            }
+
+            var span = new Span<ushort>(*values, numValues);
+            span.Reverse();
+        }
+
         [UnmanagedCallersOnly(EntryPoint = "sum_string_lengths")]
         public static int SumStringLengths(ushort** strArray)
         {