Guarded devirtualization foundations (#21270)
[platform/upstream/coreclr.git] / tests / src / Interop / COM / NativeServer / ComHelpers.h
1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4
5 #pragma once
6
7 #include <Windows.h>
8 #include <comdef.h>
9 #include <cassert>
10 #include <exception>
11 #include <type_traits>
12 #include <atomic>
13
// Common macro for working in COM.
//
// Evaluates `exp`, stores the result in a variable named `hr` that must
// already be declared in the calling scope, and returns that HRESULT from the
// enclosing function when FAILED(hr) is true.
//
// The do { ... } while (false) wrapper makes `RETURN_IF_FAILED(x);` a single
// statement, so it is safe in an unbraced if/else. `exp` is parenthesized so
// that comma or low-precedence expressions are evaluated as a unit.
#define RETURN_IF_FAILED(exp) do { hr = (exp); if (FAILED(hr)) { return hr; } } while (false)
16
17 namespace Internal
18 {
19     template<typename C, typename I>
20     HRESULT __QueryInterfaceImpl(
21         /* [in] */ C *obj,
22         /* [in] */ REFIID riid,
23         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
24     {
25         if (riid == __uuidof(I))
26         {
27             *ppvObject = static_cast<I*>(obj);
28         }
29         else if (riid == __uuidof(IUnknown))
30         {
31             *ppvObject = static_cast<IUnknown*>(obj);
32         }
33         else
34         {
35             *ppvObject = nullptr;
36             return E_NOINTERFACE;
37         }
38
39         return S_OK;
40     }
41
42     template<typename C, typename I1, typename I2, typename ...R>
43     HRESULT __QueryInterfaceImpl(
44         /* [in] */ C *obj,
45         /* [in] */ REFIID riid,
46         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
47     {
48         if (riid == __uuidof(I1))
49         {
50             *ppvObject = static_cast<I1*>(obj);
51             return S_OK;
52         }
53
54         return __QueryInterfaceImpl<C, I2, R...>(obj, riid, ppvObject);
55     }
56 }
57
58 // Implementation of IUnknown operations
59 class UnknownImpl
60 {
61 public:
62     UnknownImpl() = default;
63     virtual ~UnknownImpl() = default;
64
65     UnknownImpl(const UnknownImpl&) = delete;
66     UnknownImpl& operator=(const UnknownImpl&) = delete;
67
68     UnknownImpl(UnknownImpl&&) = default;
69     UnknownImpl& operator=(UnknownImpl&&) = default;
70
71     template<typename C, typename ...I>
72     HRESULT DoQueryInterface(
73         /* [in] */ C *derived,
74         /* [in] */ REFIID riid,
75         /* [iid_is][out] */ _COM_Outptr_ void **ppvObject)
76     {
77         assert(derived != nullptr);
78         if (ppvObject == nullptr)
79             return E_POINTER;
80
81         HRESULT hr = Internal::__QueryInterfaceImpl<C, I...>(derived, riid, ppvObject);
82         if (hr == S_OK)
83             DoAddRef();
84
85         return hr;
86     }
87
88     ULONG DoAddRef()
89     {
90         assert(_refCount > 0);
91         return (++_refCount);
92     }
93
94     ULONG DoRelease()
95     {
96         assert(_refCount > 0);
97         ULONG c = (--_refCount);
98         if (c == 0)
99             delete this;
100         return c;
101     }
102
103 private:
104     std::atomic<ULONG> _refCount = 1;
105 };
106
// Macro that defines the AddRef and Release members (the IUnknown
// ref-counting methods) for a class deriving from UnknownImpl. Both simply
// forward to the shared counter in UnknownImpl.
#define DEFINE_REF_COUNTING() \
    STDMETHOD_(ULONG, AddRef)(void) \
    { \
        return UnknownImpl::DoAddRef(); \
    } \
    STDMETHOD_(ULONG, Release)(void) \
    { \
        return UnknownImpl::DoRelease(); \
    }
111
112 // Templated class factory
113 template<typename T>
114 class ClassFactoryBasic : public UnknownImpl, public IClassFactory
115 {
116 public: // static
117     static HRESULT Create(_In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
118     {
119         try
120         {
121             auto cf = new ClassFactoryBasic();
122             HRESULT hr = cf->QueryInterface(riid, ppv);
123             cf->Release();
124             return hr;
125         }
126         catch (const std::bad_alloc&)
127         {
128             return E_OUTOFMEMORY;
129         }
130     }
131
132 public: // IClassFactory
133     STDMETHOD(CreateInstance)(
134         _In_opt_  IUnknown *pUnkOuter,
135         _In_  REFIID riid,
136         _COM_Outptr_  void **ppvObject)
137     {
138         if (pUnkOuter != nullptr)
139             return CLASS_E_NOAGGREGATION;
140
141         try
142         {
143             auto ti = new T();
144             HRESULT hr = ti->QueryInterface(riid, ppvObject);
145             ti->Release();
146             return hr;
147         }
148         catch (const std::bad_alloc&)
149         {
150             return E_OUTOFMEMORY;
151         }
152     }
153
154     STDMETHOD(LockServer)(/* [in] */ BOOL fLock)
155     {
156         assert(false && "Not impl");
157         return E_NOTIMPL;
158     }
159
160 public: // IUnknown
161     STDMETHOD(QueryInterface)(
162         /* [in] */ REFIID riid,
163         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
164     {
165         return DoQueryInterface<ClassFactoryBasic, IClassFactory>(this, riid, ppvObject);
166     }
167
168     DEFINE_REF_COUNTING();
169 };
170
171 // Templated class factory for aggregation
172 template<typename T>
173 class ClassFactoryAggregate : public UnknownImpl, public IClassFactory
174 {
175 public: // static
176     static HRESULT Create(_In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
177     {
178         try
179         {
180             auto cf = new ClassFactoryAggregate();
181             HRESULT hr = cf->QueryInterface(riid, ppv);
182             cf->Release();
183             return hr;
184         }
185         catch (const std::bad_alloc&)
186         {
187             return E_OUTOFMEMORY;
188         }
189     }
190
191 public: // IClassFactory
192     STDMETHOD(CreateInstance)(
193         _In_opt_  IUnknown *pUnkOuter,
194         _In_  REFIID riid,
195         _COM_Outptr_  void **ppvObject)
196     {
197         if (pUnkOuter != nullptr && riid != IID_IUnknown)
198             return CLASS_E_NOAGGREGATION;
199
200         try
201         {
202             auto ti = new T(pUnkOuter);
203             HRESULT hr = ti->QueryInterface(riid, ppvObject);
204             ti->Release();
205             return hr;
206         }
207         catch (const std::bad_alloc&)
208         {
209             return E_OUTOFMEMORY;
210         }
211     }
212
213     STDMETHOD(LockServer)(/* [in] */ BOOL fLock)
214     {
215         assert(false && "Not impl");
216         return E_NOTIMPL;
217     }
218
219 public: // IUnknown
220     STDMETHOD(QueryInterface)(
221         /* [in] */ REFIID riid,
222         /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
223     {
224         return DoQueryInterface<ClassFactoryAggregate, IClassFactory>(this, riid, ppvObject);
225     }
226
227     DEFINE_REF_COUNTING();
228 };