IVGCVSW-4268 Print all Descriptors on dot graph
[platform/upstream/armnn.git] / include / armnn / TypesUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Tensor.hpp>
8 #include <armnn/Types.hpp>
9
10 #include <cmath>
11 #include <ostream>
12 #include <set>
13
14 namespace armnn
15 {
16
17 constexpr char const* GetStatusAsCString(Status status)
18 {
19     switch (status)
20     {
21         case armnn::Status::Success: return "Status::Success";
22         case armnn::Status::Failure: return "Status::Failure";
23         default:                     return "Unknown";
24     }
25 }
26
27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28 {
29     switch (activation)
30     {
31         case ActivationFunction::Sigmoid:       return "Sigmoid";
32         case ActivationFunction::TanH:          return "TanH";
33         case ActivationFunction::Linear:        return "Linear";
34         case ActivationFunction::ReLu:          return "ReLu";
35         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
36         case ActivationFunction::SoftReLu:      return "SoftReLu";
37         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
38         case ActivationFunction::Abs:           return "Abs";
39         case ActivationFunction::Sqrt:          return "Sqrt";
40         case ActivationFunction::Square:        return "Square";
41         default:                                return "Unknown";
42     }
43 }
44
45 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
46 {
47     switch (function)
48     {
49         case ArgMinMaxFunction::Max:    return "Max";
50         case ArgMinMaxFunction::Min:    return "Min";
51         default:                        return "Unknown";
52     }
53 }
54
55 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
56 {
57     switch (operation)
58     {
59         case ComparisonOperation::Equal:          return "Equal";
60         case ComparisonOperation::Greater:        return "Greater";
61         case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
62         case ComparisonOperation::Less:           return "Less";
63         case ComparisonOperation::LessOrEqual:    return "LessOrEqual";
64         case ComparisonOperation::NotEqual:       return "NotEqual";
65         default:                                  return "Unknown";
66     }
67 }
68
69 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
70 {
71     switch (pooling)
72     {
73         case PoolingAlgorithm::Average:  return "Average";
74         case PoolingAlgorithm::Max:      return "Max";
75         case PoolingAlgorithm::L2:       return "L2";
76         default:                         return "Unknown";
77     }
78 }
79
80 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
81 {
82     switch (rounding)
83     {
84         case OutputShapeRounding::Ceiling:  return "Ceiling";
85         case OutputShapeRounding::Floor:    return "Floor";
86         default:                            return "Unknown";
87     }
88 }
89
90 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
91 {
92     switch (method)
93     {
94         case PaddingMethod::Exclude:       return "Exclude";
95         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
96         default:                           return "Unknown";
97     }
98 }
99
100 constexpr unsigned int GetDataTypeSize(DataType dataType)
101 {
102     switch (dataType)
103     {
104         case DataType::Float16:               return 2U;
105         case DataType::Float32:
106         case DataType::Signed32:              return 4U;
107         case DataType::QAsymmU8:              return 1U;
108         case DataType::QSymmS8:               return 1U;
109         case DataType::QuantizedSymm8PerAxis: return 1U;
110         case DataType::QSymmS16:       return 2U;
111         case DataType::Boolean:               return 1U;
112         default:                              return 0U;
113     }
114 }
115
116 template <unsigned N>
117 constexpr bool StrEqual(const char* strA, const char (&strB)[N])
118 {
119     bool isEqual = true;
120     for (unsigned i = 0; isEqual && (i < N); ++i)
121     {
122         isEqual = (strA[i] == strB[i]);
123     }
124     return isEqual;
125 }
126
127 /// Deprecated function that will be removed together with
128 /// the Compute enum
129 constexpr armnn::Compute ParseComputeDevice(const char* str)
130 {
131     if (armnn::StrEqual(str, "CpuAcc"))
132     {
133         return armnn::Compute::CpuAcc;
134     }
135     else if (armnn::StrEqual(str, "CpuRef"))
136     {
137         return armnn::Compute::CpuRef;
138     }
139     else if (armnn::StrEqual(str, "GpuAcc"))
140     {
141         return armnn::Compute::GpuAcc;
142     }
143     else
144     {
145         return armnn::Compute::Undefined;
146     }
147 }
148
149 constexpr const char* GetDataTypeName(DataType dataType)
150 {
151     switch (dataType)
152     {
153         case DataType::Float16:               return "Float16";
154         case DataType::Float32:               return "Float32";
155         case DataType::QAsymmU8:              return "QAsymmU8";
156         case DataType::QSymmS8:               return "QSymmS8";
157         case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
158         case DataType::QSymmS16:       return "QSymm16";
159         case DataType::Signed32:              return "Signed32";
160         case DataType::Boolean:               return "Boolean";
161
162         default:
163             return "Unknown";
164     }
165 }
166
167 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
168 {
169     switch (dataLayout)
170     {
171         case DataLayout::NCHW: return "NCHW";
172         case DataLayout::NHWC: return "NHWC";
173         default:               return "Unknown";
174     }
175 }
176
177 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
178 {
179     switch (channel)
180     {
181         case NormalizationAlgorithmChannel::Across: return "Across";
182         case NormalizationAlgorithmChannel::Within: return "Within";
183         default:                                    return "Unknown";
184     }
185 }
186
187 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
188 {
189     switch (method)
190     {
191         case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
192         case NormalizationAlgorithmMethod::LocalContrast:   return "LocalContrast";
193         default:                                            return "Unknown";
194     }
195 }
196
197 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
198 {
199     switch (method)
200     {
201         case ResizeMethod::Bilinear:        return "Bilinear";
202         case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
203         default:                            return "Unknown";
204     }
205 }
206
207 template<typename T>
208 struct IsHalfType
209     : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
210 {};
211
212 template<typename T>
213 constexpr bool IsQuantizedType()
214 {
215     return std::is_integral<T>::value;
216 }
217
218 constexpr bool IsQuantizedType(DataType dataType)
219 {
220     return dataType == DataType::QAsymmU8        ||
221            dataType == DataType::QSymmS8         ||
222            dataType == DataType::QSymmS16 ||
223            dataType == DataType::QuantizedSymm8PerAxis;
224 }
225
226 inline std::ostream& operator<<(std::ostream& os, Status stat)
227 {
228     os << GetStatusAsCString(stat);
229     return os;
230 }
231
232
233 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
234 {
235     os << "[";
236     for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
237     {
238         if (i!=0)
239         {
240             os << ",";
241         }
242         os << shape[i];
243     }
244     os << "]";
245     return os;
246 }
247
248 /// Quantize a floating point data type into an 8-bit data type.
249 /// @param value - The value to quantize.
250 /// @param scale - The scale (must be non-zero).
251 /// @param offset - The offset.
252 /// @return - The quantized value calculated as round(value/scale)+offset.
253 ///
254 template<typename QuantizedType>
255 QuantizedType Quantize(float value, float scale, int32_t offset);
256
257 /// Dequantize an 8-bit data type into a floating point data type.
258 /// @param value - The value to dequantize.
259 /// @param scale - The scale (must be non-zero).
260 /// @param offset - The offset.
261 /// @return - The dequantized value calculated as (value-offset)*scale.
262 ///
263 template <typename QuantizedType>
264 float Dequantize(QuantizedType value, float scale, int32_t offset);
265
266 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
267 {
268     if (info.GetDataType() != dataType)
269     {
270         std::stringstream ss;
271         ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
272            << " for tensor:" << info.GetShape()
273            << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
274         throw armnn::Exception(ss.str());
275     }
276 }
277
278 } //namespace armnn