1// Licensed to the .NET Foundation under one or more agreements.
2// The .NET Foundation licenses this file to you under the MIT license.
3// See the LICENSE file in the project root for more information.
4//*****************************************************************************
5//
6// InternalUnknownImpl.h
7//
// Defines utility classes ComUtil::IUnknownCommon and ComUtil::IUnknownCommonExternal,
// which provide default implementations for IUnknown's AddRef, Release, and
// QueryInterface methods.
10//
11// Use: a class that implements one or more interfaces should derive from
12// ComUtil::IUnknownCommon with a template parameter list consisting of the
13// list of implemented interfaces.
14//
15// Example:
16// class MyInterfacesImpl :
17// public IUnknownCommon<MyInterface1, MyInterface2>
18// { ... };
19//
20// IUnknownCommon will provide base AddRef and Release semantics, and will
21// also provide an implementation of QueryInterface that will evaluate the
22// arguments against the set of supported interfaces and return the
23// appropriate result.
24//
25// If you need to specify multiple interfaces where one is a base interface
26// of another and implementing all of them would result in a compiler error,
27// you can use the NoDerive wrapper to tell IUnknownCommon to not derive from
28// this interface but just use it for QueryInterface calls.
29//
30// Example:
31// interface A
32// { ... };
33// interface B : public A
34// { ... };
35// class MyInterfacesImpl : public IUnknownCommon<B, NoDerive<A> >
36// { ... };
37//
38// If a base type also implements IUnknownCommon, then you must override
39// QueryInterface with a method that delegates to your type's
40// IUnknownCommon::QueryInterface and then to BaseType::QueryInterface.
41//
42
43
44//
45//*****************************************************************************
46
47#ifndef __InternalUnknownImpl_h__
48#define __InternalUnknownImpl_h__
49
50#include <winnt.h>
51#include "winwrap.h"
52#include "contract.h"
53#include "ex.h"
54#include "volatile.h"
55#include "mpl/type_list"
56#include "debugmacros.h"
57
58#define COMUTIL_IIDOF(x) __uuidof(x)
59
60namespace ComUtil
61{
62 //---------------------------------------------------------------------------------------------
// Marker base for wrapper types (NoDerive, ItfBase): exposes the wrapped type
// as a nested `wrapped_type`, which IsTypeWrapper detects and Unwrap strips.
template <typename T>
struct TypeWrapper
{
    using wrapped_type = T;
};
66
namespace detail
{
    // Result types of different sizes; comparing sizeof() of a call's result
    // reveals which _IsTypeWrapper overload was selected.
    using _Yes = char (&)[1];
    using _No = char (&)[2];

    // Catch-all overload, selected when the preferred overload below is not
    // viable. Only ever named inside sizeof, so it is declared but not defined.
    static inline _No _IsTypeWrapper(...);

    // Preferred overload; removed by SFINAE unless T::wrapped_type names a type
    // to which a pointer can be formed.
    template <typename T>
    static _Yes _IsTypeWrapper(T *, typename T::wrapped_type * = nullptr);
}
77
78 //---------------------------------------------------------------------------------------------
79 template <typename T>
80 struct IsTypeWrapper
81 {
82 static const bool value = std::integral_constant<
83 bool,
84 sizeof(detail::_IsTypeWrapper((T*)0)) == sizeof(detail::_Yes)>::value;
85 };
86
87 //-----------------------------------------------------------------------------------------
88 // Utility to remove marker type wrappers.
89 template <typename T, bool IsWrapper = IsTypeWrapper<T>::value>
90 struct UnwrapOne
91 { typedef T type; };
92
93 template <typename T>
94 struct UnwrapOne<T, true>
95 { typedef typename T::wrapped_type type; };
96
97 template <typename T, bool IsWrapper = IsTypeWrapper<T>::value>
98 struct Unwrap
99 { typedef T type; };
100
101 template <typename T>
102 struct Unwrap<T, true>
103 { typedef typename Unwrap< typename UnwrapOne<T>::type >::type type; };
104
105 //---------------------------------------------------------------------------------------------
106 // Used as a flag to indicate that an interface should not be used as a base class.
107 // See DeriveTypeList below.
108 template <typename T>
109 struct NoDerive : public TypeWrapper<T>
110 { };
111
112 //---------------------------------------------------------------------------------------------
113 // Used to indicate that a base class contributes implemented interfaces.
114 template <typename T>
115 struct ItfBase : public TypeWrapper<T>
116 { };
117
118 namespace detail
119 {
120 using namespace mpl;
121
122 //-----------------------------------------------------------------------------------------
123 // Exposes a type that derives every type in the given type list, except for those marked
124 // with NoDerive.
125 template <typename ListT>
126 struct DeriveTypeList;
127
128 // Common case. Derive from list head and recursively on list tail.
129 template <typename HeadT, typename TailT>
130 struct DeriveTypeList< type_list<HeadT, TailT> > :
131 public Unwrap<HeadT>::type,
132 public DeriveTypeList<TailT>
133 {};
134
135 // Non-derived case. Skip this type, continue with tail.
136 template <typename HeadT, typename TailT>
137 struct DeriveTypeList< type_list< NoDerive< HeadT >, TailT> > :
138 public DeriveTypeList<TailT>
139 {};
140
141 // Termination case.
142 template <>
143 struct DeriveTypeList<null_type>
144 {};
145
146 //-----------------------------------------------------------------------------------------
147 template <typename ItfTypeListT>
148 struct GetFirstInterface;
149
150 template <typename HeadT, typename TailT>
151 struct GetFirstInterface< type_list<HeadT, TailT> >
152 { typedef HeadT type; };
153
154 template <typename HeadT, typename TailT>
155 struct GetFirstInterface< type_list< ItfBase< HeadT >, TailT> >
156 { typedef typename GetFirstInterface<TailT>::type type; };
157
158 template <>
159 struct GetFirstInterface< null_type >
160 { typedef IUnknown type; };
161
//-----------------------------------------------------------------------------------------
// Uses type lists to implement the helper. Type lists are implemented
// through templates, and can be best understood if compared to Scheme
// car and cdr: each list type has a head type and a tail type. The
// head type is typically a concrete type, and the tail type is
// typically another list containing the remainder of the list. Type
// lists are terminated with a tail type of null_type.
//
// QueryInterface is implemented using QIHelper, which uses type_lists
// and partial specialization to recursively walk the type list and
// look to see if the requested interface is supported. If not, then
// the termination case is reached and a final test against IUnknown
// is made before returning a failure.
//-----------------------------------------------------------------------------------------
176 template <typename InterfaceTypeList>
177 struct QIHelper;
178
179 template <typename HeadT, typename TailT>
180 struct QIHelper< type_list< HeadT, TailT > >
181 {
182 template <typename IUnknownCommonT>
183 static inline HRESULT QI(
184 REFIID riid,
185 void **ppvObject,
186 IUnknownCommonT *pThis)
187 {
188 STATIC_CONTRACT_NOTHROW;
189 STATIC_CONTRACT_GC_NOTRIGGER;
190 STATIC_CONTRACT_ENTRY_POINT;
191
192 HRESULT hr = S_OK;
193
194 typedef typename Unwrap<HeadT>::type ItfT;
195
196 // If the interface type matches that of the head of the list,
197 // then cast to it and return success.
198 if (riid == COMUTIL_IIDOF(ItfT))
199 {
200 ItfT *pItf = static_cast<ItfT *>(pThis);
201 pItf->AddRef();
202 *ppvObject = pItf;
203 }
204 // If not, recurse on the tail of the list.
205 else
206 hr = QIHelper<TailT>::QI(riid, ppvObject, pThis);
207
208 return hr;
209 }
210 };
211
212 template <typename HeadT, typename TailT>
213 struct QIHelper< type_list< ItfBase< HeadT >, TailT> >
214 {
215 template <typename IUnknownCommonT>
216 static inline HRESULT QI(
217 REFIID riid,
218 void **ppvObject,
219 IUnknownCommonT *pThis)
220 {
221 STATIC_CONTRACT_NOTHROW;
222 STATIC_CONTRACT_GC_NOTRIGGER;
223 STATIC_CONTRACT_ENTRY_POINT;
224
225 HRESULT hr = S_OK;
226
227 hr = pThis->HeadT::QueryInterface(riid, ppvObject);
228
229 if (hr == E_NOINTERFACE)
230 hr = QIHelper<TailT>::QI(riid, ppvObject, pThis);
231
232 return hr;
233 }
234 };
235
236 // This is the termination case. In this case, we check if the
237 // requested interface is IUnknown (which is common to all interfaces).
238 template <>
239 struct QIHelper< null_type >
240 {
241 template <typename IUnknownCommonT>
242 static inline HRESULT QI(
243 REFIID riid,
244 void **ppvObject,
245 IUnknownCommonT *pThis)
246 {
247 STATIC_CONTRACT_NOTHROW;
248 STATIC_CONTRACT_GC_NOTRIGGER;
249 STATIC_CONTRACT_ENTRY_POINT;
250
251 HRESULT hr = S_OK;
252
253 // If the request was for IUnknown, cast and return success.
254 if (riid == COMUTIL_IIDOF(IUnknown))
255 {
256 typedef typename detail::GetFirstInterface<
257 typename IUnknownCommonT::InterfaceListT>::type IUnknownCastHelper;
258
259 // Cast to first interface type to then cast to IUnknown unambiguously.
260 IUnknown *pItf = static_cast<IUnknown *>(
261 static_cast<IUnknownCastHelper *>(pThis));
262 pItf->AddRef();
263 *ppvObject = pItf;
264 }
265 // Otherwise none of the interfaces match the requested IID,
266 // so return E_NOINTERFACE.
267 else
268 {
269 *ppvObject = nullptr;
270 hr = E_NOINTERFACE;
271 }
272
273 return hr;
274 }
275 };
276
277 //-----------------------------------------------------------------------------------------
278 // Is used as a virtual base to ensure that there is a single reference count field.
279 struct IUnknownCommonRef
280 {
281 inline
282 IUnknownCommonRef()
283 : m_cRef(0)
284 {}
285
286 Volatile<LONG> m_cRef;
287 };
288 }
289
290 //---------------------------------------------------------------------------------------------
291 // IUnknownCommon
292 //
293 // T0-T9 - the list of interfaces to implement.
294 template
295 <
296 typename T0 = mpl::null_type,
297 typename T1 = mpl::null_type,
298 typename T2 = mpl::null_type,
299 typename T3 = mpl::null_type,
300 typename T4 = mpl::null_type,
301 typename T5 = mpl::null_type,
302 typename T6 = mpl::null_type,
303 typename T7 = mpl::null_type,
304 typename T8 = mpl::null_type,
305 typename T9 = mpl::null_type
306 >
307 class IUnknownCommon :
308 virtual protected detail::IUnknownCommonRef,
309 public detail::DeriveTypeList< typename mpl::make_type_list<
310 T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>::type >
311 {
312 public:
313 typedef typename mpl::make_type_list<
314 T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>::type InterfaceListT;
315
316 // Add a virtual destructor to force derived types to also have virtual destructors.
317 virtual ~IUnknownCommon()
318 {
319 WRAPPER_NO_CONTRACT;
320 clr::dbg::PoisonMem(*this);
321 }
322
323 // Standard AddRef implementation
324 STDMETHOD_(ULONG, AddRef())
325 {
326 STATIC_CONTRACT_LIMITED_METHOD;
327 STATIC_CONTRACT_ENTRY_POINT;
328
329 return InterlockedIncrement(&m_cRef);
330 }
331
332 // Standard Release implementation.
333 STDMETHOD_(ULONG, Release())
334 {
335 STATIC_CONTRACT_LIMITED_METHOD;
336 STATIC_CONTRACT_ENTRY_POINT;
337
338 _ASSERTE(m_cRef > 0);
339
340 ULONG cRef = InterlockedDecrement(&m_cRef);
341
342 if (cRef == 0)
343 delete this; // Relies on virtual dtor to work properly.
344
345 return cRef;
346 }
347
348 // Uses detail::QIHelper for implementation.
349 STDMETHOD(QueryInterface(REFIID riid, void **ppvObject))
350 {
351 STATIC_CONTRACT_LIMITED_METHOD;
352 STATIC_CONTRACT_ENTRY_POINT;
353
354 if (ppvObject == nullptr)
355 return E_INVALIDARG;
356
357 *ppvObject = nullptr;
358
359 return detail::QIHelper<InterfaceListT>::QI(
360 riid, ppvObject, this);
361 }
362
363 template <typename ItfT>
364 HRESULT QueryInterface(ItfT **ppItf)
365 {
366 return QueryInterface(__uuidof(ItfT), reinterpret_cast<void**>(ppItf));
367 }
368
369 protected:
370 // May only be constructed as a base type.
371 inline IUnknownCommon() :
372 IUnknownCommonRef()
373 { WRAPPER_NO_CONTRACT; }
374 };
375
376 //---------------------------------------------------------------------------------------------
377 // IUnknownCommonExternal
378 //
379 // T0-T9 - the list of interfaces to implement.
380 template
381 <
382 typename T0 = mpl::null_type,
383 typename T1 = mpl::null_type,
384 typename T2 = mpl::null_type,
385 typename T3 = mpl::null_type,
386 typename T4 = mpl::null_type,
387 typename T5 = mpl::null_type,
388 typename T6 = mpl::null_type,
389 typename T7 = mpl::null_type,
390 typename T8 = mpl::null_type,
391 typename T9 = mpl::null_type
392 >
393 class IUnknownCommonExternal :
394 virtual protected detail::IUnknownCommonRef,
395 public detail::DeriveTypeList< typename mpl::make_type_list<
396 T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>::type >
397 {
398 public:
399 typedef typename mpl::make_type_list<
400 T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>::type InterfaceListT;
401
402 // Standard AddRef implementation
403 STDMETHOD_(ULONG, AddRef())
404 {
405 STATIC_CONTRACT_LIMITED_METHOD;
406 STATIC_CONTRACT_ENTRY_POINT;
407
408 return InterlockedIncrement(&m_cRef);
409 }
410
411 // Standard Release implementation.
412 // Should be called outside VM only
413 STDMETHOD_(ULONG, Release())
414 {
415 STATIC_CONTRACT_LIMITED_METHOD;
416 STATIC_CONTRACT_ENTRY_POINT;
417
418 _ASSERTE(m_cRef > 0);
419
420 ULONG cRef = InterlockedDecrement(&m_cRef);
421
422 if (cRef == 0)
423 {
424 Cleanup(); // Cleans up the object
425 delete this;
426 }
427
428 return cRef;
429 }
430
431 // Internal release
432 // Should be called inside VM only
433 STDMETHOD_(ULONG, InternalRelease())
434 {
435 LIMITED_METHOD_CONTRACT;
436
437 _ASSERTE(m_cRef > 0);
438
439 ULONG cRef = InterlockedDecrement(&m_cRef);
440
441 if (cRef == 0)
442 {
443 InternalCleanup(); // Cleans up the object, internal version
444 delete this;
445 }
446
447 return cRef;
448 }
449
450 // Uses detail::QIHelper for implementation.
451 STDMETHOD(QueryInterface(REFIID riid, void **ppvObject))
452 {
453 STATIC_CONTRACT_LIMITED_METHOD;
454 STATIC_CONTRACT_ENTRY_POINT;
455
456 if (ppvObject == nullptr)
457 return E_INVALIDARG;
458
459 *ppvObject = nullptr;
460
461 return detail::QIHelper<InterfaceListT>::QI(
462 riid, ppvObject, this);
463 }
464
465 template <typename ItfT>
466 HRESULT QueryInterface(ItfT **ppItf)
467 {
468 return QueryInterface(__uuidof(ItfT), reinterpret_cast<void**>(ppItf));
469 }
470
471 protected:
472 // May only be constructed as a base type.
473 inline IUnknownCommonExternal() :
474 IUnknownCommonRef()
475 { WRAPPER_NO_CONTRACT; }
476
477 // Internal version of cleanup
478 virtual void InternalCleanup() = 0;
479
480 // External version of cleanup
481 // Not surprisingly, this should call InternalCleanup to avoid duplicate code
482 // Not implemented here to avoid bringing too much into this header file
483 virtual void Cleanup() = 0;
484 };
485}
486
487#undef COMUTIL_IIDOF
488
489using ComUtil::NoDerive;
490using ComUtil::ItfBase;
491using ComUtil::IUnknownCommon;
492using ComUtil::IUnknownCommonExternal;
493
494#endif // __InternalUnknownImpl_h__
495