[cominterop] Handle NULL pointers when marshalling native-to-managed return values...
authorZebediah Figura <obsequiousnewt@gmail.com>
Wed, 31 Jul 2019 12:14:39 +0000 (07:14 -0500)
committerAlexander Köplinger <alex.koeplinger@outlook.com>
Wed, 31 Jul 2019 12:14:39 +0000 (14:14 +0200)
Fixes a bug running Rak24u with wine-mono: bugs.winehq.org/show_bug.cgi?id=47561

Commit migrated from https://github.com/mono/mono/commit/b51c696081a58515adf53a50e5746ed08bd1062e

src/mono/mono/metadata/cominterop.c
src/mono/mono/tests/cominterop.cs
src/mono/mono/tests/libtest.c

index 4ef0685..af9ca7c 100644 (file)
@@ -2298,8 +2298,8 @@ cominterop_get_managed_wrapper_adjusted (MonoMethod *method)
        MonoMarshalSpec **mspecs;
        MonoMethodSignature *sig, *sig_native;
        MonoExceptionClause *main_clause = NULL;
+       int hr = 0, retval = 0;
        int pos_leave;
-       int hr = 0;
        int i;
        gboolean const preserve_sig = (method->iflags & METHOD_IMPL_ATTRIBUTE_PRESERVE_SIG) != 0;
 
@@ -2334,8 +2334,11 @@ cominterop_get_managed_wrapper_adjusted (MonoMethod *method)
        mspecs [0] = NULL;
 
 #ifndef DISABLE_JIT
-       if (!preserve_sig)
+       if (!preserve_sig) {
+               if (!MONO_TYPE_IS_VOID (sig->ret))
+                       retval = mono_mb_add_local (mb, sig->ret);
                hr = mono_mb_add_local (mb, mono_get_int32_type ());
+       }
        else if (!MONO_TYPE_IS_VOID (sig->ret))
                hr = mono_mb_add_local (mb, sig->ret);
 
@@ -2343,10 +2346,6 @@ cominterop_get_managed_wrapper_adjusted (MonoMethod *method)
        main_clause = g_new0 (MonoExceptionClause, 1);
        main_clause->try_offset = mono_mb_get_label (mb);
 
-       /* load last param to store result if not preserve_sig and not void */
-       if (!preserve_sig && !MONO_TYPE_IS_VOID (sig->ret))
-               mono_mb_emit_ldarg (mb, sig_native->param_count-1);
-
        /* the CCW -> object conversion */
        mono_mb_emit_ldarg (mb, 0);
        mono_mb_emit_icon (mb, FALSE);
@@ -2359,12 +2358,21 @@ cominterop_get_managed_wrapper_adjusted (MonoMethod *method)
 
        if (!MONO_TYPE_IS_VOID (sig->ret)) {
                if (!preserve_sig) {
+                       mono_mb_emit_stloc (mb, retval);
+                       mono_mb_emit_ldarg (mb, sig_native->param_count - 1);
+                       const int pos_null = mono_mb_emit_branch (mb, CEE_BRFALSE);
+
+                       mono_mb_emit_ldarg (mb, sig_native->param_count - 1);
+                       mono_mb_emit_ldloc (mb, retval);
+
                        MonoClass *rclass = mono_class_from_mono_type_internal (sig->ret);
                        if (m_class_is_valuetype (rclass)) {
                                mono_mb_emit_op (mb, CEE_STOBJ, rclass);
                        } else {
                                mono_mb_emit_byte (mb, mono_type_to_stind (sig->ret));
                        }
+
+                       mono_mb_patch_branch (mb, pos_null);
                } else
                        mono_mb_emit_stloc (mb, hr);
        }
index b6cfe15..45527ca 100644 (file)
@@ -254,6 +254,12 @@ public class Tests
        [DllImport ("libtest")]
        public static extern int mono_test_marshal_array_ccw_itest (int count, [MarshalAs (UnmanagedType.LPArray, SizeParamIndex=0)] ITest[] ppUnk);
 
+       [DllImport ("libtest")]
+       public static extern int mono_test_marshal_retval_ccw_itest ([MarshalAs (UnmanagedType.Interface)]ITest itest, bool test_null);
+
+       [DllImport ("libtest")]
+       public static extern int mono_test_marshal_retval_ccw_itest ([MarshalAs (UnmanagedType.Interface)]ITestPresSig itest, bool test_null);
+
        [DllImport("libtest")]
        public static extern int mono_test_cominterop_ccw_queryinterface ([MarshalAs (UnmanagedType.Interface)] IOtherTest itest);
 
@@ -578,6 +584,13 @@ public class Tests
                        if (mono_test_cominterop_ccw_queryinterface (otherTest) != 0)
                                return 202;
 
+                       if (mono_test_marshal_retval_ccw_itest(test, true) != 0)
+                               return 203;
+
+                       /* Passing NULL to an out parameter will crash. */
+                       if (mono_test_marshal_retval_ccw_itest(test_pres_sig, false) != 0)
+                               return 204;
+
                        #endregion // COM Callable Wrapper Tests
 
                        #region SAFEARRAY tests
@@ -794,6 +807,8 @@ public class Tests
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                void ITestOut ([MarshalAs (UnmanagedType.Interface)]out ITest val);
                int Return22NoICall();
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               int IntOut();
        }
 
        [ComImport ()]
@@ -847,6 +862,9 @@ public class Tests
                int ITestOut ([MarshalAs (UnmanagedType.Interface)]out ITestPresSig val);
                [PreserveSig ()]
                int Return22NoICall();
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               [PreserveSig ()]
+               int IntOut (out int val);
        }
 
        [System.Runtime.InteropServices.GuidAttribute ("00000000-0000-0000-0000-000000000002")]
@@ -886,9 +904,10 @@ public class Tests
                public virtual extern void ITestIn ([MarshalAs (UnmanagedType.Interface)]ITest val);
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                public virtual extern void ITestOut ([MarshalAs (UnmanagedType.Interface)]out ITest val);
-
                [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
                public virtual extern int Return22NoICall();
+               [MethodImplAttribute (MethodImplOptions.InternalCall, MethodCodeType = MethodCodeType.Runtime)]
+               public virtual extern int IntOut();
        }
 
        [System.Runtime.InteropServices.GuidAttribute ("00000000-0000-0000-0000-000000000002")]
@@ -1033,6 +1052,12 @@ public class Tests
                {
                        return 88;
                }
+
+               public int IntOut(out int val)
+               {
+                       val = 33;
+                       return 0;
+               }
        }
 
        public class ManagedTest : ITest
@@ -1127,6 +1152,11 @@ public class Tests
                {
                        return 99;
                }
+
+               public int IntOut()
+               {
+                       return 33;
+               }
        }
 
        [ComVisible (true)]
index 3b04f2d..d4c5ade 100644 (file)
@@ -3389,6 +3389,7 @@ typedef struct
        int (STDCALL *ITestIn)(MonoComObject* pUnk, MonoComObject* pUnk2);
        int (STDCALL *ITestOut)(MonoComObject* pUnk, MonoComObject* *ppUnk);
        int (STDCALL *Return22NoICall)(MonoComObject* pUnk);
+       int (STDCALL *IntOut)(MonoComObject* pUnk, int *a);
 } MonoIUnknown;
 
 struct MonoComObject
@@ -3512,6 +3513,11 @@ Return22NoICall(MonoComObject* pUnk)
        return 22;
 }
 
+LIBTEST_API int STDCALL
+IntOut(MonoComObject* pUnk, int *a)
+{
+       return S_OK;
+}
 
 static void create_com_object (MonoComObject** pOut);
 
@@ -3545,6 +3551,7 @@ static void create_com_object (MonoComObject** pOut)
        (*pOut)->vtbl->ITestOut = ITestOut;
        (*pOut)->vtbl->get_ITest = get_ITest;
        (*pOut)->vtbl->Return22NoICall = Return22NoICall;
+       (*pOut)->vtbl->IntOut = IntOut;
 }
 
 static MonoComObject* same_object = NULL;
@@ -3655,6 +3662,29 @@ mono_test_marshal_array_ccw_itest (int count, MonoComObject ** ppUnk)
        return 0;
 }
 
+LIBTEST_API int STDCALL
+mono_test_marshal_retval_ccw_itest (MonoComObject *pUnk, int test_null)
+{
+       int hr = 0, i = 0;
+
+       if (!pUnk)
+               return 1;
+
+       hr = pUnk->vtbl->IntOut (pUnk, &i);
+       if (hr != 0)
+               return 2;
+       if (i != 33)
+               return 3;
+       if (test_null)
+       {
+               hr = pUnk->vtbl->IntOut (pUnk, NULL);
+               if (hr != 0)
+                       return 4;
+       }
+
+       return 0;
+}
+
 /*
  * mono_method_get_unmanaged_thunk tests
  */