Port CustomMarshalers.dll to C# and .NET Core (#21343)
authorJeremy Koritzinsky <jkoritzinsky@gmail.com>
Thu, 6 Dec 2018 17:56:21 +0000 (09:56 -0800)
committerGitHub <noreply@github.com>
Thu, 6 Dec 2018 17:56:21 +0000 (09:56 -0800)
* Import CustomMarshalers from corefx PR.

* Add tests for marshalling IEnumerator and IEnumerable using the System.Runtime.InteropServices.CustomMarshalers.

* Verify that the marshaller is used in the managed -> native direction in the tests.

* Remove #if's

* PR Feedback.

* Move ComHelpers.h into interop common headers.

* Add some template machinery to ensure that we don't accidentally pass in an incorrect type into DoQueryInterface.

* Add VariantClear

* Make marshalers internal and root them in ILLinkTrim.xml

* Revert enable-if machinery.

* Add tests for the pass-through unwrapping done by the EnumeratorToEnumVariantMarshaler.

* Remove some of my helper code.

* Added todo about exposing EnumerableToDispatchMarshaler.

* PR Feedback.

* Removed the last placement new.

* Use the runtime hooks to enable us to not have to publically expose the custom marshalers.

* Remove out of date comments.

27 files changed:
src/System.Private.CoreLib/ILLinkTrim.xml
src/System.Private.CoreLib/Resources/Strings.resx
src/System.Private.CoreLib/System.Private.CoreLib.csproj
src/System.Private.CoreLib/shared/System/Collections/IEnumerable.cs
src/System.Private.CoreLib/shared/System/HResults.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs [deleted file]
src/vm/classnames.h
src/vm/mlinfo.cpp
src/vm/mngstdinterfaces.cpp
src/vm/mngstditflist.h
tests/src/Interop/CMakeLists.txt
tests/src/Interop/COM/NativeServer/Servers.h
tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt [new file with mode: 0644]
tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp [new file with mode: 0644]
tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h [new file with mode: 0644]
tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs [new file with mode: 0644]
tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj [new file with mode: 0644]
tests/src/Interop/common/ComHelpers.h [moved from tests/src/Interop/COM/NativeServer/ComHelpers.h with 100% similarity]

index 5445899..53eae5d 100644 (file)
@@ -39,5 +39,8 @@
     <type fullname="System.Diagnostics.Tracing.RuntimeEventSource" />
     <!-- Accessed via reflection in TraceLogging-style EventSource events. -->
     <type fullname="*f__AnonymousType*" />
+    <!-- Accessed via native code. -->
+    <type fullname="System.Runtime.InteropServices.ComTypes.IEnumerable" />
+    <type fullname="System.Runtime.InteropServices.CustomMarshalers.*" />
   </assembly>
 </linker>
index 982d552..d48b228 100644 (file)
   <data name="Argument_PrecisionTooLarge" xml:space="preserve">
     <value>Precision cannot be larger than {0}.</value>
   </data>
+  <data name="Arg_EnumNotCloneable" xml:space="preserve">
+    <value>The supplied object does not implement ICloneable.</value>
+  </data>
+  <data name="InvalidOp_InvalidNewEnumVariant" xml:space="preserve">
+    <value>The returned enumerator does not implement IEnumVARIANT.</value>
+  </data>
 </root>
index 040cd88..b9842af 100644 (file)
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IConnectionPointContainer.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumConnectionPoints.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumConnections.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumerable.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumMoniker.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumString.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComTypes\IEnumVARIANT.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsInfo.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsMethod.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsSink.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\ComDataHelpers.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumVariantViewOfEnumerator.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableToDispatchMarshaler.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableViewOfDispatch.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumeratorToEnumVariantMarshaler.cs" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumeratorViewOfEnumVariant.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\DispatchWrapper.cs" />
-    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\NativeMethods.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\IDispatch.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\Attributes.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\BindableVectorToCollectionAdapter.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\WindowsRuntime\BindableVectorToListAdapter.cs" />
index e5edeff..3bc1547 100644 (file)
@@ -7,6 +7,8 @@ using System.Runtime.InteropServices;
 
 namespace System.Collections
 {
+    [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")]
+    [ComVisible(true)]
     public interface IEnumerable
     {
         // Returns an IEnumerator for this enumerable Object.  The enumerator provides
index 4a5ec0d..a3d87ac 100644 (file)
@@ -26,6 +26,7 @@ namespace System
     internal static partial class HResults
     {
         internal const int S_OK = unchecked((int)0x00000000);
+        internal const int S_FALSE = unchecked((int)0x1);
         internal const int COR_E_ABANDONEDMUTEX = unchecked((int)0x8013152D);
         internal const int COR_E_AMBIGUOUSMATCH = unchecked((int)0x8000211D);
         internal const int COR_E_APPDOMAINUNLOADED = unchecked((int)0x80131014);
index b9f10bc..0ddde5d 100644 (file)
@@ -127,9 +127,9 @@ namespace System.Runtime.InteropServices
         CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv)
         {
             ppv = IntPtr.Zero;
-            if (iid == _iidSourceItf || iid == typeof(NativeMethods.IDispatch).GUID)
+            if (iid == _iidSourceItf || iid == typeof(IDispatch).GUID)
             {
-                ppv = Marshal.GetComInterfaceForObject(this, typeof(NativeMethods.IDispatch), CustomQueryInterfaceMode.Ignore);
+                ppv = Marshal.GetComInterfaceForObject(this, typeof(IDispatch), CustomQueryInterfaceMode.Ignore);
                 return CustomQueryInterfaceResult.Handled;
             }
             else if (iid == IID_IManagedObject)
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComTypes/IEnumerable.cs
new file mode 100644 (file)
index 0000000..e541960
--- /dev/null
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information
+
+using System;
+namespace System.Runtime.InteropServices.ComTypes
+{
+    /*==========================================================================
+    ** Interface: IEnumerable
+    ** Purpose:
+    ** This interface is redefined here since the original IEnumerable interface
+    ** has all its methods marked as ecall's since it is a managed standard
+    ** interface. This interface is used from within the runtime to make a call
+    ** on the COM server directly when it implements the IEnumerable interface.
+    ==========================================================================*/
+    [Guid("496B0ABE-CDEE-11d3-88E8-00902754C43A")]
+    internal interface IEnumerable
+    {
+        [DispId(-4)]
+        System.Collections.IEnumerator GetEnumerator();
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/ComDataHelpers.cs
new file mode 100644 (file)
index 0000000..c398987
--- /dev/null
@@ -0,0 +1,30 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal static class ComDataHelpers
+    {
+        public static TView GetOrCreateManagedViewFromComData<T, TView>(object comObject, Func<T, TView> createCallback)
+        {
+            object key = typeof(TView);
+
+            if (Marshal.GetComObjectData(comObject, key) is TView managedView)
+            {
+                return managedView;
+            }
+            
+            managedView = createCallback((T)comObject);
+            if (!Marshal.SetComObjectData(comObject, key, managedView))
+            {
+                managedView = (TView)Marshal.GetComObjectData(comObject, key);
+            }
+            return managedView;
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumVariantViewOfEnumerator.cs
new file mode 100644 (file)
index 0000000..d43b682
--- /dev/null
@@ -0,0 +1,102 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Runtime.InteropServices.ComTypes;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal class EnumVariantViewOfEnumerator : IEnumVARIANT, ICustomAdapter
+    {
+        public EnumVariantViewOfEnumerator(IEnumerator enumerator)
+        {
+            if (enumerator is null)
+            {
+                throw new ArgumentNullException(nameof(enumerator));
+            }
+
+            Enumerator = enumerator;
+        }
+
+        public IEnumerator Enumerator { get; }
+
+        public IEnumVARIANT Clone()
+        {
+            if (Enumerator is ICloneable clonable)
+            {
+                return new EnumVariantViewOfEnumerator((IEnumerator)clonable.Clone());
+            }
+            else
+            {
+                throw new COMException(SR.Arg_EnumNotCloneable, HResults.E_FAIL);
+            }
+        }
+
+        public int Next(int celt, object[] rgVar, IntPtr pceltFetched)
+        {
+            int numElements = 0;
+
+            try
+            {
+                if (celt > 0 && rgVar == null)
+                {
+                    return HResults.E_INVALIDARG;
+                }
+
+                while ((numElements < celt) && Enumerator.MoveNext())
+                {
+                    rgVar[numElements++] = Enumerator.Current;
+                }
+
+                if (pceltFetched != IntPtr.Zero)
+                {
+                    Marshal.WriteInt32(pceltFetched, numElements);
+                }
+            }
+            catch (Exception e)
+            {
+                return e.HResult;
+            }
+
+            return numElements == celt ? HResults.S_OK : HResults.S_FALSE;
+        }
+
+        public int Reset()
+        {
+            try
+            {
+                Enumerator.Reset();
+            }
+            catch (Exception e)
+            {
+                return e.HResult;
+            }
+
+            return HResults.S_OK;
+        }
+
+        public int Skip(int celt)
+        {
+            try
+            {
+                while (celt > 0 && Enumerator.MoveNext())
+                {
+                    celt--;
+                }
+            }
+            catch (Exception e)
+            {
+                return e.HResult;
+            }
+
+            return celt == 0 ? HResults.S_OK : HResults.S_FALSE;
+        }
+
+        public object GetUnderlyingObject()
+        {
+            return Enumerator;
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableToDispatchMarshaler.cs
new file mode 100644 (file)
index 0000000..b3d319a
--- /dev/null
@@ -0,0 +1,56 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal class EnumerableToDispatchMarshaler : ICustomMarshaler
+    {
+        private static readonly EnumerableToDispatchMarshaler s_enumerableToDispatchMarshaler = new EnumerableToDispatchMarshaler();
+
+        public static ICustomMarshaler GetInstance(string cookie) => s_enumerableToDispatchMarshaler;
+
+        private EnumerableToDispatchMarshaler()
+        {
+        }
+
+        public void CleanUpManagedData(object ManagedObj)
+        {
+        }
+
+        public void CleanUpNativeData(IntPtr pNativeData)
+        {
+            Marshal.Release(pNativeData);
+        }
+
+        public int GetNativeDataSize()
+        {
+            // Return -1 to indicate the managed type this marshaler handles is not a value type.
+            return -1;
+        }
+
+        public IntPtr MarshalManagedToNative(object ManagedObj)
+        {
+            if (ManagedObj == null)
+            {
+                throw new ArgumentNullException(nameof(ManagedObj));
+            }
+
+            return Marshal.GetComInterfaceForObject<object, IEnumerable>(ManagedObj);
+        }
+
+        public object MarshalNativeToManaged(IntPtr pNativeData)
+        {
+            if (pNativeData == IntPtr.Zero)
+            {
+                throw new ArgumentNullException(nameof(pNativeData));
+            }
+
+            object comObject = Marshal.GetObjectForIUnknown(pNativeData);
+
+            return ComDataHelpers.GetOrCreateManagedViewFromComData<object, EnumerableViewOfDispatch>(comObject, obj => new EnumerableViewOfDispatch(obj));
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs
new file mode 100644 (file)
index 0000000..8de914e
--- /dev/null
@@ -0,0 +1,51 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Runtime.InteropServices.ComTypes;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal class EnumerableViewOfDispatch : ICustomAdapter, System.Collections.IEnumerable
+    {
+        // Reserved DISPID slot for getting an enumerator from an IDispatch-implementing COM interface.
+        private const int DISPID_NEWENUM = -4;
+        private const int LCID_DEFAULT = 1;
+        private readonly object _dispatch;
+
+        public EnumerableViewOfDispatch(object dispatch)
+        {
+            _dispatch = dispatch;
+        }
+
+        private IDispatch Dispatch => (IDispatch)_dispatch;
+
+        public IEnumerator GetEnumerator()
+        {
+            DISPPARAMS dispParams = new DISPPARAMS();
+            Guid guid = Guid.Empty;
+            Dispatch.Invoke(
+                DISPID_NEWENUM,
+                ref guid,
+                LCID_DEFAULT,
+                InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET,
+                ref dispParams,
+                out object result,
+                IntPtr.Zero,
+                IntPtr.Zero);
+
+            if (!(result is IEnumVARIANT enumVariant))
+            {
+                throw new InvalidOperationException(SR.InvalidOp_InvalidNewEnumVariant);
+            }
+
+            return (IEnumerator)EnumeratorToEnumVariantMarshaler.GetInstance(null).MarshalNativeToManaged(Marshal.GetIUnknownForObject(enumVariant));
+        }
+
+        public object GetUnderlyingObject()
+        {
+            return _dispatch;
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorToEnumVariantMarshaler.cs
new file mode 100644 (file)
index 0000000..fddc0c7
--- /dev/null
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal class EnumeratorToEnumVariantMarshaler : ICustomMarshaler
+    {
+        private static readonly EnumeratorToEnumVariantMarshaler s_enumeratorToEnumVariantMarshaler = new EnumeratorToEnumVariantMarshaler();
+
+        public static ICustomMarshaler GetInstance(string cookie) => s_enumeratorToEnumVariantMarshaler;
+
+        private EnumeratorToEnumVariantMarshaler()
+        {
+        }
+
+        public void CleanUpManagedData(object ManagedObj)
+        {
+        }
+
+        public void CleanUpNativeData(IntPtr pNativeData)
+        {
+            Marshal.Release(pNativeData);
+        }
+
+        public int GetNativeDataSize()
+        {
+            // Return -1 to indicate the managed type this marshaler handles is not a value type.
+            return -1;
+        }
+
+        public IntPtr MarshalManagedToNative(object ManagedObj)
+        {
+            if (ManagedObj == null)
+            {
+                throw new ArgumentNullException(nameof(ManagedObj));
+            }
+
+            if (ManagedObj is EnumeratorViewOfEnumVariant view)
+            {
+                return Marshal.GetComInterfaceForObject<object, IEnumVARIANT>(view.GetUnderlyingObject());
+            }
+
+            EnumVariantViewOfEnumerator nativeView = new EnumVariantViewOfEnumerator((IEnumerator)ManagedObj);
+
+            return Marshal.GetComInterfaceForObject<EnumVariantViewOfEnumerator, IEnumVARIANT>(nativeView);
+        }
+
+        public object MarshalNativeToManaged(IntPtr pNativeData)
+        {
+            if (pNativeData == IntPtr.Zero)
+            {
+                throw new ArgumentNullException(nameof(pNativeData));
+            }
+
+            object comObject = Marshal.GetObjectForIUnknown(pNativeData);
+
+            if (!comObject.GetType().IsCOMObject)
+            {
+                if (comObject is EnumVariantViewOfEnumerator enumVariantView)
+                {
+                    return enumVariantView.Enumerator;
+                }
+
+                return comObject as IEnumerator;
+            }
+
+            return ComDataHelpers.GetOrCreateManagedViewFromComData<IEnumVARIANT, EnumeratorViewOfEnumVariant>(comObject, var => new EnumeratorViewOfEnumVariant(var));
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumeratorViewOfEnumVariant.cs
new file mode 100644 (file)
index 0000000..b5ef969
--- /dev/null
@@ -0,0 +1,69 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices.CustomMarshalers
+{
+    internal class EnumeratorViewOfEnumVariant : ICustomAdapter, IEnumerator
+    {
+        private readonly IEnumVARIANT _enumVariantObject;
+        private bool _fetchedLastObject;
+        private object[] _nextArray = new object[1];
+        private object _current;
+
+        public EnumeratorViewOfEnumVariant(IEnumVARIANT enumVariantObject)
+        {
+            _enumVariantObject = enumVariantObject;
+            _fetchedLastObject = false;
+            _current = null;
+        }
+
+        public object Current => _current;
+
+        public unsafe bool MoveNext()
+        {
+            if (_fetchedLastObject)
+            {
+                _current = null;
+                return false;
+            }
+
+            int numFetched = 0;
+
+            if (_enumVariantObject.Next(1, _nextArray, (IntPtr)(&numFetched)) == HResults.S_FALSE)
+            {
+                _fetchedLastObject = true;
+
+                if (numFetched == 0)
+                {
+                    _current = null;
+                    return false;
+                }
+            }
+
+            _current = _nextArray[0];
+
+            return true;
+        }
+
+        public void Reset()
+        {
+            int hr = _enumVariantObject.Reset();
+            if (hr < 0)
+            {
+                Marshal.ThrowExceptionForHR(hr);
+            }
+        }
+
+        public object GetUnderlyingObject()
+        {
+            return _enumVariantObject;
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs
new file mode 100644 (file)
index 0000000..183efa5
--- /dev/null
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices.ComTypes;
+using System.Text;
+
+namespace System.Runtime.InteropServices
+{
+    [ComImport]
+    [Guid("00020400-0000-0000-C000-000000000046")]
+    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+    internal interface IDispatch
+    {
+        void GetIDsOfNames(
+            ref Guid riid,
+            [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 2), In]
+            string[] rgszNames,
+            int cNames,
+            int lcid,
+            [Out] int[] rgDispId);
+
+        ITypeInfo GetTypeInfo(
+            int iTInfo,
+            int lcid);
+
+        int GetTypeInfoCount();
+
+        void Invoke(
+            int dispIdMember,
+            ref Guid riid,
+            int lcid,
+            InvokeFlags wFlags,
+            ref DISPPARAMS pDispParams,
+            out object pVarResult,
+            IntPtr pExcepInfo,
+            IntPtr puArgErr);
+    }
+
+    [Flags]
+    internal enum InvokeFlags : short
+    {
+        DISPATCH_METHOD = 1,
+        DISPATCH_PROPERTYGET = 2,
+        DISPATCH_PROPERTYPUT = 4,
+        DISPATCH_PROPERTYPUTREF = 8
+    }
+}
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/NativeMethods.cs
deleted file mode 100644 (file)
index bd7e9e5..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-#if FEATURE_COMINTEROP
-
-namespace System.Runtime.InteropServices
-{
-    /// <summary>
-    /// Part of ComEventHelpers APIs which allow binding managed delegates
-    /// to COM's connection point based events.
-    /// </summary>
-    internal static class NativeMethods
-    {
-        [ComImport]
-        [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
-        [Guid("00020400-0000-0000-C000-000000000046")]
-        internal interface IDispatch
-        {
-        }
-    }
-}
-
-#endif
index f7b0ce7..cb71df3 100644 (file)
@@ -27,6 +27,7 @@
 #ifdef FEATURE_COMINTEROP
 #define g_ECMAKeyToken "B77A5C561934E089"       // The ECMA key used by some framework assemblies: mscorlib, system, etc.
 #define g_FXKeyToken "b03f5f7f11d50a3a"         // The FX key used by other framework assemblies: System.Web, System.Drawing, etc.
+#define g_CorelibAsmName "System.Private.CoreLib"
 #define g_SystemAsmName "System"
 #define g_SystemRuntimeAsmName "System.Runtime"
 #define g_DrawingAsmName "System.Drawing"
index f9439e7..24af47c 100644 (file)
@@ -40,7 +40,7 @@
 
 
 #ifdef FEATURE_COMINTEROP
-    DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, "CustomMarshalers", VER_ASSEMBLYVERSION_STR, g_FXKeyToken);
+    DEFINE_ASM_QUAL_TYPE_NAME(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME, g_EnumeratorToEnumClassName, g_CorelibAsmName, VER_ASSEMBLYVERSION_STR, g_ECMAKeyToken);
 
     static const int        ENUMERATOR_TO_ENUM_VARIANT_CM_NAME_LEN    = lengthof(ENUMERATOR_TO_ENUM_VARIANT_CM_NAME);
     static const char       ENUMERATOR_TO_ENUM_VARIANT_CM_COOKIE[]    = {""};
index 7e93877..11bc5bf 100644 (file)
@@ -990,9 +990,7 @@ FCIMPL1(Object*, StdMngIEnumerable::GetEnumerator, Object* refThisUNSAFE)
 
     if (retVal == NULL)
     {
-        // In desktop CLR we'll attempt to call through IDispatch(DISPID_NEWENUM)
-        // This is not supported in CoreCLR
-        COMPlusThrow(kPlatformNotSupportedException, IDS_EE_ERROR_IDISPATCH);
+        retVal = ObjectToOBJECTREF((Object*)GetEnumeratorWorker(args));
     }
 
     GCPROTECT_END();
index f0d5fb3..6ee8b56 100644 (file)
@@ -30,7 +30,7 @@
 #define MNGSTDITF_DEFINE_METH3(FriendlyName, MethName, MethSig, FcallDecl) \
     MNGSTDITF_DEFINE_METH_IMPL(FriendlyName, MethName##_3, MethName, MethSig, FcallDecl)
         
-#define CUSTOM_MARSHALER_ASM ", CustomMarshalers, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"
+#define CUSTOM_MARSHALER_ASM ", System.Private.CoreLib, Version=" VER_ASSEMBLYVERSION_STR ", Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"
 
 
 
index d9f3200..ab5d94f 100644 (file)
@@ -72,6 +72,7 @@ if(WIN32)
     add_subdirectory(PInvoke/NativeCallManagedComVisible)
     # This test doesn't necessarily need to be Windows-only, but the implementation is very tied to Windows APIs
     add_subdirectory(PInvoke/DateTime)
+    add_subdirectory(PInvoke/IEnumerator)
     add_subdirectory(ArrayMarshalling/SafeArray)
     add_subdirectory(COM/NativeServer)
     add_subdirectory(COM/NativeClients/Primitives)
index 475aede..f00edbb 100644 (file)
@@ -55,7 +55,7 @@ private:
 };
 
 #ifndef COM_CLIENT
-    #include "ComHelpers.h"
+    #include <ComHelpers.h>
 
     #define DEF_FUNC(n) virtual COM_DECLSPEC_NOTHROW HRESULT STDMETHODCALLTYPE n
 
diff --git a/tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt b/tests/src/Interop/PInvoke/IEnumerator/CMakeLists.txt
new file mode 100644 (file)
index 0000000..c040698
--- /dev/null
@@ -0,0 +1,11 @@
+cmake_minimum_required (VERSION 2.6)
+project (IEnumeratorNative)
+include_directories(${INC_PLATFORM_DIR})
+set(SOURCES IEnumeratorNative.cpp )
+
+# add the executable
+add_library (IEnumeratorNative SHARED ${SOURCES})
+target_link_libraries(IEnumeratorNative ${LINK_LIBRARIES_ADDITIONAL}) 
+
+# add the install targets
+install (TARGETS IEnumeratorNative DESTINATION bin)
diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.cpp
new file mode 100644 (file)
index 0000000..43f786a
--- /dev/null
@@ -0,0 +1,112 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include "IEnumeratorNative.h"
+#include <xplatform.h>
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumerator(int start, int count, IEnumVARIANT** ppEnum)
+{
+    if (count < 0)
+    {
+        return E_INVALIDARG;
+    }
+
+    *ppEnum = new IntegerEnumerator(start, count);
+
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumerator(IEnumVARIANT* pEnum, int start, int count)
+{
+    if (count < 0)
+    {
+        return E_INVALIDARG;
+    }
+
+    HRESULT hr = S_OK;
+
+    VARIANT element;
+    ULONG numFetched;
+    
+    for(int i = start; i < start + count; ++i)
+    {
+        VariantClear(&element);
+        hr = pEnum->Next(1, &element, &numFetched);
+        if(FAILED(hr) || numFetched != 1)
+        {
+            return hr;
+        }
+
+        if (V_I4(&element) != i)
+        {
+            return E_UNEXPECTED;
+        }
+    }
+    
+    hr = pEnum->Next(1, &element, &numFetched);
+    if (hr != S_FALSE || numFetched != 0)
+    {
+        return E_UNEXPECTED;
+    }
+
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE GetIntegerEnumeration(int start, int count, IDispatch** ppDisp)
+{
+    if (count < 0)
+    {
+        return E_INVALIDARG;
+    }
+
+    *ppDisp = new IntegerEnumerable(start, count);
+
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyIntegerEnumeration(IDispatch* pDisp, int start, int count)
+{
+    DISPPARAMS params{};
+    VARIANT result;
+    HRESULT hr = pDisp->Invoke(
+        DISPID_NEWENUM,
+        IID_NULL,
+        LOCALE_USER_DEFAULT,
+        DISPATCH_METHOD | DISPATCH_PROPERTYGET,
+        &params,
+        &result,
+        NULL,
+        NULL
+    );
+
+    if (FAILED(hr))
+    {
+        return hr;
+    }
+
+    if(!((V_VT(&result) == VT_UNKNOWN) || (V_VT(&result) == VT_DISPATCH)))
+    {
+        return E_UNEXPECTED;
+    }
+
+    IEnumVARIANT* pEnum;
+
+    hr = V_UNKNOWN(&result)->QueryInterface<IEnumVARIANT>(&pEnum);
+
+    if (FAILED(hr))
+    {
+        return hr;
+    }
+
+    hr = VerifyIntegerEnumerator(pEnum, start, count);
+
+    pEnum->Release();
+
+    return hr;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE PassThroughEnumerator(IEnumVARIANT* in, IEnumVARIANT** out)
+{
+    return in->QueryInterface(out);
+}
diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h
new file mode 100644 (file)
index 0000000..c512901
--- /dev/null
@@ -0,0 +1,153 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include <xplatform.h>
+#include <platformdefines.h>
+#include <ComHelpers.h>
+#include <algorithm>
+
+class IntegerEnumerator : public UnknownImpl, public IEnumVARIANT
+{
+    int start;
+    int count;
+    int current;
+
+public:
+    IntegerEnumerator(int start, int count)
+        :UnknownImpl(),
+        start(start),
+        count(count),
+        current(start)
+    {
+    }
+
+    HRESULT STDMETHODCALLTYPE Next( 
+        ULONG celt,
+        VARIANT *rgVar,
+        ULONG *pCeltFetched) override
+    {
+        for(*pCeltFetched = 0; *pCeltFetched < celt && current < start + count; ++*pCeltFetched, ++current)
+        {
+            VariantClear(&(rgVar[*pCeltFetched]));
+            V_VT(&rgVar[*pCeltFetched]) = VT_I4;
+            V_I4(&(rgVar[*pCeltFetched])) = current;
+        }
+        
+        return celt == *pCeltFetched ? S_OK : S_FALSE;
+    }
+
+    HRESULT STDMETHODCALLTYPE Skip(ULONG celt) override
+    {
+        int original = current;
+        current = std::min(current + (int)celt, start + count);
+        return original + (int)celt <= start + count ? S_OK : S_FALSE;
+    }
+
+    HRESULT STDMETHODCALLTYPE Reset(void) override
+    {
+        current = start;
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE Clone(IEnumVARIANT **ppEnum) override
+    {
+        IntegerEnumerator* clone = new IntegerEnumerator(start, count);
+        clone->current = current;
+        *ppEnum = clone;
+
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE QueryInterface(
+        REFIID riid,
+        void** ppvObject)
+    {
+        return DoQueryInterface<IntegerEnumerator, IEnumVARIANT>(this, riid, ppvObject);
+    }
+
+    DEFINE_REF_COUNTING();
+};
+
+class IntegerEnumerable : public UnknownImpl, public IDispatch
+{
+private:
+    int start;
+    int count;
+public:
+    IntegerEnumerable(int start, int count)
+        :UnknownImpl(),
+        start(start),
+        count(count)
+    {
+    }
+
+    HRESULT STDMETHODCALLTYPE GetTypeInfoCount( 
+        UINT *pctinfo) override
+    {
+        *pctinfo = 0;
+        return S_OK;
+    }
+    
+    HRESULT STDMETHODCALLTYPE GetTypeInfo( 
+        UINT iTInfo,
+        LCID lcid,
+        ITypeInfo **ppTInfo) override
+    {
+        return E_NOTIMPL;
+    }
+    
+    HRESULT STDMETHODCALLTYPE GetIDsOfNames( 
+        REFIID riid,
+        LPOLESTR *rgszNames,
+        UINT cNames,
+        LCID lcid,
+        DISPID *rgDispId) override
+    {
+        bool containsUnknown = false;
+        DISPID *curr = rgDispId;
+        for (UINT i = 0; i < cNames; ++i)
+        {
+            *curr = DISPID_UNKNOWN;
+            LPOLESTR name = rgszNames[i];
+            if(wcscmp(name, W("GetEnumerator")) == 0)
+            {
+                *curr = DISPID_NEWENUM;
+            }
+
+            containsUnknown &= (*curr == DISPID_UNKNOWN);
+            curr++;
+        }
+
+        return (containsUnknown) ? DISP_E_UNKNOWNNAME : S_OK;
+    }
+    
+    HRESULT STDMETHODCALLTYPE Invoke(
+        DISPID dispIdMember,
+        REFIID riid,
+        LCID lcid,
+        WORD wFlags,
+        DISPPARAMS *pDispParams,
+        VARIANT *pVarResult,
+        EXCEPINFO *pExcepInfo,
+        UINT *puArgErr) override
+    {
+        if (dispIdMember == DISPID_NEWENUM && (wFlags & INVOKE_PROPERTYGET) == INVOKE_PROPERTYGET)
+        {
+            V_VT(pVarResult) = VT_UNKNOWN;
+            V_UNKNOWN(pVarResult) = new IntegerEnumerator(start, count);
+            return S_OK;
+        }
+
+        return E_NOTIMPL;
+    }
+
+    HRESULT STDMETHODCALLTYPE QueryInterface(
+        REFIID riid,
+        void** ppvObject)
+    {
+        return DoQueryInterface<IntegerEnumerable, IDispatch>(this, riid, ppvObject);
+    }
+
+    DEFINE_REF_COUNTING();
+};
diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.cs
new file mode 100644 (file)
index 0000000..0c4e4d8
--- /dev/null
@@ -0,0 +1,94 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using TestLibrary;
+
+namespace PInvokeTests
+{
+    static class IEnumeratorNative
+    {
+        [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+        public static extern IEnumerator GetIntegerEnumerator(
+            int start,
+            int count);
+
+        [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+        public static extern IEnumerable GetIntegerEnumeration(
+            int start,
+            int count);
+
+        [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+        public static extern void VerifyIntegerEnumerator(
+            IEnumerator enumerator,
+            int start,
+            int count);
+
+        [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+        public static extern void VerifyIntegerEnumeration(
+            IEnumerable enumerable,
+            int start,
+            int count);
+
+        [DllImport(nameof(IEnumeratorNative), PreserveSig = false)]
+        public static extern IEnumerator PassThroughEnumerator(IEnumerator enumerator);
+    }
+
+    public static class IEnumeratorTests
+    {
+        private static void TestNativeToManaged()
+        {
+            Assert.AreAllEqual(Enumerable.Range(1, 10), EnumeratorAsEnumerable(IEnumeratorNative.GetIntegerEnumerator(1, 10)));
+            Assert.AreAllEqual(Enumerable.Range(1, 10), IEnumeratorNative.GetIntegerEnumeration(1, 10).OfType<int>());
+        }
+
+        private static void TestManagedToNative()
+        {
+            IEnumeratorNative.VerifyIntegerEnumerator(Enumerable.Range(1, 10).GetEnumerator(), 1, 10);
+            IEnumeratorNative.VerifyIntegerEnumeration(Enumerable.Range(1, 10), 1, 10);
+        }
+
+        private static void TestNativeRoundTrip()
+        {
+            IEnumerator nativeEnumerator = IEnumeratorNative.GetIntegerEnumerator(1, 10);
+            Assert.AreEqual(nativeEnumerator, IEnumeratorNative.PassThroughEnumerator(nativeEnumerator));
+        }
+
+        private static void TestManagedRoundTrip()
+        {
+            IEnumerator managedEnumerator = Enumerable.Range(1, 10).GetEnumerator();
+            Assert.AreEqual(managedEnumerator, IEnumeratorNative.PassThroughEnumerator(managedEnumerator));
+        }
+
+        public static int Main()
+        {
+            try
+            {
+                TestNativeToManaged();
+                TestManagedToNative();
+                TestNativeRoundTrip();
+                TestManagedRoundTrip();
+            }
+            catch (System.Exception e)
+            {
+                Console.WriteLine(e.ToString());
+                return 101;
+            }
+
+            return 100;
+        }
+
+        private static IEnumerable<int> EnumeratorAsEnumerable(IEnumerator enumerator)
+        {
+            while (enumerator.MoveNext())
+            {
+                yield return (int)enumerator.Current;
+            }
+        }
+    }
+}
diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorTest.csproj
new file mode 100644 (file)
index 0000000..f35118a
--- /dev/null
@@ -0,0 +1,36 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+  <Import Project="../../Interop.settings.targets" />
+  <PropertyGroup>
+    <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
+    <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
+    <AssemblyName>IEnumeratorTest</AssemblyName>
+    <SchemaVersion>2.0</SchemaVersion>
+    <ProjectGuid>{F1E66554-8C8E-4141-85CF-D0CD6A0CD0B0}</ProjectGuid>
+    <OutputType>exe</OutputType>
+    <ProjectTypeGuids>{786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
+    <SolutionDir Condition="$(SolutionDir) == '' Or $(SolutionDir) == '*Undefined*'">..\..\</SolutionDir>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <DefineConstants>$(DefineConstants);STATIC</DefineConstants>
+
+    <!-- IEnumerator/IEnumerable marshalling unsupported outside of windows -->
+    <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+    <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+  </PropertyGroup>
+  <!-- Default configurations to help VS understand the configurations -->
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'"></PropertyGroup>
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'"></PropertyGroup>
+  <ItemGroup>
+    <CodeAnalysisDependentAssemblyPaths Condition=" '$(VS100COMNTOOLS)' != '' " Include="$(VS100COMNTOOLS)..\IDE\PrivateAssemblies">
+      <Visible>False</Visible>
+    </CodeAnalysisDependentAssemblyPaths>
+  </ItemGroup>
+  <ItemGroup>
+    <Compile Include="*.cs" />
+  </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="CMakeLists.txt"/>
+  </ItemGroup>
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>