From: Jeremy Koritzinsky Date: Thu, 6 Dec 2018 17:56:21 +0000 (-0800) Subject: Port CustomMarshalers.dll to C# and .NET Core (#21343) X-Git-Tag: accepted/tizen/unified/20190422.045933~432 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5955967d808d1d88ef927235276d3ca82b5d9363;p=platform%2Fupstream%2Fcoreclr.git Port CustomMarshalers.dll to C# and .NET Core (#21343) * Import CustomMarshalers from corefx PR. * Add tests for marshalling IEnumerator and IEnumerable using the System.Runtime.InteropServices.CustomMarshalers. * Verify that the marshaller is used in the managed -> native direction in the tests. * Remove #if's * PR Feedback. * Move ComHelpers.h into interop common headers. * Add some template machinery to ensure that we don't accidentally pass in an incorrect type into DoQueryInterface. * Add VariantClear * Make marshalers internal and root them in ILLinkTrim.xml * Revert enable-if machinery. * Add tests for the pass-through unwrapping done by the EnumeratorToEnumVariantMarshaler. * Remove some of my helper code. * Added todo about exposing EnumerableToDispatchMarshaler. * PR Feedback. * Removed the last placement new. * Use the runtime hooks to enable us to not have to publically expose the custom marshalers. * Remove out of date comments. --- diff --git a/src/System.Private.CoreLib/ILLinkTrim.xml b/src/System.Private.CoreLib/ILLinkTrim.xml index 5445899..53eae5d 100644 --- a/src/System.Private.CoreLib/ILLinkTrim.xml +++ b/src/System.Private.CoreLib/ILLinkTrim.xml @@ -39,5 +39,8 @@ + + + diff --git a/src/System.Private.CoreLib/Resources/Strings.resx b/src/System.Private.CoreLib/Resources/Strings.resx index 982d552..d48b228 100644 --- a/src/System.Private.CoreLib/Resources/Strings.resx +++ b/src/System.Private.CoreLib/Resources/Strings.resx @@ -3652,4 +3652,10 @@ Precision cannot be larger than {0}. + + The supplied object does not implement ICloneable. + + + The returned enumerator does not implement IEnumVARIANT. + diff --git a/src/System.Private.CoreLib/System.Private.CoreLib.csproj b/src/System.Private.CoreLib/System.Private.CoreLib.csproj index 040cd88..b9842af 100644 --- a/src/System.Private.CoreLib/System.Private.CoreLib.csproj +++ b/src/System.Private.CoreLib/System.Private.CoreLib.csproj @@ -269,6 +269,7 @@ + @@ -366,8 +367,14 @@ + + + + + + - + diff --git a/src/System.Private.CoreLib/shared/System/Collections/IEnumerable.cs b/src/System.Private.CoreLib/shared/System/Collections/IEnumerable.cs index e5edeff..3bc1547 100644 --- a/src/System.Private.CoreLib/shared/System/Collections/IEnumerable.cs +++ b/src/System.Private.CoreLib/shared/System/Collections/IEnumerable.cs @@ -7,6 +7,8 @@ using System.Runtime.InteropServices; namespace System.Collections { + [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")] + [ComVisible(true)] public interface IEnumerable { // Returns an IEnumerator for this enumerable Object. The enumerator provides diff --git a/src/System.Private.CoreLib/shared/System/HResults.cs b/src/System.Private.CoreLib/shared/System/HResults.cs index 4a5ec0d..a3d87ac 100644 --- a/src/System.Private.CoreLib/shared/System/HResults.cs +++ b/src/System.Private.CoreLib/shared/System/HResults.cs @@ -26,6 +26,7 @@ namespace System internal static partial class HResults { internal const int S_OK = unchecked((int)0x00000000); + internal const int S_FALSE = unchecked((int)0x1); internal const int COR_E_ABANDONEDMUTEX = unchecked((int)0x8013152D); internal const int COR_E_AMBIGUOUSMATCH = unchecked((int)0x8000211D); internal const int COR_E_APPDOMAINUNLOADED = unchecked((int)0x80131014); diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs index b9f10bc..0ddde5d 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs @@ -127,9 +127,9 @@ namespace System.Runtime.InteropServices CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv) { ppv = IntPtr.Zero; - if (iid == _iidSourceItf || iid == typeof(NativeMethods.IDispatch).GUID) + if (iid == _iidSourceItf || iid == typeof(IDispatch).GUID) { - ppv = Marshal.GetComInterfaceForObject(this, typeof(NativeMethods.IDispatch), CustomQueryInterfaceMode.Ignore); + ppv = Marshal.GetComInterfaceForObject(this, typeof(IDispatch), CustomQueryInterfaceMode.Ignore); return CustomQueryInterfaceResult.Handled; } else if (iid == IID_IManagedObject) diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs new file mode 100644 index 0000000..e541960 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information + +using System; +namespace System.Runtime.InteropServices.ComTypes +{ + /*========================================================================== + ** Interface: IEnumerable + ** Purpose: + ** This interface is redefined here since the original IEnumerable interface + ** has all its methods marked as ecall's since it is a managed standard + ** interface. This interface is used from within the runtime to make a call + ** on the COM server directly when it implements the IEnumerable interface. + ==========================================================================*/ + [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")] + internal interface IEnumerable + { + [DispId(-4)] + System.Collections.IEnumerator GetEnumerator(); + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs new file mode 100644 index 0000000..c398987 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal static class ComDataHelpers + { + public static TView GetOrCreateManagedViewFromComData(object comObject, Func createCallback) + { + object key = typeof(TView); + + if (Marshal.GetComObjectData(comObject, key) is TView managedView) + { + return managedView; + } + + managedView = createCallback((T)comObject); + if (!Marshal.SetComObjectData(comObject, key, managedView)) + { + managedView = (TView)Marshal.GetComObjectData(comObject, key); + } + return managedView; + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs new file mode 100644 index 0000000..d43b682 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Runtime.InteropServices.ComTypes; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal class EnumVariantViewOfEnumerator : IEnumVARIANT, ICustomAdapter + { + public EnumVariantViewOfEnumerator(IEnumerator enumerator) + { + if (enumerator is null) + { + throw new ArgumentNullException(nameof(enumerator)); + } + + Enumerator = enumerator; + } + + public IEnumerator Enumerator { get; } + + public IEnumVARIANT Clone() + { + if (Enumerator is ICloneable clonable) + { + return new EnumVariantViewOfEnumerator((IEnumerator)clonable.Clone()); + } + else + { + throw new COMException(SR.Arg_EnumNotCloneable, HResults.E_FAIL); + } + } + + public int Next(int celt, object[] rgVar, IntPtr pceltFetched) + { + int numElements = 0; + + try + { + if (celt > 0 && rgVar == null) + { + return HResults.E_INVALIDARG; + } + + while ((numElements < celt) && Enumerator.MoveNext()) + { + rgVar[numElements++] = Enumerator.Current; + } + + if (pceltFetched != IntPtr.Zero) + { + Marshal.WriteInt32(pceltFetched, numElements); + } + } + catch (Exception e) + { + return e.HResult; + } + + return numElements == celt ? HResults.S_OK : HResults.S_FALSE; + } + + public int Reset() + { + try + { + Enumerator.Reset(); + } + catch (Exception e) + { + return e.HResult; + } + + return HResults.S_OK; + } + + public int Skip(int celt) + { + try + { + while (celt > 0 && Enumerator.MoveNext()) + { + celt--; + } + } + catch (Exception e) + { + return e.HResult; + } + + return celt == 0 ? HResults.S_OK : HResults.S_FALSE; + } + + public object GetUnderlyingObject() + { + return Enumerator; + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs new file mode 100644 index 0000000..b3d319a --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal class EnumerableToDispatchMarshaler : ICustomMarshaler + { + private static readonly EnumerableToDispatchMarshaler s_enumerableToDispatchMarshaler = new EnumerableToDispatchMarshaler(); + + public static ICustomMarshaler GetInstance(string cookie) => s_enumerableToDispatchMarshaler; + + private EnumerableToDispatchMarshaler() + { + } + + public void CleanUpManagedData(object ManagedObj) + { + } + + public void CleanUpNativeData(IntPtr pNativeData) + { + Marshal.Release(pNativeData); + } + + public int GetNativeDataSize() + { + // Return -1 to indicate the managed type this marshaler handles is not a value type. + return -1; + } + + public IntPtr MarshalManagedToNative(object ManagedObj) + { + if (ManagedObj == null) + { + throw new ArgumentNullException(nameof(ManagedObj)); + } + + return Marshal.GetComInterfaceForObject(ManagedObj); + } + + public object MarshalNativeToManaged(IntPtr pNativeData) + { + if (pNativeData == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(pNativeData)); + } + + object comObject = Marshal.GetObjectForIUnknown(pNativeData); + + return ComDataHelpers.GetOrCreateManagedViewFromComData(comObject, obj => new EnumerableViewOfDispatch(obj)); + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs new file mode 100644 index 0000000..8de914e --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections; +using System.Runtime.InteropServices.ComTypes; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal class EnumerableViewOfDispatch : ICustomAdapter, System.Collections.IEnumerable + { + // Reserved DISPID slot for getting an enumerator from an IDispatch-implementing COM interface. + private const int DISPID_NEWENUM = -4; + private const int LCID_DEFAULT = 1; + private readonly object _dispatch; + + public EnumerableViewOfDispatch(object dispatch) + { + _dispatch = dispatch; + } + + private IDispatch Dispatch => (IDispatch)_dispatch; + + public IEnumerator GetEnumerator() + { + DISPPARAMS dispParams = new DISPPARAMS(); + Guid guid = Guid.Empty; + Dispatch.Invoke( + DISPID_NEWENUM, + ref guid, + LCID_DEFAULT, + InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET, + ref dispParams, + out object result, + IntPtr.Zero, + IntPtr.Zero); + + if (!(result is IEnumVARIANT enumVariant)) + { + throw new InvalidOperationException(SR.InvalidOp_InvalidNewEnumVariant); + } + + return (IEnumerator)EnumeratorToEnumVariantMarshaler.GetInstance(null).MarshalNativeToManaged(Marshal.GetIUnknownForObject(enumVariant)); + } + + public object GetUnderlyingObject() + { + return _dispatch; + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs new file mode 100644 index 0000000..fddc0c7 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.InteropServices.ComTypes; +using System.Text; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal class EnumeratorToEnumVariantMarshaler : ICustomMarshaler + { + private static readonly EnumeratorToEnumVariantMarshaler s_enumeratorToEnumVariantMarshaler = new EnumeratorToEnumVariantMarshaler(); + + public static ICustomMarshaler GetInstance(string cookie) => s_enumeratorToEnumVariantMarshaler; + + private EnumeratorToEnumVariantMarshaler() + { + } + + public void CleanUpManagedData(object ManagedObj) + { + } + + public void CleanUpNativeData(IntPtr pNativeData) + { + Marshal.Release(pNativeData); + } + + public int GetNativeDataSize() + { + // Return -1 to indicate the managed type this marshaler handles is not a value type. + return -1; + } + + public IntPtr MarshalManagedToNative(object ManagedObj) + { + if (ManagedObj == null) + { + throw new ArgumentNullException(nameof(ManagedObj)); + } + + if (ManagedObj is EnumeratorViewOfEnumVariant view) + { + return Marshal.GetComInterfaceForObject(view.GetUnderlyingObject()); + } + + EnumVariantViewOfEnumerator nativeView = new EnumVariantViewOfEnumerator((IEnumerator)ManagedObj); + + return Marshal.GetComInterfaceForObject(nativeView); + } + + public object MarshalNativeToManaged(IntPtr pNativeData) + { + if (pNativeData == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(pNativeData)); + } + + object comObject = Marshal.GetObjectForIUnknown(pNativeData); + + if (!comObject.GetType().IsCOMObject) + { + if (comObject is EnumVariantViewOfEnumerator enumVariantView) + { + return enumVariantView.Enumerator; + } + + return comObject as IEnumerator; + } + + return ComDataHelpers.GetOrCreateManagedViewFromComData(comObject, var => new EnumeratorViewOfEnumVariant(var)); + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs new file mode 100644 index 0000000..b5ef969 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs @@ -0,0 +1,69 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.InteropServices.ComTypes; +using System.Text; + +namespace System.Runtime.InteropServices.CustomMarshalers +{ + internal class EnumeratorViewOfEnumVariant : ICustomAdapter, IEnumerator + { + private readonly IEnumVARIANT _enumVariantObject; + private bool _fetchedLastObject; + private object[] _nextArray = new object[1]; + private object _current; + + public EnumeratorViewOfEnumVariant(IEnumVARIANT enumVariantObject) + { + _enumVariantObject = enumVariantObject; + _fetchedLastObject = false; + _current = null; + } + + public object Current => _current; + + public unsafe bool MoveNext() + { + if (_fetchedLastObject) + { + _current = null; + return false; + } + + int numFetched = 0; + + if (_enumVariantObject.Next(1, _nextArray, (IntPtr)(&numFetched)) == HResults.S_FALSE) + { + _fetchedLastObject = true; + + if (numFetched == 0) + { + _current = null; + return false; + } + } + + _current = _nextArray[0]; + + return true; + } + + public void Reset() + { + int hr = _enumVariantObject.Reset(); + if (hr < 0) + { + Marshal.ThrowExceptionForHR(hr); + } + } + + public object GetUnderlyingObject() + { + return _enumVariantObject; + } + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs new file mode 100644 index 0000000..183efa5 --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices.ComTypes; +using System.Text; + +namespace System.Runtime.InteropServices +{ + [ComImport] + [Guid("00020400-0000-0000-C000-000000000046")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + internal interface IDispatch + { + void GetIDsOfNames( + ref Guid riid, + [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 2), In] + string[] rgszNames, + int cNames, + int lcid, + [Out] int[] rgDispId); + + ITypeInfo GetTypeInfo( + int iTInfo, + int lcid); + + int GetTypeInfoCount(); + + void Invoke( + int dispIdMember, + ref Guid riid, + int lcid, + InvokeFlags wFlags, + ref DISPPARAMS pDispParams, + out object pVarResult, + IntPtr pExcepInfo, + IntPtr puArgErr); + } + + [Flags] + internal enum InvokeFlags : short + { + DISPATCH_METHOD = 1, + DISPATCH_PROPERTYGET = 2, + DISPATCH_PROPERTYPUT = 4, + DISPATCH_PROPERTYPUTREF = 8 + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs deleted file mode 100644 index bd7e9e5..0000000 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -#if FEATURE_COMINTEROP - -namespace System.Runtime.InteropServices -{ - /// - /// Part of ComEventHelpers APIs which allow binding managed delegates - /// to COM's connection point based events. - /// - internal static class NativeMethods - { - [ComImport] - [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] - [Guid("00020400-0000-0000-C000-000000000046")] - internal interface IDispatch - { - } - } -} - -#endif diff --git a/src/vm/classnames.h b/src/vm/classnames.h index f7b0ce7..cb71df3 100644 --- a/src/vm/classnames.h +++ b/src/vm/classnames.h @@ -27,6 +27,7 @@ #ifdef FEATURE_COMINTEROP #define g_ECMAKeyToken "B77A5C561934E089" // The ECMA key used by some framework assemblies: mscorlib, system, etc. #define g_FXKeyToken "b03f5f7f11d50a3a" // The FX key used by other framework assemblies: System.Web, System.Drawing, etc. +#define g_CorelibAsmName "System.Private.CoreLib" #define g_SystemAsmName "System" #define g_SystemRuntimeAsmName "System.Runtime" #define g_DrawingAsmName "System.Drawing" diff --git a/src/vm/mlinfo.cpp b/src/vm/mlinfo.cpp index f9439e7..24af47c 100644 --- a/src/vm/mlinfo.cpp +++ b/src/vm/mlinfo.cpp @@ -40,7 +40,7 @@ #ifdef FEATURE_COMINTEROP - DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, "CustomMarshalers", VER_ASSEMBLYVERSION_STR, g_FXKeyToken); + DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, g_CorelibAsmName, VER_ASSEMBLYVERSION_STR, g_ECMAKeyToken); static const int ENUMERATOR_TO_ENUM_VARIANT_CM_NAME_LEN = lengthof(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME); static const char ENUMERATOR_TO_ENUM_VARIANT_CM_COOKIE[] = {""}; diff --git a/src/vm/mngstdinterfaces.cpp b/src/vm/mngstdinterfaces.cpp index 7e93877..11bc5bf 100644 --- a/src/vm/mngstdinterfaces.cpp +++ b/src/vm/mngstdinterfaces.cpp @@ -990,9 +990,7 @@ FCIMPL1(Object*, StdMngIEnumerable::GetEnumerator, Object* refThisUNSAFE) if (retVal == NULL) { - // In desktop CLR we'll attempt to call through IDispatch(DISPID_NEWENUM) - // This is not supported in CoreCLR - COMPlusThrow(kPlatformNotSupportedException, IDS_EE_ERROR_IDISPATCH); + retVal = ObjectToOBJECTREF((Object*)GetEnumeratorWorker(args)); } GCPROTECT_END(); diff --git a/src/vm/mngstditflist.h b/src/vm/mngstditflist.h index f0d5fb3..6ee8b56 100644 --- a/src/vm/mngstditflist.h +++ b/src/vm/mngstditflist.h @@ -30,7 +30,7 @@ #define MNGSTDITF_DEFINE_METH3(FriendlyName, MethName, MethSig, FcallDecl) \ MNGSTDITF_DEFINE_METH_IMPL(FriendlyName, MethName##_3, MethName, MethSig, FcallDecl) -#define CUSTOM_MARSHALER_ASM ", CustomMarshalers, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a" +#define CUSTOM_MARSHALER_ASM ", System.Private.CoreLib, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a" diff --git a/tests/src/Interop/CMakeLists.txt b/tests/src/Interop/CMakeLists.txt index d9f3200..ab5d94f 100644 --- a/tests/src/Interop/CMakeLists.txt +++ b/tests/src/Interop/CMakeLists.txt @@ -72,6 +72,7 @@ if(WIN32) add_subdirectory(PInvoke/NativeCallManagedComVisible) # This test doesn't necessarily need to be Windows-only, but the implementation is very tied to Windows APIs add_subdirectory(PInvoke/DateTime) + add_subdirectory(PInvoke/IEnumerator) add_subdirectory(ArrayMarshalling/SafeArray) add_subdirectory(COM/NativeServer) add_subdirectory(COM/NativeClients/Primitives) diff --git a/tests/src/Interop/COM/NativeServer/Servers.h b/tests/src/Interop/COM/NativeServer/Servers.h index 475aede..f00edbb 100644 --- a/tests/src/Interop/COM/NativeServer/Servers.h +++ b/tests/src/Interop/COM/NativeServer/Servers.h @@ -55,7 +55,7 @@ private: }; #ifndef COM_CLIENT - #include "ComHelpers.h" + #include #define DEF_FUNC(n) virtual COM_DECLSPEC_NOTHROW HRESULT STDMETHODCALLTYPE n diff --git a/tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt b/tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt new file mode 100644 index 0000000..c040698 --- /dev/null +++ b/tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required (VERSION 2.6) +project (IEnumeratorNative) +include_directories(${INC_PLATFORM_DIR}) +set(SOURCES IEnumeratorNative.cpp ) + +# add the executable +add_library (IEnumeratorNative SHARED ${SOURCES}) +target_link_libraries(IEnumeratorNative ${LINK_LIBRARIES_ADDITIONAL}) + +# add the install targets +install (TARGETS IEnumeratorNative DESTINATION bin) diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp new file mode 100644 index 0000000..43f786a --- /dev/null +++ b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#include "IEnumeratorNative.h" +#include + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumerator(int start, int count, IEnumVARIANT** ppEnum) +{ + if (count < 0) + { + return E_INVALIDARG; + } + + *ppEnum = new IntegerEnumerator(start, count); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumerator(IEnumVARIANT* pEnum, int start, int count) +{ + if (count < 0) + { + return E_INVALIDARG; + } + + HRESULT hr = S_OK; + + VARIANT element; + ULONG numFetched; + + for(int i = start; i < start + count; ++i) + { + VariantClear(&element); + hr = pEnum->Next(1, &element, &numFetched); + if(FAILED(hr) || numFetched != 1) + { + return hr; + } + + if (V_I4(&element) != i) + { + return E_UNEXPECTED; + } + } + + hr = pEnum->Next(1, &element, &numFetched); + if (hr != S_FALSE || numFetched != 0) + { + return E_UNEXPECTED; + } + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumeration(int start, int count, IDispatch** ppDisp) +{ + if (count < 0) + { + return E_INVALIDARG; + } + + *ppDisp = new IntegerEnumerable(start, count); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumeration(IDispatch* pDisp, int start, int count) +{ + DISPPARAMS params{}; + VARIANT result; + HRESULT hr = pDisp->Invoke( + DISPID_NEWENUM, + IID_NULL, + LOCALE_USER_DEFAULT, + DISPATCH_METHOD | DISPATCH_PROPERTYGET, + ¶ms, + &result, + NULL, + NULL + ); + + if (FAILED(hr)) + { + return hr; + } + + if(!((V_VT(&result) == VT_UNKNOWN) || (V_VT(&result) == VT_DISPATCH))) + { + return E_UNEXPECTED; + } + + IEnumVARIANT* pEnum; + + hr = V_UNKNOWN(&result)->QueryInterface(&pEnum); + + if (FAILED(hr)) + { + return hr; + } + + hr = VerifyIntegerEnumerator(pEnum, start, count); + + pEnum->Release(); + + return hr; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE PassThroughEnumerator(IEnumVARIANT* in, IEnumVARIANT** out) +{ + return in->QueryInterface(out); +} diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h new file mode 100644 index 0000000..c512901 --- /dev/null +++ b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h @@ -0,0 +1,153 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#include +#include +#include +#include + +class IntegerEnumerator : public UnknownImpl, public IEnumVARIANT +{ + int start; + int count; + int current; + +public: + IntegerEnumerator(int start, int count) + :UnknownImpl(), + start(start), + count(count), + current(start) + { + } + + HRESULT STDMETHODCALLTYPE Next( + ULONG celt, + VARIANT *rgVar, + ULONG *pCeltFetched) override + { + for(*pCeltFetched = 0; *pCeltFetched < celt && current < start + count; ++*pCeltFetched, ++current) + { + VariantClear(&(rgVar[*pCeltFetched])); + V_VT(&rgVar[*pCeltFetched]) = VT_I4; + V_I4(&(rgVar[*pCeltFetched])) = current; + } + + return celt == *pCeltFetched ? S_OK : S_FALSE; + } + + HRESULT STDMETHODCALLTYPE Skip(ULONG celt) override + { + int original = current; + current = std::min(current + (int)celt, start + count); + return original + (int)celt <= start + count ? S_OK : S_FALSE; + } + + HRESULT STDMETHODCALLTYPE Reset(void) override + { + current = start; + return S_OK; + } + + HRESULT STDMETHODCALLTYPE Clone(IEnumVARIANT **ppEnum) override + { + IntegerEnumerator* clone = new IntegerEnumerator(start, count); + clone->current = current; + *ppEnum = clone; + + return S_OK; + } + + HRESULT STDMETHODCALLTYPE QueryInterface( + REFIID riid, + void** ppvObject) + { + return DoQueryInterface(this, riid, ppvObject); + } + + DEFINE_REF_COUNTING(); +}; + +class IntegerEnumerable : public UnknownImpl, public IDispatch +{ +private: + int start; + int count; +public: + IntegerEnumerable(int start, int count) + :UnknownImpl(), + start(start), + count(count) + { + } + + HRESULT STDMETHODCALLTYPE GetTypeInfoCount( + UINT *pctinfo) override + { + *pctinfo = 0; + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetTypeInfo( + UINT iTInfo, + LCID lcid, + ITypeInfo **ppTInfo) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE GetIDsOfNames( + REFIID riid, + LPOLESTR *rgszNames, + UINT cNames, + LCID lcid, + DISPID *rgDispId) override + { + bool containsUnknown = false; + DISPID *curr = rgDispId; + for (UINT i = 0; i < cNames; ++i) + { + *curr = DISPID_UNKNOWN; + LPOLESTR name = rgszNames[i]; + if(wcscmp(name, W("GetEnumerator")) == 0) + { + *curr = DISPID_NEWENUM; + } + + containsUnknown &= (*curr == DISPID_UNKNOWN); + curr++; + } + + return (containsUnknown) ? DISP_E_UNKNOWNNAME : S_OK; + } + + HRESULT STDMETHODCALLTYPE Invoke( + DISPID dispIdMember, + REFIID riid, + LCID lcid, + WORD wFlags, + DISPPARAMS *pDispParams, + VARIANT *pVarResult, + EXCEPINFO *pExcepInfo, + UINT *puArgErr) override + { + if (dispIdMember == DISPID_NEWENUM && (wFlags & INVOKE_PROPERTYGET) == INVOKE_PROPERTYGET) + { + V_VT(pVarResult) = VT_UNKNOWN; + V_UNKNOWN(pVarResult) = new IntegerEnumerator(start, count); + return S_OK; + } + + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE QueryInterface( + REFIID riid, + void** ppvObject) + { + return DoQueryInterface(this, riid, ppvObject); + } + + DEFINE_REF_COUNTING(); +}; diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs new file mode 100644 index 0000000..0c4e4d8 --- /dev/null +++ b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using TestLibrary; + +namespace PInvokeTests +{ + static class IEnumeratorNative + { + [DllImport(nameof(IEnumeratorNative), PreserveSig = false)] + public static extern IEnumerator GetIntegerEnumerator( + int start, + int count); + + [DllImport(nameof(IEnumeratorNative), PreserveSig = false)] + public static extern IEnumerable GetIntegerEnumeration( + int start, + int count); + + [DllImport(nameof(IEnumeratorNative), PreserveSig = false)] + public static extern void VerifyIntegerEnumerator( + IEnumerator enumerator, + int start, + int count); + + [DllImport(nameof(IEnumeratorNative), PreserveSig = false)] + public static extern void VerifyIntegerEnumeration( + IEnumerable enumerable, + int start, + int count); + + [DllImport(nameof(IEnumeratorNative), PreserveSig = false)] + public static extern IEnumerator PassThroughEnumerator(IEnumerator enumerator); + } + + public static class IEnumeratorTests + { + private static void TestNativeToManaged() + { + Assert.AreAllEqual(Enumerable.Range(1, 10), EnumeratorAsEnumerable(IEnumeratorNative.GetIntegerEnumerator(1, 10))); + Assert.AreAllEqual(Enumerable.Range(1, 10), IEnumeratorNative.GetIntegerEnumeration(1, 10).OfType()); + } + + private static void TestManagedToNative() + { + IEnumeratorNative.VerifyIntegerEnumerator(Enumerable.Range(1, 10).GetEnumerator(), 1, 10); + IEnumeratorNative.VerifyIntegerEnumeration(Enumerable.Range(1, 10), 1, 10); + } + + private static void TestNativeRoundTrip() + { + IEnumerator nativeEnumerator = IEnumeratorNative.GetIntegerEnumerator(1, 10); + Assert.AreEqual(nativeEnumerator, IEnumeratorNative.PassThroughEnumerator(nativeEnumerator)); + } + + private static void TestManagedRoundTrip() + { + IEnumerator managedEnumerator = Enumerable.Range(1, 10).GetEnumerator(); + Assert.AreEqual(managedEnumerator, IEnumeratorNative.PassThroughEnumerator(managedEnumerator)); + } + + public static int Main() + { + try + { + TestNativeToManaged(); + TestManagedToNative(); + TestNativeRoundTrip(); + TestManagedRoundTrip(); + } + catch (System.Exception e) + { + Console.WriteLine(e.ToString()); + return 101; + } + + return 100; + } + + private static IEnumerable EnumeratorAsEnumerable(IEnumerator enumerator) + { + while (enumerator.MoveNext()) + { + yield return (int)enumerator.Current; + } + } + } +} diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj new file mode 100644 index 0000000..f35118a --- /dev/null +++ b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj @@ -0,0 +1,36 @@ + + + + + + Debug + AnyCPU + IEnumeratorTest + 2.0 + {F1E66554-8C8E-4141-85CF-D0CD6A0CD0B0} + exe + {786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC} + ..\..\ + true + $(DefineConstants);STATIC + + + true + true + + + + + + + False + + + + + + + + + + diff --git a/tests/src/Interop/COM/NativeServer/ComHelpers.h b/tests/src/Interop/common/ComHelpers.h similarity index 100% rename from tests/src/Interop/COM/NativeServer/ComHelpers.h rename to tests/src/Interop/common/ComHelpers.h