UnmanagedCallersOnlyAttribute in load_assembly_and_get_function_pointer (#35763)
authorAaron Robinson <arobins@microsoft.com>
Tue, 5 May 2020 18:03:02 +0000 (11:03 -0700)
committerGitHub <noreply@github.com>
Tue, 5 May 2020 18:03:02 +0000 (11:03 -0700)
* Add support for UnmanagedCallersOnlyAttribute in the load_assembly_and_get_function_pointer API.

* Handle x86 of UnmanagedCallersOnly in managed GetFunctionPointer() API.

15 files changed:
src/coreclr/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComponentActivator.cs
src/coreclr/src/tools/Common/JitInterface/CorInfoImpl.cs
src/coreclr/src/tools/crossgen2/ILCompiler.ReadyToRun/JitInterface/CorInfoImpl.ReadyToRun.cs
src/coreclr/src/vm/comdelegate.cpp
src/coreclr/src/vm/comdelegate.h
src/coreclr/src/vm/corhost.cpp
src/coreclr/src/vm/dllimportcallback.cpp
src/coreclr/src/vm/dllimportcallback.h
src/coreclr/src/vm/jitinterface.cpp
src/coreclr/src/vm/runtimehandles.cpp
src/installer/corehost/cli/coreclr_delegates.h
src/installer/corehost/cli/test/nativehost/host_context_test.cpp
src/installer/test/Assets/TestProjects/ComponentWithNoDependencies/Component.cs
src/installer/test/HostActivation.Tests/NativeHosting/ComponentActivation.cs
src/libraries/System.Private.CoreLib/src/Resources/Strings.resx

index 1c25db0..c571fcf 100644 (file)
@@ -6,6 +6,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Reflection;
 using System.Runtime.InteropServices;
 
@@ -48,7 +49,7 @@ namespace Internal.Runtime.InteropServices
         /// <param name="reserved">Extensibility parameter (currently unused)</param>
         /// <param name="functionHandle">Pointer where to store the function pointer result</param>
         [UnmanagedCallersOnly]
-        public static int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNative,
+        public static unsafe int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNative,
                                                             IntPtr typeNameNative,
                                                             IntPtr methodNameNative,
                                                             IntPtr delegateTypeNative,
@@ -57,18 +58,36 @@ namespace Internal.Runtime.InteropServices
         {
             try
             {
+                // Load the assembly and create a resolver callback for types.
                 string assemblyPath = MarshalToString(assemblyPathNative, nameof(assemblyPathNative));
+                IsolatedComponentLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);
+                Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);
+
+                // Get the requested type.
                 string typeName = MarshalToString(typeNameNative, nameof(typeNameNative));
+                Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;
+
+                // Get the method name on the type.
                 string methodName = MarshalToString(methodNameNative, nameof(methodNameNative));
 
-                string delegateType;
+                // Determine the signature of the type. There are 3 possibilities:
+                //  * No delegate type was supplied - use the default (i.e. ComponentEntryPoint).
+                //  * A sentinel value was supplied - the function is marked UnmanagedCallersOnly. This means
+                //      a function pointer can be returned without creating a delegate.
+                //  * A delegate type was supplied - Load the type and create a delegate for that method.
+                Type? delegateType;
                 if (delegateTypeNative == IntPtr.Zero)
                 {
-                    delegateType = typeof(ComponentEntryPoint).AssemblyQualifiedName!;
+                    delegateType = typeof(ComponentEntryPoint);
+                }
+                else if (delegateTypeNative == (IntPtr)(-1))
+                {
+                    delegateType = null;
                 }
                 else
                 {
-                    delegateType = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
+                    string delegateTypeName = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
+                    delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
                 }
 
                 if (reserved != IntPtr.Zero)
@@ -81,17 +100,35 @@ namespace Internal.Runtime.InteropServices
                     throw new ArgumentNullException(nameof(functionHandle));
                 }
 
-                Delegate d = CreateDelegate(assemblyPath, typeName, methodName, delegateType);
+                IntPtr functionPtr;
+                if (delegateType == null)
+                {
+                    // Match search semantics of the CreateDelegate() function below.
+                    BindingFlags bindingFlags = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
+                    MethodInfo? methodInfo = type.GetMethod(methodName, bindingFlags);
+                    if (methodInfo == null)
+                        throw new MissingMethodException(typeName, methodName);
 
-                IntPtr functionPtr = Marshal.GetFunctionPointerForDelegate(d);
+                    // Verify the function is properly marked.
+                    if (null == methodInfo.GetCustomAttribute<UnmanagedCallersOnlyAttribute>())
+                        throw new InvalidOperationException(SR.InvalidOperation_FunctionMissingUnmanagedCallersOnly);
 
-                lock (s_delegates)
+                    functionPtr = methodInfo.MethodHandle.GetFunctionPointer();
+                }
+                else
                 {
-                    // Keep a reference to the delegate to prevent it from being garbage collected
-                    s_delegates[functionPtr] = d;
+                    Delegate d = Delegate.CreateDelegate(delegateType, type, methodName)!;
+
+                    functionPtr = Marshal.GetFunctionPointerForDelegate(d);
+
+                    lock (s_delegates)
+                    {
+                        // Keep a reference to the delegate to prevent it from being garbage collected
+                        s_delegates[functionPtr] = d;
+                    }
                 }
 
-                Marshal.WriteIntPtr(functionHandle, functionPtr);
+                *(IntPtr*)functionHandle = functionPtr;
             }
             catch (Exception e)
             {
@@ -101,23 +138,6 @@ namespace Internal.Runtime.InteropServices
             return 0;
         }
 
-        private static Delegate CreateDelegate(string assemblyPath, string typeName, string methodName, string delegateTypeName)
-        {
-            // Throws
-            IsolatedComponentLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);
-
-            Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);
-
-            // Throws
-            Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;
-
-            // Throws
-            Type delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
-
-            // Throws
-            return Delegate.CreateDelegate(delegateType, type, methodName)!;
-        }
-
         private static IsolatedComponentLoadContext GetIsolatedComponentLoadContext(string assemblyPath)
         {
             IsolatedComponentLoadContext? alc;
index aeef11e..753fc0a 100644 (file)
@@ -18,6 +18,7 @@ using Internal.Runtime.CompilerServices;
 using Internal.IL;
 using Internal.TypeSystem;
 using Internal.TypeSystem.Ecma;
+using Internal.TypeSystem.Interop;
 using Internal.CorConstants;
 
 using ILCompiler;
@@ -2934,6 +2935,22 @@ namespace Internal.JitInterface
                 }
 #endif
 
+                // Validate UnmanagedCallersOnlyAttribute usage
+                if (!this.MethodBeingCompiled.Signature.IsStatic) // Must be a static method
+                {
+                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramNonStaticMethod, this.MethodBeingCompiled);
+                }
+
+                if (this.MethodBeingCompiled.HasInstantiation || this.MethodBeingCompiled.OwningType.HasInstantiation) // No generics involved
+                {
+                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramGenericMethod, this.MethodBeingCompiled);
+                }
+
+                if (Marshaller.IsMarshallingRequired(this.MethodBeingCompiled.Signature, Array.Empty<ParameterMetadata>())) // Only blittable arguments
+                {
+                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramNonBlittableTypes, this.MethodBeingCompiled);
+                }
+
                 flags.Set(CorJitFlag.CORJIT_FLAG_REVERSE_PINVOKE);
             }
 
index 97fbe08..a0691e4 100644 (file)
@@ -1153,25 +1153,6 @@ namespace Internal.JitInterface
                 ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramCallVirtStatic, originalMethod);
             }
 
-            if ((flags & CORINFO_CALLINFO_FLAGS.CORINFO_CALLINFO_LDFTN) != 0
-                && originalMethod.IsUnmanagedCallersOnly)
-            {
-                if (!originalMethod.Signature.IsStatic) // Must be a static method
-                {
-                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramNonStaticMethod, originalMethod);
-                }
-
-                if (originalMethod.HasInstantiation || originalMethod.OwningType.HasInstantiation) // No generics involved
-                {
-                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramGenericMethod, originalMethod);
-                }
-
-                if (Marshaller.IsMarshallingRequired(originalMethod.Signature, Array.Empty<ParameterMetadata>())) // Only blittable arguments
-                {
-                    ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramNonBlittableTypes, originalMethod);
-                }
-            }
-
             exactType = type;
 
             constrainedType = null;
index 791421c..20a724b 100644 (file)
 #include "cgensys.h"
 #include "asmconstants.h"
 #include "virtualcallstub.h"
-#include "callingconvention.h"
-#include "customattribute.h"
 #include "typestring.h"
-#include "../md/compiler/custattr.h"
 #ifdef FEATURE_COMINTEROP
 #include "comcallablewrapper.h"
 #endif // FEATURE_COMINTEROP
@@ -1133,59 +1130,23 @@ void COMDelegate::BindToMethod(DELEGATEREF   *pRefThis,
 }
 
 #if defined(TARGET_X86)
-// Marshals a managed method to an unmanaged callback provided the
-// managed method is static and it's parameters require no marshalling.
-PCODE COMDelegate::ConvertToCallback(MethodDesc* pMD)
+// Marshals a managed method to an unmanaged callback.
+PCODE COMDelegate::ConvertToUnmanagedCallback(MethodDesc* pMD)
 {
     CONTRACTL
     {
         THROWS;
         GC_TRIGGERS;
         PRECONDITION(pMD != NULL);
+        PRECONDITION(pMD->HasUnmanagedCallersOnlyAttribute());
         INJECT_FAULT(COMPlusThrowOM());
     }
     CONTRACTL_END;
 
-    PCODE pCode = NULL;
-
     // Get UMEntryThunk from the thunk cache.
     UMEntryThunk *pUMEntryThunk = pMD->GetLoaderAllocator()->GetUMEntryThunkCache()->GetUMEntryThunk(pMD);
 
-#if !defined(FEATURE_STUBS_AS_IL)
-
-    // System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute
-    BYTE* pData = NULL;
-    LONG cData = 0;
-    CorPinvokeMap callConv = (CorPinvokeMap)0;
-
-    HRESULT hr = pMD->GetCustomAttribute(WellKnownAttribute::UnmanagedCallersOnly, (const VOID **)(&pData), (ULONG *)&cData);
-    IfFailThrow(hr);
-
-    if (cData > 0)
-    {
-        CustomAttributeParser ca(pData, cData);
-        // UnmanagedCallersOnly has two optional named arguments CallingConvention and EntryPoint.
-        CaNamedArg namedArgs[2];
-        CaTypeCtor caType(SERIALIZATION_TYPE_STRING);
-        // First, the void constructor.
-        IfFailThrow(ParseKnownCaArgs(ca, NULL, 0));
-
-        // Now the optional named properties
-        namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)callConv);
-        namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caType);
-        IfFailThrow(ParseKnownCaNamedArgs(ca, namedArgs, lengthof(namedArgs)));
-
-        callConv = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
-        // Let UMThunkMarshalInfo choose the default if calling convension not definied.
-        if (namedArgs[0].val.type.tag != SERIALIZATION_TYPE_UNDEFINED)
-        {
-            UMThunkMarshInfo* pUMThunkMarshalInfo = pUMEntryThunk->GetUMThunkMarshInfo();
-            pUMThunkMarshalInfo->SetCallingConvention(callConv);
-        }
-}
-#endif  // !FEATURE_STUBS_AS_IL
-
-    pCode = (PCODE)pUMEntryThunk->GetCode();
+    PCODE pCode = (PCODE)pUMEntryThunk->GetCode();
     _ASSERTE(pCode != NULL);
     return pCode;
 }
@@ -2140,6 +2101,29 @@ FCIMPLEND
 
 #endif // CROSSGEN_COMPILE
 
+void COMDelegate::ThrowIfInvalidUnmanagedCallersOnlyUsage(MethodDesc* pMD)
+{
+    CONTRACTL
+    {
+        THROWS;
+        GC_TRIGGERS;
+        PRECONDITION(pMD != NULL);
+        PRECONDITION(pMD->HasUnmanagedCallersOnlyAttribute());
+    }
+    CONTRACTL_END;
+
+    if (!pMD->IsStatic())
+        EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_NonStaticMethod")));
+
+    // No generic methods
+    if (pMD->HasClassOrMethodInstantiation())
+        EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_GenericMethod")));
+
+    // Arguments
+    if (NDirect::MarshalingRequired(pMD, pMD->GetSig(), pMD->GetModule()))
+        EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_NonBlittableTypes")));
+}
+
 BOOL COMDelegate::NeedsWrapperDelegate(MethodDesc* pTargetMD)
 {
     LIMITED_METHOD_CONTRACT;
index d11d79f..39db2e7 100644 (file)
@@ -87,7 +87,7 @@ public:
 #if defined(TARGET_X86)
     // Marshals a managed method to an unmanaged callback.
     // This is only used on x86. See usage for further details.
-    static PCODE ConvertToCallback(MethodDesc* pMD);
+    static PCODE ConvertToUnmanagedCallback(MethodDesc* pMD);
 #endif // defined(TARGET_X86)
 
     // Marshals an unmanaged callback to Delegate
@@ -127,6 +127,10 @@ public:
 
     static BOOL IsTrueMulticastDelegate(OBJECTREF delegate);
 
+    // Throw if the method violates any usage restrictions
+    // for UnmanagedCallersOnlyAttribute.
+    static void ThrowIfInvalidUnmanagedCallersOnlyUsage(MethodDesc* pMD);
+
 private:
     static Stub* SetupShuffleThunk(MethodTable * pDelMT, MethodDesc *pTargetMeth);
 
index e52bef8..15c2bd1 100644 (file)
@@ -801,11 +801,8 @@ HRESULT CorHost2::CreateDelegate(
 
         if (pMD->HasUnmanagedCallersOnlyAttribute())
         {
-            if (NDirect::MarshalingRequired(pMD, pMD->GetSig(), pMD->GetModule()))
-                ThrowHR(COR_E_INVALIDPROGRAM);
-
 #ifdef TARGET_X86
-            *fnPtr = (INT_PTR)COMDelegate::ConvertToCallback(pMD);
+            *fnPtr = (INT_PTR)COMDelegate::ConvertToUnmanagedCallback(pMD);
 #else
             *fnPtr = pMD->GetMultiCallableAddrOfCode();
 #endif
index 4ecd77d..9ec4153 100644 (file)
@@ -21,6 +21,8 @@
 #include "dbginterface.h"
 #include "stubgen.h"
 #include "appdomain.inl"
+#include "callingconvention.h"
+#include "customattribute.h"
 
 #ifndef CROSSGEN_COMPILE
 
@@ -612,6 +614,43 @@ VOID UMEntryThunk::CompileUMThunkWorker(UMThunkStubInfo *pInfo,
     pcpusl->X86EmitNearJump(pEnableRejoin);
 }
 
+VOID UMThunkMarshInfo::SetUpForUnmanagedCallersOnly()
+{
+    STANDARD_VM_CONTRACT;
+
+    MethodDesc* pMD = GetMethod();
+    _ASSERTE(pMD != NULL && pMD->HasUnmanagedCallersOnlyAttribute());
+
+    // Validate UnmanagedCallersOnlyAttribute usage
+    COMDelegate::ThrowIfInvalidUnmanagedCallersOnlyUsage(pMD);
+
+    BYTE* pData = NULL;
+    LONG cData = 0;
+    CorPinvokeMap callConv = (CorPinvokeMap)0;
+
+    HRESULT hr = pMD->GetCustomAttribute(WellKnownAttribute::UnmanagedCallersOnly, (const VOID **)(&pData), (ULONG *)&cData);
+    IfFailThrow(hr);
+
+    _ASSERTE(cData > 0);
+
+    CustomAttributeParser ca(pData, cData);
+    // UnmanagedCallersOnly has two optional named arguments CallingConvention and EntryPoint.
+    CaNamedArg namedArgs[2];
+    CaTypeCtor caType(SERIALIZATION_TYPE_STRING);
+    // First, the void constructor.
+    IfFailThrow(ParseKnownCaArgs(ca, NULL, 0));
+
+    // Now the optional named properties
+    namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)callConv);
+    namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caType);
+    IfFailThrow(ParseKnownCaNamedArgs(ca, namedArgs, lengthof(namedArgs)));
+
+    callConv = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
+    // Let UMThunkMarshalInfo choose the default if calling convension not definied.
+    if (namedArgs[0].val.type.tag != SERIALIZATION_TYPE_UNDEFINED)
+        m_callConv = (UINT16)callConv;
+}
+
 // Compiles an unmanaged to managed thunk for the given signature.
 Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStaticSigInfo* pSigInfo, MetaSig *pMetaSig, BOOL fNoStub)
 {
@@ -721,7 +760,9 @@ Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStat
 
     m_cbActualArgSize = cbActualArgSize;
 
-    m_callConv = static_cast<UINT16>(pSigInfo->GetCallConv());
+    // This could have been set in the UnmanagedCallersOnly scenario.
+    if (m_callConv == UINT16_MAX)
+        m_callConv = static_cast<UINT16>(pSigInfo->GetCallConv());
 
     UMThunkStubInfo stubInfo;
     memset(&stubInfo, 0, sizeof(stubInfo));
@@ -1117,6 +1158,7 @@ VOID UMThunkMarshInfo::LoadTimeInit(Signature sig, Module * pModule, MethodDesc
     m_sig = sig;
 
 #if defined(TARGET_X86) && !defined(FEATURE_STUBS_AS_IL)
+    m_callConv = UINT16_MAX;
     INDEBUG(m_cbRetPop = 0xcccc;)
 #endif
 }
@@ -1142,6 +1184,14 @@ VOID UMThunkMarshInfo::RunTimeInit()
 
     MethodDesc * pMD = GetMethod();
 
+#if defined(TARGET_X86) && !defined(FEATURE_STUBS_AS_IL)
+    if (pMD != NULL
+        && pMD->HasUnmanagedCallersOnlyAttribute())
+    {
+        SetUpForUnmanagedCallersOnly();
+    }
+#endif // TARGET_X86 && !FEATURE_STUBS_AS_IL
+
     // Lookup NGened stub - currently we only support ngening of reverse delegate invoke interop stubs
     if (pMD != NULL && pMD->IsEEImpl())
     {
index 12bc89a..b92b8e8 100644 (file)
@@ -146,26 +146,6 @@ public:
         return m_cbRetPop;
     }
 
-    CorPinvokeMap GetCallingConvention()
-    {
-        CONTRACTL
-        {
-            NOTHROW;
-            GC_NOTRIGGER;
-            MODE_ANY;
-            SUPPORTS_DAC;
-            PRECONDITION(IsCompletelyInited());
-        }
-        CONTRACTL_END;
-
-        return (CorPinvokeMap)m_callConv;
-    }
-
-    VOID SetCallingConvention(const CorPinvokeMap callConv)
-    {
-        m_callConv = (UINT16)callConv;
-    }
-
 #else
     PCODE GetExecStubEntryPoint();
 #endif
@@ -195,6 +175,9 @@ public:
 
     VOID SetupArguments(char *pSrc, ArgumentRegisters *pArgRegs, char *pDst);
 #else
+private:
+    VOID SetUpForUnmanagedCallersOnly();
+
     // Compiles an unmanaged to managed thunk for the given signature. The thunk
     // will call the stub or, if fNoStub == TRUE, directly the managed target.
     Stub *CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStaticSigInfo* pSigInfo, MetaSig *pMetaSig, BOOL fNoStub);
index 552cffc..8e29a10 100644 (file)
@@ -5095,22 +5095,6 @@ void CEEInfo::getCallInfo(
         EX_THROW(EEMessageException, (kMissingMethodException, IDS_EE_MISSING_METHOD, W("?")));
     }
 
-    // If this call is for a LDFTN and the target method has the UnmanagedCallersOnlyAttribute,
-    // then validate it adheres to the limitations.
-    if ((flags & CORINFO_CALLINFO_LDFTN) && pMD->HasUnmanagedCallersOnlyAttribute())
-    {
-        if (!pMD->IsStatic())
-            EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_NonStaticMethod")));
-
-        // No generic methods
-        if (pMD->HasClassOrMethodInstantiation())
-            EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_GenericMethod")));
-
-        // Arguments
-        if (NDirect::MarshalingRequired(pMD, pMD->GetSig(), pMD->GetModule()))
-            EX_THROW(EEResourceException, (kInvalidProgramException, W("InvalidProgram_NonBlittableTypes")));
-    }
-
     TypeHandle exactType = TypeHandle(pResolvedToken->hClass);
 
     TypeHandle constrainedType;
@@ -9225,7 +9209,7 @@ void CEEInfo::getFunctionFixedEntryPoint(CORINFO_METHOD_HANDLE   ftn,
     // https://github.com/dotnet/runtime/issues/33582
     if (pMD->HasUnmanagedCallersOnlyAttribute())
     {
-        pResult->addr = (void*)COMDelegate::ConvertToCallback(pMD);
+        pResult->addr = (void*)COMDelegate::ConvertToUnmanagedCallback(pMD);
     }
     else
     {
@@ -12440,7 +12424,10 @@ CorJitResult CallCompileMethodWithSEHWrapper(EEJitManager *jitMgr,
 
 #if !defined(TARGET_X86)
     if (ftn->HasUnmanagedCallersOnlyAttribute())
+    {
+        COMDelegate::ThrowIfInvalidUnmanagedCallersOnlyUsage(ftn);
         flags.Set(CORJIT_FLAGS::CORJIT_FLAG_REVERSE_PINVOKE);
+    }
 #endif // !TARGET_X86
 
     return flags;
index 2a80972..a16bc1a 100644 (file)
@@ -22,6 +22,7 @@
 #include "eeconfig.h"
 #include "eehash.h"
 #include "interoputil.h"
+#include "comdelegate.h"
 #include "typedesc.h"
 #include "virtualcallstub.h"
 #include "contractimpl.h"
@@ -1759,7 +1760,22 @@ void * QCALLTYPE RuntimeMethodHandle::GetFunctionPointer(MethodDesc * pMethod)
     // Ensure the method is active so
     // the function pointer can be used.
     pMethod->EnsureActive();
+
+#if defined(TARGET_X86)
+    // Deferring X86 support until a need is observed or
+    // time permits investigation into all the potential issues.
+    // https://github.com/dotnet/runtime/issues/33582
+    if (pMethod->HasUnmanagedCallersOnlyAttribute())
+    {
+        funcPtr = (void*)COMDelegate::ConvertToUnmanagedCallback(pMethod);
+    }
+    else
+    {
+        funcPtr = (void*)pMethod->GetMultiCallableAddrOfCode();
+    }
+#else
     funcPtr = (void*)pMethod->GetMultiCallableAddrOfCode();
+#endif
 
     END_QCALL;
 
index ca2800b..7e0dd1a 100644 (file)
     typedef char char_t;
 #endif
 
+#define UNMANAGEDCALLERSONLY_METHOD ((const char_t*)-1)
+
 // Signature of delegate returned by coreclr_delegate_type::load_assembly_and_get_function_pointer
 typedef int (CORECLR_DELEGATE_CALLTYPE *load_assembly_and_get_function_pointer_fn)(
     const char_t *assembly_path      /* Fully qualified path to assembly */,
     const char_t *type_name          /* Assembly qualified type name */,
     const char_t *method_name        /* Public static method name compatible with delegateType */,
-    const char_t *delegate_type_name /* Assembly qualified delegate type name or null */,
+    const char_t *delegate_type_name /* Assembly qualified delegate type name or null
+                                        or UNMANAGEDCALLERSONLY_METHOD if the method is marked with
+                                        the UnmanagedCallersOnlyAttribute. */,
     void         *reserved           /* Extensibility parameter (currently unused and must be 0) */,
     /*out*/ void **delegate          /* Pointer where to store the function pointer result */);
 
index 708fa2b..f221543 100644 (file)
@@ -6,6 +6,7 @@
 #include <pal.h>
 #include <error_codes.h>
 #include <future>
+#include <array>
 #include <hostfxr.h>
 #include <coreclr_delegates.h>
 #include <corehost_context_contract.h>
@@ -225,6 +226,74 @@ namespace
         return -1;
     }
 
+    struct _printable_delegate_name_t
+    {
+        const pal::char_t* name;
+    };
+
+    std::basic_ostream<pal::char_t>& operator<<(std::basic_ostream<pal::char_t>& stream, const _printable_delegate_name_t &p)
+    {
+        if (p.name == nullptr)
+        {
+            return stream << _X("nullptr");
+        }
+        else if (p.name == UNMANAGEDCALLERSONLY_METHOD)
+        {
+            return stream << _X("UNMANAGEDCALLERSONLY_METHOD");
+        }
+        else
+        {
+            return stream << _X("\"") << p.name << _X("\"");
+        }
+    }
+
+    const _printable_delegate_name_t to_printable_delegate_name(const pal::char_t *delegate_name)
+    {
+        return _printable_delegate_name_t{ delegate_name };
+    }
+
+    int call_delegate_flavour(
+        load_assembly_and_get_function_pointer_fn delegate,
+        const pal::char_t *assembly_path,
+        const pal::char_t *type_name,
+        const pal::char_t *method_name,
+        const pal::char_t *log_prefix,
+        pal::stringstream_t &test_output)
+    {
+        const pal::char_t *delegate_name = nullptr;
+        pal::string_t method_name_local{ method_name };
+        if (pal::string_t::npos != method_name_local.find(_X("Unmanaged")))
+            delegate_name = UNMANAGEDCALLERSONLY_METHOD;
+
+        test_output << log_prefix << _X("calling load_assembly_and_get_function_pointer(\"")
+            << assembly_path << _X("\", \"")
+            << type_name << _X("\", \"")
+            << method_name << _X("\", ")
+            << to_printable_delegate_name(delegate_name) << _X(", ")
+            << _X("nullptr, &componentEntryPointDelegate)")
+            << std::endl;
+
+        component_entry_point_fn componentEntryPointDelegate = nullptr;
+        int rc = delegate(assembly_path,
+                        type_name,
+                        method_name,
+                        delegate_name,
+                        nullptr /* reserved */,
+                        (void **)&componentEntryPointDelegate);
+
+        if (rc != StatusCode::Success)
+        {
+            test_output << log_prefix << _X("load_assembly_and_get_function_pointer failed: ") << std::hex << std::showbase << rc << std::endl;
+        }
+        else
+        {
+            test_output << log_prefix << _X("load_assembly_and_get_function_pointer succeeded: ") << std::hex << std::showbase << rc << std::endl;
+            rc = call_delegate_with_try_except(componentEntryPointDelegate, method_name, log_prefix, test_output);
+        }
+
+        return rc;
+    }
+
     bool load_assembly_and_get_function_pointer_test(
         const hostfxr_exports &hostfxr,
         const pal::char_t *config_path,
@@ -258,31 +327,7 @@ namespace
             else
             {
                 test_output << log_prefix << _X("hostfxr_get_runtime_delegate succeeded: ") << std::hex << std::showbase << rc << std::endl;
-
-                test_output << log_prefix << _X("calling load_assembly_and_get_function_pointer(\"")
-                    << assembly_path << _X("\", \"")
-                    << type_name << _X("\", \"")
-                    << method_name << _X("\", \"")
-                    << _X("nullptr, nullptr, &componentEntryPointDelegate)")
-                    << std::endl;
-
-                component_entry_point_fn componentEntryPointDelegate = nullptr;
-                rc = delegate(assembly_path,
-                              type_name,
-                              method_name,
-                              nullptr /* delegateTypeNative */,
-                              nullptr /* reserved */,
-                              (void **)&componentEntryPointDelegate);
-
-                if (rc != StatusCode::Success)
-                {
-                    test_output << log_prefix << _X("load_assembly_and_get_function_pointer failed: ") << std::hex << std::showbase << rc << std::endl;
-                }
-                else
-                {
-                    test_output << log_prefix << _X("load_assembly_and_get_function_pointer succeeded: ") << std::hex << std::showbase << rc << std::endl;
-                    rc = call_delegate_with_try_except(componentEntryPointDelegate, method_name, log_prefix, test_output);
-                }
+                rc = call_delegate_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
             }
         }
 
@@ -525,4 +570,4 @@ bool host_context_test::load_assembly_and_get_function_pointer(
     hostfxr_exports hostfxr{ hostfxr_path };
 
     return load_assembly_and_get_function_pointer_test(hostfxr, config_path, argc, argv, config_log_prefix, test_output);
-}
\ No newline at end of file
+}
index bd8b496..19a1606 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Runtime.InteropServices;
 
 namespace Component
 {
@@ -11,12 +12,18 @@ namespace Component
         private static int componentCallCount = 0;
         private static int entryPoint1CallCount = 0;
         private static int entryPoint2CallCount = 0;
+        private static int unmanagedEntryPoint1CallCount = 0;
+
+        private static void PrintComponentCallLog(string name, IntPtr arg, int size)
+        {
+            Console.WriteLine($"Called {name}(0x{arg.ToString("x")}, {size}) - component call count: {componentCallCount}");
+        }
 
         public static int ComponentEntryPoint1(IntPtr arg, int size)
         {
             componentCallCount++;
             entryPoint1CallCount++;
-            Console.WriteLine($"Called {nameof(ComponentEntryPoint1)}(0x{arg.ToString("x")}, {size}) - component call count: {componentCallCount}");
+            PrintComponentCallLog(nameof(ComponentEntryPoint1), arg, size);
             return entryPoint1CallCount;
         }
 
@@ -24,15 +31,24 @@ namespace Component
         {
             componentCallCount++;
             entryPoint2CallCount++;
-            Console.WriteLine($"Called {nameof(ComponentEntryPoint2)}(0x{arg.ToString("x")}, {size}) - component call count: {componentCallCount}");
+            PrintComponentCallLog(nameof(ComponentEntryPoint2), arg, size);
             return entryPoint2CallCount;
         }
 
         public static int ThrowException(IntPtr arg, int size)
         {
             componentCallCount++;
-            Console.WriteLine($"Called {nameof(ThrowException)}(0x{arg.ToString("x")}, {size}) - component call count: {componentCallCount}");
+            PrintComponentCallLog(nameof(ThrowException), arg, size);
             throw new InvalidOperationException(nameof(ThrowException));
         }
+
+        [UnmanagedCallersOnly]
+        public static int UnmanagedComponentEntryPoint1(IntPtr arg, int size)
+        {
+            componentCallCount++;
+            unmanagedEntryPoint1CallCount++;
+            PrintComponentCallLog(nameof(UnmanagedComponentEntryPoint1), arg, size);
+            return unmanagedEntryPoint1CallCount;
+        }
     }
 }
\ No newline at end of file
index e123787..3e9c6b5 100644 (file)
@@ -57,9 +57,11 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
         }
 
         [Theory]
-        [InlineData(1)]
-        [InlineData(10)]
-        public void CallDelegate_MultipleEntryPoints(int callCount)
+        [InlineData(1, false)]
+        [InlineData(1, true)]
+        [InlineData(10, false)]
+        [InlineData(10, true)]
+        public void CallDelegate_MultipleEntryPoints(int callCount, bool callUnmanaged)
         {
             var componentProject = sharedState.ComponentWithNoDependenciesFixture.TestProject;
             string[] baseArgs =
@@ -68,12 +70,14 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
                 sharedState.HostFxrPath,
                 componentProject.RuntimeConfigJson,
             };
+
+            string comp1Name = callUnmanaged ? sharedState.UnmanagedComponentEntryPoint1 : sharedState.ComponentEntryPoint1;
             string[] componentInfo =
             {
-                // ComponentEntryPoint1
+                // [Unmanaged]ComponentEntryPoint1
                 componentProject.AppDll,
                 sharedState.ComponentTypeName,
-                sharedState.ComponentEntryPoint1,
+                comp1Name,
                 // ComponentEntryPoint2
                 componentProject.AppDll,
                 sharedState.ComponentTypeName,
@@ -95,7 +99,7 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
             for (int i = 1; i <= callCount; ++i)
             {
                 result.Should()
-                    .ExecuteComponentEntryPoint(sharedState.ComponentEntryPoint1, i * 2 - 1, i)
+                    .ExecuteComponentEntryPoint(comp1Name, i * 2 - 1, i)
                     .And.ExecuteComponentEntryPoint(sharedState.ComponentEntryPoint2, i * 2, i);
             }
         }
@@ -175,6 +179,7 @@ namespace Microsoft.DotNet.CoreSetup.Test.HostActivation.NativeHosting
             public string ComponentTypeName { get; }
             public string ComponentEntryPoint1 => "ComponentEntryPoint1";
             public string ComponentEntryPoint2 => "ComponentEntryPoint2";
+            public string UnmanagedComponentEntryPoint1 => "UnmanagedComponentEntryPoint1";
 
             public SharedTestState()
             {
index 9f004bd..5957640 100644 (file)
   <data name="NotSupported_NoCodepageData" xml:space="preserve">
     <value>No data is available for encoding {0}. For information on defining a custom encoding, see the documentation for the Encoding.RegisterProvider method.</value>
   </data>
+  <data name="InvalidOperation_FunctionMissingUnmanagedCallersOnly" xml:space="preserve">
+    <value>Function not marked with UnmanagedCallersOnlyAttribute.</value>
+  </data>
   <data name="InvalidProgram_NonBlittableTypes" xml:space="preserve">
     <value>Non-blittable parameter types are invalid for UnmanagedCallersOnly methods.</value>
   </data>