IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / include / armnn / TypesUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "Types.hpp"
8 #include "Tensor.hpp"
9 #include <cmath>
10 #include <ostream>
11 #include <boost/assert.hpp>
12 #include <boost/numeric/conversion/cast.hpp>
13 #include <set>
14
15 namespace armnn
16 {
17
18 constexpr char const* GetStatusAsCString(Status status)
19 {
20     switch (status)
21     {
22         case armnn::Status::Success: return "Status::Success";
23         case armnn::Status::Failure: return "Status::Failure";
24         default:                     return "Unknown";
25     }
26 }
27
28 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
29 {
30     switch (activation)
31     {
32         case ActivationFunction::Sigmoid:       return "Sigmoid";
33         case ActivationFunction::TanH:          return "TanH";
34         case ActivationFunction::Linear:        return "Linear";
35         case ActivationFunction::ReLu:          return "ReLu";
36         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
37         case ActivationFunction::SoftReLu:      return "SoftReLu";
38         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
39         case ActivationFunction::Abs:           return "Abs";
40         case ActivationFunction::Sqrt:          return "Sqrt";
41         case ActivationFunction::Square:        return "Square";
42         default:                                return "Unknown";
43     }
44 }
45
46 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
47 {
48     switch (pooling)
49     {
50         case PoolingAlgorithm::Average:  return "Average";
51         case PoolingAlgorithm::Max:      return "Max";
52         case PoolingAlgorithm::L2:       return "L2";
53         default:                         return "Unknown";
54     }
55 }
56
57 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
58 {
59     switch (rounding)
60     {
61         case OutputShapeRounding::Ceiling:  return "Ceiling";
62         case OutputShapeRounding::Floor:    return "Floor";
63         default:                            return "Unknown";
64     }
65 }
66
67
68 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
69 {
70     switch (method)
71     {
72         case PaddingMethod::Exclude:       return "Exclude";
73         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
74         default:                           return "Unknown";
75     }
76 }
77
78 constexpr unsigned int GetDataTypeSize(DataType dataType)
79 {
80     switch (dataType)
81     {
82         case DataType::Float16:          return 2U;
83         case DataType::Float32:
84         case DataType::Signed32:         return 4U;
85         case DataType::QuantisedAsymm8:  return 1U;
86         case DataType::Boolean:          return 1U;
87         default:                         return 0U;
88     }
89 }
90
91 template <unsigned N>
92 constexpr bool StrEqual(const char* strA, const char (&strB)[N])
93 {
94     bool isEqual = true;
95     for (unsigned i = 0; isEqual && (i < N); ++i)
96     {
97         isEqual = (strA[i] == strB[i]);
98     }
99     return isEqual;
100 }
101
102 /// Deprecated function that will be removed together with
103 /// the Compute enum
104 constexpr armnn::Compute ParseComputeDevice(const char* str)
105 {
106     if (armnn::StrEqual(str, "CpuAcc"))
107     {
108         return armnn::Compute::CpuAcc;
109     }
110     else if (armnn::StrEqual(str, "CpuRef"))
111     {
112         return armnn::Compute::CpuRef;
113     }
114     else if (armnn::StrEqual(str, "GpuAcc"))
115     {
116         return armnn::Compute::GpuAcc;
117     }
118     else
119     {
120         return armnn::Compute::Undefined;
121     }
122 }
123
124 constexpr const char* GetDataTypeName(DataType dataType)
125 {
126     switch (dataType)
127     {
128         case DataType::Float16:         return "Float16";
129         case DataType::Float32:         return "Float32";
130         case DataType::QuantisedAsymm8: return "Unsigned8";
131         case DataType::Signed32:        return "Signed32";
132
133         default:
134             return "Unknown";
135     }
136 }
137
138 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
139 {
140     switch (dataLayout)
141     {
142         case DataLayout::NCHW: return "NCHW";
143         case DataLayout::NHWC: return "NHWC";
144         default:               return "Unknown";
145     }
146 }
147
148
149 template<typename T>
150 struct IsHalfType
151     : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
152 {};
153
154 template<typename T>
155 constexpr bool IsQuantizedType()
156 {
157     return std::is_integral<T>::value;
158 }
159
160 inline std::ostream& operator<<(std::ostream& os, Status stat)
161 {
162     os << GetStatusAsCString(stat);
163     return os;
164 }
165
166
167 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
168 {
169     os << "[";
170     for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
171     {
172         if (i!=0)
173         {
174             os << ",";
175         }
176         os << shape[i];
177     }
178     os << "]";
179     return os;
180 }
181
182 /// Quantize a floating point data type into an 8-bit data type.
183 /// @param value - The value to quantize.
184 /// @param scale - The scale (must be non-zero).
185 /// @param offset - The offset.
186 /// @return - The quantized value calculated as round(value/scale)+offset.
187 ///
188 template<typename QuantizedType>
189 inline QuantizedType Quantize(float value, float scale, int32_t offset)
190 {
191     // TODO : check we act sensibly for Inf, NaN and -Inf
192     //        see IVGCVSW-1849
193     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
194     constexpr QuantizedType max = std::numeric_limits<QuantizedType>::max();
195     constexpr QuantizedType min = std::numeric_limits<QuantizedType>::lowest();
196     BOOST_ASSERT(scale != 0.f);
197     int quantized = boost::numeric_cast<int>(round(value / scale)) + offset;
198     QuantizedType quantizedBits = quantized <= min
199                                   ? min
200                                   : quantized >= max
201                                     ? max
202                                     : static_cast<QuantizedType>(quantized);
203     return quantizedBits;
204 }
205
206 /// Dequantize an 8-bit data type into a floating point data type.
207 /// @param value - The value to dequantize.
208 /// @param scale - The scale (must be non-zero).
209 /// @param offset - The offset.
210 /// @return - The dequantized value calculated as (value-offset)*scale.
211 ///
212 template <typename QuantizedType>
213 inline float Dequantize(QuantizedType value, float scale, int32_t offset)
214 {
215     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
216     BOOST_ASSERT(scale != 0.f);
217     float dequantized = boost::numeric_cast<float>(value - offset) * scale;
218     return dequantized;
219 }
220
221 template <armnn::DataType DataType>
222 void VerifyTensorInfoDataType(const armnn::TensorInfo & info)
223 {
224     if (info.GetDataType() != DataType)
225     {
226         std::stringstream ss;
227         ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
228             << " for tensor:" << info.GetShape()
229             << ". The type expected to be: " << armnn::GetDataTypeName(DataType);
230         throw armnn::Exception(ss.str());
231     }
232 }
233
234 } //namespace armnn