From 7fa7d854adba15c6214c90e555f9bc52aee5891f Mon Sep 17 00:00:00 2001 From: Elinor Fung <47805090+elinor-fung@users.noreply.github.com> Date: Wed, 8 Apr 2020 14:07:13 -0700 Subject: [PATCH] Make NativeLibrary Load/TryLoad use ALC extension points for the specified assembly (#34519) --- src/coreclr/src/pal/inc/pal.h | 1 - src/coreclr/src/vm/dllimport.cpp | 860 ++++++++++----------- src/coreclr/src/vm/dllimport.h | 10 +- .../ResolveUnmanagedDllTests.cs | 106 ++- .../NativeLibrary/Callback/CallbackStressTest.cs | 43 +- .../NativeLibrary/Callback/CallbackTests.cs | 130 ++-- .../NativeLibrary/Callback/CallbackTests.csproj | 1 + .../NativeLibraryToLoad/NativeLibraryToLoad.cs | 2 +- .../Runtime/InteropServices/NativeLibrary.cs | 23 +- 9 files changed, 648 insertions(+), 528 deletions(-) diff --git a/src/coreclr/src/pal/inc/pal.h b/src/coreclr/src/pal/inc/pal.h index a7751ee..7926b29 100644 --- a/src/coreclr/src/pal/inc/pal.h +++ b/src/coreclr/src/pal/inc/pal.h @@ -72,7 +72,6 @@ extern "C" { // Native system libray handle. // On Unix systems, NATIVE_LIBRARY_HANDLE type represents a library handle not registered with the PAL. -// To get a HMODULE on Unix, call PAL_RegisterLibraryDirect() on a NATIVE_LIBRARY_HANDLE. typedef PVOID NATIVE_LIBRARY_HANDLE; /******************* Processor-specific glue *****************************/ diff --git a/src/coreclr/src/vm/dllimport.cpp b/src/coreclr/src/vm/dllimport.cpp index 4875c57..0046ed9 100644 --- a/src/coreclr/src/vm/dllimport.cpp +++ b/src/coreclr/src/vm/dllimport.cpp @@ -6306,9 +6306,7 @@ private: SString m_message; }; // class LoadLibErrorTracker -// Load the library directly. On Unix systems, don't register it yet with PAL. -// * External callers like System.Runtime.InteropServices.NativeLibrary.Load() need the raw system handle -// * Internal callers like LoadLibraryModule() can convert this handle to a HMODULE via PAL APIs on Unix +// Load the library directly and return the raw system handle static NATIVE_LIBRARY_HANDLE LocalLoadLibraryHelper( LPCWSTR name, DWORD flags, LoadLibErrorTracker *pErrorTracker ) { STANDARD_VM_CONTRACT; @@ -6432,65 +6430,6 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryFromPath(LPCWSTR libraryPath, BOOL thr } // static -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *callingAssembly, - BOOL hasDllImportSearchFlags, DWORD dllImportSearchFlags, - BOOL throwOnError) -{ - CONTRACTL - { - STANDARD_VM_CHECK; - PRECONDITION(CheckPointer(libraryName)); - PRECONDITION(CheckPointer(callingAssembly)); - } - CONTRACTL_END; - - LoadLibErrorTracker errorTracker; - - // First checks if a default dllImportSearchPathFlags was passed in, if so, use that value. - // Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute. - // If so, use that value. - BOOL searchAssemblyDirectory; - DWORD dllImportSearchPathFlags; - - if (hasDllImportSearchFlags) - { - dllImportSearchPathFlags = dllImportSearchFlags & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY; - searchAssemblyDirectory = dllImportSearchFlags & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY; - - } - else - { - GetDllImportSearchPathFlags(callingAssembly->GetManifestModule(), - &dllImportSearchPathFlags, &searchAssemblyDirectory); - } - - NATIVE_LIBRARY_HANDLE hmod = - LoadLibraryModuleBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName); - - if (throwOnError && (hmod == nullptr)) - { - SString libraryPathSString(libraryName); - errorTracker.Throw(libraryPathSString); - } - - return hmod; -} - -// static -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD, LoadLibErrorTracker * pErrorTracker, PCWSTR wszLibName) -{ - STANDARD_VM_CONTRACT; - - BOOL searchAssemblyDirectory; - DWORD dllImportSearchPathFlags; - - GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory); - - Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); - return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName); -} - -// static void NDirect::FreeNativeLibrary(NATIVE_LIBRARY_HANDLE handle) { STANDARD_VM_CONTRACT; @@ -6532,483 +6471,547 @@ INT_PTR NDirect::GetNativeLibraryExport(NATIVE_LIBRARY_HANDLE handle, LPCWSTR sy return address; } -#ifndef TARGET_UNIX -BOOL IsWindowsAPISet(PCWSTR wszLibName) +namespace { - STANDARD_VM_CONTRACT; +#ifndef TARGET_UNIX + BOOL IsWindowsAPISet(PCWSTR wszLibName) + { + STANDARD_VM_CONTRACT; - // This is replicating quick check from the OS implementation of api sets. - return SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || - SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0; -} + // This is replicating quick check from the OS implementation of api sets. + return SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || + SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0; + } #endif // !TARGET_UNIX -// static -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, PCWSTR wszLibName) -{ - STANDARD_VM_CONTRACT; - //Dynamic Pinvoke Support: - //Check if we need to provide the host a chance to provide the unmanaged dll + NATIVE_LIBRARY_HANDLE LoadNativeLibraryViaAssemblyLoadContext(Assembly * pAssembly, PCWSTR wszLibName) + { + STANDARD_VM_CONTRACT; #ifndef TARGET_UNIX - if (IsWindowsAPISet(wszLibName)) - { - // Prevent Overriding of Windows API sets. - return NULL; - } + if (IsWindowsAPISet(wszLibName)) + { + // Prevent Overriding of Windows API sets. + return NULL; + } #endif // !TARGET_UNIX - NATIVE_LIBRARY_HANDLE hmod = NULL; - AppDomain* pDomain = GetAppDomain(); - CLRPrivBinderCoreCLR *pTPABinder = pDomain->GetTPABinderContext(); - Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); + NATIVE_LIBRARY_HANDLE hmod = NULL; + AppDomain* pDomain = GetAppDomain(); + CLRPrivBinderCoreCLR *pTPABinder = pDomain->GetTPABinderContext(); - PEFile *pManifestFile = pAssembly->GetManifestFile(); - PTR_ICLRPrivBinder pBindingContext = pManifestFile->GetBindingContext(); + PEFile *pManifestFile = pAssembly->GetManifestFile(); + PTR_ICLRPrivBinder pBindingContext = pManifestFile->GetBindingContext(); - //Step 0: Check if the assembly was bound using TPA. - // The Binding Context can be null or an overridden TPA context - if (pBindingContext == NULL) - { - // If we do not have any binder associated, then return to the default resolution mechanism. - return NULL; - } + //Step 0: Check if the assembly was bound using TPA. + // The Binding Context can be null or an overridden TPA context + if (pBindingContext == NULL) + { + // If we do not have any binder associated, then return to the default resolution mechanism. + return NULL; + } - UINT_PTR assemblyBinderID = 0; - IfFailThrow(pBindingContext->GetBinderID(&assemblyBinderID)); + UINT_PTR assemblyBinderID = 0; + IfFailThrow(pBindingContext->GetBinderID(&assemblyBinderID)); - ICLRPrivBinder *pCurrentBinder = reinterpret_cast(assemblyBinderID); + ICLRPrivBinder *pCurrentBinder = reinterpret_cast(assemblyBinderID); - // For assemblies bound via TPA binder, we should use the standard mechanism to make the pinvoke call. - if (AreSameBinderInstance(pCurrentBinder, pTPABinder)) - { - return NULL; - } + // For assemblies bound via TPA binder, we should use the standard mechanism to make the pinvoke call. + if (AreSameBinderInstance(pCurrentBinder, pTPABinder)) + { + return NULL; + } #ifdef FEATURE_COMINTEROP - CLRPrivBinderWinRT *pWinRTBinder = pDomain->GetWinRtBinder(); - if (AreSameBinderInstance(pCurrentBinder, pWinRTBinder)) - { - // We could be here when a non-WinRT assembly load is triggerred by a winmd (e.g. System.Runtime being loaded due to - // types being referenced from Windows.Foundation.Winmd) or when dealing with a winmd (which is bound using WinRT binder). - // - // For this, we should use the standard mechanism to make pinvoke call as well. - return NULL; - } + CLRPrivBinderWinRT *pWinRTBinder = pDomain->GetWinRtBinder(); + if (AreSameBinderInstance(pCurrentBinder, pWinRTBinder)) + { + // We could be here when a non-WinRT assembly load is triggerred by a winmd (e.g. System.Runtime being loaded due to + // types being referenced from Windows.Foundation.Winmd) or when dealing with a winmd (which is bound using WinRT binder). + // + // For this, we should use the standard mechanism to make pinvoke call as well. + return NULL; + } #endif // FEATURE_COMINTEROP - //Step 1: If the assembly was not bound using TPA, - // Call System.Runtime.Loader.AssemblyLoadContext.ResolveUnamanagedDll to give - // The custom assembly context a chance to load the unmanaged dll. + //Step 1: If the assembly was not bound using TPA, + // Call System.Runtime.Loader.AssemblyLoadContext.ResolveUnmanagedDll to give + // The custom assembly context a chance to load the unmanaged dll. - GCX_COOP(); + GCX_COOP(); - STRINGREF pUnmanagedDllName; - pUnmanagedDllName = StringObject::NewString(wszLibName); + STRINGREF pUnmanagedDllName; + pUnmanagedDllName = StringObject::NewString(wszLibName); - GCPROTECT_BEGIN(pUnmanagedDllName); + GCPROTECT_BEGIN(pUnmanagedDllName); - // Get the pointer to the managed assembly load context - INT_PTR ptrManagedAssemblyLoadContext = ((CLRPrivBinderAssemblyLoadContext *)pCurrentBinder)->GetManagedAssemblyLoadContext(); + // Get the pointer to the managed assembly load context + INT_PTR ptrManagedAssemblyLoadContext = ((CLRPrivBinderAssemblyLoadContext *)pCurrentBinder)->GetManagedAssemblyLoadContext(); - // Prepare to invoke System.Runtime.Loader.AssemblyLoadContext.ResolveUnamanagedDll method. - PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__RESOLVEUNMANAGEDDLL); - DECLARE_ARGHOLDER_ARRAY(args, 2); - args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(pUnmanagedDllName); - args[ARGNUM_1] = PTR_TO_ARGHOLDER(ptrManagedAssemblyLoadContext); + // Prepare to invoke System.Runtime.Loader.AssemblyLoadContext.ResolveUnmanagedDll method. + PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__RESOLVEUNMANAGEDDLL); + DECLARE_ARGHOLDER_ARRAY(args, 2); + args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(pUnmanagedDllName); + args[ARGNUM_1] = PTR_TO_ARGHOLDER(ptrManagedAssemblyLoadContext); - // Make the call - CALL_MANAGED_METHOD(hmod, NATIVE_LIBRARY_HANDLE, args); + // Make the call + CALL_MANAGED_METHOD(hmod, NATIVE_LIBRARY_HANDLE, args); - GCPROTECT_END(); + GCPROTECT_END(); - return hmod; -} - -// Return the AssemblyLoadContext for an assembly -INT_PTR GetManagedAssemblyLoadContext(Assembly* pAssembly) -{ - STANDARD_VM_CONTRACT; + return hmod; + } - PTR_ICLRPrivBinder pBindingContext = pAssembly->GetManifestFile()->GetBindingContext(); - if (pBindingContext == NULL) + // Return the AssemblyLoadContext for an assembly + INT_PTR GetManagedAssemblyLoadContext(Assembly* pAssembly) { - // GetBindingContext() returns NULL for System.Private.CoreLib - return NULL; - } + STANDARD_VM_CONTRACT; + + PTR_ICLRPrivBinder pBindingContext = pAssembly->GetManifestFile()->GetBindingContext(); + if (pBindingContext == NULL) + { + // GetBindingContext() returns NULL for System.Private.CoreLib + return NULL; + } - UINT_PTR assemblyBinderID = 0; - IfFailThrow(pBindingContext->GetBinderID(&assemblyBinderID)); + UINT_PTR assemblyBinderID = 0; + IfFailThrow(pBindingContext->GetBinderID(&assemblyBinderID)); - AppDomain *pDomain = GetAppDomain(); - ICLRPrivBinder *pCurrentBinder = reinterpret_cast(assemblyBinderID); + AppDomain *pDomain = GetAppDomain(); + ICLRPrivBinder *pCurrentBinder = reinterpret_cast(assemblyBinderID); #ifdef FEATURE_COMINTEROP - if (AreSameBinderInstance(pCurrentBinder, pDomain->GetWinRtBinder())) - { - // No ALC associated handle with WinRT Binders. - return NULL; - } + if (AreSameBinderInstance(pCurrentBinder, pDomain->GetWinRtBinder())) + { + // No ALC associated handle with WinRT Binders. + return NULL; + } #endif // FEATURE_COMINTEROP - // The code here deals with two implementations of ICLRPrivBinder interface: - // - CLRPrivBinderCoreCLR for the TPA binder in the default ALC, and - // - CLRPrivBinderAssemblyLoadContext for custom ALCs. - // in order obtain the associated ALC handle. - INT_PTR ptrManagedAssemblyLoadContext = AreSameBinderInstance(pCurrentBinder, pDomain->GetTPABinderContext()) - ? ((CLRPrivBinderCoreCLR *)pCurrentBinder)->GetManagedAssemblyLoadContext() - : ((CLRPrivBinderAssemblyLoadContext *)pCurrentBinder)->GetManagedAssemblyLoadContext(); - - return ptrManagedAssemblyLoadContext; -} - -// static -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaEvent(NDirectMethodDesc * pMD, PCWSTR wszLibName) -{ - STANDARD_VM_CONTRACT; + // The code here deals with two implementations of ICLRPrivBinder interface: + // - CLRPrivBinderCoreCLR for the TPA binder in the default ALC, and + // - CLRPrivBinderAssemblyLoadContext for custom ALCs. + // in order obtain the associated ALC handle. + INT_PTR ptrManagedAssemblyLoadContext = AreSameBinderInstance(pCurrentBinder, pDomain->GetTPABinderContext()) + ? ((CLRPrivBinderCoreCLR *)pCurrentBinder)->GetManagedAssemblyLoadContext() + : ((CLRPrivBinderAssemblyLoadContext *)pCurrentBinder)->GetManagedAssemblyLoadContext(); - NATIVE_LIBRARY_HANDLE hmod = NULL; - Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); - INT_PTR ptrManagedAssemblyLoadContext = GetManagedAssemblyLoadContext(pAssembly); + return ptrManagedAssemblyLoadContext; + } - if (ptrManagedAssemblyLoadContext == NULL) + NATIVE_LIBRARY_HANDLE LoadNativeLibraryViaAssemblyLoadContextEvent(Assembly * pAssembly, PCWSTR wszLibName) { - return NULL; - } + STANDARD_VM_CONTRACT; - GCX_COOP(); + INT_PTR ptrManagedAssemblyLoadContext = GetManagedAssemblyLoadContext(pAssembly); + if (ptrManagedAssemblyLoadContext == NULL) + { + return NULL; + } - struct { - STRINGREF DllName; - OBJECTREF AssemblyRef; - } gc = { NULL, NULL }; + NATIVE_LIBRARY_HANDLE hmod = NULL; - GCPROTECT_BEGIN(gc); + GCX_COOP(); - gc.DllName = StringObject::NewString(wszLibName); - gc.AssemblyRef = pAssembly->GetExposedObject(); + struct { + STRINGREF DllName; + OBJECTREF AssemblyRef; + } gc = { NULL, NULL }; - // Prepare to invoke System.Runtime.Loader.AssemblyLoadContext.ResolveUnmanagedDllUsingEvent method - // While ResolveUnmanagedDllUsingEvent() could compute the AssemblyLoadContext using the AssemblyRef - // argument, it will involve another pInvoke to the runtime. So AssemblyLoadContext is passed in - // as an additional argument. - PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__RESOLVEUNMANAGEDDLLUSINGEVENT); - DECLARE_ARGHOLDER_ARRAY(args, 3); - args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.DllName); - args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.AssemblyRef); - args[ARGNUM_2] = PTR_TO_ARGHOLDER(ptrManagedAssemblyLoadContext); + GCPROTECT_BEGIN(gc); - // Make the call - CALL_MANAGED_METHOD(hmod, NATIVE_LIBRARY_HANDLE, args); + gc.DllName = StringObject::NewString(wszLibName); + gc.AssemblyRef = pAssembly->GetExposedObject(); - GCPROTECT_END(); + // Prepare to invoke System.Runtime.Loader.AssemblyLoadContext.ResolveUnmanagedDllUsingEvent method + // While ResolveUnmanagedDllUsingEvent() could compute the AssemblyLoadContext using the AssemblyRef + // argument, it will involve another pInvoke to the runtime. So AssemblyLoadContext is passed in + // as an additional argument. + PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__RESOLVEUNMANAGEDDLLUSINGEVENT); + DECLARE_ARGHOLDER_ARRAY(args, 3); + args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.DllName); + args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.AssemblyRef); + args[ARGNUM_2] = PTR_TO_ARGHOLDER(ptrManagedAssemblyLoadContext); - return hmod; -} + // Make the call + CALL_MANAGED_METHOD(hmod, NATIVE_LIBRARY_HANDLE, args); -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaCallback(NDirectMethodDesc * pMD, LPCWSTR wszLibName) -{ - STANDARD_VM_CONTRACT; + GCPROTECT_END(); - if (pMD->GetModule()->IsSystem()) - { - // Don't attempt to callback on Corelib itself. - // The LoadLibrary callback stub is managed code that requires CoreLib - return NULL; + return hmod; } - DWORD dllImportSearchPathFlags; - BOOL searchAssemblyDirectory; - BOOL hasDllImportSearchPathFlags = GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory); - dllImportSearchPathFlags |= searchAssemblyDirectory ? DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY : 0; - - Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); - NATIVE_LIBRARY_HANDLE handle = NULL; + NATIVE_LIBRARY_HANDLE LoadNativeLibraryViaDllImportResolver(NDirectMethodDesc * pMD, LPCWSTR wszLibName) + { + STANDARD_VM_CONTRACT; - GCX_COOP(); + if (pMD->GetModule()->IsSystem()) + { + // Don't attempt to callback on Corelib itself. + // The LoadLibrary callback stub is managed code that requires CoreLib + return NULL; + } - struct { - STRINGREF libNameRef; - OBJECTREF assemblyRef; - } gc = { NULL, NULL }; + DWORD dllImportSearchPathFlags; + BOOL searchAssemblyDirectory; + BOOL hasDllImportSearchPathFlags = GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory); + dllImportSearchPathFlags |= searchAssemblyDirectory ? DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY : 0; - GCPROTECT_BEGIN(gc); + Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); + NATIVE_LIBRARY_HANDLE handle = NULL; - gc.libNameRef = StringObject::NewString(wszLibName); - gc.assemblyRef = pAssembly->GetExposedObject(); + GCX_COOP(); - PREPARE_NONVIRTUAL_CALLSITE(METHOD__NATIVELIBRARY__LOADLIBRARYCALLBACKSTUB); - DECLARE_ARGHOLDER_ARRAY(args, 4); - args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.libNameRef); - args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.assemblyRef); - args[ARGNUM_2] = BOOL_TO_ARGHOLDER(hasDllImportSearchPathFlags); - args[ARGNUM_3] = DWORD_TO_ARGHOLDER(dllImportSearchPathFlags); + struct { + STRINGREF libNameRef; + OBJECTREF assemblyRef; + } gc = { NULL, NULL }; - // Make the call - CALL_MANAGED_METHOD(handle, NATIVE_LIBRARY_HANDLE, args); - GCPROTECT_END(); + GCPROTECT_BEGIN(gc); - return handle; -} + gc.libNameRef = StringObject::NewString(wszLibName); + gc.assemblyRef = pAssembly->GetExposedObject(); -// Try to load the module alongside the assembly where the PInvoke was declared. -NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker) -{ - STANDARD_VM_CONTRACT; + PREPARE_NONVIRTUAL_CALLSITE(METHOD__NATIVELIBRARY__LOADLIBRARYCALLBACKSTUB); + DECLARE_ARGHOLDER_ARRAY(args, 4); + args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.libNameRef); + args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.assemblyRef); + args[ARGNUM_2] = BOOL_TO_ARGHOLDER(hasDllImportSearchPathFlags); + args[ARGNUM_3] = DWORD_TO_ARGHOLDER(dllImportSearchPathFlags); - NATIVE_LIBRARY_HANDLE hmod = NULL; + // Make the call + CALL_MANAGED_METHOD(handle, NATIVE_LIBRARY_HANDLE, args); + GCPROTECT_END(); - SString path = pAssembly->GetManifestFile()->GetPath(); + return handle; + } - SString::Iterator lastPathSeparatorIter = path.End(); - if (PEAssembly::FindLastPathSeparator(path, lastPathSeparatorIter)) + // Try to load the module alongside the assembly where the PInvoke was declared. + NATIVE_LIBRARY_HANDLE LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker) { - lastPathSeparatorIter++; - path.Truncate(lastPathSeparatorIter); + STANDARD_VM_CONTRACT; - path.Append(libName); - hmod = LocalLoadLibraryHelper(path, flags, pErrorTracker); - } + NATIVE_LIBRARY_HANDLE hmod = NULL; - return hmod; -} + SString path = pAssembly->GetManifestFile()->GetPath(); -// Try to load the module from the native DLL search directories -NATIVE_LIBRARY_HANDLE NDirect::LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker) -{ - STANDARD_VM_CONTRACT; + SString::Iterator lastPathSeparatorIter = path.End(); + if (PEAssembly::FindLastPathSeparator(path, lastPathSeparatorIter)) + { + lastPathSeparatorIter++; + path.Truncate(lastPathSeparatorIter); - NATIVE_LIBRARY_HANDLE hmod = NULL; - AppDomain* pDomain = GetAppDomain(); + path.Append(libName); + hmod = LocalLoadLibraryHelper(path, flags, pErrorTracker); + } + + return hmod; + } - if (pDomain->HasNativeDllSearchDirectories()) + // Try to load the module from the native DLL search directories + NATIVE_LIBRARY_HANDLE LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker) { - AppDomain::PathIterator pathIter = pDomain->IterateNativeDllSearchDirectories(); - while (hmod == NULL && pathIter.Next()) + STANDARD_VM_CONTRACT; + + NATIVE_LIBRARY_HANDLE hmod = NULL; + AppDomain* pDomain = GetAppDomain(); + + if (pDomain->HasNativeDllSearchDirectories()) { - SString qualifiedPath(*(pathIter.GetPath())); - qualifiedPath.Append(libName); - if (!Path::IsRelative(qualifiedPath)) + AppDomain::PathIterator pathIter = pDomain->IterateNativeDllSearchDirectories(); + while (hmod == NULL && pathIter.Next()) { - hmod = LocalLoadLibraryHelper(qualifiedPath, flags, pErrorTracker); + SString qualifiedPath(*(pathIter.GetPath())); + qualifiedPath.Append(libName); + if (!Path::IsRelative(qualifiedPath)) + { + hmod = LocalLoadLibraryHelper(qualifiedPath, flags, pErrorTracker); + } } } - } - return hmod; -} + return hmod; + } #ifdef TARGET_UNIX -static const int MaxVariationCount = 4; -static void DetermineLibNameVariations(const WCHAR** libNameVariations, int* numberOfVariations, const SString& libName, bool libNameIsRelativePath) -{ - // Supported lib name variations - static auto NameFmt = W("%.0s%s%.0s"); - static auto PrefixNameFmt = W("%s%s%.0s"); - static auto NameSuffixFmt = W("%.0s%s%s"); - static auto PrefixNameSuffixFmt = W("%s%s%s"); + const int MaxVariationCount = 4; + void DetermineLibNameVariations(const WCHAR** libNameVariations, int* numberOfVariations, const SString& libName, bool libNameIsRelativePath) + { + // Supported lib name variations + static auto NameFmt = W("%.0s%s%.0s"); + static auto PrefixNameFmt = W("%s%s%.0s"); + static auto NameSuffixFmt = W("%.0s%s%s"); + static auto PrefixNameSuffixFmt = W("%s%s%s"); - _ASSERTE(*numberOfVariations >= MaxVariationCount); + _ASSERTE(*numberOfVariations >= MaxVariationCount); - int varCount = 0; - if (!libNameIsRelativePath) - { - libNameVariations[varCount++] = NameFmt; - } - else - { - // We check if the suffix is contained in the name, because on Linux it is common to append - // a version number to the library name (e.g. 'libicuuc.so.57'). - bool containsSuffix = false; - SString::CIterator it = libName.Begin(); - if (libName.Find(it, PLATFORM_SHARED_LIB_SUFFIX_W)) + int varCount = 0; + if (!libNameIsRelativePath) { - it += COUNTOF(PLATFORM_SHARED_LIB_SUFFIX_W); - containsSuffix = it == libName.End() || *it == (WCHAR)'.'; + libNameVariations[varCount++] = NameFmt; } + else + { + // We check if the suffix is contained in the name, because on Linux it is common to append + // a version number to the library name (e.g. 'libicuuc.so.57'). + bool containsSuffix = false; + SString::CIterator it = libName.Begin(); + if (libName.Find(it, PLATFORM_SHARED_LIB_SUFFIX_W)) + { + it += COUNTOF(PLATFORM_SHARED_LIB_SUFFIX_W); + containsSuffix = it == libName.End() || *it == (WCHAR)'.'; + } - // If the path contains a path delimiter, we don't add a prefix - it = libName.Begin(); - bool containsDelim = libName.Find(it, DIRECTORY_SEPARATOR_STR_W); + // If the path contains a path delimiter, we don't add a prefix + it = libName.Begin(); + bool containsDelim = libName.Find(it, DIRECTORY_SEPARATOR_STR_W); - if (containsSuffix) - { - libNameVariations[varCount++] = NameFmt; + if (containsSuffix) + { + libNameVariations[varCount++] = NameFmt; - if (!containsDelim) - libNameVariations[varCount++] = PrefixNameFmt; + if (!containsDelim) + libNameVariations[varCount++] = PrefixNameFmt; - libNameVariations[varCount++] = NameSuffixFmt; + libNameVariations[varCount++] = NameSuffixFmt; - if (!containsDelim) - libNameVariations[varCount++] = PrefixNameSuffixFmt; - } - else - { - libNameVariations[varCount++] = NameSuffixFmt; + if (!containsDelim) + libNameVariations[varCount++] = PrefixNameSuffixFmt; + } + else + { + libNameVariations[varCount++] = NameSuffixFmt; - if (!containsDelim) - libNameVariations[varCount++] = PrefixNameSuffixFmt; + if (!containsDelim) + libNameVariations[varCount++] = PrefixNameSuffixFmt; - libNameVariations[varCount++] = NameFmt; + libNameVariations[varCount++] = NameFmt; - if (!containsDelim) - libNameVariations[varCount++] = PrefixNameFmt; + if (!containsDelim) + libNameVariations[varCount++] = PrefixNameFmt; + } } - } - *numberOfVariations = varCount; -} -#else // TARGET_UNIX -static const int MaxVariationCount = 2; -static void DetermineLibNameVariations(const WCHAR** libNameVariations, int* numberOfVariations, const SString& libName, bool libNameIsRelativePath) -{ - // Supported lib name variations - static auto NameFmt = W("%.0s%s%.0s"); - static auto NameSuffixFmt = W("%.0s%s%s"); - - _ASSERTE(*numberOfVariations >= MaxVariationCount); - - int varCount = 0; - - // The purpose of following code is to workaround LoadLibrary limitation: - // LoadLibrary won't append extension if filename itself contains '.'. Thus it will break the following scenario: - // [DllImport("A.B")] // The full name for file is "A.B.dll". This is common code pattern for cross-platform PInvoke - // The workaround for above scenario is to call LoadLibrary with "A.B" first, if it fails, then call LoadLibrary with "A.B.dll" - auto it = libName.Begin(); - if (!libNameIsRelativePath || - !libName.Find(it, W('.')) || - libName.EndsWith(W(".")) || - libName.EndsWithCaseInsensitive(W(".dll")) || - libName.EndsWithCaseInsensitive(W(".exe"))) - { - // Follow LoadLibrary rules in MSDN doc: https://msdn.microsoft.com/en-us/library/windows/desktop/ms684175(v=vs.85).aspx - // If the string specifies a full path, the function searches only that path for the module. - // If the string specifies a module name without a path and the file name extension is omitted, the function appends the default library extension .dll to the module name. - // To prevent the function from appending .dll to the module name, include a trailing point character (.) in the module name string. - libNameVariations[varCount++] = NameFmt; - } - else - { - libNameVariations[varCount++] = NameFmt; - libNameVariations[varCount++] = NameSuffixFmt; + *numberOfVariations = varCount; } +#else // TARGET_UNIX + const int MaxVariationCount = 2; + void DetermineLibNameVariations(const WCHAR** libNameVariations, int* numberOfVariations, const SString& libName, bool libNameIsRelativePath) + { + // Supported lib name variations + static auto NameFmt = W("%.0s%s%.0s"); + static auto NameSuffixFmt = W("%.0s%s%s"); + + _ASSERTE(*numberOfVariations >= MaxVariationCount); + + int varCount = 0; + + // The purpose of following code is to workaround LoadLibrary limitation: + // LoadLibrary won't append extension if filename itself contains '.'. Thus it will break the following scenario: + // [DllImport("A.B")] // The full name for file is "A.B.dll". This is common code pattern for cross-platform PInvoke + // The workaround for above scenario is to call LoadLibrary with "A.B" first, if it fails, then call LoadLibrary with "A.B.dll" + auto it = libName.Begin(); + if (!libNameIsRelativePath || + !libName.Find(it, W('.')) || + libName.EndsWith(W(".")) || + libName.EndsWithCaseInsensitive(W(".dll")) || + libName.EndsWithCaseInsensitive(W(".exe"))) + { + // Follow LoadLibrary rules in MSDN doc: https://msdn.microsoft.com/en-us/library/windows/desktop/ms684175(v=vs.85).aspx + // If the string specifies a full path, the function searches only that path for the module. + // If the string specifies a module name without a path and the file name extension is omitted, the function appends the default library extension .dll to the module name. + // To prevent the function from appending .dll to the module name, include a trailing point character (.) in the module name string. + libNameVariations[varCount++] = NameFmt; + } + else + { + libNameVariations[varCount++] = NameFmt; + libNameVariations[varCount++] = NameSuffixFmt; + } - *numberOfVariations = varCount; -} + *numberOfVariations = varCount; + } #endif // TARGET_UNIX -// Search for the library and variants of its name in probing directories. -//static -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssembly, - BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags, - LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName) -{ - STANDARD_VM_CONTRACT; + // Search for the library and variants of its name in probing directories. + NATIVE_LIBRARY_HANDLE LoadNativeLibraryBySearch(Assembly *callingAssembly, + BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags, + LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName) + { + STANDARD_VM_CONTRACT; - NATIVE_LIBRARY_HANDLE hmod = NULL; + NATIVE_LIBRARY_HANDLE hmod = NULL; #if defined(FEATURE_CORESYSTEM) && !defined(TARGET_UNIX) - // Try to go straight to System32 for Windows API sets. This is replicating quick check from - // the OS implementation of api sets. - if (IsWindowsAPISet(wszLibName)) - { - hmod = LocalLoadLibraryHelper(wszLibName, LOAD_LIBRARY_SEARCH_SYSTEM32, pErrorTracker); - if (hmod != NULL) + // Try to go straight to System32 for Windows API sets. This is replicating quick check from + // the OS implementation of api sets. + if (IsWindowsAPISet(wszLibName)) { - return hmod; + hmod = LocalLoadLibraryHelper(wszLibName, LOAD_LIBRARY_SEARCH_SYSTEM32, pErrorTracker); + if (hmod != NULL) + { + return hmod; + } } - } #endif // FEATURE_CORESYSTEM && !TARGET_UNIX - AppDomain* pDomain = GetAppDomain(); - DWORD loadWithAlteredPathFlags = GetLoadWithAlteredSearchPathFlag(); - bool libNameIsRelativePath = Path::IsRelative(wszLibName); - - // P/Invokes are often declared with variations on the actual library name. - // For example, it's common to leave off the extension/suffix of the library - // even if it has one, or to leave off a prefix like "lib" even if it has one - // (both of these are typically done to smooth over cross-platform differences). - // We try to dlopen with such variations on the original. - const WCHAR* prefixSuffixCombinations[MaxVariationCount] = {}; - int numberOfVariations = COUNTOF(prefixSuffixCombinations); - DetermineLibNameVariations(prefixSuffixCombinations, &numberOfVariations, wszLibName, libNameIsRelativePath); - for (int i = 0; i < numberOfVariations; i++) - { - SString currLibNameVariation; - currLibNameVariation.Printf(prefixSuffixCombinations[i], PLATFORM_SHARED_LIB_PREFIX_W, wszLibName, PLATFORM_SHARED_LIB_SUFFIX_W); - - // NATIVE_DLL_SEARCH_DIRECTORIES set by host is considered well known path - hmod = LoadFromNativeDllSearchDirectories(currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker); - if (hmod != NULL) - { - return hmod; - } + AppDomain* pDomain = GetAppDomain(); + DWORD loadWithAlteredPathFlags = GetLoadWithAlteredSearchPathFlag(); + bool libNameIsRelativePath = Path::IsRelative(wszLibName); + + // P/Invokes are often declared with variations on the actual library name. + // For example, it's common to leave off the extension/suffix of the library + // even if it has one, or to leave off a prefix like "lib" even if it has one + // (both of these are typically done to smooth over cross-platform differences). + // We try to dlopen with such variations on the original. + const WCHAR* prefixSuffixCombinations[MaxVariationCount] = {}; + int numberOfVariations = COUNTOF(prefixSuffixCombinations); + DetermineLibNameVariations(prefixSuffixCombinations, &numberOfVariations, wszLibName, libNameIsRelativePath); + for (int i = 0; i < numberOfVariations; i++) + { + SString currLibNameVariation; + currLibNameVariation.Printf(prefixSuffixCombinations[i], PLATFORM_SHARED_LIB_PREFIX_W, wszLibName, PLATFORM_SHARED_LIB_SUFFIX_W); + + // NATIVE_DLL_SEARCH_DIRECTORIES set by host is considered well known path + hmod = LoadFromNativeDllSearchDirectories(currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker); + if (hmod != NULL) + { + return hmod; + } - if (!libNameIsRelativePath) - { - DWORD flags = loadWithAlteredPathFlags; - if ((dllImportSearchPathFlags & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0) + if (!libNameIsRelativePath) { - // LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR is the only flag affecting absolute path. Don't OR the flags - // unconditionally as all absolute path P/Invokes could then lose LOAD_WITH_ALTERED_SEARCH_PATH. - flags |= dllImportSearchPathFlags; + DWORD flags = loadWithAlteredPathFlags; + if ((dllImportSearchPathFlags & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0) + { + // LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR is the only flag affecting absolute path. Don't OR the flags + // unconditionally as all absolute path P/Invokes could then lose LOAD_WITH_ALTERED_SEARCH_PATH. + flags |= dllImportSearchPathFlags; + } + + hmod = LocalLoadLibraryHelper(currLibNameVariation, flags, pErrorTracker); + if (hmod != NULL) + { + return hmod; + } + } + else if ((callingAssembly != nullptr) && searchAssemblyDirectory) + { + hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker); + if (hmod != NULL) + { + return hmod; + } } - hmod = LocalLoadLibraryHelper(currLibNameVariation, flags, pErrorTracker); + hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker); if (hmod != NULL) { return hmod; } } - else if ((callingAssembly != nullptr) && searchAssemblyDirectory) + + // This may be an assembly name + // Format is "fileName, assemblyDisplayName" + MAKE_UTF8PTR_FROMWIDE(szLibName, wszLibName); + char *szComma = strchr(szLibName, ','); + if (szComma) { - hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker); - if (hmod != NULL) + *szComma = '\0'; + // Trim white spaces + while (COMCharacter::nativeIsWhiteSpace(*(++szComma))); + + AssemblySpec spec; + if (SUCCEEDED(spec.Init(szComma))) { - return hmod; + // Need to perform case insensitive hashing. + SString moduleName(SString::Utf8, szLibName); + moduleName.LowerCase(); + + StackScratchBuffer buffer; + szLibName = (LPSTR)moduleName.GetUTF8(buffer); + + Assembly *pAssembly = spec.LoadAssembly(FILE_LOADED); + Module *pModule = pAssembly->FindModuleByName(szLibName); + + hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker); } } - hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker); - if (hmod != NULL) - { - return hmod; - } + return hmod; } - // This may be an assembly name - // Format is "fileName, assemblyDisplayName" - MAKE_UTF8PTR_FROMWIDE(szLibName, wszLibName); - char *szComma = strchr(szLibName, ','); - if (szComma) + NATIVE_LIBRARY_HANDLE LoadNativeLibraryBySearch(NDirectMethodDesc *pMD, LoadLibErrorTracker *pErrorTracker, PCWSTR wszLibName) { - *szComma = '\0'; - // Trim white spaces - while (COMCharacter::nativeIsWhiteSpace(*(++szComma))); + STANDARD_VM_CONTRACT; - AssemblySpec spec; - if (SUCCEEDED(spec.Init(szComma))) - { - // Need to perform case insensitive hashing. - SString moduleName(SString::Utf8, szLibName); - moduleName.LowerCase(); + BOOL searchAssemblyDirectory; + DWORD dllImportSearchPathFlags; - StackScratchBuffer buffer; - szLibName = (LPSTR)moduleName.GetUTF8(buffer); + GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory); - Assembly *pAssembly = spec.LoadAssembly(FILE_LOADED); - Module *pModule = pAssembly->FindModuleByName(szLibName); + Assembly *pAssembly = pMD->GetMethodTable()->GetAssembly(); + return LoadNativeLibraryBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName); + } +} - hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker); - } +// static +NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *callingAssembly, + BOOL hasDllImportSearchFlags, DWORD dllImportSearchFlags, + BOOL throwOnError) +{ + CONTRACTL + { + STANDARD_VM_CHECK; + PRECONDITION(CheckPointer(libraryName)); + PRECONDITION(CheckPointer(callingAssembly)); + } + CONTRACTL_END; + + NATIVE_LIBRARY_HANDLE hmod = nullptr; + + // Resolve using the AssemblyLoadContext.LoadUnmanagedDll implementation + hmod = LoadNativeLibraryViaAssemblyLoadContext(callingAssembly, libraryName); + if (hmod != nullptr) + return hmod; + + // Check if a default dllImportSearchPathFlags was passed in. If so, use that value. + // Otherwise, check if the assembly has the DefaultDllImportSearchPathsAttribute attribute. + // If so, use that value. + BOOL searchAssemblyDirectory; + DWORD dllImportSearchPathFlags; + if (hasDllImportSearchFlags) + { + dllImportSearchPathFlags = dllImportSearchFlags & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY; + searchAssemblyDirectory = dllImportSearchFlags & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY; + + } + else + { + GetDllImportSearchPathFlags(callingAssembly->GetManifestModule(), + &dllImportSearchPathFlags, &searchAssemblyDirectory); + } + + LoadLibErrorTracker errorTracker; + hmod = LoadNativeLibraryBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName); + if (hmod != nullptr) + return hmod; + + // Resolve using the AssemblyLoadContext.ResolvingUnmanagedDll event + hmod = LoadNativeLibraryViaAssemblyLoadContextEvent(callingAssembly, libraryName); + if (hmod != nullptr) + return hmod; + + if (throwOnError) + { + SString libraryPathSString(libraryName); + errorTracker.Throw(libraryPathSString); } return hmod; } -// This Method returns an instance of the PAL-Registered handle -NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLibErrorTracker * pErrorTracker) +NATIVE_LIBRARY_HANDLE NDirect::LoadNativeLibrary(NDirectMethodDesc * pMD, LoadLibErrorTracker * pErrorTracker) { CONTRACTL { @@ -7024,23 +7027,19 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLi PREFIX_ASSUME( name != NULL ); MAKE_WIDEPTR_FROMUTF8( wszLibName, name ); - NativeLibraryHandleHolder hmod = LoadLibraryModuleViaCallback(pMD, wszLibName); + NativeLibraryHandleHolder hmod = LoadNativeLibraryViaDllImportResolver(pMD, wszLibName); if (hmod != NULL) { return hmod.Extract(); } AppDomain* pDomain = GetAppDomain(); + Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly(); - // AssemblyLoadContext is not supported in AppX mode and thus, - // we should not perform PInvoke resolution via it when operating in AppX mode. - if (!AppX::IsAppXProcess()) + hmod = LoadNativeLibraryViaAssemblyLoadContext(pAssembly, wszLibName); + if (hmod != NULL) { - hmod = LoadLibraryModuleViaHost(pMD, wszLibName); - if (hmod != NULL) - { - return hmod.Extract(); - } + return hmod.Extract(); } hmod = pDomain->FindUnmanagedImageInCache(wszLibName); @@ -7049,7 +7048,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLi return hmod.Extract(); } - hmod = LoadLibraryModuleBySearch(pMD, pErrorTracker, wszLibName); + hmod = LoadNativeLibraryBySearch(pMD, pErrorTracker, wszLibName); if (hmod != NULL) { // If we have a handle add it to the cache. @@ -7057,13 +7056,10 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLi return hmod.Extract(); } - if (!AppX::IsAppXProcess()) + hmod = LoadNativeLibraryViaAssemblyLoadContextEvent(pAssembly, wszLibName); + if (hmod != NULL) { - hmod = LoadLibraryModuleViaEvent(pMD, wszLibName); - if (hmod != NULL) - { - return hmod.Extract(); - } + return hmod.Extract(); } return hmod.Extract(); @@ -7116,7 +7112,7 @@ VOID NDirect::NDirectLink(NDirectMethodDesc *pMD) LoadLibErrorTracker errorTracker; BOOL fSuccess = FALSE; - NATIVE_LIBRARY_HANDLE hmod = LoadLibraryModule( pMD, &errorTracker ); + NATIVE_LIBRARY_HANDLE hmod = LoadNativeLibrary( pMD, &errorTracker ); if ( hmod ) { LPVOID pvTarget = NDirectGetEntryPoint(pMD, hmod); diff --git a/src/coreclr/src/vm/dllimport.h b/src/coreclr/src/vm/dllimport.h index cdc67c3..4b7f011 100644 --- a/src/coreclr/src/vm/dllimport.h +++ b/src/coreclr/src/vm/dllimport.h @@ -76,7 +76,7 @@ public: static NATIVE_LIBRARY_HANDLE LoadLibraryByName(LPCWSTR name, Assembly *callingAssembly, BOOL hasDllImportSearchPathFlags, DWORD dllImportSearchPathFlags, BOOL throwOnError); - static NATIVE_LIBRARY_HANDLE LoadLibraryModule(NDirectMethodDesc * pMD, LoadLibErrorTracker *pErrorTracker); + static NATIVE_LIBRARY_HANDLE LoadNativeLibrary(NDirectMethodDesc * pMD, LoadLibErrorTracker *pErrorTracker); static void FreeNativeLibrary(NATIVE_LIBRARY_HANDLE handle); static INT_PTR GetNativeLibraryExport(NATIVE_LIBRARY_HANDLE handle, LPCWSTR symbolName, BOOL throwOnError); @@ -122,14 +122,6 @@ public: private: NDirect() {LIMITED_METHOD_CONTRACT;}; // prevent "new"'s on this class - - static NATIVE_LIBRARY_HANDLE LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker); - static NATIVE_LIBRARY_HANDLE LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker); - static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, LPCWSTR wszLibName); - static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaEvent(NDirectMethodDesc * pMD, LPCWSTR wszLibName); - static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaCallback(NDirectMethodDesc * pMD, LPCWSTR wszLibName); - static NATIVE_LIBRARY_HANDLE LoadLibraryModuleBySearch(NDirectMethodDesc * pMD, LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName); - static NATIVE_LIBRARY_HANDLE LoadLibraryModuleBySearch(Assembly *callingAssembly, BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags, LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName); }; //---------------------------------------------------------------- diff --git a/src/coreclr/tests/src/Interop/NativeLibrary/AssemblyLoadContext/ResolveUnmanagedDllTests.cs b/src/coreclr/tests/src/Interop/NativeLibrary/AssemblyLoadContext/ResolveUnmanagedDllTests.cs index a6460e3..3b575f9 100644 --- a/src/coreclr/tests/src/Interop/NativeLibrary/AssemblyLoadContext/ResolveUnmanagedDllTests.cs +++ b/src/coreclr/tests/src/Interop/NativeLibrary/AssemblyLoadContext/ResolveUnmanagedDllTests.cs @@ -3,24 +3,54 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.Loader; using System.Reflection; using System.Runtime.InteropServices; using TestLibrary; +public class FakeNativeLibrary +{ + public const string Name = "FakeNativeLibrary"; + public const string RedirectName = "FakeNativeLibraryRedirect"; + + public static readonly IntPtr Handle = new IntPtr(ResolveUnmanagedDllTests.rand.Next()); +} + public class ALC : AssemblyLoadContext { - public bool LoadUnmanagedDllCalled { get; private set; } + private List invocations = new List(); + + public void Reset() + { + invocations.Clear(); + } + + public void Validate(params string[] expectedNames) + { + Assert.AreAllEqual(expectedNames, invocations, $"Unexpected invocations for {nameof(LoadUnmanagedDll)}."); + } protected override IntPtr LoadUnmanagedDll(string unmanagedDllName) { - LoadUnmanagedDllCalled = true; + invocations.Add(unmanagedDllName); if (string.Equals(unmanagedDllName, NativeLibraryToLoad.InvalidName)) return LoadUnmanagedDllFromPath(NativeLibraryToLoad.GetFullPath()); + if (string.Equals(unmanagedDllName, FakeNativeLibrary.Name)) + return FakeNativeLibrary.Handle; + + if (string.Equals(unmanagedDllName, FakeNativeLibrary.RedirectName)) + { + IntPtr ptr; + if (NativeLibrary.TryLoad(FakeNativeLibrary.Name, Assemblies.First(), null, out ptr)) + return ptr; + } + return IntPtr.Zero; } } @@ -28,7 +58,7 @@ public class ALC : AssemblyLoadContext public class ResolveUnmanagedDllTests { private static readonly int seed = 123; - private static readonly Random rand = new Random(seed); + internal static readonly Random rand = new Random(seed); public static int Main() { @@ -50,14 +80,48 @@ public class ResolveUnmanagedDllTests { Console.WriteLine($"Running {nameof(ValidateLoadUnmanagedDll)}..."); + ALC alc = new ALC(); + var asm = alc.LoadFromAssemblyPath(Assembly.GetExecutingAssembly().Location); + + Console.WriteLine(" -- Validate explicit load..."); + + // ALC implementation returns a fake handle value + IntPtr ptr = NativeLibrary.Load(FakeNativeLibrary.Name, asm, null); + alc.Validate(FakeNativeLibrary.Name); + Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}"); + + alc.Reset(); + ptr = IntPtr.Zero; + + bool success = NativeLibrary.TryLoad(FakeNativeLibrary.Name, asm, null, out ptr); + Assert.IsTrue(success, $"NativeLibrary.TryLoad should have succeeded"); + alc.Validate(FakeNativeLibrary.Name); + Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}"); + + alc.Reset(); + + // ALC implementation calls NativeLibrary.TryLoad with a different name + ptr = NativeLibrary.Load(FakeNativeLibrary.RedirectName, asm, null); + alc.Validate(FakeNativeLibrary.RedirectName, FakeNativeLibrary.Name); + Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}"); + + alc.Reset(); + ptr = IntPtr.Zero; + + success = NativeLibrary.TryLoad(FakeNativeLibrary.RedirectName, asm, null, out ptr); + Assert.IsTrue(success, $"NativeLibrary.TryLoad should have succeeded"); + alc.Validate(FakeNativeLibrary.RedirectName, FakeNativeLibrary.Name); + Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}"); + + alc.Reset(); + Console.WriteLine(" -- Validate p/invoke..."); int addend1 = rand.Next(int.MaxValue / 2); int addend2 = rand.Next(int.MaxValue / 2); int expected = addend1 + addend2; - ALC alc = new ALC(); int value = NativeSumInAssemblyLoadContext(alc, addend1, addend2); - Assert.IsTrue(alc.LoadUnmanagedDllCalled, "AssemblyLoadContext.LoadUnmanagedDll should have been called."); + alc.Validate(NativeLibraryToLoad.InvalidName); Assert.AreEqual(expected, value, $"Unexpected return value for {nameof(NativeSum)}"); } @@ -65,14 +129,39 @@ public class ResolveUnmanagedDllTests { Console.WriteLine($"Running {nameof(ValidateResolvingUnmanagedDllEvent)}..."); + Console.WriteLine(" -- Validate explicit load: custom ALC..."); + AssemblyLoadContext alcExplicitLoad = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent)); + var asm = alcExplicitLoad.LoadFromAssemblyPath(Assembly.GetExecutingAssembly().Location); + ValidateResolvingUnmanagedDllEvent_ExplicitLoad(asm); + + Console.WriteLine(" -- Validate explicit load: default ALC..."); + ValidateResolvingUnmanagedDllEvent_ExplicitLoad(Assembly.GetExecutingAssembly()); + Console.WriteLine(" -- Validate p/invoke: custom ALC..."); - AssemblyLoadContext alc = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent)); - ValidateResolvingUnmanagedDllEvent_PInvoke(alc); + AssemblyLoadContext alcPInvoke = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent)); + ValidateResolvingUnmanagedDllEvent_PInvoke(alcPInvoke); Console.WriteLine(" -- Validate p/invoke: default ALC..."); ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext.Default); } + private static void ValidateResolvingUnmanagedDllEvent_ExplicitLoad(Assembly assembly) + { + AssemblyLoadContext alc = AssemblyLoadContext.GetLoadContext(assembly); + using (var handler = new Handlers(alc, returnValid: false)) + { + Assert.Throws(() => NativeLibrary.Load(FakeNativeLibrary.Name, assembly, null)); + Assert.IsTrue(handler.EventHandlerInvoked, "Event handler should have been invoked"); + } + + using (var handler = new Handlers(alc, returnValid: true)) + { + IntPtr ptr = NativeLibrary.Load(FakeNativeLibrary.Name, assembly, null); + Assert.IsTrue(handler.EventHandlerInvoked, "Event handler should have been invoked"); + Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}"); + } + } + private static void ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext alc) { int addend1 = rand.Next(int.MaxValue / 2); @@ -155,6 +244,9 @@ public class ResolveUnmanagedDllTests if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName)) return NativeLibrary.Load(NativeLibraryToLoad.Name, assembly, null); + if (string.Equals(libraryName, FakeNativeLibrary.Name)) + return FakeNativeLibrary.Handle; + return IntPtr.Zero; } } diff --git a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackStressTest.cs b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackStressTest.cs index ceb4e56..49afbb3 100644 --- a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackStressTest.cs +++ b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackStressTest.cs @@ -24,21 +24,26 @@ public class CallbackStressTest public static void SetResolve() { Console.WriteLine("Setting PInvoke Resolver"); - + DllImportResolver resolver = (string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) => { - if (dllImportSearchPath != DllImportSearchPath.System32) + if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName)) { - Console.WriteLine($"Unexpected dllImportSearchPath: {dllImportSearchPath.ToString()}"); - throw new ArgumentException(); + if (dllImportSearchPath != DllImportSearchPath.System32) + { + Console.WriteLine($"Unexpected dllImportSearchPath: {dllImportSearchPath.ToString()}"); + throw new ArgumentException(); + } + + return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null); } - return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null); + return IntPtr.Zero; }; NativeLibrary.SetDllImportResolver( - Assembly.GetExecutingAssembly(), + Assembly.GetExecutingAssembly(), resolver); } @@ -61,7 +66,7 @@ public class CallbackStressTest s_PInvokesExecuted += (a == 20 ? 1 : 0); } catch (DllNotFoundException) { s_CatchCalled++; } - + throw new ArgumentException(); } @@ -97,7 +102,7 @@ public class CallbackStressTest } finally { s_FinallyCalled++; } } - + [MethodImpl(MethodImplOptions.NoInlining)] public static void ManualRaiseException() { @@ -111,7 +116,7 @@ public class CallbackStressTest // TODO: test on Unix when implementing pinvoke inlining s_SEHExceptionCatchCalled++; #endif - } + } public static int Main() { @@ -123,13 +128,13 @@ public class CallbackStressTest s_WrongPInvokesExecuted++; } catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; } - + try { DoCall(); } catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; } - + try { DoCallTryFinally(); } catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; } - + try { DoCallTryCatch(true); } catch (ArgumentException) { GC.Collect(); s_OtherExceptionCatchCalled++; } @@ -138,10 +143,10 @@ public class CallbackStressTest try { DoCallTryRethrowDifferentExceptionInCatch(); } catch (InvalidOperationException) { GC.Collect(); s_OtherExceptionCatchCalled++; } - + ManualRaiseException(); } - + SetResolve(); for(int i = 0; i < s_LoopCounter; i++) @@ -152,11 +157,11 @@ public class CallbackStressTest try { DoCallTryCatch(false); } catch (ArgumentException) { GC.Collect(); s_OtherExceptionCatchCalled++; } - + ManualRaiseException(); } - - if (s_FinallyCalled == s_LoopCounter && + + if (s_FinallyCalled == s_LoopCounter && s_CatchCalled == (s_LoopCounter * 7) && s_OtherExceptionCatchCalled == (s_LoopCounter * 3) && s_WrongPInvokesExecuted == 0 && @@ -166,7 +171,7 @@ public class CallbackStressTest Console.WriteLine("PASS"); return 100; } - + Console.WriteLine("s_FinallyCalled = " + s_FinallyCalled); Console.WriteLine("s_CatchCalled = " + s_CatchCalled); Console.WriteLine("s_OtherExceptionCatchCalled = " + s_OtherExceptionCatchCalled); @@ -179,7 +184,7 @@ public class CallbackStressTest [DllImport(NativeLibraryToLoad.InvalidName)] [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] static extern int NativeSum(int arg1, int arg2); - + #if WINDOWS [DllImport("kernel32")] static extern void RaiseException(uint dwExceptionCode, uint dwExceptionFlags, uint nNumberOfArguments, IntPtr lpArguments); diff --git a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.cs b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.cs index 81215a0..ec5c1cf 100644 --- a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.cs +++ b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.cs @@ -2,87 +2,109 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.IO; using System.Reflection; using System.Runtime.InteropServices; +using TestLibrary; + [assembly: DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] public class CallbackTests { + private static readonly int seed = 123; + private static readonly Random rand = new Random(seed); + public static int Main() { try { - Assembly assembly = Assembly.GetExecutingAssembly(); + // The first test sets the resolver for the executing assembly + // Subsequents tests assume the resolver has already been set. + ValidateSetDllImportResolver(); - DllImportResolver resolver = - (string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) => - { - if (dllImportSearchPath != DllImportSearchPath.System32) - { - Console.WriteLine($"Unexpected dllImportSearchPath: {dllImportSearchPath.ToString()}"); - throw new ArgumentException(); - } + ValidatePInvoke(); + } + catch (Exception e) + { + Console.WriteLine($"Test Failure: {e}"); + return 101; + } - return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null); - }; + return 100; + } - DllImportResolver anotherResolver = - (string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) => - IntPtr.Zero; + public static void ValidateSetDllImportResolver() + { + Console.WriteLine($"Running {nameof(ValidateSetDllImportResolver)}..."); + Assembly assembly = Assembly.GetExecutingAssembly(); + DllImportResolver resolver = Resolver.Instance.Callback; - try - { - NativeSum(10, 10); - Console.WriteLine("Exception expected: no callback registered yet"); - return 101; - } - catch (DllNotFoundException) {} + // Invalid arguments + Assert.Throws(() => NativeLibrary.SetDllImportResolver(null, resolver), "Exception expected for null assembly parameter"); + Assert.Throws(() => NativeLibrary.SetDllImportResolver(assembly, null), "Exception expected for null resolver parameter"); - try - { - NativeLibrary.SetDllImportResolver(null, resolver); + // No callback registered yet + Assert.Throws(() => NativeSum(10, 10)); - Console.WriteLine("Exception expected: assembly parameter null"); - return 102; - } - catch (ArgumentNullException) { } + // Set a resolver callback + NativeLibrary.SetDllImportResolver(assembly, resolver); - try - { - NativeLibrary.SetDllImportResolver(assembly, null); + // Try to set the resolver again on the same assembly + Assert.Throws(() => NativeLibrary.SetDllImportResolver(assembly, resolver), "Should not be able to re-register resolver"); - Console.WriteLine("Exception expected: resolver parameter null"); - return 103; - } - catch (ArgumentNullException) { } + // Try to set another resolver on the same assembly + DllImportResolver anotherResolver = + (string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) => + IntPtr.Zero; + Assert.Throws(() => NativeLibrary.SetDllImportResolver(assembly, anotherResolver), "Should not be able to register another resolver"); + } - // Set a resolver callback - NativeLibrary.SetDllImportResolver(assembly, resolver); + public static void ValidatePInvoke() + { + Console.WriteLine($"Running {nameof(ValidatePInvoke)}..."); + int addend1 = rand.Next(int.MaxValue / 2); + int addend2 = rand.Next(int.MaxValue / 2); + int expected = addend1 + addend2; - try - { - // Try to set another resolver on the same assembly. - NativeLibrary.SetDllImportResolver(assembly, anotherResolver); + Resolver.Instance.Reset(); + int value = NativeSum(addend1, addend2); + Resolver.Instance.Validate(NativeLibraryToLoad.InvalidName); + Assert.AreEqual(expected, value, $"Unexpected return value from {nameof(NativeSum)}"); + } - Console.WriteLine("Exception expected: Trying to register second resolver"); - return 104; - } - catch (InvalidOperationException) { } + private class Resolver + { + public static Resolver Instance = new Resolver(); - if (NativeSum(10, 10) != 20) - { - Console.WriteLine("Unexpected ReturnValue from NativeSum()"); - return 105; - } + public DllImportResolver Callback => ResolveDllImport; + + private List invocations = new List(); + + public void Reset() + { + invocations.Clear(); } - catch (Exception e) + + public void Validate(params string[] expectedNames) { - Console.WriteLine($"Unexpected exception: {e.ToString()} {e.Message}"); - return 106; + Assert.AreEqual(expectedNames.Length, invocations.Count, $"Unexpected invocation count for registered {nameof(DllImportResolver)}."); + for (int i = 0; i < expectedNames.Length; i++) + Assert.AreEqual(expectedNames[i], invocations[i], $"Unexpected library name received by registered resolver."); } - return 100; + private IntPtr ResolveDllImport(string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) + { + invocations.Add(libraryName); + + if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName)) + { + Assert.AreEqual(DllImportSearchPath.System32, dllImportSearchPath, $"Unexpected {nameof(dllImportSearchPath)}: {dllImportSearchPath.ToString()}"); + return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null); + } + + return IntPtr.Zero; + } } [DllImport(NativeLibraryToLoad.InvalidName)] diff --git a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.csproj b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.csproj index c1532ef..f3b8093 100644 --- a/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.csproj +++ b/src/coreclr/tests/src/Interop/NativeLibrary/Callback/CallbackTests.csproj @@ -9,5 +9,6 @@ + diff --git a/src/coreclr/tests/src/Interop/NativeLibrary/NativeLibraryToLoad/NativeLibraryToLoad.cs b/src/coreclr/tests/src/Interop/NativeLibrary/NativeLibraryToLoad/NativeLibraryToLoad.cs index 55c4581..251c79c 100644 --- a/src/coreclr/tests/src/Interop/NativeLibrary/NativeLibraryToLoad/NativeLibraryToLoad.cs +++ b/src/coreclr/tests/src/Interop/NativeLibrary/NativeLibraryToLoad/NativeLibraryToLoad.cs @@ -32,4 +32,4 @@ public class NativeLibraryToLoad string directory = Path.GetDirectoryName(assembly.Location); return Path.Combine(directory, GetFileName()); } -} \ No newline at end of file +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeLibrary.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeLibrary.cs index 9bd65fe..bcff553 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeLibrary.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeLibrary.cs @@ -68,10 +68,12 @@ namespace System.Runtime.InteropServices /// If DllImportSearchPath parameter is non-null, the flags in this enumeration are used. /// Otherwise, the flags specified by the DefaultDllImportSearchPaths attribute on the /// calling assembly (if any) are used. - /// This LoadLibrary() method does not invoke the managed call-backs for native library resolution: - /// * The per-assembly registered callback + /// This method follows the native library resolution for the AssemblyLoadContext of the + /// specified assembly. It will invoke the managed extension points: /// * AssemblyLoadContext.LoadUnmanagedDll() /// * AssemblyLoadContext.ResolvingUnmanagedDllEvent + /// It does not invoke extension points that are not tied to the AssemblyLoadContext: + /// * The per-assembly registered DllImportResolver callback /// /// The name of the native library to be loaded /// The assembly loading the native library @@ -79,7 +81,7 @@ namespace System.Runtime.InteropServices /// The handle for the loaded library /// If libraryPath or assembly is null /// If assembly is not a RuntimeAssembly - /// If the library can't be found. + /// If the library can't be found. /// If the library is not valid. public static IntPtr Load(string libraryName, Assembly assembly, DllImportSearchPath? searchPath) { @@ -98,10 +100,21 @@ namespace System.Runtime.InteropServices /// /// NativeLibrary Loader: High-level API that doesn't throw. + /// Given a library name, this function searches specific paths based on the + /// runtime configuration, input parameters, and attributes of the calling assembly. + /// If DllImportSearchPath parameter is non-null, the flags in this enumeration are used. + /// Otherwise, the flags specified by the DefaultDllImportSearchPaths attribute on the + /// calling assembly (if any) are used. + /// This method follows the native library resolution for the AssemblyLoadContext of the + /// specified assembly. It will invoke the managed extension points: + /// * AssemblyLoadContext.LoadUnmanagedDll() + /// * AssemblyLoadContext.ResolvingUnmanagedDllEvent + /// It does not invoke extension points that are not tied to the AssemblyLoadContext: + /// * The per-assembly registered DllImportResolver callback /// /// The name of the native library to be loaded - /// The search path /// The assembly loading the native library + /// The search path /// The out-parameter for the loaded native library handle /// True on successful load, false otherwise /// If libraryPath or assembly is null @@ -214,7 +227,7 @@ namespace System.Runtime.InteropServices } catch (ArgumentException) { - // ConditionalWealTable throws ArgumentException if the Key already exists + // ConditionalWeakTable throws ArgumentException if the Key already exists throw new InvalidOperationException(SR.InvalidOperation_CannotRegisterSecondResolver); } } -- 2.7.4