Enable loading COM component in default ALC via runtime config setting (#79026)
authorElinor Fung <elfung@microsoft.com>
Fri, 2 Dec 2022 17:47:06 +0000 (09:47 -0800)
committerGitHub <noreply@github.com>
Fri, 2 Dec 2022 17:47:06 +0000 (09:47 -0800)
Allow COM components to opt-in to being loaded in the default ALC:

- `comhost` checks for the `System.Runtime.InteropServices.COM.LoadComponentInDefaultContext` property and uses new functions on `ComActivator`, loading into the default context if the property is set to true
- Default behaviour remains loading into an isolated context
- Fall back to using `ComActivator` functions that always load into an isolated context if new functions aren't found and loading in an isolated context

src/coreclr/System.Private.CoreLib/src/ILLink/ILLink.Descriptors.Windows.xml
src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivationContextInternal.cs
src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
src/installer/tests/Assets/TestProjects/ComLibrary/ComLibrary.cs
src/installer/tests/HostActivation.Tests/NativeHosting/Comhost.cs
src/libraries/System.Private.CoreLib/src/ILLink/ILLink.Descriptors.LibraryBuild.xml
src/native/corehost/comhost/comhost.cpp
src/native/corehost/fxr_resolver.h
src/native/corehost/ijwhost/ijwhost.cpp
src/tests/Interop/COM/Activator/ComActivationContext.cs
src/tests/Interop/COM/Activator/Program.cs

index 7ddddc6f48af8c3c14e778f84a301cb86e1bc791..4addc7fd09a87b903d44a1d3ec92f9c4bcf3f900 100644 (file)
     </type>
   </assembly>
 
+  <assembly fullname="System.Private.CoreLib" feature="System.Runtime.InteropServices.BuiltInComInterop.IsSupported" featurevalue="true">
+    <!-- Enables the .NET COM host (.NET 8.0+) to load a COM component. -->
+    <type fullname="Internal.Runtime.InteropServices.ComActivator" >
+      <method name="GetClassFactoryForTypeInContext" />
+      <method name="RegisterClassForTypeInContext" />
+      <method name="UnregisterClassForTypeInContext" />
+    </type>
+  </assembly>
+
   <assembly fullname="System.Private.CoreLib" feature="System.Runtime.InteropServices.EnableCppCLIHostActivation" featurevalue="true">
     <!-- Enables the .NET IJW host (.NET 7.0+) to load an in-memory module as a .NET assembly. -->
     <type fullname="Internal.Runtime.InteropServices.InMemoryAssemblyLoader">
index fd88d8edad3ce0a23a27fb237990d943bed54b42..c3153f47dec5125afb0fc4769ac0fd8cda83064d 100644 (file)
@@ -17,34 +17,4 @@ namespace Internal.Runtime.InteropServices
         public char* TypeNameBuffer;
         public IntPtr ClassFactoryDest;
     }
-
-    //
-    // Types below are 'public' only to aid in testing of functionality.
-    // They should not be considered publicly consumable.
-    //
-
-    [StructLayout(LayoutKind.Sequential)]
-    internal partial struct ComActivationContext
-    {
-        public Guid ClassId;
-        public Guid InterfaceId;
-        public string AssemblyPath;
-        public string AssemblyName;
-        public string TypeName;
-    }
-
-    [ComImport]
-    [ComVisible(false)]
-    [Guid("00000001-0000-0000-C000-000000000046")]
-    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
-    internal interface IClassFactory
-    {
-        [RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
-        void CreateInstance(
-            [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
-            ref Guid riid,
-            out IntPtr ppvObject);
-
-        void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
-    }
 }
index 476ff22862ec03eafb5becd4a8d6b59829e5c180..0ee478fa67545842bdf8b9c6d8f13217c6e03ac8 100644 (file)
@@ -29,6 +29,21 @@ namespace Internal.Runtime.InteropServices
         public bool fLicVerified;
     }
 
+    [ComImport]
+    [ComVisible(false)]
+    [Guid("00000001-0000-0000-C000-000000000046")]
+    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+    internal interface IClassFactory
+    {
+        [RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
+        void CreateInstance(
+            [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
+            ref Guid riid,
+            out IntPtr ppvObject);
+
+        void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
+    }
+
     [ComImport]
     [ComVisible(false)]
     [Guid("B196B28F-BAB4-101A-B69C-00AA00341D07")]
@@ -57,9 +72,17 @@ namespace Internal.Runtime.InteropServices
             out IntPtr ppvObject);
     }
 
-    internal partial struct ComActivationContext
+    [StructLayout(LayoutKind.Sequential)]
+    internal struct ComActivationContext
     {
-        public static unsafe ComActivationContext Create(ref ComActivationContextInternal cxtInt)
+        public Guid ClassId;
+        public Guid InterfaceId;
+        public string AssemblyPath;
+        public string AssemblyName;
+        public string TypeName;
+        public bool IsolatedContext;
+
+        public static unsafe ComActivationContext Create(ref ComActivationContextInternal cxtInt, bool isolatedContext)
         {
             if (!Marshal.IsBuiltInComSupported)
             {
@@ -72,7 +95,8 @@ namespace Internal.Runtime.InteropServices
                 InterfaceId = cxtInt.InterfaceId,
                 AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
                 AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
-                TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
+                TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!,
+                IsolatedContext = isolatedContext
             };
         }
     }
@@ -84,6 +108,9 @@ namespace Internal.Runtime.InteropServices
         // unloadable COM server ALCs, this will need to be changed.
         private static readonly Dictionary<string, AssemblyLoadContext> s_assemblyLoadContexts = new Dictionary<string, AssemblyLoadContext>(StringComparer.InvariantCultureIgnoreCase);
 
+        // COM component assembly paths loaded in the default ALC
+        private static readonly HashSet<string> s_loadedInDefaultContext = new HashSet<string>(StringComparer.InvariantCultureIgnoreCase);
+
         /// <summary>
         /// Entry point for unmanaged COM activation API from managed code
         /// </summary>
@@ -107,7 +134,7 @@ namespace Internal.Runtime.InteropServices
                 throw new ArgumentException(null, nameof(cxt));
             }
 
-            Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
+            Type classType = FindClassType(cxt);
 
             if (LicenseInteropProxy.HasLicense(classType))
             {
@@ -145,7 +172,7 @@ namespace Internal.Runtime.InteropServices
                 throw new ArgumentException(null, nameof(cxt));
             }
 
-            Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
+            Type classType = FindClassType(cxt);
 
             Type? currentType = classType;
             bool calledFunction = false;
@@ -213,17 +240,45 @@ namespace Internal.Runtime.InteropServices
         }
 
         /// <summary>
-        /// Internal entry point for unmanaged COM activation API from native code
+        /// Gets a class factory for COM activation in an isolated load context
         /// </summary>
         /// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
         [UnmanagedCallersOnly]
         private static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInternal* pCxtInt)
         {
             if (!Marshal.IsBuiltInComSupported)
-            {
                 throw new NotSupportedException(SR.NotSupported_COM);
-            }
 
+#pragma warning disable IL2026 // suppressed in ILLink.Suppressions.LibraryBuild.xml
+            return GetClassFactoryForTypeImpl(pCxtInt, isolatedContext: true);
+#pragma warning restore IL2026
+        }
+
+        /// <summary>
+        /// Gets a class factory for COM activation in the specified load context
+        /// </summary>
+        /// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
+        /// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
+        [UnmanagedCallersOnly]
+        private static unsafe int GetClassFactoryForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
+        {
+            if (!Marshal.IsBuiltInComSupported)
+                throw new NotSupportedException(SR.NotSupported_COM);
+
+            if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
+                throw new ArgumentOutOfRangeException(nameof(loadContext));
+
+            return GetClassFactoryForTypeLocal(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);
+
+            // Use a local function for a targeted suppression of the requires unreferenced code warning
+            [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:RequiresUnreferencedCode",
+                Justification = "The same feature switch applies to GetClassFactoryForTypeInternal and this function. We rely on the warning from GetClassFactoryForTypeInternal.")]
+            static int GetClassFactoryForTypeLocal(ComActivationContextInternal* pCxtInt, bool isolatedContext) => GetClassFactoryForTypeImpl(pCxtInt, isolatedContext);
+        }
+
+        [RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
+        private static unsafe int GetClassFactoryForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
+        {
             ref ComActivationContextInternal cxtInt = ref *pCxtInt;
 
             if (IsLoggingEnabled())
@@ -240,10 +295,8 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
 
             try
             {
-                var cxt = ComActivationContext.Create(ref cxtInt);
-#pragma warning disable IL2026 // suppressed in ILLink.Suppressions.LibraryBuild.xml
+                var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
                 object cf = GetClassFactoryForType(cxt);
-#pragma warning restore IL2026
                 IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf);
                 Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown);
             }
@@ -256,17 +309,37 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
         }
 
         /// <summary>
-        /// Internal entry point for registering a managed COM server API from native code
+        /// Registers a managed COM server in an isolated load context
         /// </summary>
         /// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
         [UnmanagedCallersOnly]
         private static unsafe int RegisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
         {
             if (!Marshal.IsBuiltInComSupported)
-            {
                 throw new NotSupportedException(SR.NotSupported_COM);
-            }
 
+            return RegisterClassForTypeImpl(pCxtInt, isolatedContext: true);
+        }
+
+        /// <summary>
+        /// Registers a managed COM server in the specified load context
+        /// </summary>
+        /// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
+        /// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
+        [UnmanagedCallersOnly]
+        private static unsafe int RegisterClassForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
+        {
+            if (!Marshal.IsBuiltInComSupported)
+                throw new NotSupportedException(SR.NotSupported_COM);
+
+            if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
+                throw new ArgumentOutOfRangeException(nameof(loadContext));
+
+            return RegisterClassForTypeImpl(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);
+        }
+
+        private static unsafe int RegisterClassForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
+        {
             ref ComActivationContextInternal cxtInt = ref *pCxtInt;
 
             if (IsLoggingEnabled())
@@ -289,7 +362,7 @@ $@"{nameof(RegisterClassForTypeInternal)} arguments:
 
             try
             {
-                var cxt = ComActivationContext.Create(ref cxtInt);
+                var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
                 ClassRegistrationScenarioForTypeLocal(cxt, register: true);
             }
             catch (Exception e)
@@ -306,16 +379,36 @@ $@"{nameof(RegisterClassForTypeInternal)} arguments:
         }
 
         /// <summary>
-        /// Internal entry point for unregistering a managed COM server API from native code
+        /// Unregisters a managed COM server in an isolated load context
         /// </summary>
         [UnmanagedCallersOnly]
         private static unsafe int UnregisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
         {
             if (!Marshal.IsBuiltInComSupported)
-            {
                 throw new NotSupportedException(SR.NotSupported_COM);
-            }
 
+            return UnregisterClassForTypeImpl(pCxtInt, isolatedContext: true);
+        }
+
+        /// <summary>
+        /// Unregisters a managed COM server in the specified load context
+        /// </summary>
+        /// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
+        /// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
+        [UnmanagedCallersOnly]
+        private static unsafe int UnregisterClassForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
+        {
+            if (!Marshal.IsBuiltInComSupported)
+                throw new NotSupportedException(SR.NotSupported_COM);
+
+            if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
+                throw new ArgumentOutOfRangeException(nameof(loadContext));
+
+            return UnregisterClassForTypeImpl(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);
+        }
+
+        private static unsafe int UnregisterClassForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
+        {
             ref ComActivationContextInternal cxtInt = ref *pCxtInt;
 
             if (IsLoggingEnabled())
@@ -338,7 +431,7 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments:
 
             try
             {
-                var cxt = ComActivationContext.Create(ref cxtInt);
+                var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
                 ClassRegistrationScenarioForTypeLocal(cxt, register: false);
             }
             catch (Exception e)
@@ -370,14 +463,14 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments:
         }
 
         [RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
-        private static Type FindClassType(Guid clsid, string assemblyPath, string assemblyName, string typeName)
+        private static Type FindClassType(ComActivationContext cxt)
         {
             try
             {
-                AssemblyLoadContext alc = GetALC(assemblyPath);
-                var assemblyNameLocal = new AssemblyName(assemblyName);
+                AssemblyLoadContext alc = GetALC(cxt.AssemblyPath, cxt.IsolatedContext);
+                var assemblyNameLocal = new AssemblyName(cxt.AssemblyName);
                 Assembly assem = alc.LoadFromAssemblyName(assemblyNameLocal);
-                Type? t = assem.GetType(typeName);
+                Type? t = assem.GetType(cxt.TypeName);
                 if (t != null)
                 {
                     return t;
@@ -387,7 +480,7 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments:
             {
                 if (IsLoggingEnabled())
                 {
-                    Log($"COM Activation of {clsid} failed. {e}");
+                    Log($"COM Activation of {cxt.ClassId} failed. {e}");
                 }
             }
 
@@ -396,16 +489,39 @@ $@"{nameof(UnregisterClassForTypeInternal)} arguments:
         }
 
         [RequiresUnreferencedCode("The trimmer might remove types which are needed by the assemblies loaded in this method.")]
-        private static AssemblyLoadContext GetALC(string assemblyPath)
+        private static AssemblyLoadContext GetALC(string assemblyPath, bool isolatedContext)
         {
             AssemblyLoadContext? alc;
-
-            lock (s_assemblyLoadContexts)
+            if (isolatedContext)
+            {
+                lock (s_assemblyLoadContexts)
+                {
+                    if (!s_assemblyLoadContexts.TryGetValue(assemblyPath, out alc))
+                    {
+                        alc = new IsolatedComponentLoadContext(assemblyPath);
+                        s_assemblyLoadContexts.Add(assemblyPath, alc);
+                    }
+                }
+            }
+            else
             {
-                if (!s_assemblyLoadContexts.TryGetValue(assemblyPath, out alc))
+                alc = AssemblyLoadContext.Default;
+                lock (s_loadedInDefaultContext)
                 {
-                    alc = new IsolatedComponentLoadContext(assemblyPath);
-                    s_assemblyLoadContexts.Add(assemblyPath, alc);
+                    if (!s_loadedInDefaultContext.Contains(assemblyPath))
+                    {
+                        var resolver = new AssemblyDependencyResolver(assemblyPath);
+                        AssemblyLoadContext.Default.Resolving +=
+                            (context, assemblyName) =>
+                            {
+                                string? assemblyPath = resolver.ResolveAssemblyToPath(assemblyName);
+                                return assemblyPath != null
+                                    ? context.LoadFromAssemblyPath(assemblyPath)
+                                    : null;
+                            };
+
+                        s_loadedInDefaultContext.Add(assemblyPath);
+                    }
                 }
             }
 
index cf4418acd2744828ef2f0600bf7e08037c5cbe28..ae8947fb6a1e5cb65964009c21a7c9af917433e8 100644 (file)
@@ -2,7 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Reflection;
 using System.Runtime.InteropServices;
+using System.Runtime.Loader;
 
 namespace ComLibrary
 {
@@ -25,6 +27,8 @@ namespace ComLibrary
     {
         public Server()
         {
+            Assembly asm = Assembly.GetExecutingAssembly();
+            Console.WriteLine($"{asm.GetName().Name}: AssemblyLoadContext = {AssemblyLoadContext.GetLoadContext(asm)}");
             Console.WriteLine($"New instance of {nameof(Server)} created");
         }
     }
index 373dea32c722694c33b3b9fc0388a373a7b71759..9f6018b41b9229a6d2f86e7f5e41409e13f7307f 100644 (file)
@@ -40,7 +40,8 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
                 .Execute();
 
             result.Should().Pass()
-                .And.HaveStdOutContaining("New instance of Server created");
+                .And.HaveStdOutContaining("New instance of Server created")
+                .And.ExecuteInIsolatedContext(sharedState.ComLibraryFixture.TestProject.AssemblyName);
 
             for (var i = 1; i <= count; ++i)
             {
@@ -48,6 +49,46 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
             }
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public void ActivateClass_ContextConfig(bool inDefaultContext)
+        {
+            using (var fixture = sharedState.ComLibraryFixture.Copy())
+            {
+                var comHost = Path.Combine(
+                    fixture.TestProject.BuiltApp.Location,
+                    $"{fixture.TestProject.AssemblyName}.comhost.dll");
+
+                RuntimeConfig.FromFile(fixture.TestProject.RuntimeConfigJson)
+                    .WithProperty("System.Runtime.InteropServices.COM.LoadComponentInDefaultContext", inDefaultContext.ToString())
+                    .Save();
+
+                string[] args = {
+                    "comhost",
+                    "synchronous",
+                    "1",
+                    comHost,
+                    sharedState.ClsidString
+                    };
+                CommandResult result = sharedState.CreateNativeHostCommand(args, fixture.BuiltDotnet.BinPath)
+                    .Execute();
+
+                result.Should().Pass()
+                    .And.HaveStdOutContaining("New instance of Server created")
+                    .And.HaveStdOutContaining($"Activation of {sharedState.ClsidString} succeeded.");
+
+                if (inDefaultContext)
+                {
+                    result.Should().ExecuteInDefaultContext(sharedState.ComLibraryFixture.TestProject.AssemblyName);
+                }
+                else
+                {
+                    result.Should().ExecuteInIsolatedContext(sharedState.ComLibraryFixture.TestProject.AssemblyName);
+                }
+            }
+        }
+
         [Fact]
         public void ActivateClass_IgnoreAppLocalHostFxr()
         {
index d7380437c86f8d02dd757a01977a362f688ed41f..7c201b271a54d16505cfe778cdb04cd82a5b4419 100644 (file)
     </type>
     <type fullname="Internal.Runtime.InteropServices.ComActivator">
       <!-- Used by hostpolicy.cpp -->
+      <method name="GetClassFactoryForTypeInContext" />
       <method name="GetClassFactoryForTypeInternal" />
+      <method name="RegisterClassForTypeInContext" />
       <method name="RegisterClassForTypeInternal" />
+      <method name="UnregisterClassForTypeInContext" />
       <method name="UnregisterClassForTypeInternal" />
     </type>
     <type fullname="Internal.Runtime.InteropServices.InMemoryAssemblyLoader">
index ebc68253312f0861dcbeca63bb57f632dff6e0cc..c5317bf31503fa7af220c1b58b3fecf39ddb45c1 100644 (file)
@@ -11,6 +11,7 @@
 #include "utils.h"
 #include <type_traits>
 #include <minipal/utils.h>
+#include <coreclr_delegates.h>
 
 using comhost::clsid_map_entry;
 using comhost::clsid_map;
@@ -42,14 +43,59 @@ struct com_activation_context
     void **class_factory_dest;
 };
 
-using com_delegate_fn = int(STDMETHODCALLTYPE*)(com_activation_context*);
+#define ISOLATED_CONTEXT (void*)-1
+using com_delegate_fn = int(STDMETHODCALLTYPE*)(com_activation_context*, void*);
+using com_delegate_no_load_context_fn = int(STDMETHODCALLTYPE*)(com_activation_context*);
 
 namespace
 {
-    int get_com_delegate(hostfxr_delegate_type del_type, pal::string_t *app_path, com_delegate_fn *delegate)
+    struct com_delegates
     {
-        return load_fxr_and_get_delegate(
-            del_type,
+        com_delegate_fn delegate;
+
+        // Delegate that does not take a load context. This version has existed since COM support was
+        // added in .NET Core 3.0. It is used as a fallback when loading in an isolated load context
+        // in versions of .NET without the functions that take a load context.
+        com_delegate_no_load_context_fn delegate_no_load_cxt;
+    };
+
+    // Fallback for loading a COM server in an isolated context in versions of .NET that don't have the
+    // functions that take a load context.
+    int get_com_delegate_no_load_context(hostfxr_delegate_type del_type, get_function_pointer_fn get_function_pointer, com_delegate_no_load_context_fn *delegate)
+    {
+        const pal::char_t* method_name;
+        switch (del_type)
+        {
+        case hostfxr_delegate_type::hdt_com_activation:
+            method_name = _X("GetClassFactoryForTypeInternal");
+            break;
+        case hostfxr_delegate_type::hdt_com_register:
+            method_name = _X("RegisterClassForTypeInternal");
+            break;
+        case hostfxr_delegate_type::hdt_com_unregister:
+            method_name = _X("UnregisterClassForTypeInternal");
+            break;
+        default:
+            return StatusCode::InvalidArgFailure;
+        }
+
+        return get_function_pointer(
+            _X("Internal.Runtime.InteropServices.ComActivator, System.Private.CoreLib"),
+            method_name,
+            UNMANAGEDCALLERSONLY_METHOD,
+            nullptr, // load context
+            nullptr, // reserved
+            reinterpret_cast<void**>(delegate));
+    }
+
+    int get_com_delegate(hostfxr_delegate_type del_type, pal::string_t *app_path, com_delegates &delegates, void **load_context)
+    {
+        delegates.delegate = nullptr;
+        delegates.delegate_no_load_cxt = nullptr;
+
+        get_function_pointer_fn get_function_pointer;
+        int status = load_fxr_and_get_delegate(
+            hostfxr_delegate_type::hdt_get_function_pointer,
             [app_path](const pal::string_t& host_path, pal::string_t* config_path_out)
             {
                 // Strip the comhost suffix to get the 'app' and config
@@ -66,8 +112,61 @@ namespace
 
                 return StatusCode::Success;
             },
-            delegate
+            [load_context](pal::dll_t fxr, hostfxr_handle context)
+            {
+                *load_context = ISOLATED_CONTEXT;
+                auto get_runtime_property_value = reinterpret_cast<hostfxr_get_runtime_property_value_fn>(pal::get_symbol(fxr, "hostfxr_get_runtime_property_value"));
+                const pal::char_t* value;
+                if (get_runtime_property_value(context, _X("System.Runtime.InteropServices.COM.LoadComponentInDefaultContext"), &value) == StatusCode::Success
+                    && pal::strcasecmp(value, _X("true")) == 0)
+                {
+                    *load_context = nullptr; // Default context
+                }
+            },
+            reinterpret_cast<void**>(&get_function_pointer)
         );
+        if (status != StatusCode::Success)
+            return status;
+
+        const pal::char_t* method_name;
+        switch (del_type)
+        {
+        case hostfxr_delegate_type::hdt_com_activation:
+            method_name = _X("GetClassFactoryForTypeInContext");
+            break;
+        case hostfxr_delegate_type::hdt_com_register:
+            method_name = _X("RegisterClassForTypeInContext");
+            break;
+        case hostfxr_delegate_type::hdt_com_unregister:
+            method_name = _X("UnregisterClassForTypeInContext");
+            break;
+        default:
+            return StatusCode::InvalidArgFailure;
+        }
+
+        status = get_function_pointer(
+            _X("Internal.Runtime.InteropServices.ComActivator, System.Private.CoreLib"),
+            method_name,
+            UNMANAGEDCALLERSONLY_METHOD,
+            nullptr, // load context
+            nullptr, // reserved
+            (void**)&delegates.delegate);
+
+        if (status == StatusCode::Success)
+            return status;
+
+        // Newer methods with context not found and using isolated context.
+        // Fall back to methods without context.
+        // The runtime will throw MissingMethodException, so we check for the corresponding COR_E_MISSINGMETHOD HRESULT.
+        // We also need to check for COR_E_MISSINGMEMBER due to a pre-7.0 bug where the HRESULT was not correctly set on
+        // MissingMethodException and it ended up with the HRESULT for MissingMemberException
+        if ((status == 0x80131513 /*COR_E_MISSINGMETHOD*/ || status == 0x80131512 /*COR_E_MISSINGMEMBER*/)
+            && *load_context == ISOLATED_CONTEXT)
+        {
+            status = get_com_delegate_no_load_context(del_type, get_function_pointer, &delegates.delegate_no_load_cxt);
+        }
+
+        return status;
     }
 
     void report_com_error_info(const GUID& guid, pal::string_t errs)
@@ -108,14 +207,15 @@ COM_API HRESULT STDMETHODCALLTYPE DllGetClassObject(
 
     HRESULT hr;
     pal::string_t app_path;
-    com_delegate_fn act;
+    com_delegates act;
+    void* load_context;
     {
         trace::setup();
         reset_redirected_error_writer();
 
         error_writer_scope_t writer_scope(redirected_error_writer);
 
-        int ec = get_com_delegate(hostfxr_delegate_type::hdt_com_activation, &app_path, &act);
+        int ec = get_com_delegate(hostfxr_delegate_type::hdt_com_activation, &app_path, act, &load_context);
         if (ec != StatusCode::Success)
         {
             report_com_error_info(rclsid, std::move(get_redirected_error_string()));
@@ -123,6 +223,8 @@ COM_API HRESULT STDMETHODCALLTYPE DllGetClassObject(
         }
     }
 
+    assert(act.delegate != nullptr || load_context == ISOLATED_CONTEXT);
+
     // Query the CLR for the type
 
     IUnknown *classFactory = nullptr;
@@ -135,7 +237,14 @@ COM_API HRESULT STDMETHODCALLTYPE DllGetClassObject(
         iter->second.type.c_str(),
         (void**)&classFactory
     };
-    RETURN_IF_FAILED(act(&cxt));
+    if (act.delegate != nullptr)
+    {
+        RETURN_IF_FAILED(act.delegate(&cxt, load_context));
+    }
+    else
+    {
+        RETURN_IF_FAILED(act.delegate_no_load_cxt(&cxt));
+    }
     assert(classFactory != nullptr);
 
     hr = classFactory->QueryInterface(riid, ppv);
@@ -463,8 +572,10 @@ COM_API HRESULT STDMETHODCALLTYPE DllRegisterServer(void)
 
     HRESULT hr;
     pal::string_t app_path;
-    com_delegate_fn reg;
-    RETURN_IF_FAILED(get_com_delegate(hostfxr_delegate_type::hdt_com_register, &app_path, &reg));
+    com_delegates reg;
+    void* load_context;
+    RETURN_IF_FAILED(get_com_delegate(hostfxr_delegate_type::hdt_com_register, &app_path, reg, &load_context));
+    assert(reg.delegate != nullptr || load_context == ISOLATED_CONTEXT);
 
     com_activation_context cxt
     {
@@ -486,7 +597,14 @@ COM_API HRESULT STDMETHODCALLTYPE DllRegisterServer(void)
         cxt.class_id = p.first;
         cxt.assembly_name = p.second.assembly.c_str();
         cxt.type_name = p.second.type.c_str();
-        RETURN_IF_FAILED(reg(&cxt));
+        if (reg.delegate != nullptr)
+        {
+            RETURN_IF_FAILED(reg.delegate(&cxt, load_context));
+        }
+        else
+        {
+            RETURN_IF_FAILED(reg.delegate_no_load_cxt(&cxt));
+        }
     }
 
     return S_OK;
@@ -507,8 +625,10 @@ COM_API HRESULT STDMETHODCALLTYPE DllUnregisterServer(void)
 
     HRESULT hr;
     pal::string_t app_path;
-    com_delegate_fn unreg;
-    RETURN_IF_FAILED(get_com_delegate(hostfxr_delegate_type::hdt_com_unregister, &app_path, &unreg));
+    com_delegates unreg;
+    void* load_context;
+    RETURN_IF_FAILED(get_com_delegate(hostfxr_delegate_type::hdt_com_unregister, &app_path, unreg, &load_context));
+    assert(unreg.delegate != nullptr || load_context == ISOLATED_CONTEXT);
 
     com_activation_context cxt
     {
@@ -527,7 +647,14 @@ COM_API HRESULT STDMETHODCALLTYPE DllUnregisterServer(void)
         cxt.class_id = p.first;
         cxt.assembly_name = p.second.assembly.c_str();
         cxt.type_name = p.second.type.c_str();
-        RETURN_IF_FAILED(unreg(&cxt));
+        if (unreg.delegate != nullptr)
+        {
+            RETURN_IF_FAILED(unreg.delegate(&cxt, load_context));
+        }
+        else
+        {
+            RETURN_IF_FAILED(unreg.delegate_no_load_cxt(&cxt));
+        }
 
         // Unregister the CLSID from registry
         RETURN_IF_FAILED(RemoveClsid(p.second));
index 152a2ae92e1d6afe575f3d2f7682d96cfb4031c1..3df6c2de3053d3488851f9e9e0cc9f1f728389c1 100644 (file)
@@ -17,8 +17,8 @@ namespace fxr_resolver
     bool try_get_existing_fxr(pal::dll_t *out_fxr, pal::string_t *out_fxr_path);
 }
 
-template<typename THostPathToConfigCallback, typename TDelegate>
-int load_fxr_and_get_delegate(hostfxr_delegate_type type, THostPathToConfigCallback host_path_to_config_path, TDelegate* delegate)
+template<typename THostPathToConfigCallback, typename TBeforeRunCallback>
+int load_fxr_and_get_delegate(hostfxr_delegate_type type, THostPathToConfigCallback host_path_to_config_path, TBeforeRunCallback on_before_run, void** delegate)
 {
     pal::dll_t fxr;
 
@@ -64,11 +64,8 @@ int load_fxr_and_get_delegate(hostfxr_delegate_type type, THostPathToConfigCallb
 
     pal::string_t config_path;
     pal::hresult_t status = host_path_to_config_path(host_path, &config_path);
-
     if (status != StatusCode::Success)
-    {
         return status;
-    }
 
     hostfxr_initialize_parameters parameters {
         sizeof(hostfxr_initialize_parameters),
@@ -86,7 +83,9 @@ int load_fxr_and_get_delegate(hostfxr_delegate_type type, THostPathToConfigCallb
         if (!STATUS_CODE_SUCCEEDED(rc))
             return rc;
 
-        rc = hostfxr_get_runtime_delegate(context, type, reinterpret_cast<void**>(delegate));
+        on_before_run(fxr, context);
+
+        rc = hostfxr_get_runtime_delegate(context, type, delegate);
 
         int rcClose = hostfxr_close(context);
         if (rcClose != StatusCode::Success)
index ab4f047ee8123c8fd730cea51cc6a9dc7dac4c42..7f1684cca6263049f804ac8fc447b9fb2f921e65 100644 (file)
@@ -40,7 +40,8 @@ pal::hresult_t get_load_in_memory_assembly_delegate(pal::dll_t handle, load_in_m
 
             return StatusCode::Success;
         },
-        &get_function_pointer
+        [](pal::dll_t fxr, hostfxr_handle context){ },
+        reinterpret_cast<void**>(&get_function_pointer)
     );
     if (status != StatusCode::Success)
         return status;
index 9864107f694f5fe6acf72191761594b43809152f..9f3ba0c1de767994a959d2a5577f86e88c799b2b 100644 (file)
@@ -15,5 +15,6 @@ namespace Activator
         public string AssemblyPath;
         public string AssemblyName;
         public string TypeName;
+        public bool IsolatedContext;
     }
 }
index b45ec9da6b1ca879c6e0501e184653c7e71530d4..0bb5763d73f6d665ded4dca2baba63328b8f96a8 100644 (file)
@@ -128,9 +128,9 @@ namespace Activator
             }
         }
 
-        static void ValidateAssemblyIsolation(bool builtInComDisabled)
+        static void ValidateAssemblyIsolation(bool builtInComDisabled, bool useIsolatedContext)
         {
-            Console.WriteLine($"Running {nameof(ValidateAssemblyIsolation)}...");
+            Console.WriteLine($"Running {nameof(ValidateAssemblyIsolation)}({nameof(ComActivationContext.IsolatedContext)}={useIsolatedContext})...");
 
             string assemblySubPath = Path.Combine(Environment.CurrentDirectory, "Servers");
             string assemblyAPath = Path.Combine(assemblySubPath, "AssemblyA.dll");
@@ -157,7 +157,8 @@ namespace Activator
                     InterfaceId = typeof(IClassFactory).GUID,
                     AssemblyPath = assemblyAPath,
                     AssemblyName = "AssemblyA",
-                    TypeName = "ClassFromA"
+                    TypeName = "ClassFromA",
+                    IsolatedContext = useIsolatedContext,
                 };
 
                 if (builtInComDisabled)
@@ -188,7 +189,8 @@ namespace Activator
                     InterfaceId = typeof(IClassFactory).GUID,
                     AssemblyPath = assemblyBPath,
                     AssemblyName = "AssemblyB",
-                    TypeName = "ClassFromB"
+                    TypeName = "ClassFromB",
+                    IsolatedContext = useIsolatedContext
                 };
 
                 var factory = GetClassFactoryForType(cxt);
@@ -200,7 +202,14 @@ namespace Activator
                 typeCFromAssemblyB = (Type)svr.GetTypeFromC();
             }
 
-            Assert.NotEqual(typeCFromAssemblyA, typeCFromAssemblyB);
+            if (useIsolatedContext)
+            {
+                Assert.NotEqual(typeCFromAssemblyA, typeCFromAssemblyB);
+            }
+            else
+            {
+                Assert.Equal(typeCFromAssemblyA, typeCFromAssemblyB);
+            }
         }
 
         static void ValidateUserDefinedRegistrationCallbacks()
@@ -334,10 +343,11 @@ namespace Activator
                 InvalidInterfaceRequest();
                 ClassNotRegistered(builtInComDisabled);
                 NonrootedAssemblyPath(builtInComDisabled);
-                ValidateAssemblyIsolation(builtInComDisabled);
+                ValidateAssemblyIsolation(builtInComDisabled, useIsolatedContext: true);
                 if (!builtInComDisabled)
                 {
                     // We don't test this scenario with builtInComDisabled since it is covered by ValidateAssemblyIsolation() above
+                    ValidateAssemblyIsolation(builtInComDisabled, useIsolatedContext: false);
                     ValidateUserDefinedRegistrationCallbacks();
                 }
             }