* Import CustomMarshalers from corefx PR.
* Add tests for marshalling IEnumerator and IEnumerable using the System.Runtime.InteropServices.CustomMarshalers.
* Verify that the marshaller is used in the managed -> native direction in the tests.
* Remove #if's
* PR Feedback.
* Move ComHelpers.h into interop common headers.
* Add some template machinery to ensure that we don't accidentally pass in an incorrect type into DoQueryInterface.
* Add VariantClear
* Make marshalers internal and root them in ILLinkTrim.xml
* Revert enable-if machinery.
* Add tests for the pass-through unwrapping done by the EnumeratorToEnumVariantMarshaler.
* Remove some of my helper code.
* Added todo about exposing EnumerableToDispatchMarshaler.
* PR Feedback.
* Removed the last placement new.
* Use the runtime hooks to enable us to not have to publically expose the custom marshalers.
* Remove out of date comments.
<type fullname="System.Diagnostics.Tracing.RuntimeEventSource" />
<!-- Accessed via reflection in TraceLogging-style EventSource events. -->
<type fullname="*f__AnonymousType*" />
+ <!-- Accessed via native code. -->
+ <type fullname="System.Runtime.InteropServices.ComTypes.IEnumerable" />
+ <type fullname="System.Runtime.InteropServices.CustomMarshalers.*" />
</assembly>
</linker>
<data name="Argument_PrecisionTooLarge" xml:space="preserve">
<value>Precision cannot be larger than {0}.</value>
</data>
+ <data name="Arg_EnumNotCloneable" xml:space="preserve">
+ <value>The supplied object does not implement ICloneable.</value>
+ </data>
+ <data name="InvalidOp_InvalidNewEnumVariant" xml:space="preserve">
+ <value>The returned enumerator does not implement IEnumVARIANT.</value>
+ </data>
</root>
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IConnectionPointContainer.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumConnectionPoints.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumConnections.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumerable.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumMoniker.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumString.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumVARIANT.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsInfo.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsMethod.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsSink.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\ComDataHelpers.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumVariantViewOfEnumerator.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableToDispatchMarshaler.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableViewOfDispatch.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumeratorToEnumVariantMarshaler.cs" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumeratorViewOfEnumVariant.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\DispatchWrapper.cs" />
- <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\NativeMethods.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
+ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\IDispatch.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\Attributes.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\BindableVectorToCollectionAdapter.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\BindableVectorToListAdapter.cs" />
namespace System.Collections
{
+ [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")]
+ [ComVisible(true)]
public interface IEnumerable
{
// Returns an IEnumerator for this enumerable Object. The enumerator provides
internal static partial class HResults
{
internal const int S_OK = unchecked((int)0x00000000);
+ internal const int S_FALSE = unchecked((int)0x1);
internal const int COR_E_ABANDONEDMUTEX = unchecked((int)0x8013152D);
internal const int COR_E_AMBIGUOUSMATCH = unchecked((int)0x8000211D);
internal const int COR_E_APPDOMAINUNLOADED = unchecked((int)0x80131014);
CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv)
{
ppv = IntPtr.Zero;
- if (iid == _iidSourceItf || iid == typeof(NativeMethods.IDispatch).GUID)
+ if (iid == _iidSourceItf || iid == typeof(IDispatch).GUID)
{
- ppv = Marshal.GetComInterfaceForObject(this, typeof(NativeMethods.IDispatch), CustomQueryInterfaceMode.Ignore);
+ ppv = Marshal.GetComInterfaceForObject(this, typeof(IDispatch), CustomQueryInterfaceMode.Ignore);
return CustomQueryInterfaceResult.Handled;
}
else if (iid == IID_IManagedObject)
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information
+
+using System;
+namespace System.Runtime.InteropServices.ComTypes
+{
+ /*==========================================================================
+ ** Interface: IEnumerable
+ ** Purpose:
+ ** This interface is redefined here since the original IEnumerable interface
+ ** has all its methods marked as ecall's since it is a managed standard
+ ** interface. This interface is used from within the runtime to make a call
+ ** on the COM server directly when it implements the IEnumerable interface.
+ ==========================================================================*/
+ [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")]
+ internal interface IEnumerable
+ {
+ [DispId(-4)]
+ System.Collections.IEnumerator GetEnumerator();
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal static class ComDataHelpers
+ {
+ public static TView GetOrCreateManagedViewFromComData<T, TView>(object comObject, Func<T, TView> createCallback)
+ {
+ object key = typeof(TView);
+
+ if (Marshal.GetComObjectData(comObject, key) is TView managedView)
+ {
+ return managedView;
+ }
+
+ managedView = createCallback((T)comObject);
+ if (!Marshal.SetComObjectData(comObject, key, managedView))
+ {
+ managedView = (TView)Marshal.GetComObjectData(comObject, key);
+ }
+ return managedView;
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Runtime.InteropServices.ComTypes;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal class EnumVariantViewOfEnumerator : IEnumVARIANT, ICustomAdapter
+ {
+ public EnumVariantViewOfEnumerator(IEnumerator enumerator)
+ {
+ if (enumerator is null)
+ {
+ throw new ArgumentNullException(nameof(enumerator));
+ }
+
+ Enumerator = enumerator;
+ }
+
+ public IEnumerator Enumerator { get; }
+
+ public IEnumVARIANT Clone()
+ {
+ if (Enumerator is ICloneable clonable)
+ {
+ return new EnumVariantViewOfEnumerator((IEnumerator)clonable.Clone());
+ }
+ else
+ {
+ throw new COMException(SR.Arg_EnumNotCloneable, HResults.E_FAIL);
+ }
+ }
+
+ public int Next(int celt, object[] rgVar, IntPtr pceltFetched)
+ {
+ int numElements = 0;
+
+ try
+ {
+ if (celt > 0 && rgVar == null)
+ {
+ return HResults.E_INVALIDARG;
+ }
+
+ while ((numElements < celt) && Enumerator.MoveNext())
+ {
+ rgVar[numElements++] = Enumerator.Current;
+ }
+
+ if (pceltFetched != IntPtr.Zero)
+ {
+ Marshal.WriteInt32(pceltFetched, numElements);
+ }
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return numElements == celt ? HResults.S_OK : HResults.S_FALSE;
+ }
+
+ public int Reset()
+ {
+ try
+ {
+ Enumerator.Reset();
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return HResults.S_OK;
+ }
+
+ public int Skip(int celt)
+ {
+ try
+ {
+ while (celt > 0 && Enumerator.MoveNext())
+ {
+ celt--;
+ }
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return celt == 0 ? HResults.S_OK : HResults.S_FALSE;
+ }
+
+ public object GetUnderlyingObject()
+ {
+ return Enumerator;
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal class EnumerableToDispatchMarshaler : ICustomMarshaler
+ {
+ private static readonly EnumerableToDispatchMarshaler s_enumerableToDispatchMarshaler = new EnumerableToDispatchMarshaler();
+
+ public static ICustomMarshaler GetInstance(string cookie) => s_enumerableToDispatchMarshaler;
+
+ private EnumerableToDispatchMarshaler()
+ {
+ }
+
+ public void CleanUpManagedData(object ManagedObj)
+ {
+ }
+
+ public void CleanUpNativeData(IntPtr pNativeData)
+ {
+ Marshal.Release(pNativeData);
+ }
+
+ public int GetNativeDataSize()
+ {
+ // Return -1 to indicate the managed type this marshaler handles is not a value type.
+ return -1;
+ }
+
+ public IntPtr MarshalManagedToNative(object ManagedObj)
+ {
+ if (ManagedObj == null)
+ {
+ throw new ArgumentNullException(nameof(ManagedObj));
+ }
+
+ return Marshal.GetComInterfaceForObject<object, IEnumerable>(ManagedObj);
+ }
+
+ public object MarshalNativeToManaged(IntPtr pNativeData)
+ {
+ if (pNativeData == IntPtr.Zero)
+ {
+ throw new ArgumentNullException(nameof(pNativeData));
+ }
+
+ object comObject = Marshal.GetObjectForIUnknown(pNativeData);
+
+ return ComDataHelpers.GetOrCreateManagedViewFromComData<object, EnumerableViewOfDispatch>(comObject, obj => new EnumerableViewOfDispatch(obj));
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Runtime.InteropServices.ComTypes;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal class EnumerableViewOfDispatch : ICustomAdapter, System.Collections.IEnumerable
+ {
+ // Reserved DISPID slot for getting an enumerator from an IDispatch-implementing COM interface.
+ private const int DISPID_NEWENUM = -4;
+ private const int LCID_DEFAULT = 1;
+ private readonly object _dispatch;
+
+ public EnumerableViewOfDispatch(object dispatch)
+ {
+ _dispatch = dispatch;
+ }
+
+ private IDispatch Dispatch => (IDispatch)_dispatch;
+
+ public IEnumerator GetEnumerator()
+ {
+ DISPPARAMS dispParams = new DISPPARAMS();
+ Guid guid = Guid.Empty;
+ Dispatch.Invoke(
+ DISPID_NEWENUM,
+ ref guid,
+ LCID_DEFAULT,
+ InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET,
+ ref dispParams,
+ out object result,
+ IntPtr.Zero,
+ IntPtr.Zero);
+
+ if (!(result is IEnumVARIANT enumVariant))
+ {
+ throw new InvalidOperationException(SR.InvalidOp_InvalidNewEnumVariant);
+ }
+
+ return (IEnumerator)EnumeratorToEnumVariantMarshaler.GetInstance(null).MarshalNativeToManaged(Marshal.GetIUnknownForObject(enumVariant));
+ }
+
+ public object GetUnderlyingObject()
+ {
+ return _dispatch;
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal class EnumeratorToEnumVariantMarshaler : ICustomMarshaler
+ {
+ private static readonly EnumeratorToEnumVariantMarshaler s_enumeratorToEnumVariantMarshaler = new EnumeratorToEnumVariantMarshaler();
+
+ public static ICustomMarshaler GetInstance(string cookie) => s_enumeratorToEnumVariantMarshaler;
+
+ private EnumeratorToEnumVariantMarshaler()
+ {
+ }
+
+ public void CleanUpManagedData(object ManagedObj)
+ {
+ }
+
+ public void CleanUpNativeData(IntPtr pNativeData)
+ {
+ Marshal.Release(pNativeData);
+ }
+
+ public int GetNativeDataSize()
+ {
+ // Return -1 to indicate the managed type this marshaler handles is not a value type.
+ return -1;
+ }
+
+ public IntPtr MarshalManagedToNative(object ManagedObj)
+ {
+ if (ManagedObj == null)
+ {
+ throw new ArgumentNullException(nameof(ManagedObj));
+ }
+
+ if (ManagedObj is EnumeratorViewOfEnumVariant view)
+ {
+ return Marshal.GetComInterfaceForObject<object, IEnumVARIANT>(view.GetUnderlyingObject());
+ }
+
+ EnumVariantViewOfEnumerator nativeView = new EnumVariantViewOfEnumerator((IEnumerator)ManagedObj);
+
+ return Marshal.GetComInterfaceForObject<EnumVariantViewOfEnumerator, IEnumVARIANT>(nativeView);
+ }
+
+ public object MarshalNativeToManaged(IntPtr pNativeData)
+ {
+ if (pNativeData == IntPtr.Zero)
+ {
+ throw new ArgumentNullException(nameof(pNativeData));
+ }
+
+ object comObject = Marshal.GetObjectForIUnknown(pNativeData);
+
+ if (!comObject.GetType().IsCOMObject)
+ {
+ if (comObject is EnumVariantViewOfEnumerator enumVariantView)
+ {
+ return enumVariantView.Enumerator;
+ }
+
+ return comObject as IEnumerator;
+ }
+
+ return ComDataHelpers.GetOrCreateManagedViewFromComData<IEnumVARIANT, EnumeratorViewOfEnumVariant>(comObject, var => new EnumeratorViewOfEnumVariant(var));
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+ internal class EnumeratorViewOfEnumVariant : ICustomAdapter, IEnumerator
+ {
+ private readonly IEnumVARIANT _enumVariantObject;
+ private bool _fetchedLastObject;
+ private object[] _nextArray = new object[1];
+ private object _current;
+
+ public EnumeratorViewOfEnumVariant(IEnumVARIANT enumVariantObject)
+ {
+ _enumVariantObject = enumVariantObject;
+ _fetchedLastObject = false;
+ _current = null;
+ }
+
+ public object Current => _current;
+
+ public unsafe bool MoveNext()
+ {
+ if (_fetchedLastObject)
+ {
+ _current = null;
+ return false;
+ }
+
+ int numFetched = 0;
+
+ if (_enumVariantObject.Next(1, _nextArray, (IntPtr)(&numFetched)) == HResults.S_FALSE)
+ {
+ _fetchedLastObject = true;
+
+ if (numFetched == 0)
+ {
+ _current = null;
+ return false;
+ }
+ }
+
+ _current = _nextArray[0];
+
+ return true;
+ }
+
+ public void Reset()
+ {
+ int hr = _enumVariantObject.Reset();
+ if (hr < 0)
+ {
+ Marshal.ThrowExceptionForHR(hr);
+ }
+ }
+
+ public object GetUnderlyingObject()
+ {
+ return _enumVariantObject;
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices
+{
+ [ComImport]
+ [Guid("00020400-0000-0000-C000-000000000046")]
+ [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+ internal interface IDispatch
+ {
+ void GetIDsOfNames(
+ ref Guid riid,
+ [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 2), In]
+ string[] rgszNames,
+ int cNames,
+ int lcid,
+ [Out] int[] rgDispId);
+
+ ITypeInfo GetTypeInfo(
+ int iTInfo,
+ int lcid);
+
+ int GetTypeInfoCount();
+
+ void Invoke(
+ int dispIdMember,
+ ref Guid riid,
+ int lcid,
+ InvokeFlags wFlags,
+ ref DISPPARAMS pDispParams,
+ out object pVarResult,
+ IntPtr pExcepInfo,
+ IntPtr puArgErr);
+ }
+
+ [Flags]
+ internal enum InvokeFlags : short
+ {
+ DISPATCH_METHOD = 1,
+ DISPATCH_PROPERTYGET = 2,
+ DISPATCH_PROPERTYPUT = 4,
+ DISPATCH_PROPERTYPUTREF = 8
+ }
+}
+++ /dev/null
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-#if FEATURE_COMINTEROP
-
-namespace System.Runtime.InteropServices
-{
- /// <summary>
- /// Part of ComEventHelpers APIs which allow binding managed delegates
- /// to COM's connection point based events.
- /// </summary>
- internal static class NativeMethods
- {
- [ComImport]
- [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
- [Guid("00020400-0000-0000-C000-000000000046")]
- internal interface IDispatch
- {
- }
- }
-}
-
-#endif
#ifdef FEATURE_COMINTEROP
#define g_ECMAKeyToken "B77A5C561934E089" // The ECMA key used by some framework assemblies: mscorlib, system, etc.
#define g_FXKeyToken "b03f5f7f11d50a3a" // The FX key used by other framework assemblies: System.Web, System.Drawing, etc.
+#define g_CorelibAsmName "System.Private.CoreLib"
#define g_SystemAsmName "System"
#define g_SystemRuntimeAsmName "System.Runtime"
#define g_DrawingAsmName "System.Drawing"
#ifdef FEATURE_COMINTEROP
- DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, "CustomMarshalers", VER_ASSEMBLYVERSION_STR, g_FXKeyToken);
+ DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, g_CorelibAsmName, VER_ASSEMBLYVERSION_STR, g_ECMAKeyToken);
static const int ENUMERATOR_TO_ENUM_VARIANT_CM_NAME_LEN = lengthof(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME);
static const char ENUMERATOR_TO_ENUM_VARIANT_CM_COOKIE[] = {""};
if (retVal == NULL)
{
- // In desktop CLR we'll attempt to call through IDispatch(DISPID_NEWENUM)
- // This is not supported in CoreCLR
- COMPlusThrow(kPlatformNotSupportedException, IDS_EE_ERROR_IDISPATCH);
+ retVal = ObjectToOBJECTREF((Object*)GetEnumeratorWorker(args));
}
GCPROTECT_END();
#define MNGSTDITF_DEFINE_METH3(FriendlyName, MethName, MethSig, FcallDecl) \
MNGSTDITF_DEFINE_METH_IMPL(FriendlyName, MethName##_3, MethName, MethSig, FcallDecl)
-#define CUSTOM_MARSHALER_ASM ", CustomMarshalers, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"
+#define CUSTOM_MARSHALER_ASM ", System.Private.CoreLib, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"
add_subdirectory(PInvoke/NativeCallManagedComVisible)
# This test doesn't necessarily need to be Windows-only, but the implementation is very tied to Windows APIs
add_subdirectory(PInvoke/DateTime)
+ add_subdirectory(PInvoke/IEnumerator)
add_subdirectory(ArrayMarshalling/SafeArray)
add_subdirectory(COM/NativeServer)
add_subdirectory(COM/NativeClients/Primitives)
};
#ifndef COM_CLIENT
- #include "ComHelpers.h"
+ #include <ComHelpers.h>
#define DEF_FUNC(n) virtual COM_DECLSPEC_NOTHROW HRESULT STDMETHODCALLTYPE n
--- /dev/null
+cmake_minimum_required (VERSION 2.6)
+project (IEnumeratorNative)
+include_directories(${INC_PLATFORM_DIR})
+set(SOURCES IEnumeratorNative.cpp )
+
+# add the executable
+add_library (IEnumeratorNative SHARED ${SOURCES})
+target_link_libraries(IEnumeratorNative ${LINK_LIBRARIES_ADDITIONAL})
+
+# add the install targets
+install (TARGETS IEnumeratorNative DESTINATION bin)
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include "IEnumeratorNative.h"
+#include <xplatform.h>
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumerator(int start, int count, IEnumVARIANT** ppEnum)
+{
+ if (count < 0)
+ {
+ return E_INVALIDARG;
+ }
+
+ *ppEnum = new IntegerEnumerator(start, count);
+
+ return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumerator(IEnumVARIANT* pEnum, int start, int count)
+{
+ if (count < 0)
+ {
+ return E_INVALIDARG;
+ }
+
+ HRESULT hr = S_OK;
+
+ VARIANT element;
+ ULONG numFetched;
+
+ for(int i = start; i < start + count; ++i)
+ {
+ VariantClear(&element);
+ hr = pEnum->Next(1, &element, &numFetched);
+ if(FAILED(hr) || numFetched != 1)
+ {
+ return hr;
+ }
+
+ if (V_I4(&element) != i)
+ {
+ return E_UNEXPECTED;
+ }
+ }
+
+ hr = pEnum->Next(1, &element, &numFetched);
+ if (hr != S_FALSE || numFetched != 0)
+ {
+ return E_UNEXPECTED;
+ }
+
+ return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumeration(int start, int count, IDispatch** ppDisp)
+{
+ if (count < 0)
+ {
+ return E_INVALIDARG;
+ }
+
+ *ppDisp = new IntegerEnumerable(start, count);
+
+ return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumeration(IDispatch* pDisp, int start, int count)
+{
+ DISPPARAMS params{};
+ VARIANT result;
+ HRESULT hr = pDisp->Invoke(
+ DISPID_NEWENUM,
+ IID_NULL,
+ LOCALE_USER_DEFAULT,
+ DISPATCH_METHOD | DISPATCH_PROPERTYGET,
+ ¶ms,
+ &result,
+ NULL,
+ NULL
+ );
+
+ if (FAILED(hr))
+ {
+ return hr;
+ }
+
+ if(!((V_VT(&result) == VT_UNKNOWN) || (V_VT(&result) == VT_DISPATCH)))
+ {
+ return E_UNEXPECTED;
+ }
+
+ IEnumVARIANT* pEnum;
+
+ hr = V_UNKNOWN(&result)->QueryInterface<IEnumVARIANT>(&pEnum);
+
+ if (FAILED(hr))
+ {
+ return hr;
+ }
+
+ hr = VerifyIntegerEnumerator(pEnum, start, count);
+
+ pEnum->Release();
+
+ return hr;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE PassThroughEnumerator(IEnumVARIANT* in, IEnumVARIANT** out)
+{
+ return in->QueryInterface(out);
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include <xplatform.h>
+#include <platformdefines.h>
+#include <ComHelpers.h>
+#include <algorithm>
+
+class IntegerEnumerator : public UnknownImpl, public IEnumVARIANT
+{
+ int start;
+ int count;
+ int current;
+
+public:
+ IntegerEnumerator(int start, int count)
+ :UnknownImpl(),
+ start(start),
+ count(count),
+ current(start)
+ {
+ }
+
+ HRESULT STDMETHODCALLTYPE Next(
+ ULONG celt,
+ VARIANT *rgVar,
+ ULONG *pCeltFetched) override
+ {
+ for(*pCeltFetched = 0; *pCeltFetched < celt && current < start + count; ++*pCeltFetched, ++current)
+ {
+ VariantClear(&(rgVar[*pCeltFetched]));
+ V_VT(&rgVar[*pCeltFetched]) = VT_I4;
+ V_I4(&(rgVar[*pCeltFetched])) = current;
+ }
+
+ return celt == *pCeltFetched ? S_OK : S_FALSE;
+ }
+
+ HRESULT STDMETHODCALLTYPE Skip(ULONG celt) override
+ {
+ int original = current;
+ current = std::min(current + (int)celt, start + count);
+ return original + (int)celt <= start + count ? S_OK : S_FALSE;
+ }
+
+ HRESULT STDMETHODCALLTYPE Reset(void) override
+ {
+ current = start;
+ return S_OK;
+ }
+
+ HRESULT STDMETHODCALLTYPE Clone(IEnumVARIANT **ppEnum) override
+ {
+ IntegerEnumerator* clone = new IntegerEnumerator(start, count);
+ clone->current = current;
+ *ppEnum = clone;
+
+ return S_OK;
+ }
+
+ HRESULT STDMETHODCALLTYPE QueryInterface(
+ REFIID riid,
+ void** ppvObject)
+ {
+ return DoQueryInterface<IntegerEnumerator, IEnumVARIANT>(this, riid, ppvObject);
+ }
+
+ DEFINE_REF_COUNTING();
+};
+
+class IntegerEnumerable : public UnknownImpl, public IDispatch
+{
+private:
+ int start;
+ int count;
+public:
+ IntegerEnumerable(int start, int count)
+ :UnknownImpl(),
+ start(start),
+ count(count)
+ {
+ }
+
+ HRESULT STDMETHODCALLTYPE GetTypeInfoCount(
+ UINT *pctinfo) override
+ {
+ *pctinfo = 0;
+ return S_OK;
+ }
+
+ HRESULT STDMETHODCALLTYPE GetTypeInfo(
+ UINT iTInfo,
+ LCID lcid,
+ ITypeInfo **ppTInfo) override
+ {
+ return E_NOTIMPL;
+ }
+
+ HRESULT STDMETHODCALLTYPE GetIDsOfNames(
+ REFIID riid,
+ LPOLESTR *rgszNames,
+ UINT cNames,
+ LCID lcid,
+ DISPID *rgDispId) override
+ {
+ bool containsUnknown = false;
+ DISPID *curr = rgDispId;
+ for (UINT i = 0; i < cNames; ++i)
+ {
+ *curr = DISPID_UNKNOWN;
+ LPOLESTR name = rgszNames[i];
+ if(wcscmp(name, W("GetEnumerator")) == 0)
+ {
+ *curr = DISPID_NEWENUM;
+ }
+
+ containsUnknown &= (*curr == DISPID_UNKNOWN);
+ curr++;
+ }
+
+ return (containsUnknown) ? DISP_E_UNKNOWNNAME : S_OK;
+ }
+
+ HRESULT STDMETHODCALLTYPE Invoke(
+ DISPID dispIdMember,
+ REFIID riid,
+ LCID lcid,
+ WORD wFlags,
+ DISPPARAMS *pDispParams,
+ VARIANT *pVarResult,
+ EXCEPINFO *pExcepInfo,
+ UINT *puArgErr) override
+ {
+ if (dispIdMember == DISPID_NEWENUM && (wFlags & INVOKE_PROPERTYGET) == INVOKE_PROPERTYGET)
+ {
+ V_VT(pVarResult) = VT_UNKNOWN;
+ V_UNKNOWN(pVarResult) = new IntegerEnumerator(start, count);
+ return S_OK;
+ }
+
+ return E_NOTIMPL;
+ }
+
+ HRESULT STDMETHODCALLTYPE QueryInterface(
+ REFIID riid,
+ void** ppvObject)
+ {
+ return DoQueryInterface<IntegerEnumerable, IDispatch>(this, riid, ppvObject);
+ }
+
+ DEFINE_REF_COUNTING();
+};
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using TestLibrary;
+
+namespace PInvokeTests
+{
+ static class IEnumeratorNative
+ {
+ [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+ public static extern IEnumerator GetIntegerEnumerator(
+ int start,
+ int count);
+
+ [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+ public static extern IEnumerable GetIntegerEnumeration(
+ int start,
+ int count);
+
+ [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+ public static extern void VerifyIntegerEnumerator(
+ IEnumerator enumerator,
+ int start,
+ int count);
+
+ [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+ public static extern void VerifyIntegerEnumeration(
+ IEnumerable enumerable,
+ int start,
+ int count);
+
+ [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+ public static extern IEnumerator PassThroughEnumerator(IEnumerator enumerator);
+ }
+
+ public static class IEnumeratorTests
+ {
+ private static void TestNativeToManaged()
+ {
+ Assert.AreAllEqual(Enumerable.Range(1, 10), EnumeratorAsEnumerable(IEnumeratorNative.GetIntegerEnumerator(1, 10)));
+ Assert.AreAllEqual(Enumerable.Range(1, 10), IEnumeratorNative.GetIntegerEnumeration(1, 10).OfType<int>());
+ }
+
+ private static void TestManagedToNative()
+ {
+ IEnumeratorNative.VerifyIntegerEnumerator(Enumerable.Range(1, 10).GetEnumerator(), 1, 10);
+ IEnumeratorNative.VerifyIntegerEnumeration(Enumerable.Range(1, 10), 1, 10);
+ }
+
+ private static void TestNativeRoundTrip()
+ {
+ IEnumerator nativeEnumerator = IEnumeratorNative.GetIntegerEnumerator(1, 10);
+ Assert.AreEqual(nativeEnumerator, IEnumeratorNative.PassThroughEnumerator(nativeEnumerator));
+ }
+
+ private static void TestManagedRoundTrip()
+ {
+ IEnumerator managedEnumerator = Enumerable.Range(1, 10).GetEnumerator();
+ Assert.AreEqual(managedEnumerator, IEnumeratorNative.PassThroughEnumerator(managedEnumerator));
+ }
+
+ public static int Main()
+ {
+ try
+ {
+ TestNativeToManaged();
+ TestManagedToNative();
+ TestNativeRoundTrip();
+ TestManagedRoundTrip();
+ }
+ catch (System.Exception e)
+ {
+ Console.WriteLine(e.ToString());
+ return 101;
+ }
+
+ return 100;
+ }
+
+ private static IEnumerable<int> EnumeratorAsEnumerable(IEnumerator enumerator)
+ {
+ while (enumerator.MoveNext())
+ {
+ yield return (int)enumerator.Current;
+ }
+ }
+ }
+}
--- /dev/null
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+ <Import Project="../../Interop.settings.targets" />
+ <PropertyGroup>
+ <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
+ <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
+ <AssemblyName>IEnumeratorTest</AssemblyName>
+ <SchemaVersion>2.0</SchemaVersion>
+ <ProjectGuid>{F1E66554-8C8E-4141-85CF-D0CD6A0CD0B0}</ProjectGuid>
+ <OutputType>exe</OutputType>
+ <ProjectTypeGuids>{786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
+ <SolutionDir Condition="$(SolutionDir) == '' Or $(SolutionDir) == '*Undefined*'">..\..\</SolutionDir>
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+ <DefineConstants>$(DefineConstants);STATIC</DefineConstants>
+
+ <!-- IEnumerator/IEnumerable marshalling unsupported outside of windows -->
+ <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+ <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+ </PropertyGroup>
+ <!-- Default configurations to help VS understand the configurations -->
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'"></PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'"></PropertyGroup>
+ <ItemGroup>
+ <CodeAnalysisDependentAssemblyPaths Condition=" '$(VS100COMNTOOLS)' != '' " Include="$(VS100COMNTOOLS)..\IDE\PrivateAssemblies">
+ <Visible>False</Visible>
+ </CodeAnalysisDependentAssemblyPaths>
+ </ItemGroup>
+ <ItemGroup>
+ <Compile Include="*.cs" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="CMakeLists.txt"/>
+ </ItemGroup>
+ <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>