ComWrappers: Add support for ICustomQueryInterface (#34733)
authorAaron Robinson <arobins@microsoft.com>
Fri, 10 Apr 2020 22:31:50 +0000 (15:31 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 22:31:50 +0000 (15:31 -0700)
* Add support for falling back to ICustomQueryInterface if the managed object wrapper
 doesn't know about the IID but implements the interface.

* The QI on the managed object wrapper is now callable from within the runtime.
This means we can no longer assume we are in preemptive mode during the QI.
Update the ICustomQueryInterface dispatch to be okay with ANY mode.

src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
src/coreclr/src/interop/comwrappers.cpp
src/coreclr/src/interop/comwrappers.hpp
src/coreclr/src/interop/inc/interoplibimports.h
src/coreclr/src/vm/interoplibinterface.cpp
src/coreclr/src/vm/metasig.h
src/coreclr/src/vm/mscorlib.h
src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs

index 59f3d47ac35856d8f4f947b7b0e4b20a64260a24..0229ac9c0dbc551f782cbfc9a68f6f6482be334c 100644 (file)
@@ -306,5 +306,17 @@ namespace System.Runtime.InteropServices
 
         [DllImport(RuntimeHelpers.QCall)]
         private static extern void GetIUnknownImplInternal(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease);
+
+        internal static int CallICustomQueryInterface(object customQueryInterfaceMaybe, ref Guid iid, out IntPtr ppObject)
+        {
+            var customQueryInterface = customQueryInterfaceMaybe as ICustomQueryInterface;
+            if (customQueryInterface is null)
+            {
+                ppObject = IntPtr.Zero;
+                return -1; // See TryInvokeICustomQueryInterfaceResult
+            }
+
+            return (int)customQueryInterface.GetInterface(ref iid, out ppObject);
+        }
     }
 }
\ No newline at end of file
index d7fa94bab1337f405f600c10b8374441106ffc10..60d80910a5078b54fcbf2cba0e7823b9ed68d2dc 100644 (file)
@@ -9,6 +9,7 @@
 
 using OBJECTHANDLE = InteropLib::OBJECTHANDLE;
 using AllocScenario = InteropLibImports::AllocScenario;
+using TryInvokeICustomQueryInterfaceResult = InteropLibImports::TryInvokeICustomQueryInterfaceResult;
 
 namespace ABI
 {
@@ -488,9 +489,8 @@ ULONGLONG ManagedObjectWrapper::UniversalRelease(_In_ ULONGLONG dec)
     return refCount;
 }
 
-void* ManagedObjectWrapper::As(_In_ REFIID riid)
+void* ManagedObjectWrapper::AsRuntimeDefined(_In_ REFIID riid)
 {
-    // Find target interface and return dispatcher or null if not found.
     for (int32_t i = 0; i < _runtimeDefinedCount; ++i)
     {
         if (IsEqualGUID(_runtimeDefined[i].IID, riid))
@@ -499,6 +499,11 @@ void* ManagedObjectWrapper::As(_In_ REFIID riid)
         }
     }
 
+    return nullptr;
+}
+
+void* ManagedObjectWrapper::AsUserDefined(_In_ REFIID riid)
+{
     for (int32_t i = 0; i < _userDefinedCount; ++i)
     {
         if (IsEqualGUID(_userDefined[i].IID, riid))
@@ -510,6 +515,16 @@ void* ManagedObjectWrapper::As(_In_ REFIID riid)
     return nullptr;
 }
 
+void* ManagedObjectWrapper::As(_In_ REFIID riid)
+{
+    // Find target interface and return dispatcher or null if not found.
+    void* typeMaybe = AsRuntimeDefined(riid);
+    if (typeMaybe == nullptr)
+        typeMaybe = AsUserDefined(riid);
+
+    return typeMaybe;
+}
+
 bool ManagedObjectWrapper::TrySetObjectHandle(_In_ OBJECTHANDLE objectHandle, _In_ OBJECTHANDLE current)
 {
     return (::InterlockedCompareExchangePointer(&Target, objectHandle, current) == current);
@@ -582,9 +597,50 @@ HRESULT ManagedObjectWrapper::QueryInterface(
         return E_POINTER;
 
     // Find target interface
-    *ppvObject = As(riid);
+    *ppvObject = AsRuntimeDefined(riid);
     if (*ppvObject == nullptr)
-        return E_NOINTERFACE;
+    {
+        // Check if the managed object has implemented ICustomQueryInterface
+        if (!IsSet(CreateComInterfaceFlagsEx::LacksICustomQueryInterface))
+        {
+            TryInvokeICustomQueryInterfaceResult result = InteropLibImports::TryInvokeICustomQueryInterface(Target, riid, ppvObject);
+            switch (result)
+            {
+                case TryInvokeICustomQueryInterfaceResult::Handled:
+                    _ASSERTE(*ppvObject != nullptr);
+                    return S_OK;
+
+                case TryInvokeICustomQueryInterfaceResult::NotHandled:
+                    // Continue querying the static tables.
+                    break;
+
+                case TryInvokeICustomQueryInterfaceResult::Failed:
+                    _ASSERTE(*ppvObject == nullptr);
+                    return E_NOINTERFACE;
+
+                default:
+                    _ASSERTE(false && "Unknown result value");
+                case TryInvokeICustomQueryInterfaceResult::FailedToInvoke:
+                    // Set the 'lacks' flag since our attempt to use ICustomQueryInterface
+                    // indicated the object lacks an implementation.
+                    SetFlag(CreateComInterfaceFlagsEx::LacksICustomQueryInterface);
+                    break;
+
+                case TryInvokeICustomQueryInterfaceResult::OnGCThread:
+                    // We are going to assume the caller is attempting to
+                    // check if this wrapper has an interface that is supported
+                    // during a GC and not trying to do something bad.
+                    // Instead of returning immediately, we handle the case
+                    // the same way that would occur if the managed object lacked
+                    // an ICustomQueryInterface implementation.
+                    break;
+            }
+        }
+
+        *ppvObject = AsUserDefined(riid);
+        if (*ppvObject == nullptr)
+            return E_NOINTERFACE;
+    }
 
     (void)AddRef();
     return S_OK;
index e9838646b0189128e06eb16f2e55a52bec809c3e..02e3b562069ba9fb743cbe6770595107caaf5793 100644 (file)
@@ -15,11 +15,12 @@ enum class CreateComInterfaceFlagsEx : int32_t
     CallerDefinedIUnknown = InteropLib::Com::CreateComInterfaceFlags_CallerDefinedIUnknown,
     TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport,
 
-    // Highest bit is reserved for internal usage
+    // Highest bits are reserved for internal usage
+    LacksICustomQueryInterface = 1 << 29,
     IsComActivated = 1 << 30,
     IsPegged = 1 << 31,
 
-    InternalMask = IsPegged | IsComActivated,
+    InternalMask = IsPegged | IsComActivated | LacksICustomQueryInterface,
 };
 
 DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx);
@@ -85,8 +86,16 @@ private:
     // the wrapper. Supplied with a decrementing value.
     ULONGLONG UniversalRelease(_In_ ULONGLONG dec);
 
+    // Query the runtime defined tables.
+    void* AsRuntimeDefined(_In_ REFIID riid);
+
+    // Query the user defined tables.
+    void* AsUserDefined(_In_ REFIID riid);
+
 public:
+    // N.B. Does not impact the reference count of the object.
     void* As(_In_ REFIID riid);
+
     // Attempt to set the target object handle based on an assumed current value.
     bool TrySetObjectHandle(_In_ InteropLib::OBJECTHANDLE objectHandle, _In_ InteropLib::OBJECTHANDLE current = nullptr);
     bool IsSet(_In_ CreateComInterfaceFlagsEx flag) const;
index 3217b78a35f3397d5b6bc75d79a138c19ba5152f..692f749df1a5e414d5e89a9e53c52d74adc5bfc3 100644 (file)
@@ -72,6 +72,28 @@ namespace InteropLibImports
         _In_ InteropLib::Com::CreateObjectFlags externalObjectFlags,
         _In_ InteropLib::Com::CreateComInterfaceFlags trackerTargetFlags,
         _Outptr_ void** trackerTarget) noexcept;
+
+    // The enum describes the value of System.Runtime.InteropServices.CustomQueryInterfaceResult
+    // and the case where the object doesn't support ICustomQueryInterface.
+    enum class TryInvokeICustomQueryInterfaceResult
+    {
+        OnGCThread = -2,
+        FailedToInvoke = -1,
+        Handled = 0,
+        NotHandled = 1,
+        Failed = 2,
+
+        // Range checks
+        Min = OnGCThread,
+        Max = Failed,
+    };
+
+    // Attempt to call the ICustomQueryInterface on the supplied object.
+    // Returns S_FALSE if the object doesn't support ICustomQueryInterface.
+    TryInvokeICustomQueryInterfaceResult TryInvokeICustomQueryInterface(
+        _In_ InteropLib::OBJECTHANDLE handle,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** obj) noexcept;
 }
 
 #endif // _INTEROP_INC_INTEROPLIBIMPORTS_H_
index 57345eafedc6be3d37dbf977a9f06cfb83e25cdf..8bb65c04cac2344cf96a1c01025f730ed9b6f43f 100644 (file)
@@ -481,6 +481,32 @@ namespace
         CALL_MANAGED_METHOD_NORET(args);
     }
 
+    int CallICustomQueryInterface(
+        _In_ OBJECTREF* implPROTECTED,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** ppObject)
+    {
+        CONTRACTL
+        {
+            THROWS;
+            MODE_COOPERATIVE;
+            PRECONDITION(implPROTECTED != NULL);
+            PRECONDITION(ppObject != NULL);
+        }
+        CONTRACTL_END;
+
+        int result;
+
+        PREPARE_NONVIRTUAL_CALLSITE(METHOD__COMWRAPPERS__CALL_ICUSTOMQUERYINTERFACE);
+        DECLARE_ARGHOLDER_ARRAY(args, 3);
+        args[ARGNUM_0]  = OBJECTREF_TO_ARGHOLDER(*implPROTECTED);
+        args[ARGNUM_1]  = PTR_TO_ARGHOLDER(&iid);
+        args[ARGNUM_2]  = PTR_TO_ARGHOLDER(ppObject);
+        CALL_MANAGED_METHOD(result, int, args);
+
+        return result;
+    }
+
     bool TryGetOrCreateComInterfaceForObjectInternal(
         _In_opt_ OBJECTREF impl,
         _In_ OBJECTREF instance,
@@ -1023,6 +1049,73 @@ namespace InteropLibImports
         return hr;
     }
 
+    TryInvokeICustomQueryInterfaceResult TryInvokeICustomQueryInterface(
+        _In_ InteropLib::OBJECTHANDLE handle,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** obj) noexcept
+    {
+        CONTRACTL
+        {
+            NOTHROW;
+            MODE_ANY;
+            PRECONDITION(handle != NULL);
+            PRECONDITION(obj != NULL);
+        }
+        CONTRACTL_END;
+
+        *obj = NULL;
+
+        // If this is a GC thread, then someone is trying to query for something
+        // at a time when we can't run managed code.
+        if (IsGCThread())
+            return TryInvokeICustomQueryInterfaceResult::OnGCThread;
+
+        // Ideally the BEGIN_EXTERNAL_ENTRYPOINT/END_EXTERNAL_ENTRYPOINT pairs
+        // would be used here. However, this code path can be entered from within
+        // and from outside the runtime.
+        MAKE_CURRENT_THREAD_AVAILABLE_EX(GetThreadNULLOk());
+        if (CURRENT_THREAD == NULL)
+        {
+            CURRENT_THREAD = SetupThreadNoThrow();
+
+            // If we failed to set up a new thread, we are going to indicate
+            // there was a general failure to invoke instead of failing fast.
+            if (CURRENT_THREAD == NULL)
+                return TryInvokeICustomQueryInterfaceResult::FailedToInvoke;
+        }
+
+        HRESULT hr;
+        auto result = TryInvokeICustomQueryInterfaceResult::FailedToInvoke;
+        EX_TRY_THREAD(CURRENT_THREAD)
+        {
+            // Switch to Cooperative mode since object references
+            // are being manipulated.
+            GCX_COOP();
+
+            struct
+            {
+                OBJECTREF objRef;
+            } gc;
+            ::ZeroMemory(&gc, sizeof(gc));
+            GCPROTECT_BEGIN(gc);
+
+            // Get the target of the external object's reference.
+            ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle);
+            gc.objRef = ObjectFromHandle(objectHandle);
+
+            result = (TryInvokeICustomQueryInterfaceResult)CallICustomQueryInterface(&gc.objRef, iid, obj);
+
+            GCPROTECT_END();
+        }
+        EX_CATCH_HRESULT(hr);
+
+        // Assert valid value.
+        _ASSERTE(TryInvokeICustomQueryInterfaceResult::Min <= result
+            && result <= TryInvokeICustomQueryInterfaceResult::Max);
+
+        return result;
+    }
+
     struct RuntimeCallContext
     {
         // Iterators for all known external objects.
index e7830468cfa5898d411f1d592061602e134cb1cf..b35e6fd613852f2bddd2fcb43f157e2c943457c4 100644 (file)
@@ -199,6 +199,7 @@ DEFINE_METASIG_T(SM(Exception_IntPtr_RetException, C(EXCEPTION) I, C(EXCEPTION))
 DEFINE_METASIG_T(SM(ComWrappers_Obj_CreateFlags_RefInt_RetPtrVoid, C(COMWRAPPERS) j g(CREATECOMINTERFACEFLAGS) r(i), P(v)))
 DEFINE_METASIG_T(SM(ComWrappers_IntPtr_CreateFlags_RetObj, C(COMWRAPPERS) I g(CREATEOBJECTFLAGS), j))
 DEFINE_METASIG_T(SM(ComWrappers_IEnumerable_RetVoid, C(COMWRAPPERS) C(IENUMERABLE), v))
+DEFINE_METASIG_T(SM(Obj_RefGuid_RefIntPtr_RetInt, j r(g(GUID)) r(I), i))
 #endif // FEATURE_COMINTEROP
 DEFINE_METASIG(SM(Int_RetVoid, i, v))
 DEFINE_METASIG(SM(Int_Int_RetVoid, i i, v))
index d2057ffd5c74d2f6ac49b764c1c90dcce3d798c9..67f140a202dafde05b119b07f334980ace5e6a9f 100644 (file)
@@ -465,6 +465,7 @@ DEFINE_CLASS(CREATEOBJECTFLAGS,           Interop,          CreateObjectFlags)
 DEFINE_METHOD(COMWRAPPERS,                COMPUTE_VTABLES,  CallComputeVtables,         SM_ComWrappers_Obj_CreateFlags_RefInt_RetPtrVoid)
 DEFINE_METHOD(COMWRAPPERS,                CREATE_OBJECT,    CallCreateObject,           SM_ComWrappers_IntPtr_CreateFlags_RetObj)
 DEFINE_METHOD(COMWRAPPERS,                RELEASE_OBJECTS,  CallReleaseObjects,         SM_ComWrappers_IEnumerable_RetVoid)
+DEFINE_METHOD(COMWRAPPERS,     CALL_ICUSTOMQUERYINTERFACE,  CallICustomQueryInterface,  SM_Obj_RefGuid_RefIntPtr_RetInt)
 #endif //FEATURE_COMINTEROP
 
 DEFINE_CLASS(SERIALIZATION_INFO,        Serialization,      SerializationInfo)
index 2eaf5aafd53c24a193a3a6eab9a027f75b49ce0f..481eafdd00ae4848ee3e11c4a02ffdb12b59b5e6 100644 (file)
@@ -109,6 +109,36 @@ namespace ComWrappersTests
             Assert.AreEqual(count, 0);
         }
 
+        static void ValidateFallbackQueryInterface()
+        {
+            Console.WriteLine($"Running {nameof(ValidateFallbackQueryInterface)}...");
+
+            var testObj = new Test()
+                {
+                    EnableICustomQueryInterface = true
+                };
+
+            var wrappers = new TestComWrappers();
+
+            // Allocate a wrapper for the object
+            IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+
+            testObj.ICustomQueryInterface_GetInterfaceResult = new IntPtr(0x2000000);
+
+            IntPtr result;
+            var anyGuid = new Guid("1E42439C-DCB5-4701-ACBD-87FE92E785DE");
+            testObj.ICustomQueryInterface_GetInterfaceIID = anyGuid;
+            int hr = Marshal.QueryInterface(comWrapper, ref anyGuid, out result);
+            Assert.AreEqual(hr, 0);
+            Assert.AreEqual(result, testObj.ICustomQueryInterface_GetInterfaceResult);
+
+            var anyGuid2 = new Guid("7996D0F9-C8DD-4544-B708-0F75C6FF076F");
+            hr = Marshal.QueryInterface(comWrapper, ref anyGuid2, out result);
+            const int E_NOINTERFACE = unchecked((int)0x80004002);
+            Assert.AreEqual(hr, E_NOINTERFACE);
+            Assert.AreEqual(result, IntPtr.Zero);
+        }
+
         static void ValidateCreateObjectCachingScenario()
         {
             Console.WriteLine($"Running {nameof(ValidateCreateObjectCachingScenario)}...");
@@ -325,6 +355,7 @@ namespace ComWrappersTests
             try
             {
                 ValidateComInterfaceCreation();
+                ValidateFallbackQueryInterface();
                 ValidateCreateObjectCachingScenario();
                 ValidatePrecreatedExternalWrapper();
                 ValidateIUnknownImpls();
index 06c052b0a237a6d13ccced454a9cdf706803bfd3..2f322018f0f8005ea24b975661d4f7e924871883 100644 (file)
@@ -16,7 +16,7 @@ namespace ComWrappersTests.Common
         void SetValue(int i);
     }
 
-    class Test : ITest
+    class Test : ITest, ICustomQueryInterface
     {
         public static int InstanceCount = 0;
 
@@ -26,6 +26,27 @@ namespace ComWrappersTests.Common
 
         public void SetValue(int i) => this.value = i;
         public int GetValue() => this.value;
+
+        public bool EnableICustomQueryInterface { get; set; } = false;
+        public Guid ICustomQueryInterface_GetInterfaceIID { get; set; }
+        public IntPtr ICustomQueryInterface_GetInterfaceResult { get; set; }
+
+        CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv)
+        {
+            ppv = IntPtr.Zero;
+            if (!EnableICustomQueryInterface)
+            {
+                return CustomQueryInterfaceResult.NotHandled;
+            }
+
+            if (iid != ICustomQueryInterface_GetInterfaceIID)
+            {
+                return CustomQueryInterfaceResult.Failed;
+            }
+
+            ppv = this.ICustomQueryInterface_GetInterfaceResult;
+            return CustomQueryInterfaceResult.Handled;
+        }
     }
 
     public struct IUnknownVtbl