ComWrappers isolation (#37861)
authorElinor Fung <47805090+elinor-fung@users.noreply.github.com>
Wed, 17 Jun 2020 03:10:20 +0000 (20:10 -0700)
committerGitHub <noreply@github.com>
Wed, 17 Jun 2020 03:10:20 +0000 (20:10 -0700)
- Assign/track ID of ComWrappers used to create object/wrapper.
- Only reuse object/wrapper created by same instance for GetOrCreate*
- Rehydrate RCW on WeakReference for global ComWrapper instances only

src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
src/coreclr/src/vm/appdomain.hpp
src/coreclr/src/vm/gchandleutilities.h
src/coreclr/src/vm/interoplibinterface.cpp
src/coreclr/src/vm/interoplibinterface.h
src/coreclr/src/vm/interoputil.cpp
src/coreclr/src/vm/syncblk.h
src/coreclr/src/vm/weakreferencenative.cpp
src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs

index 6bdd48d..b404dc6 100644 (file)
@@ -126,12 +126,20 @@ namespace System.Runtime.InteropServices
         /// </summary>
         private static ComWrappers? s_globalInstanceForMarshalling;
 
+        private static long s_instanceCounter;
+        private readonly long id = Interlocked.Increment(ref s_instanceCounter);
+
         /// <summary>
         /// Create a COM representation of the supplied object that can be passed to a non-managed environment.
         /// </summary>
         /// <param name="instance">The managed object to expose outside the .NET runtime.</param>
         /// <param name="flags">Flags used to configure the generated interface.</param>
         /// <returns>The generated COM interface that can be passed outside the .NET runtime.</returns>
+        /// <remarks>
+        /// If a COM representation was previously created for the specified <paramref name="instance" /> using
+        /// this <see cref="ComWrappers" /> instance, the previously created COM interface will be returned.
+        /// If not, a new one will be created.
+        /// </remarks>
         public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags)
         {
             IntPtr ptr;
@@ -152,16 +160,16 @@ namespace System.Runtime.InteropServices
         /// <remarks>
         /// If <paramref name="impl" /> is <c>null</c>, the global instance (if registered) will be used.
         /// </remarks>
-        private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
+        private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
         {
             if (instance == null)
                 throw new ArgumentNullException(nameof(instance));
 
-            return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue);
+            return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), impl.id, ObjectHandleOnStack.Create(ref instance), flags, out retValue);
         }
 
         [DllImport(RuntimeHelpers.QCall)]
-        private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);
+        private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);
 
         /// <summary>
         /// Compute the desired Vtable for <paramref name="obj"/> respecting the values of <paramref name="flags"/>.
@@ -211,6 +219,11 @@ namespace System.Runtime.InteropServices
         /// <param name="externalComObject">Object to import for usage into the .NET runtime.</param>
         /// <param name="flags">Flags used to describe the external object.</param>
         /// <returns>Returns a managed object associated with the supplied external COM object.</returns>
+        /// <remarks>
+        /// If a managed object was previously created for the specified <paramref name="externalComObject" />
+        /// using this <see cref="ComWrappers" /> instance, the previously created object will be returned.
+        /// If not, a new one will be created.
+        /// </remarks>
         public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags)
         {
             object? obj;
@@ -288,18 +301,18 @@ namespace System.Runtime.InteropServices
         /// <remarks>
         /// If <paramref name="impl" /> is <c>null</c>, the global instance (if registered) will be used.
         /// </remarks>
-        private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
+        private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
         {
             if (externalComObject == IntPtr.Zero)
                 throw new ArgumentNullException(nameof(externalComObject));
 
             object? wrapperMaybeLocal = wrapperMaybe;
             retValue = null;
-            return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
+            return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), impl.id, externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
         }
 
         [DllImport(RuntimeHelpers.QCall)]
-        private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
+        private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
 
         /// <summary>
         /// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime.
@@ -332,13 +345,13 @@ namespace System.Runtime.InteropServices
                 throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
             }
 
-            SetGlobalInstanceRegisteredForTrackerSupport();
+            SetGlobalInstanceRegisteredForTrackerSupport(instance.id);
         }
 
 
         [DllImport(RuntimeHelpers.QCall)]
         [SuppressGCTransition]
-        private static extern void SetGlobalInstanceRegisteredForTrackerSupport();
+        private static extern void SetGlobalInstanceRegisteredForTrackerSupport(long id);
 
         /// <summary>
         /// Register a <see cref="ComWrappers" /> instance to be used as the global instance for marshalling in the runtime.
@@ -366,12 +379,12 @@ namespace System.Runtime.InteropServices
             // Indicate to the runtime that a global instance has been registered for marshalling.
             // This allows the native runtime know to call into the managed ComWrappers only if a
             // global instance is registered for marshalling.
-            SetGlobalInstanceRegisteredForMarshalling();
+            SetGlobalInstanceRegisteredForMarshalling(instance.id);
         }
 
         [DllImport(RuntimeHelpers.QCall)]
         [SuppressGCTransition]
-        private static extern void SetGlobalInstanceRegisteredForMarshalling();
+        private static extern void SetGlobalInstanceRegisteredForMarshalling(long id);
 
         /// <summary>
         /// Get the runtime provided IUnknown implementation.
index afdee41..6b1393f 100644 (file)
@@ -1099,10 +1099,10 @@ public:
         return ::CreateRefcountedHandle(m_handleStore, object);
     }
 
-    OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, IWeakReference* pComWeakReference)
+    OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo)
     {
         WRAPPER_NO_CONTRACT;
-        return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakReference);
+        return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakHandleInfo);
     }
 #endif // FEATURE_COMINTEROP
 
index d90a2f0..039df61 100644 (file)
@@ -201,9 +201,16 @@ inline OBJECTHANDLE CreateGlobalRefcountedHandle(OBJECTREF object)
 // Special handle creation convenience functions
 
 #ifdef FEATURE_COMINTEROP
-inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, IWeakReference* pComWeakReference)
+
+struct NativeComWeakHandleInfo
+{
+    IWeakReference *WeakReference;
+    INT64 WrapperId;
+};
+
+inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo)
 {
-    OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakReference);
+    OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakHandleInfo);
     if (!hnd)
     {
         COMPlusThrowOM();
@@ -374,14 +381,16 @@ inline void DestroyNativeComWeakHandle(OBJECTHANDLE handle)
     }
     CONTRACTL_END;
 
-    // Release the WinRT weak reference if we have one. We're assuming that this will not reenter the
-    // runtime, since if we are pointing at a managed object, we should not be using HNDTYPE_WEAK_NATIVE_COM
-    // but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG.
+    // Delete the COM info and release the weak reference if we have one. We're assuming that
+    // this will not reenter the runtime, since if we are pointing at a managed object, we should
+    // not be using HNDTYPE_WEAK_NATIVE_COM but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG.
     void* pExtraInfo = GCHandleUtilities::GetGCHandleManager()->GetExtraInfoFromHandle(handle);
-    IWeakReference* pWinRTWeakReference = reinterpret_cast<IWeakReference*>(pExtraInfo);
-    if (pWinRTWeakReference != nullptr)
+    NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast<NativeComWeakHandleInfo*>(pExtraInfo);
+    if (comWeakHandleInfo != nullptr)
     {
-        pWinRTWeakReference->Release();
+        _ASSERTE(comWeakHandleInfo->WeakReference != nullptr);
+        comWeakHandleInfo->WeakReference->Release();
+        delete comWeakHandleInfo;
     }
 
     DiagHandleDestroyed(handle);
index a17bf75..ffd8d4f 100644 (file)
@@ -26,6 +26,7 @@ namespace
         void* Identity;
         void* ThreadContext;
         DWORD SyncBlockIndex;
+        INT64 WrapperId;
 
         enum
         {
@@ -41,6 +42,7 @@ namespace
             _In_ IUnknown* identity,
             _In_opt_ void* threadContext,
             _In_ DWORD syncBlockIndex,
+            _In_ INT64 wrapperId,
             _In_ DWORD flags)
         {
             CONTRACTL
@@ -57,6 +59,7 @@ namespace
             cxt->Identity = (void*)identity;
             cxt->ThreadContext = threadContext;
             cxt->SyncBlockIndex = syncBlockIndex;
+            cxt->WrapperId = wrapperId;
             cxt->Flags = flags;
         }
 
@@ -91,6 +94,40 @@ namespace
             _ASSERTE(IsActive());
             return ObjectToOBJECTREF(g_pSyncTable[SyncBlockIndex].m_Object);
         }
+
+        struct Key
+        {
+        public:
+            Key(void* identity, INT64 wrapperId)
+                : _identity { identity }
+                , _wrapperId { wrapperId }
+            {
+                _ASSERTE(identity != NULL);
+                _ASSERTE(wrapperId != ComWrappersNative::InvalidWrapperId);
+            }
+
+            DWORD Hash() const
+            {
+                DWORD hash = (_wrapperId >> 32) ^ (_wrapperId & 0xFFFFFFFF);
+#if POINTER_BITS == 32
+                return hash ^ (DWORD)_identity;
+#else
+                INT64 identityInt64 = (INT64)_identity;
+                return hash ^ (identityInt64 >> 32) ^ (identityInt64 & 0xFFFFFFFF);
+#endif
+            }
+
+            bool operator==(const Key & rhs) const { return _identity == rhs._identity && _wrapperId == rhs._wrapperId; }
+
+        private:
+            void* _identity;
+            INT64 _wrapperId;
+        };
+
+        Key GetKey() const
+        {
+            return Key(Identity, WrapperId);
+        }
     };
 
     const DWORD ExternalObjectContext::InvalidSyncBlockIndex = 0; // See syncblk.h
@@ -171,10 +208,10 @@ namespace
         class Traits : public DefaultSHashTraits<ExternalObjectContext *>
         {
         public:
-            using key_t = void*;
-            static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return (key_t)e->Identity; }
-            static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)(size_t)key; }
-            static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return (lhs == rhs); }
+            using key_t = ExternalObjectContext::Key;
+            static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return e->GetKey(); }
+            static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)key.Hash(); }
+            static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return lhs == rhs; }
         };
 
         // Alias some useful types
@@ -311,7 +348,7 @@ namespace
             RETURN gc.arrRef;
         }
 
-        ExternalObjectContext* Find(_In_ IUnknown* instance)
+        ExternalObjectContext* Find(_In_ const ExternalObjectContext::Key& key)
         {
             CONTRACT(ExternalObjectContext*)
             {
@@ -319,7 +356,6 @@ namespace
                 GC_NOTRIGGER;
                 MODE_COOPERATIVE;
                 PRECONDITION(IsLockHeld());
-                PRECONDITION(instance != NULL);
                 POSTCONDITION(CheckPointer(RETVAL, NULL_OK));
             }
             CONTRACT_END;
@@ -327,7 +363,7 @@ namespace
             // Forbid the GC from messing with the hash table.
             GCX_FORBID();
 
-            RETURN _hashMap.Lookup(instance);
+            RETURN _hashMap.Lookup(key);
         }
 
         ExternalObjectContext* Add(_In_ ExternalObjectContext* cxt)
@@ -340,7 +376,7 @@ namespace
                 PRECONDITION(IsLockHeld());
                 PRECONDITION(!Traits::IsNull(cxt) && !Traits::IsDeleted(cxt));
                 PRECONDITION(cxt->Identity != NULL);
-                PRECONDITION(Find(static_cast<IUnknown*>(cxt->Identity)) == NULL);
+                PRECONDITION(Find(cxt->GetKey()) == NULL);
                 POSTCONDITION(RETVAL == cxt);
             }
             CONTRACT_END;
@@ -349,7 +385,7 @@ namespace
             RETURN cxt;
         }
 
-        ExternalObjectContext* FindOrAdd(_In_ IUnknown* key, _In_ ExternalObjectContext* newCxt)
+        ExternalObjectContext* FindOrAdd(_In_ const ExternalObjectContext::Key& key, _In_ ExternalObjectContext* newCxt)
         {
             CONTRACT(ExternalObjectContext*)
             {
@@ -357,9 +393,8 @@ namespace
                 GC_NOTRIGGER;
                 MODE_COOPERATIVE;
                 PRECONDITION(IsLockHeld());
-                PRECONDITION(key != NULL);
                 PRECONDITION(!Traits::IsNull(newCxt) && !Traits::IsDeleted(newCxt));
-                PRECONDITION(key == newCxt->Identity);
+                PRECONDITION(key == newCxt->GetKey());
                 POSTCONDITION(CheckPointer(RETVAL));
             }
             CONTRACT_END;
@@ -392,7 +427,7 @@ namespace
             }
             CONTRACTL_END;
 
-            _hashMap.Remove(cxt->Identity);
+            _hashMap.Remove(cxt->GetKey());
         }
     };
 
@@ -400,8 +435,8 @@ namespace
     Volatile<ExtObjCxtCache*> ExtObjCxtCache::g_Instance;
 
     // Indicator for if a ComWrappers implementation is globally registered
-    bool g_IsGlobalComWrappersRegisteredForMarshalling;
-    bool g_IsGlobalComWrappersRegisteredForTrackerSupport;
+    INT64 g_marshallingGlobalInstanceId = ComWrappersNative::InvalidWrapperId;
+    INT64 g_trackerSupportGlobalInstanceId = ComWrappersNative::InvalidWrapperId;
 
     // Defined handle types for the specific object uses.
     const HandleType InstanceHandleType{ HNDTYPE_STRONG };
@@ -522,6 +557,7 @@ namespace
 
     bool TryGetOrCreateComInterfaceForObjectInternal(
         _In_opt_ OBJECTREF impl,
+        _In_ INT64 wrapperId,
         _In_ OBJECTREF instance,
         _In_ CreateComInterfaceFlags flags,
         _In_ ComWrappersScenario scenario,
@@ -533,7 +569,8 @@ namespace
             MODE_COOPERATIVE;
             PRECONDITION(instance != NULL);
             PRECONDITION(wrapperRaw != NULL);
-            PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance)|| (impl == NULL && scenario != ComWrappersScenario::Instance));
+            PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance));
+            PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId);
         }
         CONTRACT_END;
 
@@ -559,7 +596,7 @@ namespace
         _ASSERTE(syncBlock->IsPrecious());
 
         // Query the associated InteropSyncBlockInfo for an existing managed object wrapper.
-        if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
+        if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe))
         {
             // Compute VTables for the new existing COM object using the supplied COM Wrappers implementation.
             //
@@ -570,7 +607,7 @@ namespace
             void* vtables = CallComputeVTables(scenario, &gc.implRef, &gc.instRef, flags, &vtableCount);
 
             // Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper.
-            if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)
+            if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe)
                 && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0)))
             {
                 OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType);
@@ -590,14 +627,14 @@ namespace
                 _ASSERTE(!newWrapper.IsNull());
 
                 // Try setting the newly created managed object wrapper on the InteropSyncBlockInfo.
-                if (!interopInfo->TrySetManagedObjectComWrapper(newWrapper))
+                if (!interopInfo->TrySetManagedObjectComWrapper(wrapperId, newWrapper))
                 {
                     // The new wrapper couldn't be set which means a wrapper already exists.
                     newWrapper.Release();
 
                     // If the managed object wrapper couldn't be set, then
                     // it should be possible to get the current one.
-                    if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
+                    if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe))
                     {
                         UNREACHABLE();
                     }
@@ -638,6 +675,7 @@ namespace
 
     bool TryGetOrCreateObjectForComInstanceInternal(
         _In_opt_ OBJECTREF impl,
+        _In_ INT64 wrapperId,
         _In_ IUnknown* identity,
         _In_ CreateObjectFlags flags,
         _In_ ComWrappersScenario scenario,
@@ -651,6 +689,7 @@ namespace
             PRECONDITION(identity != NULL);
             PRECONDITION(objRef != NULL);
             PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance));
+            PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId);
         }
         CONTRACT_END;
 
@@ -672,13 +711,15 @@ namespace
         ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance();
         InteropLib::OBJECTHANDLE handle = NULL;
 
+        ExternalObjectContext::Key cacheKey(identity, wrapperId);
+
         // Check if the user requested a unique instance.
         bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance);
         if (!uniqueInstance)
         {
             // Query the external object cache
             ExtObjCxtCache::LockHolder lock(cache);
-            extObjCxt = cache->Find(identity);
+            extObjCxt = cache->Find(cacheKey);
 
             // If is no object found in the cache, check if the object COM instance is actually the CCW
             // representing a managed object. For the scenario of marshalling through a global instance,
@@ -750,6 +791,7 @@ namespace
                     identity,
                     GetCurrentCtxCookie(),
                     gc.objRefMaybe->GetSyncBlockIndex(),
+                    wrapperId,
                     flags);
 
                 if (uniqueInstance)
@@ -760,7 +802,7 @@ namespace
                 {
                     // Attempt to insert the new context into the cache.
                     ExtObjCxtCache::LockHolder lock(cache);
-                    extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext());
+                    extObjCxt = cache->FindOrAdd(cacheKey, resultHolder.GetContext());
                 }
 
                 // If the returned context matches the new context it means the
@@ -1039,6 +1081,7 @@ namespace InteropLibImports
             // Get wrapper for external object
             bool success = TryGetOrCreateObjectForComInstanceInternal(
                 gc.implRef,
+                g_trackerSupportGlobalInstanceId,
                 externalComObject,
                 externalObjectFlags,
                 ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1051,6 +1094,7 @@ namespace InteropLibImports
             // Get wrapper for managed object
             success = TryGetOrCreateComInterfaceForObjectInternal(
                 gc.implRef,
+                g_trackerSupportGlobalInstanceId,
                 gc.objRef,
                 trackerTargetFlags,
                 ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1220,6 +1264,7 @@ namespace InteropLibImports
 
 BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
     _In_ QCall::ObjectHandleOnStack comWrappersImpl,
+    _In_ INT64 wrapperId,
     _In_ QCall::ObjectHandleOnStack instance,
     _In_ INT32 flags,
     _Outptr_ void** wrapper)
@@ -1236,6 +1281,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
         GCX_COOP();
         success = TryGetOrCreateComInterfaceForObjectInternal(
             ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
+            wrapperId,
             ObjectToOBJECTREF(*instance.m_ppObject),
             (CreateComInterfaceFlags)flags,
             ComWrappersScenario::Instance,
@@ -1249,6 +1295,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
 
 BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance(
     _In_ QCall::ObjectHandleOnStack comWrappersImpl,
+    _In_ INT64 wrapperId,
     _In_ void* ext,
     _In_ INT32 flags,
     _In_ QCall::ObjectHandleOnStack wrapperMaybe,
@@ -1278,6 +1325,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance(
         OBJECTREF newObj;
         success = TryGetOrCreateObjectForComInstanceInternal(
             ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
+            wrapperId,
             identity,
             (CreateObjectFlags)flags,
             ComWrappersScenario::Instance,
@@ -1387,19 +1435,25 @@ void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe)
     _ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG);
 }
 
-void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling()
+void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling(INT64 id)
 {
     QCALL_CONTRACT_NO_GC_TRANSITION;
 
-    _ASSERTE(!g_IsGlobalComWrappersRegisteredForMarshalling);
-    g_IsGlobalComWrappersRegisteredForMarshalling = true;
+    _ASSERTE(g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId);
+    g_marshallingGlobalInstanceId = id;
+}
+
+bool GlobalComWrappersForMarshalling::IsRegisteredInstance(INT64 id)
+{
+    return g_marshallingGlobalInstanceId != ComWrappersNative::InvalidWrapperId
+        && g_marshallingGlobalInstanceId == id;
 }
 
 bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject(
     _In_ OBJECTREF instance,
     _Outptr_ void** wrapperRaw)
 {
-    if (!g_IsGlobalComWrappersRegisteredForMarshalling)
+    if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
         return false;
 
     // Switch to Cooperative mode since object references
@@ -1412,6 +1466,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject(
         // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
         return TryGetOrCreateComInterfaceForObjectInternal(
             NULL,
+            g_marshallingGlobalInstanceId,
             instance,
             flags,
             ComWrappersScenario::MarshallingGlobalInstance,
@@ -1424,7 +1479,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
     _In_ INT32 objFromComIPFlags,
     _Out_ OBJECTREF* objRef)
 {
-    if (!g_IsGlobalComWrappersRegisteredForMarshalling)
+    if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
         return false;
 
     // Determine the true identity of the object
@@ -1448,6 +1503,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
         // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
         return TryGetOrCreateObjectForComInstanceInternal(
             NULL /*comWrappersImpl*/,
+            g_marshallingGlobalInstanceId,
             identity,
             (CreateObjectFlags)flags,
             ComWrappersScenario::MarshallingGlobalInstance,
@@ -1456,12 +1512,18 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
     }
 }
 
-void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport()
+void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport(INT64 id)
 {
     QCALL_CONTRACT_NO_GC_TRANSITION;
 
-    _ASSERTE(!g_IsGlobalComWrappersRegisteredForTrackerSupport);
-    g_IsGlobalComWrappersRegisteredForTrackerSupport = true;
+    _ASSERTE(g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId);
+    g_trackerSupportGlobalInstanceId = id;
+}
+
+bool GlobalComWrappersForTrackerSupport::IsRegisteredInstance(INT64 id)
+{
+    return g_trackerSupportGlobalInstanceId != ComWrappersNative::InvalidWrapperId
+        && g_trackerSupportGlobalInstanceId == id;
 }
 
 bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject(
@@ -1475,12 +1537,13 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject(
     }
     CONTRACTL_END;
 
-    if (!g_IsGlobalComWrappersRegisteredForTrackerSupport)
+    if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
         return false;
 
     // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
     return TryGetOrCreateComInterfaceForObjectInternal(
         NULL,
+        g_trackerSupportGlobalInstanceId,
         instance,
         CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport,
         ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1498,7 +1561,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
     }
     CONTRACTL_END;
 
-    if (!g_IsGlobalComWrappersRegisteredForTrackerSupport)
+    if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
         return false;
 
     // Determine the true identity of the object
@@ -1513,6 +1576,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
     // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
     return TryGetOrCreateObjectForComInstanceInternal(
         NULL /*comWrappersImpl*/,
+        g_trackerSupportGlobalInstanceId,
         identity,
         CreateObjectFlags::CreateObjectFlags_TrackerObject,
         ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1520,7 +1584,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
         objRef);
 }
 
-IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid)
+IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
 {
     CONTRACTL
     {
@@ -1528,11 +1592,14 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
         GC_TRIGGERS;
         MODE_COOPERATIVE;
         PRECONDITION(CheckPointer(objectPROTECTED));
+        PRECONDITION(CheckPointer(wrapperId));
     }
     CONTRACTL_END;
 
     ASSERT_PROTECTED(objectPROTECTED);
 
+    *wrapperId = ComWrappersNative::InvalidWrapperId;
+
     SyncBlock* syncBlock = (*objectPROTECTED)->PassiveGetSyncBlock();
     if (syncBlock == nullptr)
     {
@@ -1545,10 +1612,13 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
         return nullptr;
     }
 
-    void* context;
-    if (interopInfo->TryGetExternalComObjectContext(&context))
+    void* contextMaybe;
+    if (interopInfo->TryGetExternalComObjectContext(&contextMaybe))
     {
-        IUnknown* identity = reinterpret_cast<IUnknown*>(reinterpret_cast<ExternalObjectContext*>(context)->Identity);
+        ExternalObjectContext* context = reinterpret_cast<ExternalObjectContext*>(contextMaybe);
+        *wrapperId = context->WrapperId;
+
+        IUnknown* identity = reinterpret_cast<IUnknown*>(context->Identity);
         GCX_PREEMP();
         IUnknown* result;
         if (SUCCEEDED(identity->QueryInterface(riid, (void**)&result)))
index 3e5abda..bae00fb 100644 (file)
@@ -11,6 +11,9 @@
 // Native calls for the managed ComWrappers API
 class ComWrappersNative
 {
+public:
+    static const INT64 InvalidWrapperId = 0;
+
 public: // Native QCalls for the abstract ComWrappers managed type.
     static void QCALLTYPE GetIUnknownImpl(
         _Out_ void** fpQueryInterface,
@@ -19,12 +22,14 @@ public: // Native QCalls for the abstract ComWrappers managed type.
 
     static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject(
         _In_ QCall::ObjectHandleOnStack comWrappersImpl,
+        _In_ INT64 wrapperId,
         _In_ QCall::ObjectHandleOnStack instance,
         _In_ INT32 flags,
         _Outptr_ void** wrapperRaw);
 
     static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance(
         _In_ QCall::ObjectHandleOnStack comWrappersImpl,
+        _In_ INT64 wrapperId,
         _In_ void* externalComObject,
         _In_ INT32 flags,
         _In_ QCall::ObjectHandleOnStack wrapperMaybe,
@@ -39,7 +44,7 @@ public: // COM activation
     static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
 
 public: // Unwrapping support
-    static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid);
+    static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
 };
 
 class GlobalComWrappersForMarshalling
@@ -48,9 +53,11 @@ public:
     // Native QCall for the ComWrappers managed type to indicate a global instance
     // is registered for marshalling. This should be set if the private static member
     // representing the global instance for marshalling on ComWrappers is non-null.
-    static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling();
+    static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling(_In_ INT64 id);
 
 public: // Functions operating on a registered global instance for marshalling
+    static bool IsRegisteredInstance(_In_ INT64 id);
+
     static bool TryGetOrCreateComInterfaceForObject(
         _In_ OBJECTREF instance,
         _Outptr_ void** wrapperRaw);
@@ -68,9 +75,11 @@ public:
     // Native QCall for the ComWrappers managed type to indicate a global instance
     // is registered for tracker support. This should be set if the private static member
     // representing the global instance for tracker support on ComWrappers is non-null.
-    static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport();
+    static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport(_In_ INT64 id);
 
 public: // Functions operating on a registered global instance for tracker support
+    static bool IsRegisteredInstance(_In_ INT64 id);
+
     static bool TryGetOrCreateComInterfaceForObject(
         _In_ OBJECTREF instance,
         _Outptr_ void** wrapperRaw);
index 33cad06..a2f6846 100644 (file)
@@ -1321,12 +1321,7 @@ void CleanupSyncBlockComData(InteropSyncBlockInfo* pInteropInfo)
     }
 
 #ifdef FEATURE_COMWRAPPERS
-    void* mocw;
-    if (pInteropInfo->TryGetManagedObjectComWrapper(&mocw))
-    {
-        (void)pInteropInfo->TrySetManagedObjectComWrapper(NULL, mocw);
-        ComWrappersNative::DestroyManagedObjectComWrapper(mocw);
-    }
+    pInteropInfo->ClearManagedObjectComWrappers(&ComWrappersNative::DestroyManagedObjectComWrapper);
 
     void* eoc;
     if (pInteropInfo->TryGetExternalComObjectContext(&eoc))
index e0364df..08a4fee 100644 (file)
@@ -602,6 +602,8 @@ class ComClassFactory;
 struct RCW;
 class RCWHolder;
 typedef DPTR(class ComCallWrapper)        PTR_ComCallWrapper;
+
+#include "shash.h"
 #endif // FEATURE_COMINTEROP
 
 class InteropSyncBlockInfo
@@ -791,22 +793,65 @@ public:
 #endif
 
 public:
-    bool TryGetManagedObjectComWrapper(_Out_ void** mocw)
+    bool TryGetManagedObjectComWrapper(_In_ INT64 wrapperId, _Out_ void** mocw)
     {
         LIMITED_METHOD_DAC_CONTRACT;
-        *mocw = m_managedObjectComWrapper;
-        return (*mocw != NULL);
+
+        *mocw = NULL;
+        if (m_managedObjectComWrapperMap == NULL)
+            return false;
+
+        CrstHolder lock(&m_managedObjectComWrapperLock);
+        return m_managedObjectComWrapperMap->Lookup(wrapperId, mocw);
     }
 
 #ifndef DACCESS_COMPILE
-    bool TrySetManagedObjectComWrapper(_In_ void* mocw, _In_ void* curr = NULL)
+    bool TrySetManagedObjectComWrapper(_In_ INT64 wrapperId, _In_ void* mocw, _In_ void* curr = NULL)
     {
         LIMITED_METHOD_CONTRACT;
 
-        return (FastInterlockCompareExchangePointer(
-                        &m_managedObjectComWrapper,
-                        mocw,
-                        curr) == curr);
+        if (m_managedObjectComWrapperMap == NULL)
+        {
+            NewHolder<ManagedObjectComWrapperByIdMap> map = new ManagedObjectComWrapperByIdMap();
+            if (FastInterlockCompareExchangePointer((ManagedObjectComWrapperByIdMap**)&m_managedObjectComWrapperMap, (ManagedObjectComWrapperByIdMap *)map, NULL) == NULL)
+            {
+                map.SuppressRelease();
+                m_managedObjectComWrapperLock.Init(CrstLeafLock);
+            }
+
+            _ASSERTE(m_managedObjectComWrapperMap != NULL);
+        }
+
+        CrstHolder lock(&m_managedObjectComWrapperLock);
+
+        if (m_managedObjectComWrapperMap->LookupPtr(wrapperId) != curr)
+            return false;
+
+        m_managedObjectComWrapperMap->Add(wrapperId, mocw);
+        return true;
+    }
+
+    using EnumWrappersCallback = void(void* mocw);
+    void ClearManagedObjectComWrappers(EnumWrappersCallback* callback)
+    {
+        LIMITED_METHOD_CONTRACT;
+
+        if (m_managedObjectComWrapperMap == NULL)
+            return;
+
+        CrstHolder lock(&m_managedObjectComWrapperLock);
+
+        if (callback != NULL)
+        {
+            ManagedObjectComWrapperByIdMap::Iterator iter = m_managedObjectComWrapperMap->Begin();
+            while (iter != m_managedObjectComWrapperMap->End())
+            {
+                callback(iter->Value());
+                ++iter;
+            }
+        }
+
+        m_managedObjectComWrapperMap->RemoveAll();
     }
 #endif // !DACCESS_COMPILE
 
@@ -831,8 +876,11 @@ public:
 
 private:
     // See InteropLib API for usage.
-    void* m_managedObjectComWrapper;
     void* m_externalComObjectContext;
+
+    using ManagedObjectComWrapperByIdMap = MapSHash<INT64, void*>;
+    CrstExplicitInit m_managedObjectComWrapperLock;
+    NewHolder<ManagedObjectComWrapperByIdMap> m_managedObjectComWrapperMap;
 #endif // FEATURE_COMINTEROP
 
 };
index 2cdb83c..bc52c61 100644 (file)
@@ -103,9 +103,9 @@ private:
 
 #ifdef FEATURE_COMINTEROP
 
-// Get a native COM weak reference for the object underlying an RCW if applicable.  If the incoming object cannot
-// use a native COM weak reference, nullptr is returned.  Otherwise, an AddRef-ed IWeakReference* for the COM
-// object underlying the RCW is returned.
+// Get the native COM information for the object underlying an RCW if applicable. If the incoming object cannot
+// use a native COM weak reference, nullptr is returned. Otherwise, a new NativeComWeakHandleInfo containing an
+// AddRef-ed IWeakReference* for the COM object underlying the RCW is returned.
 //
 // In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
 //  * be an RCW
@@ -113,7 +113,7 @@ private:
 //  * succeed when asked for an IWeakReference*
 //
 // Note that *pObject should be GC protected on the way into this method
-IWeakReference* GetComWeakReference(OBJECTREF* pObject)
+NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject)
 {
     CONTRACTL
     {
@@ -134,6 +134,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
     MethodTable* pMT = (*pObject)->GetMethodTable();
 
     SafeComHolder<IWeakReferenceSource> pWeakReferenceSource(nullptr);
+    INT64 wrapperId = ComWrappersNative::InvalidWrapperId;
 
     // If the object is not an RCW, then we do not want to use a native COM weak reference to it
     // If the object is a managed type deriving from a COM type, then we also do not want to use a native COM
@@ -147,7 +148,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
 #ifdef FEATURE_COMWRAPPERS
     else
     {
-        pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource));
+        pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
     }
 #endif
 
@@ -163,7 +164,9 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
         return nullptr;
     }
 
-    return pWeakReference.Extract();
+    NewHolder<NativeComWeakHandleInfo> info = new NativeComWeakHandleInfo { pWeakReference.GetValue(), wrapperId };
+    pWeakReference.SuppressRelease();
+    return info.Extract();
 }
 
 // Given an object handle that stores a native COM weak reference, attempt to create an RCW
@@ -205,6 +208,7 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
     // Since we're acquiring and releasing the lock multiple times, we need to check the handle state each time we
     // reacquire the lock to make sure that another thread hasn't reassigned the target of the handle or finalized it
     SafeComHolder<IWeakReference> pComWeakReference = nullptr;
+    INT64 wrapperId = ComWrappersNative::InvalidWrapperId;
     {
         WeakHandleSpinLockHolder handle(AcquireWeakHandleSpinLock(gc.weakReference), &gc.weakReference);
         GCX_NOTRIGGER();
@@ -233,9 +237,11 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
                 _ASSERTE(pComWeakReference.IsNull());
                 CONTRACT_VIOLATION(GCViolation);
                 IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager();
-                pComWeakReference = reinterpret_cast<IWeakReference*>(mgr->GetExtraInfoFromHandle(handle.Handle));
-                if (!pComWeakReference.IsNull())
+                NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast<NativeComWeakHandleInfo*>(mgr->GetExtraInfoFromHandle(handle.Handle));
+                if (comWeakHandleInfo != nullptr)
                 {
+                    wrapperId = comWeakHandleInfo->WrapperId;
+                    pComWeakReference = comWeakHandleInfo->WeakReference;
                     pComWeakReference->AddRef();
                 }
             }
@@ -268,9 +274,21 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
     // If we were able to get an IUnkown identity for the object, then we can find or create an associated RCW for it.
     if (!pTargetIdentity.IsNull())
     {
-        // Try the global COM wrappers first before falling back to the built-in system.
-        if (!GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw))
+        if (wrapperId != ComWrappersNative::InvalidWrapperId)
         {
+            // Try the global COM wrappers
+            if (GlobalComWrappersForTrackerSupport::IsRegisteredInstance(wrapperId))
+            {
+                (void)GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw);
+            }
+            else if (GlobalComWrappersForMarshalling::IsRegisteredInstance(wrapperId))
+            {
+                (void)GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(pTargetIdentity, ObjFromComIP::NONE, &gc.rcw);
+            }
+        }
+        else
+        {
+            // If the original RCW was not created through ComWrappers, fall back to the built-in system.
             GetObjectRefFromComIP(&gc.rcw, pTargetIdentity);
         }
     }
@@ -428,21 +446,21 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob
 
     // Create the handle.
 #ifdef FEATURE_COMINTEROP
-    IWeakReference* pRawComWeakReference = nullptr;
+    NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
     if (gc.pTarget != NULL)
     {
-        SyncBlockpSyncBlock = gc.pTarget->PassiveGetSyncBlock();
+        SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
         if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr)
         {
-            pRawComWeakReference = GetComWeakReference(&gc.pTarget);
+            comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget);
         }
     }
 
-    if (pRawComWeakReference != nullptr)
+    if (comWeakHandleInfo != nullptr)
     {
-        SafeComHolder<IWeakReference> pComWeakReferenceHolder(pRawComWeakReference);
-        gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder));
-        pComWeakReferenceHolder.SuppressRelease();
+        NewHolder<NativeComWeakHandleInfo> infoHolder(comWeakHandleInfo);
+        gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder));
+        infoHolder.SuppressRelease();
     }
     else
 #endif // FEATURE_COMINTEROP
@@ -478,23 +496,22 @@ FCIMPL3(void, WeakReferenceOfTNative::Create, WeakReferenceObject * pThisUNSAFE,
 
     _ASSERTE(gc.pThis->GetMethodTable()->GetCanonicalMethodTable() == pWeakReferenceOfTCanonMT);
 
-    // Create the handle.
 #ifdef FEATURE_COMINTEROP
-    IWeakReference* pRawComWeakReference = nullptr;
+    NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
     if (gc.pTarget != NULL)
     {
-        SyncBlockpSyncBlock = gc.pTarget->PassiveGetSyncBlock();
+        SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
         if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr)
         {
-            pRawComWeakReference = GetComWeakReference(&gc.pTarget);
+            comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget);
         }
     }
 
-    if (pRawComWeakReference != nullptr)
+    if (comWeakHandleInfo != nullptr)
     {
-        SafeComHolder<IWeakReference> pComWeakReferenceHolder(pRawComWeakReference);
-        gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder));
-        pComWeakReferenceHolder.SuppressRelease();
+        NewHolder<NativeComWeakHandleInfo> infoHolder(comWeakHandleInfo);
+        gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder));
+        infoHolder.SuppressRelease();
     }
     else
 #endif // FEATURE_COMINTEROP
@@ -747,7 +764,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
     HELPER_METHOD_FRAME_BEGIN_ATTRIB_2(Frame::FRAME_ATTR_EXACT_DEPTH|Frame::FRAME_ATTR_CAPTURE_DEPTH_2, target, weakReference);
 
 #ifdef FEATURE_COMINTEROP
-    SafeComHolder<IWeakReference> pTargetWeakReference(GetComWeakReference(&target));
+    NewHolder<NativeComWeakHandleInfo> comWeakHandleInfo(GetComWeakReferenceInfo(&target));
 #endif // FEATURE_COMINTEROP
 
 
@@ -778,21 +795,23 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
 
     if (IsNativeComWeakReferenceHandle(handle.RawHandle))
     {
-        // If the existing reference is a native COM weak reference, we need to release its IWeakReference pointer
-        // and update it with the new weak reference pointer.  If the incoming object is not an RCW that can
-        // use IWeakReference, then pTargetWeakReference will be null.  Therefore, no matter what the incoming
-        // object type is, we can unconditionally store pTargetWeakReference to the object handle's extra data.
+        // If the existing reference is a native COM weak reference, we need to delete its native COM info
+        // and update it with the new native COM info. If the incoming object is not an RCW that can use
+        // IWeakReference, then comWeakHandleInfo will be null. Therefore, no matter what the incoming
+        // object type is, we can unconditionally store comWeakHandleInfo to the object handle's extra data.
         IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager();
-        IWeakReference* pExistingWeakReference = reinterpret_cast<IWeakReference*>(mgr->GetExtraInfoFromHandle(handle.Handle));
-        mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast<void*>(pTargetWeakReference.GetValue()));
+        NativeComWeakHandleInfo* existingInfo = reinterpret_cast<NativeComWeakHandleInfo*>(mgr->GetExtraInfoFromHandle(handle.Handle));
+        mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast<void*>(comWeakHandleInfo.GetValue()));
         StoreObjectInHandle(handle.Handle, target);
 
-        if (pExistingWeakReference != nullptr)
+        if (existingInfo != nullptr)
         {
-            pExistingWeakReference->Release();
+            _ASSERTE(existingInfo->WeakReference != nullptr);
+            existingInfo->WeakReference->Release();
+            delete existingInfo;
         }
     }
-    else if (pTargetWeakReference != nullptr)
+    else if (comWeakHandleInfo != nullptr)
     {
         // The existing handle is not a native COM weak reference, but we need to store the new object in
         // a native COM weak reference.  Therefore we need to destroy the old handle and create a new native COM
@@ -801,7 +820,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
         _ASSERTE(!IsNativeComWeakReferenceHandle(handle.RawHandle));
         OBJECTHANDLE previousHandle = handle.RawHandle;
 
-        handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, pTargetWeakReference);
+        handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, comWeakHandleInfo);
         handle.RawHandle = SetNativeComWeakReferenceHandle(handle.Handle);
 
         DestroyTypedHandle(previousHandle);
@@ -813,7 +832,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
     }
 
 #ifdef FEATURE_COMINTEROP
-    pTargetWeakReference.SuppressRelease();
+    comWeakHandleInfo.SuppressRelease();
 #endif // FEATURE_COMINTEROP
 
     HELPER_METHOD_FRAME_END();
index 481eafd..2e76a65 100644 (file)
@@ -159,6 +159,50 @@ namespace ComWrappersTests
             Assert.AreNotEqual(trackerObj1, trackerObj3);
         }
 
+        static void ValidateWrappersInstanceIsolation()
+        {
+            Console.WriteLine($"Running {nameof(ValidateWrappersInstanceIsolation)}...");   
+
+            var cw1 = new TestComWrappers();
+            var cw2 = new TestComWrappers();
+
+            var testObj = new Test();
+
+            // Allocate a wrapper for the object
+            IntPtr comWrapper1 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+            IntPtr comWrapper2 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+            Assert.AreNotEqual(comWrapper1, IntPtr.Zero);
+            Assert.AreNotEqual(comWrapper2, IntPtr.Zero);
+            Assert.AreNotEqual(comWrapper1, comWrapper2);
+
+            IntPtr comWrapper3 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+            IntPtr comWrapper4 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+            Assert.AreNotEqual(comWrapper3, comWrapper4);
+            Assert.AreEqual(comWrapper1, comWrapper3);
+            Assert.AreEqual(comWrapper2, comWrapper4);
+
+            Marshal.Release(comWrapper1);
+            Marshal.Release(comWrapper2);
+            Marshal.Release(comWrapper3);
+            Marshal.Release(comWrapper4);
+
+            // Get an object from a tracker runtime.
+            IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject();
+
+            // Create objects for the COM instance
+            var trackerObj1 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+            var trackerObj2 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+            Assert.AreNotEqual(trackerObj1, trackerObj2);
+
+            var trackerObj3 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+            var trackerObj4 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+            Assert.AreNotEqual(trackerObj3, trackerObj4);
+            Assert.AreEqual(trackerObj1, trackerObj3);
+            Assert.AreEqual(trackerObj2, trackerObj4);
+
+            Marshal.Release(trackerObjRaw);
+        }
+
         static void ValidatePrecreatedExternalWrapper()
         {
             Console.WriteLine($"Running {nameof(ValidatePrecreatedExternalWrapper)}...");
@@ -357,6 +401,7 @@ namespace ComWrappersTests
                 ValidateComInterfaceCreation();
                 ValidateFallbackQueryInterface();
                 ValidateCreateObjectCachingScenario();
+                ValidateWrappersInstanceIsolation();
                 ValidatePrecreatedExternalWrapper();
                 ValidateIUnknownImpls();
                 ValidateBadComWrapperImpl();
index f85138f..64d32d6 100644 (file)
@@ -22,7 +22,14 @@ namespace ComWrappersTests
         public IntPtr Vtbl;
     }
 
-    public class WeakReferencableWrapper
+    public enum WrapperRegistration
+    {
+        Local,
+        TrackerSupport,
+        Marshalling,
+    }
+
+    public class WeakReferenceableWrapper
     {
         private struct Vtbl
         {
@@ -37,14 +44,17 @@ namespace ComWrappersTests
         private readonly IntPtr instance;
         private readonly Vtbl vtable;
 
-        public WeakReferencableWrapper(IntPtr instance)
+        public WrapperRegistration Registration { get; }
+
+        public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg)
         {
             var inst = Marshal.PtrToStructure<VtblPtr>(instance);
             this.vtable = Marshal.PtrToStructure<Vtbl>(inst.Vtbl);
             this.instance = instance;
+            Registration = reg;
         }
 
-        ~WeakReferencableWrapper()
+        ~WeakReferenceableWrapper()
         {
             if (this.instance != IntPtr.Zero)
             {
@@ -57,6 +67,13 @@ namespace ComWrappersTests
     {
         class TestComWrappers : ComWrappers
         {
+            public WrapperRegistration Registration { get; }
+
+            public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
+            {
+                Registration = reg;
+            }
+
             protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
             {
                 count = 0;
@@ -66,59 +83,158 @@ namespace ComWrappersTests
             protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
             {
                 Marshal.AddRef(externalComObject);
-                return new WeakReferencableWrapper(externalComObject);
+                return new WeakReferenceableWrapper(externalComObject, Registration);
             }
 
             protected override void ReleaseObjects(IEnumerable objects)
             {
             }
 
-            public static readonly ComWrappers Instance = new TestComWrappers();
+            public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
+            public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
         }
 
-        static void ValidateNativeWeakReference()
+        private static void ValidateWeakReferenceState(WeakReference<WeakReferenceableWrapper> wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null)
         {
-            Console.WriteLine($"Running {nameof(ValidateNativeWeakReference)}...");
+            WeakReferenceableWrapper target;
+            bool isAlive = wr.TryGetTarget(out target);
+            Assert.AreEqual(expectedIsAlive, isAlive);
 
-            static (WeakReference<WeakReferencableWrapper>, IntPtr) GetWeakReference()
-            {
-                var cw = new TestComWrappers();
+            if (isAlive && sourceWrappers != null)
+                Assert.AreEqual(sourceWrappers.Registration, target.Registration);
+        }
 
-                IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+        private static (WeakReference<WeakReferenceableWrapper>, IntPtr) GetWeakReference(TestComWrappers cw)
+        {
+            IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+            var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+            var wr = new WeakReference<WeakReferenceableWrapper>(obj);
+            ValidateWeakReferenceState(wr, expectedIsAlive: true, cw);
+            return (wr, objRaw);
+        }
+
+        private static IntPtr SetWeakReferenceTarget(WeakReference<WeakReferenceableWrapper> wr, TestComWrappers cw)
+        {
+            IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+            var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+            wr.SetTarget(obj);
+            ValidateWeakReferenceState(wr, expectedIsAlive: true, cw);
+            return objRaw;
+        }
+
+        private static void ValidateNativeWeakReference(TestComWrappers cw)
+        {
+            Console.WriteLine($"  -- Validate weak reference creation");
+            var (weakRef, nativeRef) = GetWeakReference(cw);
+
+            // Make sure RCW is collected
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+
+            // Non-globally registered ComWrappers instances do not support rehydration.
+            // A weak reference to an RCW wrapping an IWeakReference can stay alive if the RCW was created through
+            // a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak
+            // reference should be dead and stay dead once the RCW is collected.
+            bool supportsRehydration = cw.Registration != WrapperRegistration.Local;
+            
+            Console.WriteLine($"    -- Validate RCW recreation");
+            ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
+
+            // Release the last native reference.
+            Marshal.Release(nativeRef);
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+
+            // After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
+            Console.WriteLine($"    -- Validate release");
+            ValidateWeakReferenceState(weakRef, expectedIsAlive: false);
+
+            // Reset the weak reference target
+            Console.WriteLine($"  -- Validate target reset");
+            nativeRef = SetWeakReferenceTarget(weakRef, cw);
 
-                var obj = (WeakReferencableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+            // Make sure RCW is collected
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+
+            Console.WriteLine($"    -- Validate RCW recreation");
+            ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
+
+            // Release the last native reference.
+            Marshal.Release(nativeRef);
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+
+            // After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
+            Console.WriteLine($"    -- Validate release");
+            ValidateWeakReferenceState(weakRef, expectedIsAlive: false);
+        }
+
+        static void ValidateGlobalInstanceTrackerSupport()
+        {
+            Console.WriteLine($"Running {nameof(ValidateGlobalInstanceTrackerSupport)}...");
+            ValidateNativeWeakReference(TestComWrappers.TrackerSupportInstance);
+        }
+
+        static void ValidateGlobalInstanceMarshalling()
+        {
+            Console.WriteLine($"Running {nameof(ValidateGlobalInstanceMarshalling)}...");
+            ValidateNativeWeakReference(TestComWrappers.MarshallingInstance);
+        }
+
+        static void ValidateLocalInstance()
+        {
+            Console.WriteLine($"Running {nameof(ValidateLocalInstance)}...");
+            ValidateNativeWeakReference(new TestComWrappers());
+        }
+
+        static void ValidateNonComWrappers()
+        {
+            Console.WriteLine($"Running {nameof(ValidateNonComWrappers)}...");
 
-                return (new WeakReference<WeakReferencableWrapper>(obj), objRaw);
+            (WeakReference, IntPtr) GetWeakReference()
+            {
+                IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+                var obj = Marshal.GetObjectForIUnknown(objRaw);
+                return (new WeakReference(obj), objRaw);
             }
 
-            static bool CheckIfWeakReferenceIsAlive(WeakReference<WeakReferencableWrapper> wr)
+            bool HasTarget(WeakReference wr)
             {
-                return wr.TryGetTarget(out _);
+                return wr.Target != null;
             }
 
             var (weakRef, nativeRef) = GetWeakReference();
             GC.Collect();
             GC.WaitForPendingFinalizers();
-            // A weak reference to an RCW wrapping an IWeakReference should stay alive even after the RCW dies
-            Assert.IsTrue(CheckIfWeakReferenceIsAlive(weakRef));
+
+            // A weak reference to an RCW wrapping an IWeakReference created throguh the built-in system
+            // should stay alive even after the RCW dies
+            Assert.IsFalse(weakRef.IsAlive);
+            Assert.IsTrue(HasTarget(weakRef));
 
             // Release the last native reference.
             Marshal.Release(nativeRef);
-
             GC.Collect();
             GC.WaitForPendingFinalizers();
 
             // After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
-            Assert.IsFalse(CheckIfWeakReferenceIsAlive(weakRef));
-
+            Assert.IsNull(weakRef.Target);
         }
 
         static int Main(string[] doNotUse)
         {
             try
             {
-                ComWrappers.RegisterForTrackerSupport(TestComWrappers.Instance);
-                ValidateNativeWeakReference();
+                ValidateNonComWrappers();
+
+                ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance);
+                ValidateGlobalInstanceTrackerSupport();
+
+                ComWrappers.RegisterForMarshalling(TestComWrappers.MarshallingInstance);
+                ValidateGlobalInstanceMarshalling();
+
+                ValidateLocalInstance();
             }
             catch (Exception e)
             {