Don't probe for A/W variants of entry points on non-Windows. (#33250)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Wed, 11 Mar 2020 04:16:07 +0000 (21:16 -0700)
committerGitHub <noreply@github.com>
Wed, 11 Mar 2020 04:16:07 +0000 (21:16 -0700)
* Don't probe for A/W variants of entry points on non-Windows.

* Cal PAL_GetProcAddressDirect directly on Unix.

* Remove dependence on A/W probing on non-Windows.

src/coreclr/src/vm/method.cpp
src/coreclr/src/vm/method.hpp
src/coreclr/tests/src/Interop/DllImportAttribute/ExactSpelling/ExactSpellingTest.cs
src/coreclr/tests/src/Interop/DllImportAttribute/ExactSpelling/ExactSpellingTest.csproj
src/coreclr/tests/src/JIT/Directed/pinvoke/pinvokeexamplenative.cpp

index 9af4bb3..39e1e7d 100644 (file)
@@ -5254,6 +5254,8 @@ void NDirectMethodDesc::InterlockedSetNDirectFlags(WORD wFlags)
 }
 
 #ifndef CROSSGEN_COMPILE
+
+#ifdef TARGET_WINDOWS
 FARPROC NDirectMethodDesc::FindEntryPointWithMangling(NATIVE_LIBRARY_HANDLE hMod, PTR_CUTF8 entryPointName) const
 {
     CONTRACTL
@@ -5264,11 +5266,7 @@ FARPROC NDirectMethodDesc::FindEntryPointWithMangling(NATIVE_LIBRARY_HANDLE hMod
     }
     CONTRACTL_END;
 
-#ifndef TARGET_UNIX
     FARPROC pFunc = GetProcAddress(hMod, entryPointName);
-#else
-    FARPROC pFunc = PAL_GetProcAddressDirect(hMod, entryPointName);
-#endif
 
 #if defined(TARGET_X86)
 
@@ -5309,6 +5307,7 @@ FARPROC NDirectMethodDesc::FindEntryPointWithMangling(NATIVE_LIBRARY_HANDLE hMod
 
     return pFunc;
 }
+#endif
 
 //*******************************************************************************
 LPVOID NDirectMethodDesc::FindEntryPoint(NATIVE_LIBRARY_HANDLE hMod) const
@@ -5323,21 +5322,21 @@ LPVOID NDirectMethodDesc::FindEntryPoint(NATIVE_LIBRARY_HANDLE hMod) const
 
     char const * funcName = GetEntrypointName();
 
-    FARPROC pFunc = NULL;
-
-#ifndef TARGET_UNIX
+#ifndef TARGET_WINDOWS
+    return reinterpret_cast<LPVOID>(PAL_GetProcAddressDirect(hMod, funcName));
+#else
     // Handle ordinals.
     if (funcName[0] == '#')
     {
         long ordinal = atol(funcName + 1);
         return reinterpret_cast<LPVOID>(GetProcAddress(hMod, (LPCSTR)(size_t)((UINT16)ordinal)));
     }
-#endif
 
-    // Just look for the user-provided name without charset suffixes.  If it is unicode fcn, we are going
+    // Just look for the user-provided name without charset suffixes.
+    // If  it is unicode fcn, we are going
     // to need to check for the 'W' API because it takes precedence over the
     // unmangled one (on NT some APIs have unmangled ANSI exports).
-    pFunc = FindEntryPointWithMangling(hMod, funcName);
+    FARPROC pFunc = FindEntryPointWithMangling(hMod, funcName);
     if ((pFunc != NULL && IsNativeAnsi()) || IsNativeNoMangled())
     {
         return reinterpret_cast<LPVOID>(pFunc);
@@ -5369,6 +5368,7 @@ LPVOID NDirectMethodDesc::FindEntryPoint(NATIVE_LIBRARY_HANDLE hMod) const
     }
 
     return reinterpret_cast<LPVOID>(pFunc);
+#endif
 }
 #endif // CROSSGEN_COMPILE
 
index 7b150db..64936b4 100644 (file)
@@ -3140,9 +3140,10 @@ public:
     //
     LPVOID FindEntryPoint(NATIVE_LIBRARY_HANDLE hMod) const;
 
+#ifdef TARGET_WINDOWS
 private:
     FARPROC FindEntryPointWithMangling(NATIVE_LIBRARY_HANDLE mod, PTR_CUTF8 entryPointName) const;
-
+#endif
 public:
 
     void SetStackArgumentSize(WORD cbDstBuffer, CorPinvokeMap unmgdCallConv)
index c5e2ebf..d564818 100644 (file)
@@ -4,6 +4,7 @@
 
 using System;
 using System.Runtime.InteropServices;
+using TestLibrary;
 
 class ExactSpellingTest
 {
@@ -43,83 +44,113 @@ class ExactSpellingTest
         public static extern int Marshal_Int_InOut2([In, Out] int intValue);
     }
 
-    public static int Main(string[] args)
+    private static void ExactSpellingTrue()
     {
-        int failures = 0;
         int intManaged = 1000;
         int intNative = 2000;
         int intReturn = 3000;
-        
+
         Console.WriteLine("Method Unicode.Marshal_Int_InOut: ExactSpelling = true");
         int int1 = intManaged;
         int intRet1 = Unicode.Marshal_Int_InOut(int1);
-        failures += Verify(intReturn, intManaged, intRet1, int1);
-        
+        Verify(intReturn, intManaged, intRet1, int1);
+
         Console.WriteLine("Method Unicode.MarshalPointer_Int_InOut: ExactSpelling = true");
         int int2 = intManaged;
         int intRet2 = Unicode.MarshalPointer_Int_InOut(ref int2);
-        
-        failures += Verify(intReturn, intNative, intRet2, int2);
+
+        Verify(intReturn, intNative, intRet2, int2);
 
         Console.WriteLine("Method Ansi.Marshal_Int_InOut: ExactSpelling = true");
         int int3 = intManaged;
         int intRet3 = Ansi.Marshal_Int_InOut(int3);
-        failures += Verify(intReturn, intManaged, intRet3, int3);
+        Verify(intReturn, intManaged, intRet3, int3);
 
         Console.WriteLine("Method Ansi.MarshalPointer_Int_InOut: ExactSpelling = true");
         int int4 = intManaged;
         int intRet4 = Ansi.MarshalPointer_Int_InOut(ref int4);
-        failures += Verify(intReturn, intNative, intRet4, int4);
+        Verify(intReturn, intNative, intRet4, int4);
+    }
 
+    private static void ExactSpellingFalse_Windows()
+    {
+        int intManaged = 1000;
+        int intNative = 2000;
         int intReturnAnsi = 4000;
         int intReturnUnicode = 5000;
 
         Console.WriteLine("Method Unicode.Marshal_Int_InOut2: ExactSpelling = false");
         int int5 = intManaged;
         int intRet5 = Unicode.Marshal_Int_InOut2(int5);
-        failures += Verify(intReturnUnicode, intManaged, intRet5, int5);
-        
+        Verify(intReturnUnicode, intManaged, intRet5, int5);
+
         Console.WriteLine("Method Unicode.MarshalPointer_Int_InOut2: ExactSpelling = false");
         int int6 = intManaged;
         int intRet6 = Unicode.MarshalPointer_Int_InOut2(ref int6);
-        failures += Verify(intReturnUnicode, intNative, intRet6, int6);
+        Verify(intReturnUnicode, intNative, intRet6, int6);
 
         Console.WriteLine("Method Ansi.Marshal_Int_InOut2: ExactSpelling = false");
         int int7 = intManaged;
         int intRet7 = Ansi.Marshal_Int_InOut2(int7);
-        failures += Verify(intReturnAnsi, intManaged, intRet7, int7);
+        Verify(intReturnAnsi, intManaged, intRet7, int7);
 
         Console.WriteLine("Method Ansi.MarshalPointer_Int_InOut2: ExactSpelling = false");
         int int8 = intManaged;
         int intRet8 = Ansi.MarshalPointer_Int_InOut2(ref int8);
-        failures += Verify(intReturnAnsi, intNative, intRet8, int8);
+        Verify(intReturnAnsi, intNative, intRet8, int8);
 
         Console.WriteLine("Method Auto.Marshal_Int_InOut: ExactSpelling = false. Verify CharSet.Auto behavior per-platform.");
         int int9 = intManaged;
         int intRet9 = Auto.Marshal_Int_InOut2(int9);
-#if TARGET_WINDOWS
-        failures += Verify(intReturnUnicode, intManaged, intRet9, int9);
-#else
-        failures += Verify(intReturnAnsi, intManaged, intRet9, int9);
-#endif
-        
-        return 100 + failures;
+        Verify(intReturnUnicode, intManaged, intRet9, int9);
+    }
+
+    private static void ExactSpellingFalse_NonWindows()
+    {
+        int intManaged = 1000;
+        int intNative = 2000;
+        Console.WriteLine("Method Unicode.Marshal_Int_InOut2: ExactSpelling = false");
+        int int5 = intManaged;
+        Assert.Throws<EntryPointNotFoundException>(() => Unicode.Marshal_Int_InOut2(int5));
+
+        Console.WriteLine("Method Unicode.MarshalPointer_Int_InOut2: ExactSpelling = false");
+        int int6 = intManaged;
+        Assert.Throws<EntryPointNotFoundException>(() => Unicode.MarshalPointer_Int_InOut2(ref int6));
+
+        Console.WriteLine("Method Ansi.Marshal_Int_InOut2: ExactSpelling = false");
+        int int7 = intManaged;
+        Assert.Throws<EntryPointNotFoundException>(() => Ansi.Marshal_Int_InOut2(int7));
+
+        Console.WriteLine("Method Ansi.MarshalPointer_Int_InOut2: ExactSpelling = false");
+        int int8 = intManaged;
+        Assert.Throws<EntryPointNotFoundException>(() => Ansi.MarshalPointer_Int_InOut2(ref int8));
     }
 
-    private static int Verify(int expectedReturnValue, int expectedParameterValue, int actualReturnValue, int actualParameterValue)
+    public static int Main(string[] args)
     {
-        int failures = 0;
-        if (expectedReturnValue != actualReturnValue)
+        try
         {
-            failures++;
-            Console.WriteLine($"The return value is wrong. Expected {expectedReturnValue}, got {actualReturnValue}");
+            ExactSpellingTrue();
+            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+            {
+                ExactSpellingFalse_Windows();
+            }
+            else
+            {
+                ExactSpellingFalse_NonWindows();
+            }
         }
-        if (expectedParameterValue != actualParameterValue)
+        catch (System.Exception ex)
         {
-            failures++;
-            Console.WriteLine($"The parameter value is changed. Expected {expectedParameterValue}, got {actualParameterValue}");
+            Console.WriteLine(ex.ToString());
+            return 101;
         }
+        return 100;
+    }
 
-        return failures;
+    private static void Verify(int expectedReturnValue, int expectedParameterValue, int actualReturnValue, int actualParameterValue)
+    {
+        Assert.AreEqual(expectedReturnValue, actualReturnValue);
+        Assert.AreEqual(expectedParameterValue, actualParameterValue);
     }
 }
index 3f5d395..c04abb8 100644 (file)
@@ -7,6 +7,7 @@
   </ItemGroup>
   <ItemGroup>
     <ProjectReference Include="CMakeLists.txt" />
+    <ProjectReference Include="$(TestSourceDir)Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
   </ItemGroup>
   <ItemGroup>
     <TraitTags Include="OsSpecific" />
index d2245ad..c8e16ea 100644 (file)
@@ -23,7 +23,7 @@
 #define __int16     short int
 #define __int8      char        // assumes char is signed
 
-#endif 
+#endif
 
 #include <cstddef>
 
@@ -50,7 +50,7 @@ DestroyMenu(
 
 EXPORT_API
 unsigned __int32
-AppendMenuA(
+AppendMenu(
     HMENU hMenu,
     unsigned __int32 uFlags,
     unsigned __int32 uID,
@@ -69,7 +69,7 @@ AppendMenuA(
 
 EXPORT_API
 __int32
-GetMenuStringA(
+GetMenuString(
     HMENU hMenu,
     unsigned __int32 uIDItem,
     char * lpString,
@@ -79,16 +79,16 @@ GetMenuStringA(
 {
     if (flags != 0x400)
     {
-        throw "GetMenuStringA: only MF_BYPOSITION (0x400) supported for flags";
+        throw "GetMenuString: only MF_BYPOSITION (0x400) supported for flags";
     }
 
     if (cchMax < 0)
     {
-        throw "GetMenuStringA: invalid argument (cchMax)";
+        throw "GetMenuString: invalid argument (cchMax)";
     }
 
     if (uIDItem >= hMenu->size())
-    { 
+    {
         return 0;
     }
 
@@ -105,7 +105,7 @@ GetMenuStringA(
     {
         cch = cchMax - 1;
     }
-   
+
     memcpy(lpString, str.c_str(), cch);
     lpString[cch] = '\0';