ComWrappers: Add support for ICustomQueryInterface (#34733)
authorAaron Robinson <arobins@microsoft.com>
Fri, 10 Apr 2020 22:31:50 +0000 (15:31 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 22:31:50 +0000 (15:31 -0700)
* Add support for falling back to ICustomQueryInterface if the managed object wrapper
 doesn't know about the IID but implements the interface.

* The QI on the managed object wrapper is now callable from within the runtime.
This means we can no longer assume we are in preemptive mode during the QI.
Update the ICustomQueryInterface dispatch to be okay with ANY mode.

src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
src/coreclr/src/interop/comwrappers.cpp
src/coreclr/src/interop/comwrappers.hpp
src/coreclr/src/interop/inc/interoplibimports.h
src/coreclr/src/vm/interoplibinterface.cpp
src/coreclr/src/vm/metasig.h
src/coreclr/src/vm/mscorlib.h
src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs

index 59f3d47..0229ac9 100644 (file)
@@ -306,5 +306,17 @@ namespace System.Runtime.InteropServices
 
         [DllImport(RuntimeHelpers.QCall)]
         private static extern void GetIUnknownImplInternal(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease);
+
+        internal static int CallICustomQueryInterface(object customQueryInterfaceMaybe, ref Guid iid, out IntPtr ppObject)
+        {
+            var customQueryInterface = customQueryInterfaceMaybe as ICustomQueryInterface;
+            if (customQueryInterface is null)
+            {
+                ppObject = IntPtr.Zero;
+                return -1; // See TryInvokeICustomQueryInterfaceResult
+            }
+
+            return (int)customQueryInterface.GetInterface(ref iid, out ppObject);
+        }
     }
 }
\ No newline at end of file
index d7fa94b..60d8091 100644 (file)
@@ -9,6 +9,7 @@
 
 using OBJECTHANDLE = InteropLib::OBJECTHANDLE;
 using AllocScenario = InteropLibImports::AllocScenario;
+using TryInvokeICustomQueryInterfaceResult = InteropLibImports::TryInvokeICustomQueryInterfaceResult;
 
 namespace ABI
 {
@@ -488,9 +489,8 @@ ULONGLONG ManagedObjectWrapper::UniversalRelease(_In_ ULONGLONG dec)
     return refCount;
 }
 
-void* ManagedObjectWrapper::As(_In_ REFIID riid)
+void* ManagedObjectWrapper::AsRuntimeDefined(_In_ REFIID riid)
 {
-    // Find target interface and return dispatcher or null if not found.
     for (int32_t i = 0; i < _runtimeDefinedCount; ++i)
     {
         if (IsEqualGUID(_runtimeDefined[i].IID, riid))
@@ -499,6 +499,11 @@ void* ManagedObjectWrapper::As(_In_ REFIID riid)
         }
     }
 
+    return nullptr;
+}
+
+void* ManagedObjectWrapper::AsUserDefined(_In_ REFIID riid)
+{
     for (int32_t i = 0; i < _userDefinedCount; ++i)
     {
         if (IsEqualGUID(_userDefined[i].IID, riid))
@@ -510,6 +515,16 @@ void* ManagedObjectWrapper::As(_In_ REFIID riid)
     return nullptr;
 }
 
+void* ManagedObjectWrapper::As(_In_ REFIID riid)
+{
+    // Find target interface and return dispatcher or null if not found.
+    void* typeMaybe = AsRuntimeDefined(riid);
+    if (typeMaybe == nullptr)
+        typeMaybe = AsUserDefined(riid);
+
+    return typeMaybe;
+}
+
 bool ManagedObjectWrapper::TrySetObjectHandle(_In_ OBJECTHANDLE objectHandle, _In_ OBJECTHANDLE current)
 {
     return (::InterlockedCompareExchangePointer(&Target, objectHandle, current) == current);
@@ -582,9 +597,50 @@ HRESULT ManagedObjectWrapper::QueryInterface(
         return E_POINTER;
 
     // Find target interface
-    *ppvObject = As(riid);
+    *ppvObject = AsRuntimeDefined(riid);
     if (*ppvObject == nullptr)
-        return E_NOINTERFACE;
+    {
+        // Check if the managed object has implemented ICustomQueryInterface
+        if (!IsSet(CreateComInterfaceFlagsEx::LacksICustomQueryInterface))
+        {
+            TryInvokeICustomQueryInterfaceResult result = InteropLibImports::TryInvokeICustomQueryInterface(Target, riid, ppvObject);
+            switch (result)
+            {
+                case TryInvokeICustomQueryInterfaceResult::Handled:
+                    _ASSERTE(*ppvObject != nullptr);
+                    return S_OK;
+
+                case TryInvokeICustomQueryInterfaceResult::NotHandled:
+                    // Continue querying the static tables.
+                    break;
+
+                case TryInvokeICustomQueryInterfaceResult::Failed:
+                    _ASSERTE(*ppvObject == nullptr);
+                    return E_NOINTERFACE;
+
+                default:
+                    _ASSERTE(false && "Unknown result value");
+                case TryInvokeICustomQueryInterfaceResult::FailedToInvoke:
+                    // Set the 'lacks' flag since our attempt to use ICustomQueryInterface
+                    // indicated the object lacks an implementation.
+                    SetFlag(CreateComInterfaceFlagsEx::LacksICustomQueryInterface);
+                    break;
+
+                case TryInvokeICustomQueryInterfaceResult::OnGCThread:
+                    // We are going to assume the caller is attempting to
+                    // check if this wrapper has an interface that is supported
+                    // during a GC and not trying to do something bad.
+                    // Instead of returning immediately, we handle the case
+                    // the same way that would occur if the managed object lacked
+                    // an ICustomQueryInterface implementation.
+                    break;
+            }
+        }
+
+        *ppvObject = AsUserDefined(riid);
+        if (*ppvObject == nullptr)
+            return E_NOINTERFACE;
+    }
 
     (void)AddRef();
     return S_OK;
index e983864..02e3b56 100644 (file)
@@ -15,11 +15,12 @@ enum class CreateComInterfaceFlagsEx : int32_t
     CallerDefinedIUnknown = InteropLib::Com::CreateComInterfaceFlags_CallerDefinedIUnknown,
     TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport,
 
-    // Highest bit is reserved for internal usage
+    // Highest bits are reserved for internal usage
+    LacksICustomQueryInterface = 1 << 29,
     IsComActivated = 1 << 30,
     IsPegged = 1 << 31,
 
-    InternalMask = IsPegged | IsComActivated,
+    InternalMask = IsPegged | IsComActivated | LacksICustomQueryInterface,
 };
 
 DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx);
@@ -85,8 +86,16 @@ private:
     // the wrapper. Supplied with a decrementing value.
     ULONGLONG UniversalRelease(_In_ ULONGLONG dec);
 
+    // Query the runtime defined tables.
+    void* AsRuntimeDefined(_In_ REFIID riid);
+
+    // Query the user defined tables.
+    void* AsUserDefined(_In_ REFIID riid);
+
 public:
+    // N.B. Does not impact the reference count of the object.
     void* As(_In_ REFIID riid);
+
     // Attempt to set the target object handle based on an assumed current value.
     bool TrySetObjectHandle(_In_ InteropLib::OBJECTHANDLE objectHandle, _In_ InteropLib::OBJECTHANDLE current = nullptr);
     bool IsSet(_In_ CreateComInterfaceFlagsEx flag) const;
index 3217b78..692f749 100644 (file)
@@ -72,6 +72,28 @@ namespace InteropLibImports
         _In_ InteropLib::Com::CreateObjectFlags externalObjectFlags,
         _In_ InteropLib::Com::CreateComInterfaceFlags trackerTargetFlags,
         _Outptr_ void** trackerTarget) noexcept;
+
+    // The enum describes the value of System.Runtime.InteropServices.CustomQueryInterfaceResult
+    // and the case where the object doesn't support ICustomQueryInterface.
+    enum class TryInvokeICustomQueryInterfaceResult
+    {
+        OnGCThread = -2,
+        FailedToInvoke = -1,
+        Handled = 0,
+        NotHandled = 1,
+        Failed = 2,
+
+        // Range checks
+        Min = OnGCThread,
+        Max = Failed,
+    };
+
+    // Attempt to call the ICustomQueryInterface on the supplied object.
+    // Returns S_FALSE if the object doesn't support ICustomQueryInterface.
+    TryInvokeICustomQueryInterfaceResult TryInvokeICustomQueryInterface(
+        _In_ InteropLib::OBJECTHANDLE handle,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** obj) noexcept;
 }
 
 #endif // _INTEROP_INC_INTEROPLIBIMPORTS_H_
index 57345ea..8bb65c0 100644 (file)
@@ -481,6 +481,32 @@ namespace
         CALL_MANAGED_METHOD_NORET(args);
     }
 
+    int CallICustomQueryInterface(
+        _In_ OBJECTREF* implPROTECTED,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** ppObject)
+    {
+        CONTRACTL
+        {
+            THROWS;
+            MODE_COOPERATIVE;
+            PRECONDITION(implPROTECTED != NULL);
+            PRECONDITION(ppObject != NULL);
+        }
+        CONTRACTL_END;
+
+        int result;
+
+        PREPARE_NONVIRTUAL_CALLSITE(METHOD__COMWRAPPERS__CALL_ICUSTOMQUERYINTERFACE);
+        DECLARE_ARGHOLDER_ARRAY(args, 3);
+        args[ARGNUM_0]  = OBJECTREF_TO_ARGHOLDER(*implPROTECTED);
+        args[ARGNUM_1]  = PTR_TO_ARGHOLDER(&iid);
+        args[ARGNUM_2]  = PTR_TO_ARGHOLDER(ppObject);
+        CALL_MANAGED_METHOD(result, int, args);
+
+        return result;
+    }
+
     bool TryGetOrCreateComInterfaceForObjectInternal(
         _In_opt_ OBJECTREF impl,
         _In_ OBJECTREF instance,
@@ -1023,6 +1049,73 @@ namespace InteropLibImports
         return hr;
     }
 
+    TryInvokeICustomQueryInterfaceResult TryInvokeICustomQueryInterface(
+        _In_ InteropLib::OBJECTHANDLE handle,
+        _In_ REFGUID iid,
+        _Outptr_result_maybenull_ void** obj) noexcept
+    {
+        CONTRACTL
+        {
+            NOTHROW;
+            MODE_ANY;
+            PRECONDITION(handle != NULL);
+            PRECONDITION(obj != NULL);
+        }
+        CONTRACTL_END;
+
+        *obj = NULL;
+
+        // If this is a GC thread, then someone is trying to query for something
+        // at a time when we can't run managed code.
+        if (IsGCThread())
+            return TryInvokeICustomQueryInterfaceResult::OnGCThread;
+
+        // Ideally the BEGIN_EXTERNAL_ENTRYPOINT/END_EXTERNAL_ENTRYPOINT pairs
+        // would be used here. However, this code path can be entered from within
+        // and from outside the runtime.
+        MAKE_CURRENT_THREAD_AVAILABLE_EX(GetThreadNULLOk());
+        if (CURRENT_THREAD == NULL)
+        {
+            CURRENT_THREAD = SetupThreadNoThrow();
+
+            // If we failed to set up a new thread, we are going to indicate
+            // there was a general failure to invoke instead of failing fast.
+            if (CURRENT_THREAD == NULL)
+                return TryInvokeICustomQueryInterfaceResult::FailedToInvoke;
+        }
+
+        HRESULT hr;
+        auto result = TryInvokeICustomQueryInterfaceResult::FailedToInvoke;
+        EX_TRY_THREAD(CURRENT_THREAD)
+        {
+            // Switch to Cooperative mode since object references
+            // are being manipulated.
+            GCX_COOP();
+
+            struct
+            {
+                OBJECTREF objRef;
+            } gc;
+            ::ZeroMemory(&gc, sizeof(gc));
+            GCPROTECT_BEGIN(gc);
+
+            // Get the target of the external object's reference.
+            ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle);
+            gc.objRef = ObjectFromHandle(objectHandle);
+
+            result = (TryInvokeICustomQueryInterfaceResult)CallICustomQueryInterface(&gc.objRef, iid, obj);
+
+            GCPROTECT_END();
+        }
+        EX_CATCH_HRESULT(hr);
+
+        // Assert valid value.
+        _ASSERTE(TryInvokeICustomQueryInterfaceResult::Min <= result
+            && result <= TryInvokeICustomQueryInterfaceResult::Max);
+
+        return result;
+    }
+
     struct RuntimeCallContext
     {
         // Iterators for all known external objects.
index e783046..b35e6fd 100644 (file)
@@ -199,6 +199,7 @@ DEFINE_METASIG_T(SM(Exception_IntPtr_RetException, C(EXCEPTION) I, C(EXCEPTION))
 DEFINE_METASIG_T(SM(ComWrappers_Obj_CreateFlags_RefInt_RetPtrVoid, C(COMWRAPPERS) j g(CREATECOMINTERFACEFLAGS) r(i), P(v)))
 DEFINE_METASIG_T(SM(ComWrappers_IntPtr_CreateFlags_RetObj, C(COMWRAPPERS) I g(CREATEOBJECTFLAGS), j))
 DEFINE_METASIG_T(SM(ComWrappers_IEnumerable_RetVoid, C(COMWRAPPERS) C(IENUMERABLE), v))
+DEFINE_METASIG_T(SM(Obj_RefGuid_RefIntPtr_RetInt, j r(g(GUID)) r(I), i))
 #endif // FEATURE_COMINTEROP
 DEFINE_METASIG(SM(Int_RetVoid, i, v))
 DEFINE_METASIG(SM(Int_Int_RetVoid, i i, v))
index d2057ff..67f140a 100644 (file)
@@ -465,6 +465,7 @@ DEFINE_CLASS(CREATEOBJECTFLAGS,           Interop,          CreateObjectFlags)
 DEFINE_METHOD(COMWRAPPERS,                COMPUTE_VTABLES,  CallComputeVtables,         SM_ComWrappers_Obj_CreateFlags_RefInt_RetPtrVoid)
 DEFINE_METHOD(COMWRAPPERS,                CREATE_OBJECT,    CallCreateObject,           SM_ComWrappers_IntPtr_CreateFlags_RetObj)
 DEFINE_METHOD(COMWRAPPERS,                RELEASE_OBJECTS,  CallReleaseObjects,         SM_ComWrappers_IEnumerable_RetVoid)
+DEFINE_METHOD(COMWRAPPERS,     CALL_ICUSTOMQUERYINTERFACE,  CallICustomQueryInterface,  SM_Obj_RefGuid_RefIntPtr_RetInt)
 #endif //FEATURE_COMINTEROP
 
 DEFINE_CLASS(SERIALIZATION_INFO,        Serialization,      SerializationInfo)
index 2eaf5aa..481eafd 100644 (file)
@@ -109,6 +109,36 @@ namespace ComWrappersTests
             Assert.AreEqual(count, 0);
         }
 
+        static void ValidateFallbackQueryInterface()
+        {
+            Console.WriteLine($"Running {nameof(ValidateFallbackQueryInterface)}...");
+
+            var testObj = new Test()
+                {
+                    EnableICustomQueryInterface = true
+                };
+
+            var wrappers = new TestComWrappers();
+
+            // Allocate a wrapper for the object
+            IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+
+            testObj.ICustomQueryInterface_GetInterfaceResult = new IntPtr(0x2000000);
+
+            IntPtr result;
+            var anyGuid = new Guid("1E42439C-DCB5-4701-ACBD-87FE92E785DE");
+            testObj.ICustomQueryInterface_GetInterfaceIID = anyGuid;
+            int hr = Marshal.QueryInterface(comWrapper, ref anyGuid, out result);
+            Assert.AreEqual(hr, 0);
+            Assert.AreEqual(result, testObj.ICustomQueryInterface_GetInterfaceResult);
+
+            var anyGuid2 = new Guid("7996D0F9-C8DD-4544-B708-0F75C6FF076F");
+            hr = Marshal.QueryInterface(comWrapper, ref anyGuid2, out result);
+            const int E_NOINTERFACE = unchecked((int)0x80004002);
+            Assert.AreEqual(hr, E_NOINTERFACE);
+            Assert.AreEqual(result, IntPtr.Zero);
+        }
+
         static void ValidateCreateObjectCachingScenario()
         {
             Console.WriteLine($"Running {nameof(ValidateCreateObjectCachingScenario)}...");
@@ -325,6 +355,7 @@ namespace ComWrappersTests
             try
             {
                 ValidateComInterfaceCreation();
+                ValidateFallbackQueryInterface();
                 ValidateCreateObjectCachingScenario();
                 ValidatePrecreatedExternalWrapper();
                 ValidateIUnknownImpls();
index 06c052b..2f32201 100644 (file)
@@ -16,7 +16,7 @@ namespace ComWrappersTests.Common
         void SetValue(int i);
     }
 
-    class Test : ITest
+    class Test : ITest, ICustomQueryInterface
     {
         public static int InstanceCount = 0;
 
@@ -26,6 +26,27 @@ namespace ComWrappersTests.Common
 
         public void SetValue(int i) => this.value = i;
         public int GetValue() => this.value;
+
+        public bool EnableICustomQueryInterface { get; set; } = false;
+        public Guid ICustomQueryInterface_GetInterfaceIID { get; set; }
+        public IntPtr ICustomQueryInterface_GetInterfaceResult { get; set; }
+
+        CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv)
+        {
+            ppv = IntPtr.Zero;
+            if (!EnableICustomQueryInterface)
+            {
+                return CustomQueryInterfaceResult.NotHandled;
+            }
+
+            if (iid != ICustomQueryInterface_GetInterfaceIID)
+            {
+                return CustomQueryInterfaceResult.Failed;
+            }
+
+            ppv = this.ICustomQueryInterface_GetInterfaceResult;
+            return CustomQueryInterfaceResult.Handled;
+        }
     }
 
     public struct IUnknownVtbl