Support COM events (#22721)
authorAaron Robinson <arobins@microsoft.com>
Thu, 21 Feb 2019 17:51:10 +0000 (09:51 -0800)
committerGitHub <noreply@github.com>
Thu, 21 Feb 2019 17:51:10 +0000 (09:51 -0800)
* Add support for COM events
* Add COM Event tests

29 files changed:
src/System.Private.CoreLib/System.Private.CoreLib.csproj
src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs [new file with mode: 0644]
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsHelper.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsInfo.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsMethod.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs
src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs [new file with mode: 0644]
tests/src/Interop/COM/NETClients/Events/App.manifest [new file with mode: 0644]
tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj [new file with mode: 0644]
tests/src/Interop/COM/NETClients/Events/Program.cs [new file with mode: 0644]
tests/src/Interop/COM/NativeServer/ArrayTesting.h
tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest
tests/src/Interop/COM/NativeServer/ColorTesting.h
tests/src/Interop/COM/NativeServer/DispatchTesting.h
tests/src/Interop/COM/NativeServer/ErrorMarshalTesting.h
tests/src/Interop/COM/NativeServer/EventTesting.h [new file with mode: 0644]
tests/src/Interop/COM/NativeServer/NumericTesting.h
tests/src/Interop/COM/NativeServer/Servers.cpp
tests/src/Interop/COM/NativeServer/Servers.h
tests/src/Interop/COM/NativeServer/StringTesting.h
tests/src/Interop/COM/ServerContracts/NativeServers.cs
tests/src/Interop/COM/ServerContracts/Server.Contracts.cs
tests/src/Interop/COM/ServerContracts/Server.Contracts.h
tests/src/Interop/COM/ServerContracts/Server.Events.cs [new file with mode: 0644]
tests/src/Interop/COM/ServerContracts/ServerGuids.cs
tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h
tests/src/Interop/common/ComHelpers.h

index 5930aa0..5e98274 100644 (file)
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsInfo.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsMethod.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsSink.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
+    <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\Variant.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\ComDataHelpers.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumVariantViewOfEnumerator.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableToDispatchMarshaler.cs" />
   </ItemGroup>
   <ItemGroup Condition="'$(TargetsWindows)' == 'true'">
     <Compile Include="$(BclSourcesRoot)\System\DateTime.Windows.cs" />
+    <Compile Include="$(BclSourcesRoot)\Interop\Windows\OleAut32\Interop.VariantClear.cs" />
     <Compile Include="$(BclSourcesRoot)\System\ApplicationModel.Windows.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Globalization\GlobalizationMode.Windows.cs" />
     <Compile Include="$(BclSourcesRoot)\System\Threading\ClrThreadPoolBoundHandle.Windows.cs" />
diff --git a/src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs b/src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs
new file mode 100644 (file)
index 0000000..ae95fe5
--- /dev/null
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.InteropServices;
+
+internal partial class Interop
+{
+    internal partial class OleAut32
+    {
+        [DllImport(Libraries.OleAut32)]
+        internal static extern void VariantClear(IntPtr variant);
+    }
+}
index 50e9ea6..a5d431d 100644 (file)
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-
-/*============================================================
-**
-**
-** Purpose: ComEventHelpers APIs allow binding 
-** managed delegates to COM's connection point based events.
-**
-**/
 //
-// #ComEventsFeature
-// 
-// code:#ComEventsFeature defines two public methods allowing to add/remove .NET delegates handling
-// events from COM objects. Those methods are defined as part of code:ComEventsHelper static class
-// * code:ComEventsHelper.Combine - will create/reuse-an-existing COM event sink and register the
+// ComEventsFeature
+//
+// ComEventsFeature defines two public methods allowing to add/remove .NET delegates handling
+// events from COM objects. Those methods are defined as part of ComEventsHelper static class
+// * ComEventsHelper.Combine - will create/reuse-an-existing COM event sink and register the
 //     specified delegate to be raised when corresponding COM event is raised
-// * code:ComEventsHelper.Remove
-// 
-// 
-// To bind an event handler to the COM object you need to provide the following data:
-//  * rcw - the instance of the COM object you want to bind to
-//  * iid - Guid of the source interface you want the sink to implement
-//  * dispid - dispatch identifier of the event on the source interface you are interested in
-//  * d - delegate to invoked when corresponding COM event is raised.
-// 
-// #ComEventsArchitecture:
-// In COM world, events are handled by so-called event sinks. What these are? COM-based Object Models
-// (OMs) define "source" interfaces that need to be implemented by the COM clients to receive events. So,
-// event sinks are COM objects implementing a source interfaces. Once an event sink is passed to the COM
+// * ComEventsHelper.Remove
+//
+// ComEventsArchitecture:
+// In COM world, events are handled by so-called event sinks. These are COM-based Object Models
+// (OMs) that define "source" interfaces that need to be implemented by COM clients to receive events. So,
+// event sinks are COM objects implementing source interfaces. Once an event sink is passed to the COM
 // server (through a mechanism known as 'binding/advising to connection point'), COM server will be
-// calling source interface methods to "fire events" (advising, connection points, firing events etc. -
-// is all COM jargon).
-// 
+// calling source interface methods to "fire events".
+// See https://docs.microsoft.com/cpp/mfc/connection-points
+//
 // There are few interesting obervations about source interfaces. Usually source interfaces are defined
 // as 'dispinterface' - meaning that only late-bound invocations on this interface are allowed. Even
 // though it is not illegal to use early bound invocations on source interfaces - the practice is
 // discouraged because of versioning concerns.
-// 
+//
 // Notice also that each COM server object might define multiple source interfaces and hence have
 // multiple connection points (each CP handles exactly one source interface). COM objects that want to
-// fire events are required to implement IConnectionPointContainer interface which is used by the COM
-// clients to discovery connection poitns - objects implementing IConnectionPoint interface. Once
+// fire events are required to implement the IConnectionPointContainer interface which is used by COM
+// clients to discovery connection points - objects implementing IConnectionPoint interface. Once a
 // connection point is found - clients can bind to it using IConnectionPoint::Advise (see
-// code:ComEventsSink.Advise).
-// 
-// The idea behind code:#ComEventsFeature is to write a "universal event sink" COM component that is
+// ComEventsSink.Advise).
+//
+// The idea behind ComEventsFeature is to write a "universal event sink" COM component that is
 // generic enough to handle all late-bound event firings and invoke corresponding COM delegates (through
 // reflection).
-// 
-// When delegate is registered (using code:ComEventsHelper.Combine) we will verify we have corresponding
+//
+// When delegate is registered (using ComEventsHelper.Combine) we will verify we have corresponding
 // event sink created and bound.
-// 
-// But what happens when COM events are fired? code:ComEventsSink.Invoke implements IDispatch::Invoke method
-// and this is the entry point that is called. Once our event sink is invoked, we need to find the
-// corresponding delegate to invoke . We need to match the dispid of the call that is coming in to a
-// dispid of .NET delegate that has been registered for this object. Once this is found we do call the
-// delegates using reflection (code:ComEventsMethod.Invoke).
-// 
-// #ComEventsArgsMarshalling
+//
+// When COM events are fired, ComEventsSink.Invoke implements IDispatch and the Invoke method
+// is the entry point that is called. Once our event sink is invoked, we need to find the
+// corresponding delegate to invoke. We need to match the dispid of the call that is coming in to a
+// dispid of .NET delegate that has been registered for this object. Once this is found we call the
+// delegates using reflection (see ComEventsMethod.Invoke).
+//
+// ComEventsArgsMarshalling
 // Notice, that we may not have a delegate registered against every method on the source interface. If we
 // were to marshal all the input parameters for methods that do not reach user code - we would end up
 // generatic RCWs that are not reachable for user code (the inconvenience it might create is there will
 // be RCWs that users can not call Marshal.ReleaseComObject on to explicitly manage the lifetime of these
-// COM objects). The above behavior was one of the shortcoimings of legacy TLBIMP's implementation of COM
+// COM objects). The above behavior was one of the shortcomings of legacy TLBIMP's implementation of COM
 // event sinking. In our code we will not marshal any data if there is no delegate registered to handle
-// the event. (code:ComEventsMethod.Invoke)
-// 
-// #ComEventsFinalization:
+// the event. (see ComEventsMethod.Invoke)
+//
+// ComEventsFinalization:
 // Additional area of interest is when COM sink should be unadvised from the connection point. Legacy
 // TLBIMP's implementation of COM event sinks will unadvises the sink when corresponding RCW is GCed.
 // This is achieved by rooting the event sinks in a finalizable object stored in RCW's property bag
 // (using Marshal.SetComObjectData). Hence, once RCW is no longer reachable - the finalizer is called and
 // it would unadvise all the event sinks. We are employing the same strategy here. See storing an
-// instance in the RCW at code:ComEventsInfo.FromObject and undadvsing the sinks at
-// code:ComEventsInfo.~ComEventsInfo
-// 
+// instance in the RCW at ComEventsInfo.FromObject and unadvising the sinks in ComEventsInfo.~ComEventsInfo
+//
 // Classes of interest:
-// * code:ComEventsHelpers - defines public methods but there are also a number of internal classes that
-//     implement the actual COM event sink:
-// * code:ComEventsInfo - represents a finalizable container for all event sinks for a particular RCW.
+// * ComEventsHelpers - defines public methods but there are also a number of internal classes that
+//     implement the actual COM event sink
+// * ComEventsInfo - represents a finalizable container for all event sinks for a particular RCW.
 //     Lifetime of this instance corresponds to the lifetime of the RCW object
-// * code:ComEventsSink - represents a single event sink. Maintains an internal pointer to the next
-//     instance (in a singly linked list). A collection of code:ComEventsSink is stored at
-//     code:ComEventsInfo._sinks
-// * code:ComEventsMethod - represents a single method from the source interface which has .NET delegates
+// * ComEventsSink - represents a single event sink. Maintains an internal pointer to the next
+//     instance (in a singly linked list). A collection of ComEventsSink is stored at
+//     ComEventsInfo._sinks
+// * ComEventsMethod - represents a single method from the source interface which has .NET delegates
 //     attached to it. Maintains an internal pointer to the next instance (in a singly linked list). A
-//     collection of code:ComEventMethod is stored at code:ComEventsSink._methods
-//     
-// #ComEventsRetValIssue:
+//     collection of ComEventMethod is stored at ComEventsSink._methods
+//
+// ComEventsRetValIssue:
 // Issue: normally, COM events would not return any value. However, it may happen as described in
 // http://support.microsoft.com/kb/810228. Such design might represent a problem for us - e.g. what is
 // the return value of a chain of delegates - is it the value of the last call in the chain or the the
 // first one? As the above KB article indicates, in cases where OM has events returning values, it is
 // suggested that people implement their event sink by explicitly implementing the source interface. This
 // means that the problem is already quite complex and we should not be dealing with it - see
-// code:ComEventsMethod.Invoke
+// ComEventsMethod.Invoke
 
 using System;
 
@@ -115,10 +99,8 @@ namespace System.Runtime.InteropServices
         /// <param name="iid">identifier of the source interface used by COM object to fire events</param>
         /// <param name="dispid">dispatch identifier of the method on the source interface</param>
         /// <param name="d">delegate to invoke when specified COM event is fired</param>
-        public static void Combine(object rcw, Guid iid, int dispid, System.Delegate d)
+        public static void Combine(object rcw, Guid iid, int dispid, Delegate d)
         {
-            rcw = UnwrapIfTransparentProxy(rcw);
-
             lock (rcw)
             {
                 ComEventsInfo eventsInfo = ComEventsInfo.FromObject(rcw);
@@ -129,7 +111,6 @@ namespace System.Runtime.InteropServices
                     sink = eventsInfo.AddSink(ref iid);
                 }
 
-
                 ComEventsMethod method = sink.FindMethod(dispid);
                 if (method == null)
                 {
@@ -147,22 +128,27 @@ namespace System.Runtime.InteropServices
         /// <param name="iid">identifier of the source interface used by COM object to fire events</param>
         /// <param name="dispid">dispatch identifier of the method on the source interface</param>
         /// <param name="d">delegate to remove from the invocation list</param>
-        /// <returns></returns>
-        public static Delegate Remove(object rcw, Guid iid, int dispid, System.Delegate d)
+        public static Delegate Remove(object rcw, Guid iid, int dispid, Delegate d)
         {
-            rcw = UnwrapIfTransparentProxy(rcw);
-
             lock (rcw)
             {
                 ComEventsInfo eventsInfo = ComEventsInfo.Find(rcw);
                 if (eventsInfo == null)
+                {
                     return null;
+                }
+
                 ComEventsSink sink = eventsInfo.FindSink(ref iid);
                 if (sink == null)
+                {
                     return null;
+                }
+
                 ComEventsMethod method = sink.FindMethod(dispid);
                 if (method == null)
+                {
                     return null;
+                }
 
                 method.RemoveDelegate(d);
 
@@ -171,11 +157,13 @@ namespace System.Runtime.InteropServices
                     // removed the last event handler for this dispid - need to remove dispid handler
                     method = sink.RemoveMethod(method);
                 }
+
                 if (method == null)
                 {
                     // removed last dispid handler for this sink - need to remove the sink
                     sink = eventsInfo.RemoveSink(sink);
                 }
+
                 if (sink == null)
                 {
                     // removed last sink for this rcw - need to remove all traces of event info
@@ -186,10 +174,5 @@ namespace System.Runtime.InteropServices
                 return d;
             }
         }
-
-        internal static object UnwrapIfTransparentProxy(object rcw)
-        {
-            return rcw;
-        }
     }
 }
index 0fbe34d..8b47683 100644 (file)
@@ -2,33 +2,16 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-
-/*============================================================
-**
-**
-** Purpose: part of ComEventHelpers APIs which allow binding 
-** managed delegates to COM's connection point based events.
-**
-**/
+using System;
+using ComTypes = System.Runtime.InteropServices.ComTypes;
 
 namespace System.Runtime.InteropServices
 {
-    using System;
-    using ComTypes = System.Runtime.InteropServices.ComTypes;
-
-    // see code:ComEventsHelper#ComEventsArchitecture
     internal class ComEventsInfo
     {
-        #region fields
-
         private ComEventsSink _sinks;
         private object _rcw;
 
-        #endregion
-
-
-        #region ctor/dtor
-
         private ComEventsInfo(object rcw)
         {
             _rcw = rcw;
@@ -36,22 +19,17 @@ namespace System.Runtime.InteropServices
 
         ~ComEventsInfo()
         {
-            // see code:ComEventsHelper#ComEventsFinalization
+            // see notes in ComEventsHelper.cs regarding ComEventsFinalization
             _sinks = ComEventsSink.RemoveAll(_sinks);
         }
 
-        #endregion
-
-
-        #region static methods
-
-        internal static ComEventsInfo Find(object rcw)
+        public static ComEventsInfo Find(object rcw)
         {
             return (ComEventsInfo)Marshal.GetComObjectData(rcw, typeof(ComEventsInfo));
         }
 
         // it is caller's responsibility to call this method under lock(rcw)
-        internal static ComEventsInfo FromObject(object rcw)
+        public static ComEventsInfo FromObject(object rcw)
         {
             ComEventsInfo eventsInfo = Find(rcw);
             if (eventsInfo == null)
@@ -62,18 +40,13 @@ namespace System.Runtime.InteropServices
             return eventsInfo;
         }
 
-        #endregion
-
-
-        #region internal methods
-
-        internal ComEventsSink FindSink(ref Guid iid)
+        public ComEventsSink FindSink(ref Guid iid)
         {
             return ComEventsSink.Find(_sinks, ref iid);
         }
 
         // it is caller's responsibility to call this method under lock(rcw)
-        internal ComEventsSink AddSink(ref Guid iid)
+        public ComEventsSink AddSink(ref Guid iid)
         {
             ComEventsSink sink = new ComEventsSink(_rcw, iid);
             _sinks = ComEventsSink.Add(_sinks, sink);
@@ -87,7 +60,5 @@ namespace System.Runtime.InteropServices
             _sinks = ComEventsSink.Remove(_sinks, sink);
             return _sinks;
         }
-
-        #endregion
     }
 }
index 8b1bcdc..ce36100 100644 (file)
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-
-/*============================================================
-**
-**
-** Purpose: part of ComEventHelpers APIs which allow binding 
-** managed delegates to COM's connection point based events.
-**
-**/
-
 using System;
 using System.Collections.Generic;
-using System.Text;
 using System.Diagnostics;
+using System.Text;
 using System.Runtime.InteropServices;
 using System.Reflection;
 
-
 namespace System.Runtime.InteropServices
 {
-    // see code:ComEventsHelper#ComEventsArchitecture
+    /// <summary>
+    /// Part of ComEventHelpers APIs which allow binding
+    /// managed delegates to COM's connection point based events.
+    /// </summary>
     internal class ComEventsMethod
     {
-        // This delegate wrapper class handles dynamic invocation of delegates. The reason for the wrapper's
-        // existence is that under certain circumstances we need to coerce arguments to types expected by the
-        // delegates signature. Normally, reflection (Delegate.DynamicInvoke) handles types coercion
-        // correctly but one known case is when the expected signature is 'ref Enum' - in this case
-        // reflection by design does not do the coercion. Since we need to be compatible with COM interop
-        // handling of this scenario - we are pre-processing delegate's signature by looking for 'ref enums'
-        // and cache the types required for such coercion.
-        internal class DelegateWrapper
+        /// <summary>
+        /// This delegate wrapper class handles dynamic invocation of delegates. The reason for the wrapper's
+        /// existence is that under certain circumstances we need to coerce arguments to types expected by the
+        /// delegates signature. Normally, reflection (Delegate.DynamicInvoke) handles type coercion
+        /// correctly but one known case is when the expected signature is 'ref Enum' - in this case
+        /// reflection by design does not do the coercion. Since we need to be compatible with COM interop
+        /// handling of this scenario - we are pre-processing delegate's signature by looking for 'ref enums'
+        /// and cache the types required for such coercion.
+        /// </summary>
+        public class DelegateWrapper
         {
-            private Delegate _d;
+            private bool _once = false;
+            private int _expectedParamsCount;
+            private Type[] _cachedTargetTypes;
 
             public DelegateWrapper(Delegate d)
             {
-                _d = d;
+                Delegate = d;
             }
 
-            public Delegate Delegate
+            public Delegate Delegate { get; set; }
+
+            public object Invoke(object[] args)
             {
-                get { return _d; }
-                set { _d = value; }
+                if (Delegate == null)
+                {
+                    return null;
+                }
+
+                if (_once == false)
+                {
+                    PreProcessSignature();
+                    _once = true;
+                }
+
+                if (_cachedTargetTypes != null && _expectedParamsCount == args.Length)
+                {
+                    for (int i = 0; i < _expectedParamsCount; i++)
+                    {
+                        if (_cachedTargetTypes[i] != null)
+                        {
+                            args[i] = Enum.ToObject(_cachedTargetTypes[i], args[i]);
+                        }
+                    }
+                }
+
+                return Delegate.DynamicInvoke(args);
             }
-        }
 
-        #region private fields
+            private void PreProcessSignature()
+            {
+                ParameterInfo[] parameters = Delegate.Method.GetParameters();
+                _expectedParamsCount = parameters.Length;
+
+                bool needToHandleCoercion = false;
+
+                var targetTypes = new List<Type>();
+                foreach (ParameterInfo pi in parameters)
+                {
+                    Type targetType = null;
+
+                    // recognize only 'ref Enum' signatures and cache
+                    // both enum type and the underlying type.
+                    if (pi.ParameterType.IsByRef
+                        && pi.ParameterType.HasElementType
+                        && pi.ParameterType.GetElementType().IsEnum)
+                    {
+                        needToHandleCoercion = true;
+                        targetType = pi.ParameterType.GetElementType();
+                    }
+
+                    targetTypes.Add(targetType);
+                }
+
+                if (needToHandleCoercion)
+                {
+                    _cachedTargetTypes = targetTypes.ToArray();
+                }
+            }
+        }
 
         /// <summary>
         /// Invoking ComEventsMethod means invoking a multi-cast delegate attached to it.
         /// Since multicast delegate's built-in chaining supports only chaining instances of the same type,
         /// we need to complement this design by using an explicit linked list data structure.
         /// </summary>
-        private DelegateWrapper[] _delegateWrappers;
+        private List<DelegateWrapper> _delegateWrappers = new List<DelegateWrapper>();
 
-        private int _dispid;
+        private readonly int _dispid;
         private ComEventsMethod _next;
 
-        #endregion
-
-
-        #region ctor
-
-        internal ComEventsMethod(int dispid)
+        public ComEventsMethod(int dispid)
         {
-            _delegateWrappers = null;
             _dispid = dispid;
         }
 
-        #endregion
-
-
-        #region internal static methods
-
-        internal static ComEventsMethod Find(ComEventsMethod methods, int dispid)
+        public static ComEventsMethod Find(ComEventsMethod methods, int dispid)
         {
             while (methods != null && methods._dispid != dispid)
             {
                 methods = methods._next;
             }
+
             return methods;
         }
 
-        internal static ComEventsMethod Add(ComEventsMethod methods, ComEventsMethod method)
+        public static ComEventsMethod Add(ComEventsMethod methods, ComEventsMethod method)
         {
             method._next = methods;
             return method;
         }
 
-        internal static ComEventsMethod Remove(ComEventsMethod methods, ComEventsMethod method)
+        public static ComEventsMethod Remove(ComEventsMethod methods, ComEventsMethod method)
         {
+            Debug.Assert(methods != null, "removing method from empty methods collection");
+            Debug.Assert(method != null, "specify method is null");
+
             if (methods == method)
             {
                 methods = methods._next;
@@ -100,99 +142,100 @@ namespace System.Runtime.InteropServices
             {
                 ComEventsMethod current = methods;
                 while (current != null && current._next != method)
+                {
                     current = current._next;
+                }
+
                 if (current != null)
+                {
                     current._next = method._next;
+                }
             }
 
             return methods;
         }
 
-        #endregion
-        #region public properties / methods
-
-        internal bool Empty
-        {
-            get { return _delegateWrappers == null || _delegateWrappers.Length == 0; }
-        }
-
-        internal void AddDelegate(Delegate d)
+        public bool Empty
         {
-            int count = 0;
-            if (_delegateWrappers != null)
+            get
             {
-                count = _delegateWrappers.Length;
+                lock (_delegateWrappers)
+                {
+                    return _delegateWrappers.Count == 0;
+                }
             }
+        }
 
-            for (int i = 0; i < count; i++)
+        public void AddDelegate(Delegate d)
+        {
+            lock (_delegateWrappers)
             {
-                if (_delegateWrappers[i].Delegate.GetType() == d.GetType())
+                // Update an existing delegate wrapper
+                foreach (DelegateWrapper wrapper in _delegateWrappers)
                 {
-                    _delegateWrappers[i].Delegate = Delegate.Combine(_delegateWrappers[i].Delegate, d);
-                    return;
+                    if (wrapper.Delegate.GetType() == d.GetType())
+                    {
+                        wrapper.Delegate = Delegate.Combine(wrapper.Delegate, d);
+                        return;
+                    }
                 }
-            }
 
-            DelegateWrapper[] newDelegateWrappers = new DelegateWrapper[count + 1];
-            if (count > 0)
-            {
-                _delegateWrappers.CopyTo(newDelegateWrappers, 0);
+                var newWrapper = new DelegateWrapper(d);
+                _delegateWrappers.Add(newWrapper);
             }
-
-            DelegateWrapper wrapper = new DelegateWrapper(d);
-            newDelegateWrappers[count] = wrapper;
-
-            _delegateWrappers = newDelegateWrappers;
         }
 
-        internal void RemoveDelegate(Delegate d)
+        public void RemoveDelegate(Delegate d)
         {
-            int count = _delegateWrappers.Length;
-            int removeIdx = -1;
-
-            for (int i = 0; i < count; i++)
+            lock (_delegateWrappers)
             {
-                if (_delegateWrappers[i].Delegate.GetType() == d.GetType())
+                // Find delegate wrapper index
+                int removeIdx = -1;
+                DelegateWrapper wrapper = null;
+                for (int i = 0; i < _delegateWrappers.Count; i++)
                 {
-                    removeIdx = i;
-                    break;
+                    DelegateWrapper wrapperMaybe = _delegateWrappers[i];
+                    if (wrapperMaybe.Delegate.GetType() == d.GetType())
+                    {
+                        removeIdx = i;
+                        wrapper = wrapperMaybe;
+                        break;
+                    }
                 }
-            }
 
-            if (removeIdx < 0)
-                return;
+                if (removeIdx < 0)
+                {
+                    // Not present in collection
+                    return;
+                }
 
-            Delegate newDelegate = Delegate.Remove(_delegateWrappers[removeIdx].Delegate, d);
-            if (newDelegate != null)
-            {
-                _delegateWrappers[removeIdx].Delegate = newDelegate;
-                return;
+                // Update wrapper or remove from collection
+                Delegate newDelegate = Delegate.Remove(wrapper.Delegate, d);
+                if (newDelegate != null)
+                {
+                    wrapper.Delegate = newDelegate;
+                }
+                else
+                {
+                    _delegateWrappers.RemoveAt(removeIdx);
+                }
             }
+        }
 
-            // now remove the found entry from the _delegates array
-
-            if (count == 1)
-            {
-                _delegateWrappers = null;
-                return;
-            }
+        public object Invoke(object[] args)
+        {
+            Debug.Assert(!Empty);
+            object result = null;
 
-            DelegateWrapper[] newDelegateWrappers = new DelegateWrapper[count - 1];
-            int j = 0;
-            while (j < removeIdx)
+            lock (_delegateWrappers)
             {
-                newDelegateWrappers[j] = _delegateWrappers[j];
-                j++;
-            }
-            while (j < count - 1)
-            {
-                newDelegateWrappers[j] = _delegateWrappers[j + 1];
-                j++;
+                foreach (DelegateWrapper wrapper in _delegateWrappers)
+                {
+                    result = wrapper.Invoke(args);
+                }
             }
 
-            _delegateWrappers = newDelegateWrappers;
+            return result;
         }
-
-        #endregion
     }
 }
index 9281d24..c5262a6 100644 (file)
@@ -2,48 +2,32 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-
-/*============================================================
-**
-**
-** Purpose: part of ComEventHelpers APIs which allow binding 
-** managed delegates to COM's connection point based events.
-**
-**/
-
 using System;
 using System.Diagnostics;
 
+using Variant = System.Runtime.InteropServices.Variant;
+
 namespace System.Runtime.InteropServices
 {
-    // see code:ComEventsHelper#ComEventsArchitecture
-    internal class ComEventsSink : ICustomQueryInterface
+    /// <summary>
+    /// Part of ComEventHelpers APIs which allow binding
+    /// managed delegates to COM's connection point based events.
+    /// </summary>
+    internal class ComEventsSink : IDispatch, ICustomQueryInterface
     {
-        #region private fields
-
         private Guid _iidSourceItf;
         private ComTypes.IConnectionPoint _connectionPoint;
         private int _cookie;
         private ComEventsMethod _methods;
         private ComEventsSink _next;
 
-        #endregion
-
-
-        #region ctor
-
-        internal ComEventsSink(object rcw, Guid iid)
+        public ComEventsSink(object rcw, Guid iid)
         {
             _iidSourceItf = iid;
             this.Advise(rcw);
         }
 
-        #endregion
-
-
-        #region static members
-
-        internal static ComEventsSink Find(ComEventsSink sinks, ref Guid iid)
+        public static ComEventsSink Find(ComEventsSink sinks, ref Guid iid)
         {
             ComEventsSink sink = sinks;
             while (sink != null && sink._iidSourceItf != iid)
@@ -54,13 +38,13 @@ namespace System.Runtime.InteropServices
             return sink;
         }
 
-        internal static ComEventsSink Add(ComEventsSink sinks, ComEventsSink sink)
+        public static ComEventsSink Add(ComEventsSink sinks, ComEventsSink sink)
         {
             sink._next = sinks;
             return sink;
         }
 
-        internal static ComEventsSink RemoveAll(ComEventsSink sinks)
+        public static ComEventsSink RemoveAll(ComEventsSink sinks)
         {
             while (sinks != null)
             {
@@ -71,7 +55,7 @@ namespace System.Runtime.InteropServices
             return null;
         }
 
-        internal static ComEventsSink Remove(ComEventsSink sinks, ComEventsSink sink)
+        public static ComEventsSink Remove(ComEventsSink sinks, ComEventsSink sink)
         {
             Debug.Assert(sinks != null, "removing event sink from empty sinks collection");
             Debug.Assert(sink != null, "specify event sink is null");
@@ -84,7 +68,9 @@ namespace System.Runtime.InteropServices
             {
                 ComEventsSink current = sinks;
                 while (current != null && current._next != sink)
+                {
                     current = current._next;
+                }
 
                 if (current != null)
                 {
@@ -97,11 +83,6 @@ namespace System.Runtime.InteropServices
             return sinks;
         }
 
-        #endregion
-
-
-        #region public methods
-
         public ComEventsMethod RemoveMethod(ComEventsMethod method)
         {
             _methods = ComEventsMethod.Remove(_methods, method);
@@ -120,7 +101,136 @@ namespace System.Runtime.InteropServices
             return method;
         }
 
-        #endregion
+        int IDispatch.GetTypeInfoCount()
+        {
+            return 0;
+        }
+
+        ComTypes.ITypeInfo IDispatch.GetTypeInfo(int iTInfo, int lcid)
+        {
+            throw new NotImplementedException();
+        }
+
+        void IDispatch.GetIDsOfNames(ref Guid iid, string[] names, int cNames, int lcid, int[] rgDispId)
+        {
+            throw new NotImplementedException();
+        }
+
+        private const VarEnum VT_BYREF_VARIANT = VarEnum.VT_BYREF | VarEnum.VT_VARIANT;
+        private const VarEnum VT_TYPEMASK = (VarEnum) 0x0fff;
+        private const VarEnum VT_BYREF_TYPEMASK = VT_TYPEMASK | VarEnum.VT_BYREF;
+
+        private static unsafe ref Variant GetVariant(ref Variant pSrc)
+        {
+            if (pSrc.VariantType == VT_BYREF_VARIANT)
+            {
+                // For VB6 compatibility reasons, if the VARIANT is a VT_BYREF | VT_VARIANT that
+                // contains another VARIANT with VT_BYREF | VT_VARIANT, then we need to extract the
+                // inner VARIANT and use it instead of the outer one. Note that if the inner VARIANT
+                // is VT_BYREF | VT_VARIANT | VT_ARRAY, it will pass the below test too.
+                Span<Variant> pByRefVariant = new Span<Variant>(pSrc.AsByRefVariant.ToPointer(), 1);
+                if ((pByRefVariant[0].VariantType & VT_BYREF_TYPEMASK) == VT_BYREF_VARIANT)
+                {
+                   return ref pByRefVariant[0];
+                }
+            }
+
+            return ref pSrc;
+        }
+
+        unsafe void IDispatch.Invoke(
+            int dispid,
+            ref Guid riid,
+            int lcid,
+            InvokeFlags wFlags,
+            ref ComTypes.DISPPARAMS pDispParams,
+            IntPtr pVarResult,
+            IntPtr pExcepInfo,
+            IntPtr puArgErr)
+        {
+            ComEventsMethod method = FindMethod(dispid);
+            if (method == null)
+            {
+                return;
+            }
+
+            // notice the unsafe pointers we are using. This is to avoid unnecessary
+            // arguments marshalling. see code:ComEventsHelper#ComEventsArgsMarshalling
+
+            const int InvalidIdx = -1;
+            object [] args = new object[pDispParams.cArgs];
+            int [] byrefsMap = new int[pDispParams.cArgs];
+            bool [] usedArgs = new bool[pDispParams.cArgs];
+
+            int totalCount = pDispParams.cNamedArgs + pDispParams.cArgs;
+            var vars = new Span<Variant>(pDispParams.rgvarg.ToPointer(), totalCount);
+            var namedArgs = new Span<int>(pDispParams.rgdispidNamedArgs.ToPointer(), totalCount);
+
+            // copy the named args (positional) as specified
+            int i;
+            int pos;
+            for (i = 0; i < pDispParams.cNamedArgs; i++)
+            {
+                pos = namedArgs[i];
+                ref Variant pvar = ref GetVariant(ref vars[i]);
+                args[pos] = pvar.ToObject();
+                usedArgs[pos] = true;
+
+                int byrefIdx = InvalidIdx;
+                if (pvar.IsByRef)
+                {
+                    byrefIdx = i;
+                }
+
+                byrefsMap[pos] = byrefIdx;
+            }
+
+            // copy the rest of the arguments in the reverse order
+            pos = 0;
+            for (; i < pDispParams.cArgs; i++)
+            {
+                // find the next unassigned argument
+                while (usedArgs[pos])
+                {
+                    pos++;
+                }
+
+                ref Variant pvar = ref GetVariant(ref vars[pDispParams.cArgs - 1 - i]);
+                args[pos] = pvar.ToObject();
+
+                int byrefIdx = InvalidIdx;
+                if (pvar.IsByRef)
+                {
+                    byrefIdx = pDispParams.cArgs - 1 - i;
+                }
+
+                byrefsMap[pos] = byrefIdx;
+
+                pos++;
+            }
+
+            // Do the actual delegate invocation
+            object result = method.Invoke(args);
+
+            // convert result to VARIANT
+            if (pVarResult != IntPtr.Zero)
+            {
+                Marshal.GetNativeVariantForObject(result, pVarResult);
+            }
+
+            // Now we need to marshal all the byrefs back
+            for (i = 0; i < pDispParams.cArgs; i++)
+            {
+                int idxToPos = byrefsMap[i];
+                if (idxToPos == InvalidIdx)
+                {
+                    continue;
+                }
+
+                ref Variant pvar = ref GetVariant(ref vars[idxToPos]);
+                pvar.CopyFromIndirect(args[i]);
+            }
+        }
 
         CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv)
         {
@@ -134,9 +244,6 @@ namespace System.Runtime.InteropServices
             return CustomQueryInterfaceResult.NotHandled;
         }
 
-        #region private methods
-
-
         private void Advise(object rcw)
         {
             Debug.Assert(_connectionPoint == null, "comevent sink is already advised");
@@ -161,7 +268,7 @@ namespace System.Runtime.InteropServices
                 _connectionPoint.Unadvise(_cookie);
                 Marshal.ReleaseComObject(_connectionPoint);
             }
-            catch (System.Exception)
+            catch (Exception)
             {
                 // swallow all exceptions on unadvise
                 // the host may not be available at this point
@@ -171,7 +278,5 @@ namespace System.Runtime.InteropServices
                 _connectionPoint = null;
             }
         }
-
-        #endregion
     };
 }
index 8de914e..67bb393 100644 (file)
@@ -5,6 +5,8 @@
 using System.Collections;
 using System.Runtime.InteropServices.ComTypes;
 
+using Variant = System.Runtime.InteropServices.Variant;
+
 namespace System.Runtime.InteropServices.CustomMarshalers
 {
     internal class EnumerableViewOfDispatch : ICustomAdapter, System.Collections.IEnumerable
@@ -23,19 +25,25 @@ namespace System.Runtime.InteropServices.CustomMarshalers
 
         public IEnumerator GetEnumerator()
         {
-            DISPPARAMS dispParams = new DISPPARAMS();
-            Guid guid = Guid.Empty;
-            Dispatch.Invoke(
-                DISPID_NEWENUM,
-                ref guid,
-                LCID_DEFAULT,
-                InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET,
-                ref dispParams,
-                out object result,
-                IntPtr.Zero,
-                IntPtr.Zero);
-
-            if (!(result is IEnumVARIANT enumVariant))
+            Variant result;
+            unsafe
+            {
+                void *resultLocal = &result;
+                DISPPARAMS dispParams = new DISPPARAMS();
+                Guid guid = Guid.Empty;
+                Dispatch.Invoke(
+                    DISPID_NEWENUM,
+                    ref guid,
+                    LCID_DEFAULT,
+                    InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET,
+                    ref dispParams,
+                    new IntPtr(resultLocal),
+                    IntPtr.Zero,
+                    IntPtr.Zero);
+            }
+
+            object resultAsObject = result.ToObject();
+            if (!(resultAsObject is IEnumVARIANT enumVariant))
             {
                 throw new InvalidOperationException(SR.InvalidOp_InvalidNewEnumVariant);
             }
index 183efa5..fbe70fe 100644 (file)
@@ -4,7 +4,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Runtime.InteropServices.ComTypes;
 using System.Text;
 
 namespace System.Runtime.InteropServices
@@ -14,6 +13,12 @@ namespace System.Runtime.InteropServices
     [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
     internal interface IDispatch
     {
+        int GetTypeInfoCount();
+
+        ComTypes.ITypeInfo GetTypeInfo(
+            int iTInfo,
+            int lcid);
+
         void GetIDsOfNames(
             ref Guid riid,
             [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 2), In]
@@ -22,21 +27,17 @@ namespace System.Runtime.InteropServices
             int lcid,
             [Out] int[] rgDispId);
 
-        ITypeInfo GetTypeInfo(
-            int iTInfo,
-            int lcid);
-
-        int GetTypeInfoCount();
-
+        // The last 3 parameters of Invoke() are optional and must be defined
+        // as IntPtr in C#, since there is no language feature for optional ref/out.
         void Invoke(
             int dispIdMember,
             ref Guid riid,
             int lcid,
             InvokeFlags wFlags,
-            ref DISPPARAMS pDispParams,
-            out object pVarResult,
-            IntPtr pExcepInfo,
-            IntPtr puArgErr);
+            ref ComTypes.DISPPARAMS pDispParams,
+            /* out/optional */ IntPtr pVarResult,
+            /* out/optional */ IntPtr pExcepInfo,
+            /* out/optional */ IntPtr puArgErr);
     }
 
     [Flags]
diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs
new file mode 100644 (file)
index 0000000..c726255
--- /dev/null
@@ -0,0 +1,710 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Diagnostics;
+
+namespace System.Runtime.InteropServices
+{
+    /// <summary>
+    /// Variant is the basic COM type for late-binding. It can contain any other COM data type.
+    /// This type definition precisely matches the unmanaged data layout so that the struct can be passed
+    /// to and from COM calls.
+    /// </summary>
+    [StructLayout(LayoutKind.Explicit)]
+    internal struct Variant
+    {
+#if DEBUG
+        static Variant()
+        {
+            // Variant size is the size of 4 pointers (16 bytes) on a 32-bit processor,
+            // and 3 pointers (24 bytes) on a 64-bit processor.
+            int variantSize = Marshal.SizeOf(typeof(Variant));
+            if (IntPtr.Size == 4)
+            {
+                Debug.Assert(variantSize == (4 * IntPtr.Size));
+            }
+            else
+            {
+                Debug.Assert(IntPtr.Size == 8);
+                Debug.Assert(variantSize == (3 * IntPtr.Size));
+            }
+        }
+#endif
+
+        // Most of the data types in the Variant are carried in _typeUnion
+        [FieldOffset(0)] private TypeUnion _typeUnion;
+
+        // Decimal is the largest data type and it needs to use the space that is normally unused in TypeUnion._wReserved1, etc.
+        // Hence, it is declared to completely overlap with TypeUnion. A Decimal does not use the first two bytes, and so
+        // TypeUnion._vt can still be used to encode the type.
+        [FieldOffset(0)] private Decimal _decimal;
+
+        [StructLayout(LayoutKind.Sequential)]
+        private struct TypeUnion
+        {
+            public ushort _vt;
+            public ushort _wReserved1;
+            public ushort _wReserved2;
+            public ushort _wReserved3;
+
+            public UnionTypes _unionTypes;
+        }
+
+        [StructLayout(LayoutKind.Sequential)]
+        private struct Record
+        {
+            public IntPtr _record;
+            public IntPtr _recordInfo;
+        }
+
+        [StructLayout(LayoutKind.Explicit)]
+        private struct UnionTypes
+        {
+            [FieldOffset(0)] public sbyte _i1;
+            [FieldOffset(0)] public short _i2;
+            [FieldOffset(0)] public int _i4;
+            [FieldOffset(0)] public long _i8;
+            [FieldOffset(0)] public byte _ui1;
+            [FieldOffset(0)] public ushort _ui2;
+            [FieldOffset(0)] public uint _ui4;
+            [FieldOffset(0)] public ulong _ui8;
+            [FieldOffset(0)] public int _int;
+            [FieldOffset(0)] public uint _uint;
+            [FieldOffset(0)] public short _bool;
+            [FieldOffset(0)] public int _error;
+            [FieldOffset(0)] public float _r4;
+            [FieldOffset(0)] public double _r8;
+            [FieldOffset(0)] public long _cy;
+            [FieldOffset(0)] public double _date;
+            [FieldOffset(0)] public IntPtr _bstr;
+            [FieldOffset(0)] public IntPtr _unknown;
+            [FieldOffset(0)] public IntPtr _dispatch;
+            [FieldOffset(0)] public IntPtr _pvarVal;
+            [FieldOffset(0)] public IntPtr _byref;
+            [FieldOffset(0)] public Record _record;
+        }
+
+        /// <summary>
+        /// Primitive types are the basic COM types. It includes valuetypes like ints, but also reference types
+        /// like BStrs. It does not include composite types like arrays and user-defined COM types (IUnknown/IDispatch).
+        /// </summary>
+        public static bool IsPrimitiveType(VarEnum varEnum)
+        {
+            switch(varEnum)
+            {
+                case VarEnum.VT_I1:
+                case VarEnum.VT_I2:
+                case VarEnum.VT_I4:
+                case VarEnum.VT_I8:
+                case VarEnum.VT_UI1:
+                case VarEnum.VT_UI2:
+                case VarEnum.VT_UI4:
+                case VarEnum.VT_UI8:
+                case VarEnum.VT_INT:
+                case VarEnum.VT_UINT:
+                case VarEnum.VT_BOOL:
+                case VarEnum.VT_R4:
+                case VarEnum.VT_R8:
+                case VarEnum.VT_DECIMAL:
+                case VarEnum.VT_DATE:
+                case VarEnum.VT_BSTR:
+                    return true;
+            }
+
+            return false;
+        }
+
+        unsafe public void CopyFromIndirect(object value)
+        {
+            VarEnum vt = (VarEnum)(((int)this.VariantType) & ~((int)VarEnum.VT_BYREF));
+
+            if (value == null)
+            {
+                if (vt == VarEnum.VT_DISPATCH || vt == VarEnum.VT_UNKNOWN || vt == VarEnum.VT_BSTR)
+                {
+                    *(IntPtr*)this._typeUnion._unionTypes._byref = IntPtr.Zero;
+                }
+                return;
+            }
+
+            if ((vt & VarEnum.VT_ARRAY) != 0)
+            {
+                Variant vArray;
+                Marshal.GetNativeVariantForObject(value, (IntPtr)(void*)&vArray);
+                *(IntPtr*)this._typeUnion._unionTypes._byref = vArray._typeUnion._unionTypes._byref;
+                return;
+            }
+
+            switch (vt)
+            {
+                case VarEnum.VT_I1:
+                    *(sbyte*)this._typeUnion._unionTypes._byref = (sbyte)value;
+                    break;
+
+                case VarEnum.VT_UI1:
+                    *(byte*)this._typeUnion._unionTypes._byref = (byte)value;
+                    break;
+
+                case VarEnum.VT_I2:
+                    *(short*)this._typeUnion._unionTypes._byref = (short)value;
+                    break;
+
+                case VarEnum.VT_UI2:
+                    *(ushort*)this._typeUnion._unionTypes._byref = (ushort)value;
+                    break;
+
+                case VarEnum.VT_BOOL:
+                    // VARIANT_TRUE  = -1
+                    // VARIANT_FALSE = 0
+                    *(short*)this._typeUnion._unionTypes._byref = (bool)value ? (short)-1 : (short)0;
+                    break;
+
+                case VarEnum.VT_I4:
+                case VarEnum.VT_INT:
+                    *(int*)this._typeUnion._unionTypes._byref = (int)value;
+                    break;
+
+                case VarEnum.VT_UI4:
+                case VarEnum.VT_UINT:
+                    *(uint*)this._typeUnion._unionTypes._byref = (uint)value;
+                    break;
+
+                case VarEnum.VT_ERROR:
+                    *(int*)this._typeUnion._unionTypes._byref = ((ErrorWrapper)value).ErrorCode;
+                    break;
+
+                case VarEnum.VT_I8:
+                    *(Int64*)this._typeUnion._unionTypes._byref = (Int64)value;
+                    break;
+
+                case VarEnum.VT_UI8:
+                    *(UInt64*)this._typeUnion._unionTypes._byref = (UInt64)value;
+                    break;
+
+                case VarEnum.VT_R4:
+                    *(float*)this._typeUnion._unionTypes._byref = (float)value;
+                    break;
+
+                case VarEnum.VT_R8:
+                    *(double*)this._typeUnion._unionTypes._byref = (double)value;
+                    break;
+
+                case VarEnum.VT_DATE:
+                    *(double*)this._typeUnion._unionTypes._byref = ((DateTime)value).ToOADate();
+                    break;
+
+                case VarEnum.VT_UNKNOWN:
+                    *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.GetIUnknownForObject(value);
+                    break;
+
+                case VarEnum.VT_DISPATCH:
+                    *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.GetIDispatchForObject(value);
+                    break;
+
+                case VarEnum.VT_BSTR:
+                    *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.StringToBSTR((string)value);
+                    break;
+
+                case VarEnum.VT_CY:
+                    *(long*)this._typeUnion._unionTypes._byref = decimal.ToOACurrency((decimal)value);
+                    break;
+
+                case VarEnum.VT_DECIMAL:
+                    *(decimal*)this._typeUnion._unionTypes._byref = (decimal)value;
+                    break;
+
+                case VarEnum.VT_VARIANT:
+                    Marshal.GetNativeVariantForObject(value, this._typeUnion._unionTypes._byref);
+                    break;
+
+                default:
+                    throw new ArgumentException();
+            }
+        }
+
+        /// <summary>
+        /// Get the managed object representing the Variant.
+        /// </summary>
+        /// <returns></returns>
+        public object ToObject()
+        {
+            // Check the simple case upfront
+            if (IsEmpty)
+            {
+                return null;
+            }
+
+            switch (VariantType)
+            {
+                case VarEnum.VT_NULL:
+                    return DBNull.Value;
+
+                case VarEnum.VT_I1: return AsI1;
+                case VarEnum.VT_I2: return AsI2;
+                case VarEnum.VT_I4: return AsI4;
+                case VarEnum.VT_I8: return AsI8;
+                case VarEnum.VT_UI1: return AsUi1;
+                case VarEnum.VT_UI2: return AsUi2;
+                case VarEnum.VT_UI4: return AsUi4;
+                case VarEnum.VT_UI8: return AsUi8;
+                case VarEnum.VT_INT: return AsInt;
+                case VarEnum.VT_UINT: return AsUint;
+                case VarEnum.VT_BOOL: return AsBool;
+                case VarEnum.VT_ERROR: return AsError;
+                case VarEnum.VT_R4: return AsR4;
+                case VarEnum.VT_R8: return AsR8;
+                case VarEnum.VT_DECIMAL: return AsDecimal;
+                case VarEnum.VT_CY: return AsCy;
+                case VarEnum.VT_DATE: return AsDate;
+                case VarEnum.VT_BSTR: return AsBstr;
+                case VarEnum.VT_UNKNOWN: return AsUnknown;
+                case VarEnum.VT_DISPATCH: return AsDispatch;
+
+                default:
+                    unsafe
+                    {
+                        fixed (void* pThis = &this)
+                        {
+                            return Marshal.GetObjectForNativeVariant((System.IntPtr)pThis);
+                        }
+                    }
+            }
+        }
+
+        /// <summary>
+        /// Release any unmanaged memory associated with the Variant
+        /// </summary>
+        /// <returns></returns>
+        public void Clear()
+        {
+            // We do not need to call OLE32's VariantClear for primitive types or ByRefs
+            // to save ourselves the cost of interop transition.
+            // ByRef indicates the memory is not owned by the VARIANT itself while
+            // primitive types do not have any resources to free up.
+            // Hence, only safearrays, BSTRs, interfaces and user types are
+            // handled differently.
+            VarEnum vt = VariantType;
+            if ((vt & VarEnum.VT_BYREF) != 0)
+            {
+                VariantType = VarEnum.VT_EMPTY;
+            }
+            else if (((vt & VarEnum.VT_ARRAY) != 0)
+                    || ((vt) == VarEnum.VT_BSTR)
+                    || ((vt) == VarEnum.VT_UNKNOWN)
+                    || ((vt) == VarEnum.VT_DISPATCH)
+                    || ((vt) == VarEnum.VT_VARIANT)
+                    || ((vt) == VarEnum.VT_RECORD)
+                    || ((vt) == VarEnum.VT_VARIANT))
+            {
+                unsafe
+                {
+                    fixed (void* pThis = &this)
+                    {
+                        Interop.OleAut32.VariantClear((IntPtr)pThis);
+                    }
+                }
+
+                Debug.Assert(IsEmpty);
+            }
+            else
+            {
+                VariantType = VarEnum.VT_EMPTY;
+            }
+        }
+
+        public VarEnum VariantType
+        {
+            get => (VarEnum)_typeUnion._vt;
+            set => _typeUnion._vt = (ushort)value;
+        }
+
+        public bool IsEmpty => _typeUnion._vt == ((ushort)VarEnum.VT_EMPTY);
+
+        public bool IsByRef => (_typeUnion._vt & ((ushort)VarEnum.VT_BYREF)) != 0;
+
+        public void SetAsNULL()
+        {
+            Debug.Assert(IsEmpty);
+            VariantType = VarEnum.VT_NULL;
+        }
+
+        // VT_I1
+
+        public sbyte AsI1
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_I1);
+                return _typeUnion._unionTypes._i1;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_I1;
+                _typeUnion._unionTypes._i1 = value;
+            }
+        }
+
+        // VT_I2
+
+        public short AsI2
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_I2);
+                return _typeUnion._unionTypes._i2;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_I2;
+                _typeUnion._unionTypes._i2 = value;
+            }
+        }
+
+        // VT_I4
+
+        public int AsI4
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_I4);
+                return _typeUnion._unionTypes._i4;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_I4;
+                _typeUnion._unionTypes._i4 = value;
+            }
+        }
+
+        // VT_I8
+
+        public long AsI8
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_I8);
+                return _typeUnion._unionTypes._i8;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_I8;
+                _typeUnion._unionTypes._i8 = value;
+            }
+        }
+
+        // VT_UI1
+
+        public byte AsUi1
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UI1);
+                return _typeUnion._unionTypes._ui1;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UI1;
+                _typeUnion._unionTypes._ui1 = value;
+            }
+        }
+
+        // VT_UI2
+
+        public ushort AsUi2
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UI2);
+                return _typeUnion._unionTypes._ui2;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UI2;
+                _typeUnion._unionTypes._ui2 = value;
+            }
+        }
+
+        // VT_UI4
+
+        public uint AsUi4
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UI4);
+                return _typeUnion._unionTypes._ui4;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UI4;
+                _typeUnion._unionTypes._ui4 = value;
+            }
+        }
+
+        // VT_UI8
+
+        public ulong AsUi8
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UI8);
+                return _typeUnion._unionTypes._ui8;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UI8;
+                _typeUnion._unionTypes._ui8 = value;
+            }
+        }
+
+        // VT_INT
+
+        public int AsInt
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_INT);
+                return _typeUnion._unionTypes._int;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_INT;
+                _typeUnion._unionTypes._int = value;
+            }
+        }
+
+        // VT_UINT
+
+        public uint AsUint
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UINT);
+                return _typeUnion._unionTypes._uint;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UINT;
+                _typeUnion._unionTypes._uint = value;
+            }
+        }
+
+        // VT_BOOL
+
+        public bool AsBool
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_BOOL);
+                return _typeUnion._unionTypes._bool != 0;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                // VARIANT_TRUE  = -1
+                // VARIANT_FALSE = 0
+                VariantType = VarEnum.VT_BOOL;
+                _typeUnion._unionTypes._bool = value ? (short)-1 : (short)0;
+            }
+        }
+
+        // VT_ERROR
+
+        public int AsError
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_ERROR);
+                return _typeUnion._unionTypes._error;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_ERROR;
+                _typeUnion._unionTypes._error = value;
+            }
+        }
+
+        // VT_R4
+
+        public float AsR4
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_R4);
+                return _typeUnion._unionTypes._r4;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_R4;
+                _typeUnion._unionTypes._r4 = value;
+            }
+        }
+
+        // VT_R8
+
+        public double AsR8
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_R8);
+                return _typeUnion._unionTypes._r8;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_R8;
+                _typeUnion._unionTypes._r8 = value;
+            }
+        }
+
+        // VT_DECIMAL
+
+        public Decimal AsDecimal
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_DECIMAL);
+                // The first byte of Decimal is unused, but usually set to 0
+                Variant v = this;
+                v._typeUnion._vt = 0;
+                return v._decimal;
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_DECIMAL;
+                _decimal = value;
+                // _vt overlaps with _decimal, and should be set after setting _decimal
+                _typeUnion._vt = (ushort)VarEnum.VT_DECIMAL;
+            }
+        }
+
+        // VT_CY
+
+        public Decimal AsCy
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_CY);
+                return Decimal.FromOACurrency(_typeUnion._unionTypes._cy);
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_CY;
+                _typeUnion._unionTypes._cy = Decimal.ToOACurrency(value);
+            }
+        }
+
+        // VT_DATE
+
+        public DateTime AsDate
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_DATE);
+                return DateTime.FromOADate(_typeUnion._unionTypes._date);
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_DATE;
+                _typeUnion._unionTypes._date = value.ToOADate();
+            }
+        }
+
+        // VT_BSTR
+
+        public string AsBstr
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_BSTR);
+                return (string)Marshal.PtrToStringBSTR(this._typeUnion._unionTypes._bstr);
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_BSTR;
+                this._typeUnion._unionTypes._bstr = Marshal.StringToBSTR(value);
+            }
+        }
+
+        // VT_UNKNOWN
+
+        public object AsUnknown
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_UNKNOWN);
+                if (_typeUnion._unionTypes._unknown == IntPtr.Zero)
+                {
+                    return null;
+                }
+                return Marshal.GetObjectForIUnknown(_typeUnion._unionTypes._unknown);
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_UNKNOWN;
+                if (value == null)
+                {
+                    _typeUnion._unionTypes._unknown = IntPtr.Zero;
+                }
+                else
+                {
+                    _typeUnion._unionTypes._unknown = Marshal.GetIUnknownForObject(value);
+                }
+            }
+        }
+
+        // VT_DISPATCH
+
+        public object AsDispatch
+        {
+            get
+            {
+                Debug.Assert(VariantType == VarEnum.VT_DISPATCH);
+                if (_typeUnion._unionTypes._dispatch == IntPtr.Zero)
+                {
+                    return null;
+                }
+                return Marshal.GetObjectForIUnknown(_typeUnion._unionTypes._dispatch);
+            }
+            set
+            {
+                Debug.Assert(IsEmpty);
+                VariantType = VarEnum.VT_DISPATCH;
+                if (value == null)
+                {
+                    _typeUnion._unionTypes._dispatch = IntPtr.Zero;
+                }
+                else
+                {
+                    _typeUnion._unionTypes._dispatch = Marshal.GetIDispatchForObject(value);
+                }
+            }
+        }
+
+        public IntPtr AsByRefVariant
+        {
+            get
+            {
+                Debug.Assert(VariantType == (VarEnum.VT_BYREF | VarEnum.VT_VARIANT));
+                return _typeUnion._unionTypes._pvarVal;
+            }
+        }
+    }
+}
diff --git a/tests/src/Interop/COM/NETClients/Events/App.manifest b/tests/src/Interop/COM/NETClients/Events/App.manifest
new file mode 100644 (file)
index 0000000..833743f
--- /dev/null
@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?>
+<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1">
+  <assemblyIdentity
+    type="win32" 
+    name="NetPrimitivesEvents"
+    version="1.0.0.0" />
+
+  <dependency>
+    <dependentAssembly>
+      <!-- RegFree COM -->
+      <assemblyIdentity
+          type="win32"
+          name="COMNativeServer.X"
+          version="1.0.0.0"/>
+    </dependentAssembly>
+  </dependency>
+
+</assembly>
diff --git a/tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj b/tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj
new file mode 100644 (file)
index 0000000..fccbee9
--- /dev/null
@@ -0,0 +1,43 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+  <PropertyGroup>
+    <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
+    <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
+    <AssemblyName>NETClientEvents</AssemblyName>
+    <SchemaVersion>2.0</SchemaVersion>
+    <ProjectGuid>{85C57688-DA98-4DE3-AC9B-526E4747434C}</ProjectGuid>
+    <OutputType>Exe</OutputType>
+    <ProjectTypeGuids>{209912F9-0DA1-4184-9CC1-8D583BAF4A28};{87799F5D-CEBD-499D-BDBA-B2C6105CD766}</ProjectTypeGuids>
+    <ApplicationManifest>App.manifest</ApplicationManifest>
+
+    <!-- Blocked on ILAsm supporting embedding resources. See https://github.com/dotnet/coreclr/issues/20819 -->
+    <IlrtTestKind>BuildOnly</IlrtTestKind>
+
+    <!-- Blocked on CrossGen.exe supporting embedding resources. See https://github.com/dotnet/coreclr/issues/21006 -->
+    <CrossGenTest>false</CrossGenTest>
+
+    <!-- Test unsupported outside of windows -->
+    <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+    <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+    <!-- This test would require the runincontext.exe to include App.manifest describing the COM interfaces -->
+    <UnloadabilityIncompatible>true</UnloadabilityIncompatible>
+  </PropertyGroup>
+  <!-- Default configurations to help VS understand the configurations -->
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
+  </PropertyGroup>
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="Program.cs" />
+    <Compile Include="../../ServerContracts/NativeServers.cs" />
+    <Compile Include="../../ServerContracts/Server.Contracts.cs" />
+    <Compile Include="../../ServerContracts/Server.Events.cs" />
+    <Compile Include="../../ServerContracts/ServerGuids.cs" />
+  </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="../../NativeServer/CMakeLists.txt" />
+    <ProjectReference Include="../../../../Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
+  </ItemGroup>
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>
diff --git a/tests/src/Interop/COM/NETClients/Events/Program.cs b/tests/src/Interop/COM/NETClients/Events/Program.cs
new file mode 100644 (file)
index 0000000..4935924
--- /dev/null
@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace NetClient
+{
+    using System;
+    using System.Reflection;
+    using System.Runtime.InteropServices;
+
+    using TestLibrary;
+    using Server.Contract;
+    using Server.Contract.Servers;
+    using Server.Contract.Events;
+
+    class Program
+    {
+        static void Validate_BasicCOMEvent()
+        {
+            Console.WriteLine($"{nameof(Validate_BasicCOMEvent)}...");
+
+            var eventTesting = (EventTesting)new EventTestingClass();
+
+            // Verify event handler subscription
+
+            // Add event
+            eventTesting.OnEvent += OnEventEventHandler;
+
+            bool eventFired = false;
+            string message = string.Empty;
+            eventTesting.FireEvent();
+
+            Assert.IsTrue(eventFired, "Event didn't fire");
+            Assert.AreEqual(nameof(EventTesting.FireEvent), message, "Event message is incorrect");
+
+            // Remove event
+            eventTesting.OnEvent -= OnEventEventHandler;
+
+            // Verify event handler removed
+
+            eventFired = false;
+            eventTesting.FireEvent();
+
+            Assert.IsFalse(eventFired, "Event shouldn't fire");
+
+            void OnEventEventHandler(string msg)
+            {
+                eventFired = true;
+                message = msg;
+            }
+        }
+
+#pragma warning disable 618 // Must test deprecated features
+
+        // The ComAwareEventInfo is used by the compiler when PIAs
+        // containing COM Events are embedded.
+        static void Validate_COMEventViaComAwareEventInfo()
+        {
+            Console.WriteLine($"{nameof(Validate_COMEventViaComAwareEventInfo)}...");
+
+            var eventTesting = (EventTesting)new EventTestingClass();
+
+            // Verify event handler subscription
+
+            // Add event
+            var comAwareEventInfo = new ComAwareEventInfo(typeof(TestingEvents_Event), nameof(TestingEvents_Event.OnEvent));
+            var handler = new TestingEvents_OnEventEventHandler(OnEventEventHandler);
+            comAwareEventInfo.AddEventHandler(eventTesting, handler);
+
+            bool eventFired = false;
+            string message = string.Empty;
+            eventTesting.FireEvent();
+
+            Assert.IsTrue(eventFired, "Event didn't fire");
+            Assert.AreEqual(nameof(EventTesting.FireEvent), message, "Event message is incorrect");
+
+            comAwareEventInfo.RemoveEventHandler(eventTesting, handler);
+
+            // Verify event handler removed
+
+            eventFired = false;
+            eventTesting.FireEvent();
+
+            Assert.IsFalse(eventFired, "Event shouldn't fire");
+
+            void OnEventEventHandler(string msg)
+            {
+                eventFired = true;
+                message = msg;
+            }
+        }
+
+#pragma warning restore 618 // Must test deprecated features
+
+        static int Main(string[] doNotUse)
+        {
+            // RegFree COM is not supported on Windows Nano
+            if (Utilities.IsWindowsNanoServer)
+            {
+                return 100;
+            }
+
+            try
+            {
+                Validate_BasicCOMEvent();
+                Validate_COMEventViaComAwareEventInfo();
+            }
+            catch (Exception e)
+            {
+                Console.WriteLine($"Test Failure: {e}");
+                return 101;
+            }
+
+            return 100;
+        }
+    }
+}
index 8366a1c..b653a53 100644 (file)
@@ -340,7 +340,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<ArrayTesting, IArrayTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IArrayTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
index 1569d52..4509ee9 100644 (file)
     clsid="{0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726}"
     threadingModel="Both" />
 
+  <!-- EventTesting -->
+  <comClass
+    clsid="{4DBD9B61-E372-499F-84DE-EFC70AA8A009}"
+    threadingModel="Both" />
+
   <!-- AggregationTesting -->
   <comClass
     clsid="{4CEFE36D-F377-4B6E-8C34-819A8BB9CB04}"
index 5d6e174..6357ce2 100644 (file)
@@ -34,7 +34,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<ColorTesting, IColorTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IColorTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
index af9cd6c..166d6fa 100644 (file)
@@ -415,7 +415,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<DispatchTesting, IDispatch, IDispatchTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IDispatch *>(this), static_cast<IDispatchTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
index 28c0011..c28fa0f 100644 (file)
@@ -26,7 +26,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<ErrorMarshalTesting, IErrorMarshalTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IErrorMarshalTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
diff --git a/tests/src/Interop/COM/NativeServer/EventTesting.h b/tests/src/Interop/COM/NativeServer/EventTesting.h
new file mode 100644 (file)
index 0000000..f700db8
--- /dev/null
@@ -0,0 +1,237 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma once
+
+#include "Servers.h"
+
+class EventTesting :
+    public UnknownImpl,
+    public IEventTesting,
+    public IConnectionPointContainer,
+    public IConnectionPoint
+{
+private: // static
+    static const WCHAR * const Names[];
+    static const int NamesCount;
+
+private:
+    IDispatch *_eventConnections[32];
+
+public:
+    EventTesting()
+    {
+        // Ensure connections array is null
+        ::memset(_eventConnections, 0, sizeof(_eventConnections));
+    }
+
+public: // IDispatch
+        virtual HRESULT STDMETHODCALLTYPE GetTypeInfoCount( 
+            /* [out] */ __RPC__out UINT *pctinfo)
+        {
+            *pctinfo = 0;
+            return S_OK;
+        }
+        
+        virtual HRESULT STDMETHODCALLTYPE GetTypeInfo( 
+            /* [in] */ UINT iTInfo,
+            /* [in] */ LCID lcid,
+            /* [out] */ __RPC__deref_out_opt ITypeInfo **ppTInfo)
+        {
+            return E_NOTIMPL;
+        }
+        
+        virtual HRESULT STDMETHODCALLTYPE GetIDsOfNames( 
+            /* [in] */ __RPC__in REFIID,
+            /* [size_is][in] */ __RPC__in_ecount_full(cNames) LPOLESTR *rgszNames,
+            /* [range][in] */ __RPC__in_range(0,16384) UINT cNames,
+            /* [in] */ LCID,
+            /* [size_is][out] */ __RPC__out_ecount_full(cNames) DISPID *rgDispId)
+        {
+            bool containsUnknown = false;
+            DISPID *curr = rgDispId;
+            for (UINT i = 0; i < cNames; ++i)
+            {
+                *curr = DISPID_UNKNOWN;
+                LPOLESTR name = rgszNames[i];
+                for (int j = 1; j < NamesCount; ++j)
+                {
+                    const WCHAR *nameMaybe = Names[j];
+                    if (::TP_wcmp_s(name, nameMaybe) == 0)
+                    {
+                        *curr = DISPID{ j };
+                        break;
+                    }
+                }
+
+                containsUnknown &= (*curr == DISPID_UNKNOWN);
+                curr++;
+            }
+
+            return (containsUnknown) ? DISP_E_UNKNOWNNAME : S_OK;
+        }
+
+        virtual /* [local] */ HRESULT STDMETHODCALLTYPE Invoke( 
+            /* [annotation][in] */ _In_  DISPID dispIdMember,
+            /* [annotation][in] */ _In_  REFIID riid,
+            /* [annotation][in] */ _In_  LCID lcid,
+            /* [annotation][in] */ _In_  WORD wFlags,
+            /* [annotation][out][in] */ _In_  DISPPARAMS *pDispParams,
+            /* [annotation][out] */ _Out_opt_  VARIANT *pVarResult,
+            /* [annotation][out] */ _Out_opt_  EXCEPINFO *pExcepInfo,
+            /* [annotation][out] */ _Out_opt_  UINT *puArgErr)
+        {
+            //
+            // Note that arguments are received in reverse order for IDispatch::Invoke()
+            //
+
+            switch (dispIdMember)
+            {
+            case 1:
+            {
+                return FireEvent();
+            }
+            }
+
+            return E_NOTIMPL;
+        }
+
+public: // IEventTesting
+    virtual HRESULT STDMETHODCALLTYPE FireEvent()
+    {
+        return FireEvent_Impl(1 /* DISPID for the FireEvent function */);
+    }
+
+public: // IConnectionPointContainer
+    virtual HRESULT STDMETHODCALLTYPE EnumConnectionPoints( 
+        /* [out] */ __RPC__deref_out_opt IEnumConnectionPoints **ppEnum)
+    {
+        return E_NOTIMPL;
+    }
+    virtual HRESULT STDMETHODCALLTYPE FindConnectionPoint( 
+        /* [in] */ __RPC__in REFIID riid,
+        /* [out] */ __RPC__deref_out_opt IConnectionPoint **ppCP)
+    {
+        if (riid != IID_TestingEvents)
+            return CONNECT_E_NOCONNECTION;
+
+        return QueryInterface(__uuidof(*ppCP), (void**)ppCP);
+    }
+
+public: // IConnectionPoint
+    virtual HRESULT STDMETHODCALLTYPE GetConnectionInterface( 
+        /* [out] */ __RPC__out IID *pIID)
+    {
+        return E_NOTIMPL;
+    }
+    virtual HRESULT STDMETHODCALLTYPE GetConnectionPointContainer( 
+        /* [out] */ __RPC__deref_out_opt IConnectionPointContainer **ppCPC)
+    {
+        return E_NOTIMPL;
+    }
+    virtual HRESULT STDMETHODCALLTYPE Advise( 
+        /* [in] */ __RPC__in_opt IUnknown *pUnkSink,
+        /* [out] */ __RPC__out DWORD *pdwCookie)
+    {
+        if (pUnkSink == nullptr || pdwCookie == nullptr)
+            return E_POINTER;
+
+        for (DWORD i = 0; i < ARRAYSIZE(_eventConnections); ++i)
+        {
+            if (_eventConnections[i] == nullptr)
+            {
+                IDispatch *handler;
+                HRESULT hr = pUnkSink->QueryInterface(IID_IDispatch, (void**)&handler);
+                if (hr != S_OK)
+                    return CONNECT_E_CANNOTCONNECT;
+
+                _eventConnections[i] = handler;
+                *pdwCookie = i;
+                return S_OK;
+            }
+        }
+
+        return CONNECT_E_ADVISELIMIT;
+    }
+    virtual HRESULT STDMETHODCALLTYPE Unadvise( 
+        /* [in] */ DWORD dwCookie)
+    {
+        if (0 <= dwCookie && dwCookie < ARRAYSIZE(_eventConnections))
+        {
+            IDispatch *handler = _eventConnections[dwCookie];
+            if (handler != nullptr)
+            {
+                _eventConnections[dwCookie] = nullptr;
+                handler->Release();
+                return S_OK;
+            }
+        }
+
+        return E_POINTER;
+    }
+    virtual HRESULT STDMETHODCALLTYPE EnumConnections( 
+        /* [out] */ __RPC__deref_out_opt IEnumConnections **ppEnum)
+    {
+        return E_NOTIMPL;
+    }
+
+private:
+    HRESULT FireEvent_Impl(_In_ int dispId)
+    {
+        HRESULT hr = S_OK;
+
+        VARIANTARG arg;
+        ::VariantInit(&arg);
+
+        arg.vt = VT_BSTR;
+        arg.bstrVal = TP_SysAllocString(Names[dispId]);
+
+        for (DWORD i = 0; i < ARRAYSIZE(_eventConnections); ++i)
+        {
+            IDispatch *handler = _eventConnections[i];
+            if (handler != nullptr)
+            {
+                DISPPARAMS params{};
+                params.rgvarg = &arg;
+                params.cArgs = 1;
+                hr = handler->Invoke(
+                    DISPATCHTESTINGEVENTS_DISPID_ONEVENT,
+                    IID_NULL,
+                    0,
+                    DISPATCH_METHOD,
+                    &params,
+                    nullptr,
+                    nullptr,
+                    nullptr);
+
+                if (FAILED(hr))
+                    break;
+            }
+        }
+
+        return ::VariantClear(&arg);
+    }
+
+public: // IUnknown
+    STDMETHOD(QueryInterface)(
+        /* [in] */ REFIID riid,
+        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
+    {
+        return DoQueryInterface(riid, ppvObject,
+            static_cast<IDispatch *>(this),
+            static_cast<IEventTesting *>(this),
+            static_cast<IConnectionPointContainer *>(this),
+            static_cast<IConnectionPoint *>(this));
+    }
+
+    DEFINE_REF_COUNTING();
+};
+
+const WCHAR * const EventTesting::Names[] =
+{
+    W("__RESERVED__"),
+    W("FireEvent"),
+};
+
+const int EventTesting::NamesCount = ARRAYSIZE(EventTesting::Names);
index aa703be..d30427a 100644 (file)
@@ -288,7 +288,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<NumericTesting, INumericTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<INumericTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
index 80da1b7..3bf7072 100644 (file)
@@ -165,6 +165,7 @@ STDAPI DllRegisterServer(void)
     RETURN_IF_FAILED(RegisterClsid(__uuidof(StringTesting), L"Both"));
     RETURN_IF_FAILED(RegisterClsid(__uuidof(ErrorMarshalTesting), L"Both"));
     RETURN_IF_FAILED(RegisterClsid(__uuidof(DispatchTesting), L"Both"));
+    RETURN_IF_FAILED(RegisterClsid(__uuidof(EventTesting), L"Both"));
     RETURN_IF_FAILED(RegisterClsid(__uuidof(AggregationTesting), L"Both"));
     RETURN_IF_FAILED(RegisterClsid(__uuidof(ColorTesting), L"Both"));
 
@@ -180,6 +181,7 @@ STDAPI DllUnregisterServer(void)
     RETURN_IF_FAILED(RemoveClsid(__uuidof(StringTesting)));
     RETURN_IF_FAILED(RemoveClsid(__uuidof(ErrorMarshalTesting)));
     RETURN_IF_FAILED(RemoveClsid(__uuidof(DispatchTesting)));
+    RETURN_IF_FAILED(RemoveClsid(__uuidof(EventTesting)));
     RETURN_IF_FAILED(RemoveClsid(__uuidof(AggregationTesting)));
     RETURN_IF_FAILED(RemoveClsid(__uuidof(ColorTesting)));
 
@@ -203,6 +205,9 @@ STDAPI DllGetClassObject(_In_ REFCLSID rclsid, _In_ REFIID riid, _Out_ LPVOID FA
     if (rclsid == __uuidof(DispatchTesting))
         return ClassFactoryBasic<DispatchTesting>::Create(riid, ppv);
 
+    if (rclsid == __uuidof(EventTesting))
+        return ClassFactoryBasic<EventTesting>::Create(riid, ppv);
+
     if (rclsid == __uuidof(AggregationTesting))
         return ClassFactoryAggregate<AggregationTesting>::Create(riid, ppv);
 
index 38983e3..7a2a1ff 100644 (file)
@@ -15,6 +15,7 @@ class DECLSPEC_UUID("B99ABE6A-DFF6-440F-BFB6-55179B8FE18E") ArrayTesting;
 class DECLSPEC_UUID("C73C83E8-51A2-47F8-9B5C-4284458E47A6") StringTesting;
 class DECLSPEC_UUID("71CF5C45-106C-4B32-B418-43A463C6041F") ErrorMarshalTesting;
 class DECLSPEC_UUID("0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726") DispatchTesting;
+class DECLSPEC_UUID("4DBD9B61-E372-499F-84DE-EFC70AA8A009") EventTesting;
 class DECLSPEC_UUID("4CEFE36D-F377-4B6E-8C34-819A8BB9CB04") AggregationTesting;
 class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
 
@@ -23,6 +24,7 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
 #define CLSID_StringTesting __uuidof(StringTesting)
 #define CLSID_ErrorMarshalTesting __uuidof(ErrorMarshalTesting)
 #define CLSID_DispatchTesting __uuidof(DispatchTesting)
+#define CLSID_EventTesting __uuidof(EventTesting)
 #define CLSID_AggregationTesting __uuidof(AggregationTesting)
 #define CLSID_ColorTesting __uuidof(ColorTesting)
 
@@ -31,6 +33,8 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
 #define IID_IStringTesting __uuidof(IStringTesting)
 #define IID_IErrorMarshalTesting __uuidof(IErrorMarshalTesting)
 #define IID_IDispatchTesting __uuidof(IDispatchTesting)
+#define IID_TestingEvents __uuidof(TestingEvents)
+#define IID_IEventTesting __uuidof(IEventTesting)
 #define IID_IAggregationTesting __uuidof(IAggregationTesting)
 #define IID_IColorTesting __uuidof(IColorTesting)
 
@@ -67,6 +71,7 @@ private:
     #include "StringTesting.h"
     #include "ErrorMarshalTesting.h"
     #include "DispatchTesting.h"
+    #include "EventTesting.h"
     #include "AggregationTesting.h"
     #include "ColorTesting.h"
 #endif
index 06f13db..cd01fbe 100644 (file)
@@ -317,7 +317,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<StringTesting, IStringTesting>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IStringTesting *>(this));
     }
 
     DEFINE_REF_COUNTING();
index 6f70bd1..e868345 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 #pragma warning disable IDE1006 // Naming Styles
+#pragma warning disable 618 // Must test deprecated features
 
 namespace Server.Contract.Servers
 {
@@ -91,7 +92,7 @@ namespace Server.Contract.Servers
     [ComImport]
     [CoClass(typeof(DispatchTestingClass))]
     [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")]
-    internal interface DispatchTesting : Server.Contract.IDispatchTesting 
+    internal interface DispatchTesting : Server.Contract.IDispatchTesting
     {
     }
 
@@ -143,4 +144,5 @@ namespace Server.Contract.Servers
     }
 }
 
+#pragma warning restore 618 // Must test deprecated features
 #pragma warning restore IDE1006 // Naming Styles
index 3d2fb96..f401c48 100644 (file)
@@ -234,6 +234,24 @@ namespace Server.Contract
     }
 
     [ComVisible(true)]
+    [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")]
+    [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)]
+    public interface IEventTesting
+    {
+        [DispId(1)]
+        void FireEvent();
+    }
+
+    [ComImport]
+    [Guid("28ea6635-42ab-4f5b-b458-4152e78b8e86")]
+    [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)]
+    public interface TestingEvents
+    {
+        [DispId(100)]
+        void OnEvent([MarshalAs(UnmanagedType.BStr)] string msg);
+    };
+
+    [ComVisible(true)]
     [Guid("98cc27f0-d521-4f79-8b63-e980e3a92974")]
     [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
     public interface IAggregationTesting
index c9be5b0..2502567 100644 (file)
@@ -19,6 +19,8 @@ struct __declspec(uuid("592386a5-6837-444d-9de3-250815d18556"))
 /* interface */ IErrorMarshalTesting;
 struct __declspec(uuid("a5e04c1c-474e-46d2-bbc0-769d04e12b54"))
 /* interface */ IDispatchTesting;
+struct __declspec(uuid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E"))
+/* interface */ IEventTesting;
 struct __declspec(uuid("98cc27f0-d521-4f79-8b63-e980e3a92974"))
 /* interface */ IAggregationTesting;
 struct __declspec(uuid("E6D72BA7-0936-4396-8A69-3B76DA1108DA"))
@@ -33,6 +35,7 @@ _COM_SMARTPTR_TYPEDEF(IArrayTesting, __uuidof(IArrayTesting));
 _COM_SMARTPTR_TYPEDEF(IStringTesting, __uuidof(IStringTesting));
 _COM_SMARTPTR_TYPEDEF(IErrorMarshalTesting, __uuidof(IErrorMarshalTesting));
 _COM_SMARTPTR_TYPEDEF(IDispatchTesting, __uuidof(IDispatchTesting));
+_COM_SMARTPTR_TYPEDEF(IEventTesting, __uuidof(IEventTesting));
 _COM_SMARTPTR_TYPEDEF(IAggregationTesting, __uuidof(IAggregationTesting));
 _COM_SMARTPTR_TYPEDEF(IColorTesting, __uuidof(IColorTesting));
 
@@ -444,6 +447,19 @@ IDispatchTesting : IDispatch
         /*[out,retval]*/ HFA_4 *pRetVal) = 0;
 };
 
+struct __declspec(uuid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E"))
+IEventTesting : IDispatch
+{
+    virtual HRESULT STDMETHODCALLTYPE FireEvent() = 0;
+};
+
+struct __declspec(uuid("28ea6635-42ab-4f5b-b458-4152e78b8e86"))
+TestingEvents : IDispatch
+{
+#define DISPATCHTESTINGEVENTS_DISPID_ONEVENT 100
+    // void OnEvent(_In_z_ BSTR t);
+};
+
 struct __declspec(uuid("98cc27f0-d521-4f79-8b63-e980e3a92974"))
 IAggregationTesting : IUnknown
 {
diff --git a/tests/src/Interop/COM/ServerContracts/Server.Events.cs b/tests/src/Interop/COM/ServerContracts/Server.Events.cs
new file mode 100644 (file)
index 0000000..16ed578
--- /dev/null
@@ -0,0 +1,193 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma warning disable 618 // Must test deprecated features
+
+namespace Server.Contract
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Runtime.InteropServices;
+
+    using IConnectionPoint = System.Runtime.InteropServices.ComTypes.IConnectionPoint;
+    using IConnectionPointContainer = System.Runtime.InteropServices.ComTypes.IConnectionPointContainer;
+
+    namespace Servers
+    {
+        /// <summary>
+        /// Managed definition of CoClass
+        /// </summary>
+        [ComImport]
+        [CoClass(typeof(EventTestingClass))]
+        [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")]
+        internal interface EventTesting : IEventTesting, Events.TestingEvents_Event
+        {
+        }
+
+        /// <summary>
+        /// Managed activation for CoClass
+        /// </summary>
+        [ComImport]
+        [ComSourceInterfaces("Server.Contract.Events.TestingEvents\0")]
+        [Guid(Server.Contract.Guids.EventTesting)]
+        internal class EventTestingClass
+        {
+        }
+    }
+
+    /// <summary>
+    /// Classes in the Events namespace are traditionally generated by the TlbImp tool.
+    /// </summary>
+    namespace Events
+    {
+        /// <summary>
+        /// Delegate used for event handler
+        /// </summary>
+        [ComVisible(false)]
+        public delegate void TestingEvents_OnEventEventHandler(string msg);
+
+        /// <summary>
+        /// Event source interface
+        /// </summary>
+        /// <remarks>
+        /// Observe usage of the <see cref="ComEventInterfaceAttribute"/> attribute.
+        /// </remarks>
+        [ComVisible(false)]
+        [ComEventInterface(typeof(Contract.TestingEvents), typeof(TestingEvents_EventProvider))]
+        public interface TestingEvents_Event
+        {
+            event TestingEvents_OnEventEventHandler OnEvent;
+        }
+
+        /// <summary>
+        /// Managed proxy for event subscription via IConnectionPointContainer and IConnectionPoint.
+        /// </summary>
+        public sealed class TestingEvents_EventProvider : TestingEvents_Event, IDisposable
+        {
+            private readonly WeakReference ConnectionPointContainer;
+            private readonly List<TestingEvents_SinkHelper> eventSinkHelpers = new List<TestingEvents_SinkHelper>();
+
+            private IConnectionPoint connectionPoint;
+            private bool isDisposed = false;
+
+            public TestingEvents_EventProvider(object container)
+            {
+                this.ConnectionPointContainer = new WeakReference((IConnectionPointContainer)container, false);
+            }
+
+            event TestingEvents_OnEventEventHandler TestingEvents_Event.OnEvent
+            {
+                add
+                {
+                    lock (this.eventSinkHelpers)
+                    {
+                        if (this.connectionPoint == null)
+                        {
+                            this.Init();
+                        }
+
+                        var sinkHelper = new TestingEvents_SinkHelper();
+
+                        int cookie;
+                        this.connectionPoint.Advise(sinkHelper, out cookie);
+
+                        sinkHelper.Cookie = cookie;
+                        sinkHelper.OnEventDelegate = value;
+                        this.eventSinkHelpers.Add(sinkHelper);
+                    }
+                }
+                remove
+                {
+                    lock (this.eventSinkHelpers)
+                    {
+                        TestingEvents_SinkHelper sinkHelper = null;
+                        int removeIdx = -1;
+                        for (int i = 0; i < this.eventSinkHelpers.Count; ++i)
+                        {
+                            TestingEvents_SinkHelper sinkHelperMaybe = this.eventSinkHelpers[i];
+                            if (sinkHelperMaybe.OnEventDelegate.Equals(value))
+                            {
+                                removeIdx = i;
+                                sinkHelper = sinkHelperMaybe;
+                                break;
+                            }
+                        }
+
+                        if (removeIdx < 0)
+                        {
+                            return;
+                        }
+
+                        this.connectionPoint.Unadvise(sinkHelper.Cookie);
+                        this.eventSinkHelpers.RemoveAt(removeIdx);
+
+                        if (this.eventSinkHelpers.Count == 0)
+                        {
+                            Marshal.ReleaseComObject(this.connectionPoint);
+                            this.connectionPoint = null;
+                        }
+                    }
+                }
+            }
+
+            void IDisposable.Dispose()
+            {
+                if (this.isDisposed)
+                {
+                    return;
+                }
+
+                lock (this.eventSinkHelpers)
+                {
+                    foreach (TestingEvents_SinkHelper sinkHelper in this.eventSinkHelpers)
+                    {
+                        this.connectionPoint.Unadvise(sinkHelper.Cookie);
+                    }
+
+                    this.eventSinkHelpers.Clear();
+                }
+
+                Marshal.ReleaseComObject(this.connectionPoint);
+                this.connectionPoint = null;
+
+                this.isDisposed = true;
+                System.GC.SuppressFinalize(this);
+            }
+
+            private void Init()
+            {
+                var container = (IConnectionPointContainer)this.ConnectionPointContainer.Target;
+
+                Guid iid = typeof(Contract.TestingEvents).GUID;
+                IConnectionPoint connectionPoint;
+                container.FindConnectionPoint(ref iid, out connectionPoint);
+
+                this.connectionPoint = connectionPoint;
+            }
+        }
+
+        /// <summary>
+        /// Wrapper for event delegate.
+        /// </summary>
+        /// <remarks>
+        /// Observe usage of the <see cref="ClassInterfaceAttribute"/> attribute.
+        /// </remarks>
+        [ClassInterface(ClassInterfaceType.None)]
+        public class TestingEvents_SinkHelper : Contract.TestingEvents
+        {
+            public int Cookie { get; set; }
+            public TestingEvents_OnEventEventHandler OnEventDelegate { get; set; }
+
+            public void OnEvent(string msg)
+            {
+                if (this.OnEventDelegate != null)
+                {
+                    this.OnEventDelegate(msg);
+                }
+            }
+        }
+    }
+}
+
+#pragma warning restore 618 // Must test deprecated features
index 199f618..1269e6a 100644 (file)
@@ -14,6 +14,7 @@ namespace Server.Contract
         public const string StringTesting = "C73C83E8-51A2-47F8-9B5C-4284458E47A6";
         public const string ErrorMarshalTesting = "71CF5C45-106C-4B32-B418-43A463C6041F";
         public const string DispatchTesting = "0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726";
+        public const string EventTesting = "4DBD9B61-E372-499F-84DE-EFC70AA8A009";
         public const string AggregationTesting = "4CEFE36D-F377-4B6E-8C34-819A8BB9CB04";
         public const string ColorTesting = "C222F472-DA5A-4FC6-9321-92F4F7053A65";
     }
index 47456bf..9e3f1f1 100644 (file)
@@ -63,7 +63,7 @@ public:
         REFIID riid,
         void** ppvObject)
     {
-        return DoQueryInterface<IntegerEnumerator, IEnumVARIANT>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IEnumVARIANT *>(this));
     }
 
     DEFINE_REF_COUNTING();
@@ -146,7 +146,7 @@ public:
         REFIID riid,
         void** ppvObject)
     {
-        return DoQueryInterface<IntegerEnumerable, IDispatch>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IDispatch*>(this));
     }
 
     DEFINE_REF_COUNTING();
index fd89639..c4d9e6c 100644 (file)
 
 namespace Internal
 {
-    template<typename C, typename I>
+    template<typename I>
     HRESULT __QueryInterfaceImpl(
-        /* [in] */ C *obj,
         /* [in] */ REFIID riid,
-        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
+        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject,
+        /* [in] */ I obj)
     {
         if (riid == __uuidof(I))
         {
-            *ppvObject = static_cast<I*>(obj);
-        }
-        else if (riid == __uuidof(IUnknown))
-        {
-            *ppvObject = static_cast<IUnknown*>(obj);
+            *ppvObject = static_cast<I>(obj);
         }
         else
         {
@@ -39,19 +35,20 @@ namespace Internal
         return S_OK;
     }
 
-    template<typename C, typename I1, typename I2, typename ...R>
+    template<typename I1, typename ...IR>
     HRESULT __QueryInterfaceImpl(
-        /* [in] */ C *obj,
         /* [in] */ REFIID riid,
-        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
+        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject,
+        /* [in] */ I1 i1,
+        /* [in] */ IR... remain)
     {
         if (riid == __uuidof(I1))
         {
-            *ppvObject = static_cast<I1*>(obj);
+            *ppvObject = static_cast<I1>(i1);
             return S_OK;
         }
 
-        return __QueryInterfaceImpl<C, I2, R...>(obj, riid, ppvObject);
+        return __QueryInterfaceImpl(riid, ppvObject, remain...);
     }
 }
 
@@ -68,21 +65,29 @@ public:
     UnknownImpl(UnknownImpl&&) = default;
     UnknownImpl& operator=(UnknownImpl&&) = default;
 
-    template<typename C, typename ...I>
+    template<typename I1, typename ...IR>
     HRESULT DoQueryInterface(
-        /* [in] */ C *derived,
         /* [in] */ REFIID riid,
-        /* [iid_is][out] */ _COM_Outptr_ void **ppvObject)
+        /* [iid_is][out] */ _COM_Outptr_ void **ppvObject,
+        /* [in] */ I1 i1,
+        /* [in] */ IR... remain)
     {
-        assert(derived != nullptr);
         if (ppvObject == nullptr)
             return E_POINTER;
 
-        HRESULT hr = Internal::__QueryInterfaceImpl<C, I...>(derived, riid, ppvObject);
-        if (hr == S_OK)
-            DoAddRef();
+        if (riid == __uuidof(IUnknown))
+        {
+            *ppvObject = static_cast<IUnknown *>(i1);
+        }
+        else
+        {
+            HRESULT hr = Internal::__QueryInterfaceImpl(riid, ppvObject, i1, remain...);
+            if (hr != S_OK)
+                return hr;
+        }
 
-        return hr;
+        DoAddRef();
+        return S_OK;
     }
 
     ULONG DoAddRef()
@@ -162,7 +167,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<ClassFactoryBasic, IClassFactory>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IClassFactory *>(this));
     }
 
     DEFINE_REF_COUNTING();
@@ -221,7 +226,7 @@ public: // IUnknown
         /* [in] */ REFIID riid,
         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
     {
-        return DoQueryInterface<ClassFactoryAggregate, IClassFactory>(this, riid, ppvObject);
+        return DoQueryInterface(riid, ppvObject, static_cast<IClassFactory *>(this));
     }
 
     DEFINE_REF_COUNTING();