IVGCVSW-4483 Introduce PolymorphicPointerDowncast
[platform/upstream/armnn.git] / include / armnn / utility / PolymorphicDowncast.hpp
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include "Assert.hpp"

#include <armnn/Exceptions.hpp>

#include <memory>
#include <type_traits>
#include <typeinfo>
14
15 namespace armnn
16 {
17
// How a failed cast check is reported:
//  - test builds (ARMNN_POLYMORPHIC_CAST_TESTABLE defined): throw, via ConditionalThrow<std::bad_cast>
//    (presumably declared in armnn/Exceptions.hpp), so unit tests can observe the failure;
//  - otherwise: a regular ARMNN_ASSERT.
#if defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ConditionalThrow<std::bad_cast>(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ARMNN_ASSERT(cond)
#endif

// The cast check is only performed in debug builds (NDEBUG not defined) or in test builds.
// In release builds the macro expands to a no-op expression rather than to nothing, so that
// `ARMNN_POLYMORPHIC_CAST_CHECK(x);` never leaves an empty statement behind (e.g. as the body of
// an unbraced `if`, which -Wempty-body flags). `cond` is not evaluated in release builds.
#if !defined(NDEBUG) || defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ((void)0) // release builds don't check the cast
#endif
31
32
namespace utility
{

/// Raw-pointer counterpart of std::static_pointer_cast.
/// \return source converted with static_cast<To*>; no runtime type check is made.
template<typename To, typename From>
inline To* StaticPointerCast(From* source)
{
    return static_cast<To*>(source);
}

/// Raw-pointer counterpart of std::dynamic_pointer_cast.
/// \return source converted with dynamic_cast<To*>; nullptr if source is null or does not
///         point to a To.
template<typename To, typename From>
inline To* DynamicPointerCast(From* source)
{
    return dynamic_cast<To*>(source);
}

/// std::shared_ptr overload that forwards to std::static_pointer_cast.
/// The returned pointer shares ownership with `source`.
template<typename To, typename From>
std::shared_ptr<To> StaticPointerCast(const std::shared_ptr<From>& source)
{
    return std::static_pointer_cast<To>(source);
}

/// std::shared_ptr overload that forwards to std::dynamic_pointer_cast.
/// The result is empty when the managed object is not a To; otherwise it shares ownership with `source`.
template<typename To, typename From>
std::shared_ptr<To> DynamicPointerCast(const std::shared_ptr<From>& source)
{
    return std::dynamic_pointer_cast<To>(source);
}

} // namespace utility
64
/// Downcast a built-in (raw) pointer from a base type to a derived type.
///
/// Usage: Child* pChild = PolymorphicDowncast<Child*>(pBase);
///
/// The conversion itself is a static_cast. Whenever ARMNN_POLYMORPHIC_CAST_CHECK is active, the
/// static_cast result is compared with the dynamic_cast result, so a downcast to the wrong
/// dynamic type is detected.
///
/// \tparam DestType    Pointer type of the target object (e.g. Child*)
/// \tparam SourceType  Pointer type of the source object (e.g. Base*)
/// \param value        Pointer to the source object (a null pointer passes the check)
/// \return             value converted to DestType
template<typename DestType, typename SourceType>
DestType PolymorphicDowncast(SourceType value)
{
    static_assert(std::is_pointer<DestType>::value &&
                  std::is_pointer<SourceType>::value,
                  "PolymorphicDowncast only works with pointer types.");

    DestType const downcast = static_cast<DestType>(value);
    // dynamic_cast yields nullptr when *value is not a DestType, so the two only disagree on a bad cast.
    ARMNN_POLYMORPHIC_CAST_CHECK(dynamic_cast<DestType>(value) == downcast);
    return downcast;
}
83
84
85 /// Polymorphic downcast for shared pointers and build in pointers
86 ///
87 /// Usage: auto pChild = PolymorphicPointerDowncast<Child>(pBase)
88 ///
89 /// \tparam DestType    Type of the target object (Child type)
90 /// \tparam SourceType  Pointer type to the source object (Base (shared) pointer type)
91 /// \param value        Pointer to the source object
92 /// \return             Pointer of type DestType ((Shared) pointer of type child)
93 template<typename DestType, typename SourceType>
94 auto PolymorphicPointerDowncast(const SourceType& value)
95 {
96     ARMNN_POLYMORPHIC_CAST_CHECK(utility::DynamicPointerCast<DestType>(value)
97                                  == utility::StaticPointerCast<DestType>(value));
98     return utility::StaticPointerCast<DestType>(value);
99 }
100
101 } //namespace armnn