[marshal-ilgen] Fix NULL check for blittable out lparray. (#46548)
authormonojenkins <jo.shields+jenkins@xamarin.com>
Tue, 5 Jan 2021 08:48:12 +0000 (03:48 -0500)
committerGitHub <noreply@github.com>
Tue, 5 Jan 2021 08:48:12 +0000 (09:48 +0100)
Found while working on WPF clipboard code in Wine Mono. The test case causes a NullReferenceException in Mono but works in .NET Framework.

Co-authored-by: madewokherd <madewokherd@users.noreply.github.com>
src/mono/mono/metadata/marshal-ilgen.c
src/mono/mono/tests/cominterop.cs
src/mono/mono/tests/libtest.c

index ea580d3..785556b 100644 (file)
@@ -3006,6 +3006,7 @@ emit_marshal_array_ilgen (EmitMarshalContext *m, int argnum, MonoType *t,
                        mono_mb_emit_byte (mb, CEE_MUL);
                        mono_mb_emit_byte (mb, CEE_PREFIX1);
                        mono_mb_emit_byte (mb, CEE_CPBLK);                      
+                       mono_mb_patch_branch (mb, label1);
                        break;
                }
 
index e83cafd..79ae06a 100644 (file)
@@ -331,6 +331,9 @@ public class Tests
        public static extern int mono_test_marshal_safearray_in_ccw([MarshalAs (UnmanagedType.Interface)] ITest itest);
 
        [DllImport("libtest")]
+       public static extern int mono_test_marshal_lparray_out_ccw([MarshalAs (UnmanagedType.Interface)] ITest itest);
+
+       [DllImport("libtest")]
        public static extern int mono_test_default_interface_ccw([MarshalAs (UnmanagedType.Interface)] ITest itest);
 
        [DllImport("libtest")]
@@ -792,6 +795,8 @@ public class Tests
                                }
                                if (mono_test_marshal_safearray_in_ccw(test) != 0)
                                        return 97;
+                               if (mono_test_marshal_lparray_out_ccw(test) != 0)
+                                       return 98;
                        }
                        #endregion // SafeArray Tests
 
@@ -904,6 +909,8 @@ public class Tests
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                void ArrayIn3 (object[] array);
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               int ArrayOut ([Out, MarshalAs (UnmanagedType.LPArray, SizeConst=1)] int[] array);
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                [return: MarshalAs (UnmanagedType.Interface)]
                TestDefaultInterfaceClass1 GetDefInterface1();
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
@@ -971,6 +978,9 @@ public class Tests
                int ArrayIn2 ([In] object[] array);
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                int ArrayIn3 (object[] array);
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               [PreserveSig]
+               int ArrayOut ([Out, MarshalAs (UnmanagedType.LPArray, SizeConst=1)] int[] array, out int result);
        }
 
        [System.Runtime.InteropServices.GuidAttribute ("00000000-0000-0000-0000-000000000002")]
@@ -1021,6 +1031,8 @@ public class Tests
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                public virtual extern void ArrayIn3 (object[] array);
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               public virtual extern int ArrayOut ([Out, MarshalAs (UnmanagedType.LPArray, SizeConst=1)] int[] array);
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                public virtual extern TestDefaultInterfaceClass1 GetDefInterface1();
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                public virtual extern TestDefaultInterfaceClass2 GetDefInterface2();
@@ -1212,6 +1224,18 @@ public class Tests
                {
                        return ArrayIn(array);
                }
+
+               public int ArrayOut (int[] array, out int result)
+               {
+                       if (array == null)
+                               result = 0;
+                       else
+                       {
+                               array[0] = 55;
+                               result = 1;
+                       }
+                       return 0;
+               }
        }
 
        public class ManagedTest : ITest
@@ -1346,6 +1370,14 @@ public class Tests
                        ArrayIn(array);
                }
 
+               public int ArrayOut (int[] array)
+               {
+                       if (array == null)
+                               return 0;
+                       array[0] = 55;
+                       return 1;
+               }
+
                public TestDefaultInterfaceClass1 GetDefInterface1()
                {
                        return new TestDefaultInterfaceClass1();
index 175ed4a..0f078a3 100644 (file)
@@ -3436,6 +3436,7 @@ typedef struct
        int (STDCALL *ArrayIn)(MonoComObject* pUnk, void *array);
        int (STDCALL *ArrayIn2)(MonoComObject* pUnk, void *array);
        int (STDCALL *ArrayIn3)(MonoComObject* pUnk, void *array);
+       int (STDCALL *ArrayOut)(MonoComObject* pUnk, guint32 *array, guint32 *result);
        int (STDCALL *GetDefInterface1)(MonoComObject* pUnk, MonoDefItfObject **iface);
        int (STDCALL *GetDefInterface2)(MonoComObject* pUnk, MonoDefItfObject **iface);
 } MonoIUnknown;
@@ -3591,6 +3592,12 @@ ArrayIn3(MonoComObject* pUnk, void *array)
 }
 
 LIBTEST_API int STDCALL
+ArrayOut(MonoComObject* pUnk, guint32 *array, guint32 *result)
+{
+       return S_OK;
+}
+
+LIBTEST_API int STDCALL
 GetDefInterface1(MonoComObject* pUnk, MonoDefItfObject **obj)
 {
        return S_OK;
@@ -3638,6 +3645,7 @@ static void create_com_object (MonoComObject** pOut)
        (*pOut)->vtbl->ArrayIn = ArrayIn;
        (*pOut)->vtbl->ArrayIn2 = ArrayIn2;
        (*pOut)->vtbl->ArrayIn3 = ArrayIn3;
+       (*pOut)->vtbl->ArrayOut = ArrayOut;
        (*pOut)->vtbl->GetDefInterface1 = GetDefInterface1;
        (*pOut)->vtbl->GetDefInterface2 = GetDefInterface2;
 }
@@ -5714,6 +5722,29 @@ mono_test_marshal_safearray_in_ccw(MonoComObject *pUnk)
        return ret;
 }
 
+LIBTEST_API int STDCALL
+mono_test_marshal_lparray_out_ccw(MonoComObject *pUnk)
+{
+       guint32 array, result;
+       int ret;
+
+       ret = pUnk->vtbl->ArrayOut (pUnk, &array, &result);
+       if (ret)
+               return ret;
+       if (array != 55)
+               return 1;
+       if (result != 1)
+               return 2;
+
+       ret = pUnk->vtbl->ArrayOut (pUnk, NULL, &result);
+       if (ret)
+               return ret;
+       if (result != 0)
+               return 3;
+
+       return 0;
+}
+
 #endif
 
 static int call_managed_res;