From d10f6f11debe5308d21f9e9802ba155b6624e2fd Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Mon, 3 Aug 2020 15:30:09 -0700 Subject: [PATCH] Don't rely on the built-in interface marshaller during COM activation. (#40228) * Don't rely on the built-in marshaller during activation. Relying on the built-in marshaller leverages the Class interface approach which doesn't work for some interface types (e.g. interfaces inheriting from IDispatch). This approach is wrong regardless of why given that COM dictates the returned value must be properly cast the specific interface vtable. --- .../Runtime/InteropServices/ComActivator.cs | 46 ++-- src/tests/Interop/COM/Activator/Program.cs | 28 ++- .../Interop/COM/NativeClients/Dispatch/Client.cpp | 258 ++++++++++++--------- 3 files changed, 196 insertions(+), 136 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/coreclr/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index 45f4d48..c41d4dd 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -25,7 +25,7 @@ namespace Internal.Runtime.InteropServices void CreateInstance( [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter, ref Guid riid, - [MarshalAs(UnmanagedType.Interface)] out object? ppvObject); + out IntPtr ppvObject); void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock); } @@ -51,7 +51,7 @@ namespace Internal.Runtime.InteropServices new void CreateInstance( [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter, ref Guid riid, - [MarshalAs(UnmanagedType.Interface)] out object? ppvObject); + out IntPtr ppvObject); new void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock); @@ -66,7 +66,7 @@ namespace Internal.Runtime.InteropServices [MarshalAs(UnmanagedType.Interface)] object? pUnkReserved, ref Guid riid, [MarshalAs(UnmanagedType.BStr)] string bstrKey, - [MarshalAs(UnmanagedType.Interface)] out object ppvObject); + out IntPtr ppvObject); } [StructLayout(LayoutKind.Sequential)] @@ -493,28 +493,32 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments: #endif } - public static void ValidateObjectIsMarshallableAsInterface(object obj, Type interfaceType) + public static IntPtr GetObjectAsInterface(object obj, Type interfaceType) { #if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION - // If the requested "interface type" is type object then return - // because type object is always marshallable. + // If the requested "interface type" is type object then return as IUnknown if (interfaceType == typeof(object)) { - return; + return Marshal.GetIUnknownForObject(obj); } Debug.Assert(interfaceType.IsInterface); - // The intent of this call is to validate the interface can be + // The intent of this call is to get AND validate the interface can be // marshalled to native code. An exception will be thrown if the // type is unable to be marshalled to native code. // Scenarios where this is relevant: // - Interfaces that use Generics // - Interfaces that define implementation - IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); - // Decrement the above 'Marshal.GetComInterfaceForObject()' - Marshal.Release(ptr); + if (interfaceMaybe == IntPtr.Zero) + { + // E_NOINTERFACE + throw new InvalidCastException(); + } + + return interfaceMaybe; #else throw new PlatformNotSupportedException(); #endif @@ -544,18 +548,18 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments: public void CreateInstance( [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter, ref Guid riid, - [MarshalAs(UnmanagedType.Interface)] out object? ppvObject) + out IntPtr ppvObject) { #if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); - ppvObject = Activator.CreateInstance(_classType)!; + object obj = Activator.CreateInstance(_classType)!; if (pUnkOuter != null) { - ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); + obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj); } - BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType); + ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType); #else throw new PlatformNotSupportedException(); #endif @@ -593,7 +597,7 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments: public void CreateInstance( [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter, ref Guid riid, - [MarshalAs(UnmanagedType.Interface)] out object? ppvObject) + out IntPtr ppvObject) { #if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION CreateInstanceInner(pUnkOuter, ref riid, key: null, isDesignTime: true, out ppvObject); @@ -640,7 +644,7 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments: [MarshalAs(UnmanagedType.Interface)] object? pUnkReserved, ref Guid riid, [MarshalAs(UnmanagedType.BStr)] string bstrKey, - [MarshalAs(UnmanagedType.Interface)] out object ppvObject) + out IntPtr ppvObject) { #if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION Debug.Assert(pUnkReserved == null); @@ -655,18 +659,18 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments: ref Guid riid, string? key, bool isDesignTime, - out object ppvObject) + out IntPtr ppvObject) { #if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); - ppvObject = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime); + object obj = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime); if (pUnkOuter != null) { - ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); + obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj); } - BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType); + ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType); #else throw new PlatformNotSupportedException(); #endif diff --git a/src/tests/Interop/COM/Activator/Program.cs b/src/tests/Interop/COM/Activator/Program.cs index ac7fe8d..fd525aa 100644 --- a/src/tests/Interop/COM/Activator/Program.cs +++ b/src/tests/Interop/COM/Activator/Program.cs @@ -105,9 +105,11 @@ namespace Activator var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); - object svr; - factory.CreateInstance(null, ref iid, out svr); - typeCFromAssemblyA = (Type)((IGetTypeFromC)svr).GetTypeFromC(); + IntPtr svrRaw; + factory.CreateInstance(null, ref iid, out svrRaw); + var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw); + Marshal.Release(svrRaw); + typeCFromAssemblyA = (Type)svr.GetTypeFromC(); } using (HostPolicyMock.Mock_corehost_resolve_component_dependencies( @@ -127,9 +129,11 @@ namespace Activator var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); - object svr; - factory.CreateInstance(null, ref iid, out svr); - typeCFromAssemblyB = (Type)((IGetTypeFromC)svr).GetTypeFromC(); + IntPtr svrRaw; + factory.CreateInstance(null, ref iid, out svrRaw); + var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw); + Marshal.Release(svrRaw); + typeCFromAssemblyB = (Type)svr.GetTypeFromC(); } Assert.AreNotEqual(typeCFromAssemblyA, typeCFromAssemblyB, "Types should be from different AssemblyLoadContexts"); @@ -178,8 +182,10 @@ namespace Activator var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); - object svr; - factory.CreateInstance(null, ref iid, out svr); + IntPtr svrRaw; + factory.CreateInstance(null, ref iid, out svrRaw); + var svr = Marshal.GetObjectForIUnknown(svrRaw); + Marshal.Release(svrRaw); var inst = (IValidateRegistrationCallbacks)svr; Assert.IsFalse(inst.DidRegister()); @@ -215,8 +221,10 @@ namespace Activator var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); - object svr; - factory.CreateInstance(null, ref iid, out svr); + IntPtr svrRaw; + factory.CreateInstance(null, ref iid, out svrRaw); + var svr = Marshal.GetObjectForIUnknown(svrRaw); + Marshal.Release(svrRaw); var inst = (IValidateRegistrationCallbacks)svr; cxt.InterfaceId = Guid.Empty; diff --git a/src/tests/Interop/COM/NativeClients/Dispatch/Client.cpp b/src/tests/Interop/COM/NativeClients/Dispatch/Client.cpp index 36b4806..0aa509c 100644 --- a/src/tests/Interop/COM/NativeClients/Dispatch/Client.cpp +++ b/src/tests/Interop/COM/NativeClients/Dispatch/Client.cpp @@ -88,59 +88,81 @@ void Validate_Numeric_In_ReturnByRef() ULONGLONG ul1 = 4168; ULONGLONG ul2; - DISPPARAMS params; - params.cArgs = 14; - params.rgvarg = new VARIANTARG[params.cArgs]; - params.cNamedArgs = 0; - params.rgdispidNamedArgs = nullptr; + { + DISPPARAMS params; + params.cArgs = 14; + params.rgvarg = new VARIANTARG[params.cArgs]; + params.cNamedArgs = 0; + params.rgdispidNamedArgs = nullptr; - V_VT(¶ms.rgvarg[13]) = VT_UI1; - V_UI1(¶ms.rgvarg[13]) = b1; - V_VT(¶ms.rgvarg[12]) = VT_BYREF | VT_UI1; - V_UI1REF(¶ms.rgvarg[12]) = &b2; - V_VT(¶ms.rgvarg[11]) = VT_I2; - V_I2(¶ms.rgvarg[11]) = s1; - V_VT(¶ms.rgvarg[10]) = VT_BYREF | VT_I2; - V_I2REF(¶ms.rgvarg[10]) = &s2; - V_VT(¶ms.rgvarg[9]) = VT_UI2; - V_UI2(¶ms.rgvarg[9]) = us1; - V_VT(¶ms.rgvarg[8]) = VT_BYREF | VT_UI2; - V_UI2REF(¶ms.rgvarg[8]) = &us2; - V_VT(¶ms.rgvarg[7]) = VT_I4; - V_I4(¶ms.rgvarg[7]) = i1; - V_VT(¶ms.rgvarg[6]) = VT_BYREF | VT_I4; - V_I4REF(¶ms.rgvarg[6]) = &i2; - V_VT(¶ms.rgvarg[5]) = VT_UI4; - V_UI4(¶ms.rgvarg[5]) = ui1; - V_VT(¶ms.rgvarg[4]) = VT_BYREF | VT_UI4; - V_UI4REF(¶ms.rgvarg[4]) = &ui2; - V_VT(¶ms.rgvarg[3]) = VT_I8; - V_I8(¶ms.rgvarg[3]) = l1; - V_VT(¶ms.rgvarg[2]) = VT_BYREF | VT_I8; - V_I8REF(¶ms.rgvarg[2]) = &l2; - V_VT(¶ms.rgvarg[1]) = VT_UI8; - V_UI8(¶ms.rgvarg[1]) = ul1; - V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_UI8; - V_UI8REF(¶ms.rgvarg[0]) = &ul2; - - THROW_IF_FAILED(dispatchTesting->Invoke( - methodId, - IID_NULL, - lcid, - DISPATCH_METHOD, - ¶ms, - nullptr, - nullptr, - nullptr - )); + V_VT(¶ms.rgvarg[13]) = VT_UI1; + V_UI1(¶ms.rgvarg[13]) = b1; + V_VT(¶ms.rgvarg[12]) = VT_BYREF | VT_UI1; + V_UI1REF(¶ms.rgvarg[12]) = &b2; + V_VT(¶ms.rgvarg[11]) = VT_I2; + V_I2(¶ms.rgvarg[11]) = s1; + V_VT(¶ms.rgvarg[10]) = VT_BYREF | VT_I2; + V_I2REF(¶ms.rgvarg[10]) = &s2; + V_VT(¶ms.rgvarg[9]) = VT_UI2; + V_UI2(¶ms.rgvarg[9]) = us1; + V_VT(¶ms.rgvarg[8]) = VT_BYREF | VT_UI2; + V_UI2REF(¶ms.rgvarg[8]) = &us2; + V_VT(¶ms.rgvarg[7]) = VT_I4; + V_I4(¶ms.rgvarg[7]) = i1; + V_VT(¶ms.rgvarg[6]) = VT_BYREF | VT_I4; + V_I4REF(¶ms.rgvarg[6]) = &i2; + V_VT(¶ms.rgvarg[5]) = VT_UI4; + V_UI4(¶ms.rgvarg[5]) = ui1; + V_VT(¶ms.rgvarg[4]) = VT_BYREF | VT_UI4; + V_UI4REF(¶ms.rgvarg[4]) = &ui2; + V_VT(¶ms.rgvarg[3]) = VT_I8; + V_I8(¶ms.rgvarg[3]) = l1; + V_VT(¶ms.rgvarg[2]) = VT_BYREF | VT_I8; + V_I8REF(¶ms.rgvarg[2]) = &l2; + V_VT(¶ms.rgvarg[1]) = VT_UI8; + V_UI8(¶ms.rgvarg[1]) = ul1; + V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_UI8; + V_UI8REF(¶ms.rgvarg[0]) = &ul2; + + THROW_IF_FAILED(dispatchTesting->Invoke( + methodId, + IID_NULL, + lcid, + DISPATCH_METHOD, + ¶ms, + nullptr, + nullptr, + nullptr + )); + + THROW_FAIL_IF_FALSE(b2 == b1 * 2); + THROW_FAIL_IF_FALSE(s2 == s1 * 2); + THROW_FAIL_IF_FALSE(us2 == us1 * 2); + THROW_FAIL_IF_FALSE(i2 == i1 * 2); + THROW_FAIL_IF_FALSE(ui2 == ui1 * 2); + THROW_FAIL_IF_FALSE(l2 == l1 * 2); + THROW_FAIL_IF_FALSE(ul2 == ul1 * 2); + } - THROW_FAIL_IF_FALSE(b2 == b1 * 2); - THROW_FAIL_IF_FALSE(s2 == s1 * 2); - THROW_FAIL_IF_FALSE(us2 == us1 * 2); - THROW_FAIL_IF_FALSE(i2 == i1 * 2); - THROW_FAIL_IF_FALSE(ui2 == ui1 * 2); - THROW_FAIL_IF_FALSE(l2 == l1 * 2); - THROW_FAIL_IF_FALSE(ul2 == ul1 * 2); + { + b2 = 0; + s2 = 0; + us2 = 0; + i2 = 0; + ui2 = 0; + l2 = 0; + ul2 = 0; + + THROW_IF_FAILED(dispatchTesting->DoubleNumeric_ReturnByRef(b1, &b2, s1, &s2, us1, &us2, i1, (INT*)&i2, ui1, (UINT*)&ui2, l1, &l2, ul1, &ul2)); + + THROW_FAIL_IF_FALSE(b2 == b1 * 2); + THROW_FAIL_IF_FALSE(s2 == s1 * 2); + THROW_FAIL_IF_FALSE(us2 == us1 * 2); + THROW_FAIL_IF_FALSE(i2 == i1 * 2); + THROW_FAIL_IF_FALSE(ui2 == ui1 * 2); + THROW_FAIL_IF_FALSE(l2 == l1 * 2); + THROW_FAIL_IF_FALSE(ul2 == ul1 * 2); + } } namespace @@ -183,37 +205,50 @@ void Validate_Float_In_ReturnAndUpdateByRef() lcid, &methodId)); - float a = 12.34f; - float b = 1.234f; - float expected = b + a; - - DISPPARAMS params; - params.cArgs = 2; - params.rgvarg = new VARIANTARG[params.cArgs]; - params.cNamedArgs = 0; - params.rgdispidNamedArgs = nullptr; + const float a = 12.34f; + const float b_orig = 1.234f; + const float expected = b_orig + a; - VARIANT result; - - V_VT(¶ms.rgvarg[1]) = VT_R4; - V_R4(¶ms.rgvarg[1]) = a; - V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_R4; - V_R4REF(¶ms.rgvarg[0]) = &b; + float b = b_orig; + { + DISPPARAMS params; + params.cArgs = 2; + params.rgvarg = new VARIANTARG[params.cArgs]; + params.cNamedArgs = 0; + params.rgdispidNamedArgs = nullptr; + + VARIANT result; + + V_VT(¶ms.rgvarg[1]) = VT_R4; + V_R4(¶ms.rgvarg[1]) = a; + V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_R4; + V_R4REF(¶ms.rgvarg[0]) = &b; + + + THROW_IF_FAILED(dispatchTesting->Invoke( + methodId, + IID_NULL, + lcid, + DISPATCH_METHOD, + ¶ms, + &result, + nullptr, + nullptr + )); + + THROW_FAIL_IF_FALSE(EqualByBound(expected, V_R4(&result))); + THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + } - THROW_IF_FAILED(dispatchTesting->Invoke( - methodId, - IID_NULL, - lcid, - DISPATCH_METHOD, - ¶ms, - &result, - nullptr, - nullptr - )); + { + b = b_orig; + float result; + THROW_IF_FAILED(dispatchTesting->Add_Float_ReturnAndUpdateByRef(a, &b, &result)); - THROW_FAIL_IF_FALSE(EqualByBound(expected, V_R4(&result))); - THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + THROW_FAIL_IF_FALSE(EqualByBound(expected, result)); + THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + } } void Validate_Double_In_ReturnAndUpdateByRef() @@ -237,37 +272,50 @@ void Validate_Double_In_ReturnAndUpdateByRef() lcid, &methodId)); - double a = 1856.5634; - double b = 587867.757; - double expected = a + b; - - DISPPARAMS params; - params.cArgs = 2; - params.rgvarg = new VARIANTARG[params.cArgs]; - params.cNamedArgs = 0; - params.rgdispidNamedArgs = nullptr; - - VARIANT result; + const double a = 1856.5634; + const double b_orig = 587867.757; + const double expected = a + b_orig; - V_VT(¶ms.rgvarg[1]) = VT_R8; - V_R8(¶ms.rgvarg[1]) = a; - V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_R8; - V_R8REF(¶ms.rgvarg[0]) = &b; + double b = b_orig; + { + DISPPARAMS params; + params.cArgs = 2; + params.rgvarg = new VARIANTARG[params.cArgs]; + params.cNamedArgs = 0; + params.rgdispidNamedArgs = nullptr; + + VARIANT result; + + V_VT(¶ms.rgvarg[1]) = VT_R8; + V_R8(¶ms.rgvarg[1]) = a; + V_VT(¶ms.rgvarg[0]) = VT_BYREF | VT_R8; + V_R8REF(¶ms.rgvarg[0]) = &b; + + + THROW_IF_FAILED(dispatchTesting->Invoke( + methodId, + IID_NULL, + lcid, + DISPATCH_METHOD, + ¶ms, + &result, + nullptr, + nullptr + )); + + THROW_FAIL_IF_FALSE(EqualByBound(expected, V_R8(&result))); + THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + } - THROW_IF_FAILED(dispatchTesting->Invoke( - methodId, - IID_NULL, - lcid, - DISPATCH_METHOD, - ¶ms, - &result, - nullptr, - nullptr - )); + { + b = b_orig; + double result; + THROW_IF_FAILED(dispatchTesting->Add_Double_ReturnAndUpdateByRef(a, &b, &result)); - THROW_FAIL_IF_FALSE(EqualByBound(expected, V_R8(&result))); - THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + THROW_FAIL_IF_FALSE(EqualByBound(expected, result)); + THROW_FAIL_IF_FALSE(EqualByBound(expected, b)); + } } void Validate_LCID_Marshaled() -- 2.7.4