065b6839fc13ff3006434a58db578ba7efc6c701
[platform/upstream/armnn.git] / include / armnn / TypesUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Tensor.hpp>
8 #include <armnn/Types.hpp>
9
10 #include <cmath>
11 #include <ostream>
12 #include <set>
13
14 namespace armnn
15 {
16
17 constexpr char const* GetStatusAsCString(Status status)
18 {
19     switch (status)
20     {
21         case armnn::Status::Success: return "Status::Success";
22         case armnn::Status::Failure: return "Status::Failure";
23         default:                     return "Unknown";
24     }
25 }
26
27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28 {
29     switch (activation)
30     {
31         case ActivationFunction::Sigmoid:       return "Sigmoid";
32         case ActivationFunction::TanH:          return "TanH";
33         case ActivationFunction::Linear:        return "Linear";
34         case ActivationFunction::ReLu:          return "ReLu";
35         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
36         case ActivationFunction::SoftReLu:      return "SoftReLu";
37         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
38         case ActivationFunction::Abs:           return "Abs";
39         case ActivationFunction::Sqrt:          return "Sqrt";
40         case ActivationFunction::Square:        return "Square";
41         default:                                return "Unknown";
42     }
43 }
44
45 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
46 {
47     switch (function)
48     {
49         case ArgMinMaxFunction::Max:    return "Max";
50         case ArgMinMaxFunction::Min:    return "Min";
51         default:                        return "Unknown";
52     }
53 }
54
55 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
56 {
57     switch (operation)
58     {
59         case ComparisonOperation::Equal:          return "Equal";
60         case ComparisonOperation::Greater:        return "Greater";
61         case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
62         case ComparisonOperation::Less:           return "Less";
63         case ComparisonOperation::LessOrEqual:    return "LessOrEqual";
64         case ComparisonOperation::NotEqual:       return "NotEqual";
65         default:                                  return "Unknown";
66     }
67 }
68
69 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
70 {
71     switch (pooling)
72     {
73         case PoolingAlgorithm::Average:  return "Average";
74         case PoolingAlgorithm::Max:      return "Max";
75         case PoolingAlgorithm::L2:       return "L2";
76         default:                         return "Unknown";
77     }
78 }
79
80 constexpr char const* GetResizeMethodAsCString(ResizeMethod resizeMethod)
81 {
82     switch (resizeMethod)
83     {
84         case ResizeMethod::Bilinear:        return "Bilinear";
85         case ResizeMethod::NearestNeighbor: return "NearestNeighbor";
86         default:                            return "Unknown";
87     }
88 }
89
90 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
91 {
92     switch (rounding)
93     {
94         case OutputShapeRounding::Ceiling:  return "Ceiling";
95         case OutputShapeRounding::Floor:    return "Floor";
96         default:                            return "Unknown";
97     }
98 }
99
100
101 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
102 {
103     switch (method)
104     {
105         case PaddingMethod::Exclude:       return "Exclude";
106         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
107         default:                           return "Unknown";
108     }
109 }
110
111 constexpr unsigned int GetDataTypeSize(DataType dataType)
112 {
113     switch (dataType)
114     {
115         case DataType::Float16:               return 2U;
116         case DataType::Float32:
117         case DataType::Signed32:              return 4U;
118         case DataType::QAsymmU8:              return 1U;
119         case DataType::QSymmS8:               return 1U;
120         case DataType::QuantizedSymm8PerAxis: return 1U;
121         case DataType::QSymmS16:       return 2U;
122         case DataType::Boolean:               return 1U;
123         default:                              return 0U;
124     }
125 }
126
127 template <unsigned N>
128 constexpr bool StrEqual(const char* strA, const char (&strB)[N])
129 {
130     bool isEqual = true;
131     for (unsigned i = 0; isEqual && (i < N); ++i)
132     {
133         isEqual = (strA[i] == strB[i]);
134     }
135     return isEqual;
136 }
137
138 /// Deprecated function that will be removed together with
139 /// the Compute enum
140 constexpr armnn::Compute ParseComputeDevice(const char* str)
141 {
142     if (armnn::StrEqual(str, "CpuAcc"))
143     {
144         return armnn::Compute::CpuAcc;
145     }
146     else if (armnn::StrEqual(str, "CpuRef"))
147     {
148         return armnn::Compute::CpuRef;
149     }
150     else if (armnn::StrEqual(str, "GpuAcc"))
151     {
152         return armnn::Compute::GpuAcc;
153     }
154     else
155     {
156         return armnn::Compute::Undefined;
157     }
158 }
159
160 constexpr const char* GetDataTypeName(DataType dataType)
161 {
162     switch (dataType)
163     {
164         case DataType::Float16:               return "Float16";
165         case DataType::Float32:               return "Float32";
166         case DataType::QAsymmU8:              return "QAsymmU8";
167         case DataType::QSymmS8:               return "QSymmS8";
168         case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
169         case DataType::QSymmS16:       return "QSymm16";
170         case DataType::Signed32:              return "Signed32";
171         case DataType::Boolean:               return "Boolean";
172
173         default:
174             return "Unknown";
175     }
176 }
177
178 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
179 {
180     switch (dataLayout)
181     {
182         case DataLayout::NCHW: return "NCHW";
183         case DataLayout::NHWC: return "NHWC";
184         default:               return "Unknown";
185     }
186 }
187
188
189 template<typename T>
190 struct IsHalfType
191     : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
192 {};
193
194 template<typename T>
195 constexpr bool IsQuantizedType()
196 {
197     return std::is_integral<T>::value;
198 }
199
200 constexpr bool IsQuantizedType(DataType dataType)
201 {
202     return dataType == DataType::QAsymmU8        ||
203            dataType == DataType::QSymmS8         ||
204            dataType == DataType::QSymmS16 ||
205            dataType == DataType::QuantizedSymm8PerAxis;
206 }
207
208 inline std::ostream& operator<<(std::ostream& os, Status stat)
209 {
210     os << GetStatusAsCString(stat);
211     return os;
212 }
213
214
215 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
216 {
217     os << "[";
218     for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
219     {
220         if (i!=0)
221         {
222             os << ",";
223         }
224         os << shape[i];
225     }
226     os << "]";
227     return os;
228 }
229
230 /// Quantize a floating point data type into an 8-bit data type.
231 /// @param value - The value to quantize.
232 /// @param scale - The scale (must be non-zero).
233 /// @param offset - The offset.
234 /// @return - The quantized value calculated as round(value/scale)+offset.
235 ///
236 template<typename QuantizedType>
237 QuantizedType Quantize(float value, float scale, int32_t offset);
238
239 /// Dequantize an 8-bit data type into a floating point data type.
240 /// @param value - The value to dequantize.
241 /// @param scale - The scale (must be non-zero).
242 /// @param offset - The offset.
243 /// @return - The dequantized value calculated as (value-offset)*scale.
244 ///
245 template <typename QuantizedType>
246 float Dequantize(QuantizedType value, float scale, int32_t offset);
247
248 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
249 {
250     if (info.GetDataType() != dataType)
251     {
252         std::stringstream ss;
253         ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
254            << " for tensor:" << info.GetShape()
255            << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
256         throw armnn::Exception(ss.str());
257     }
258 }
259
260 } //namespace armnn