1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
11 #include <type_traits>
// Common macro for working in COM: evaluates `exp`, assigns the result to an
// `hr` variable that must already be declared at the call site, and returns
// that value from the enclosing function when FAILED(hr).
// The do { ... } while (0) wrapper turns `RETURN_IF_FAILED(x);` into exactly
// one statement. A bare { ... } block followed by the caller's `;` would end
// an unbraced if-statement early and orphan a following `else`.
// Call sites must terminate the macro with `;`.
#define RETURN_IF_FAILED(exp) do { hr = (exp); if (FAILED(hr)) { return hr; } } while (0)
19 template<typename C, typename I>
20 HRESULT __QueryInterfaceImpl(
22 /* [in] */ REFIID riid,
23 /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
25 if (riid == __uuidof(I))
27 *ppvObject = static_cast<I*>(obj);
29 else if (riid == __uuidof(IUnknown))
31 *ppvObject = static_cast<IUnknown*>(obj);
42 template<typename C, typename I1, typename I2, typename ...R>
43 HRESULT __QueryInterfaceImpl(
45 /* [in] */ REFIID riid,
46 /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
48 if (riid == __uuidof(I1))
50 *ppvObject = static_cast<I1*>(obj);
54 return __QueryInterfaceImpl<C, I2, R...>(obj, riid, ppvObject);
58 // Implementation of IUnknown operations
62 UnknownImpl() = default;
63 virtual ~UnknownImpl() = default;
65 UnknownImpl(const UnknownImpl&) = delete;
66 UnknownImpl& operator=(const UnknownImpl&) = delete;
68 UnknownImpl(UnknownImpl&&) = default;
69 UnknownImpl& operator=(UnknownImpl&&) = default;
71 template<typename C, typename ...I>
72 HRESULT DoQueryInterface(
73 /* [in] */ C *derived,
74 /* [in] */ REFIID riid,
75 /* [iid_is][out] */ _COM_Outptr_ void **ppvObject)
77 assert(derived != nullptr);
78 if (ppvObject == nullptr)
81 HRESULT hr = Internal::__QueryInterfaceImpl<C, I...>(derived, riid, ppvObject);
90 assert(_refCount > 0);
96 assert(_refCount > 0);
97 ULONG c = (--_refCount);
104 std::atomic<ULONG> _refCount = 1;
// Macro for defining the reference-counting members of a COM class.
// DEFINE_REF_COUNTING() expands to inline AddRef() and Release() members,
// declared through STDMETHOD_ with a ULONG return, that forward to
// UnknownImpl::DoAddRef() and UnknownImpl::DoRelease() respectively.
// Intended for use inside the body of a class deriving from UnknownImpl.
#define DEFINE_REF_COUNTING()                \
    STDMETHOD_(ULONG, AddRef)(void)          \
    {                                        \
        return UnknownImpl::DoAddRef();      \
    }                                        \
    STDMETHOD_(ULONG, Release)(void)         \
    {                                        \
        return UnknownImpl::DoRelease();     \
    }
112 // Templated class factory
114 class ClassFactoryBasic : public UnknownImpl, public IClassFactory
117 static HRESULT Create(_In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
121 auto cf = new ClassFactoryBasic();
122 HRESULT hr = cf->QueryInterface(riid, ppv);
126 catch (const std::bad_alloc&)
128 return E_OUTOFMEMORY;
132 public: // IClassFactory
133 STDMETHOD(CreateInstance)(
134 _In_opt_ IUnknown *pUnkOuter,
136 _COM_Outptr_ void **ppvObject)
138 if (pUnkOuter != nullptr)
139 return CLASS_E_NOAGGREGATION;
144 HRESULT hr = ti->QueryInterface(riid, ppvObject);
148 catch (const std::bad_alloc&)
150 return E_OUTOFMEMORY;
154 STDMETHOD(LockServer)(/* [in] */ BOOL fLock)
156 assert(false && "Not impl");
161 STDMETHOD(QueryInterface)(
162 /* [in] */ REFIID riid,
163 /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
165 return DoQueryInterface<ClassFactoryBasic, IClassFactory>(this, riid, ppvObject);
168 DEFINE_REF_COUNTING();
171 // Templated class factory for aggregation
173 class ClassFactoryAggregate : public UnknownImpl, public IClassFactory
176 static HRESULT Create(_In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
180 auto cf = new ClassFactoryAggregate();
181 HRESULT hr = cf->QueryInterface(riid, ppv);
185 catch (const std::bad_alloc&)
187 return E_OUTOFMEMORY;
191 public: // IClassFactory
192 STDMETHOD(CreateInstance)(
193 _In_opt_ IUnknown *pUnkOuter,
195 _COM_Outptr_ void **ppvObject)
197 if (pUnkOuter != nullptr && riid != IID_IUnknown)
198 return CLASS_E_NOAGGREGATION;
202 auto ti = new T(pUnkOuter);
203 HRESULT hr = ti->QueryInterface(riid, ppvObject);
207 catch (const std::bad_alloc&)
209 return E_OUTOFMEMORY;
213 STDMETHOD(LockServer)(/* [in] */ BOOL fLock)
215 assert(false && "Not impl");
220 STDMETHOD(QueryInterface)(
221 /* [in] */ REFIID riid,
222 /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
224 return DoQueryInterface<ClassFactoryAggregate, IClassFactory>(this, riid, ppvObject);
227 DEFINE_REF_COUNTING();