2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
10 #include <armnn/Exceptions.hpp>
13 #include <type_traits>
// Select how a failed polymorphic cast check is reported: when testing
// (ARMNN_POLYMORPHIC_CAST_TESTABLE) throw an exception via
// ConditionalThrow<std::bad_cast>, otherwise fall back to a regular ARMNN_ASSERT.
#if defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ConditionalThrow<std::bad_cast>(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ARMNN_ASSERT(cond)
#endif

// Only check the condition in debug builds or while testing.
#if !defined(NDEBUG) || defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond)
#else
// Release builds don't check the cast. Expand to a no-op expression (the same
// idiom the standard assert uses under NDEBUG) rather than to nothing, so that
// `if (x) ARMNN_POLYMORPHIC_CAST_CHECK(c);` does not leave an empty if-body
// behind (which -Wempty-body flags, fatal under -Werror).
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ((void)0)
#endif
// Shared-pointer overload of StaticPointerCast: forwards to std::static_pointer_cast,
// so the result shares ownership with the source pointer.
template <typename To, typename From>
auto StaticPointerCast(const std::shared_ptr<From>& source) -> std::shared_ptr<To>
{
    return std::static_pointer_cast<To>(source);
}
// Shared-pointer overload of DynamicPointerCast: forwards to std::dynamic_pointer_cast,
// yielding an empty shared_ptr when the runtime type does not match.
template <typename To, typename From>
auto DynamicPointerCast(const std::shared_ptr<From>& source) -> std::shared_ptr<To>
{
    return std::dynamic_pointer_cast<To>(source);
}
// Raw-pointer overload of StaticPointerCast: a plain, unchecked static_cast.
template <typename To, typename From>
inline auto StaticPointerCast(From* source) -> To*
{
    To* const converted = static_cast<To*>(source);
    return converted;
}
// Raw-pointer overload of DynamicPointerCast: a checked dynamic_cast that yields
// nullptr when the pointee is not of the requested type.
template <typename To, typename From>
inline auto DynamicPointerCast(From* source) -> To*
{
    To* const converted = dynamic_cast<To*>(source);
    return converted;
}
63 } // namespace utility
/// Polymorphic downcast for built-in (raw) pointers only.
///
/// Usage: Child* pChild = PolymorphicDowncast<Child*>(pBase);
///
/// The conversion itself is a static_cast. Whenever ARMNN_POLYMORPHIC_CAST_CHECK is
/// active (non-NDEBUG builds or ARMNN_POLYMORPHIC_CAST_TESTABLE), the result is also
/// compared against a dynamic_cast so that an invalid downcast gets reported.
///
/// \tparam DestType   Pointer type to the target object (Child pointer type)
/// \tparam SourceType Pointer type to the source object (Base pointer type)
/// \param value       Pointer to the source object
/// \return            Pointer of type DestType (pointer of the child type)
template<typename DestType, typename SourceType>
DestType PolymorphicDowncast(SourceType value)
{
    static_assert(std::is_pointer<DestType>::value && std::is_pointer<SourceType>::value,
                  "PolymorphicDowncast only works with pointer types.");

    const DestType downcast = static_cast<DestType>(value);
    ARMNN_POLYMORPHIC_CAST_CHECK(dynamic_cast<DestType>(value) == downcast);
    return downcast;
}
85 /// Polymorphic downcast for shared pointers and build in pointers
87 /// Usage: auto pChild = PolymorphicPointerDowncast<Child>(pBase)
89 /// \tparam DestType Type of the target object (Child type)
90 /// \tparam SourceType Pointer type to the source object (Base (shared) pointer type)
91 /// \param value Pointer to the source object
92 /// \return Pointer of type DestType ((Shared) pointer of type child)
93 template<typename DestType, typename SourceType>
94 auto PolymorphicPointerDowncast(const SourceType& value)
96 ARMNN_POLYMORPHIC_CAST_CHECK(utility::DynamicPointerCast<DestType>(value)
97 == utility::StaticPointerCast<DestType>(value));
98 return utility::StaticPointerCast<DestType>(value);