Release 18.03
[platform/upstream/armnn.git] / include / armnn / TypesUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Types.hpp"
8 #include <cmath>
9 #include <ostream>
10 #include <boost/assert.hpp>
11 #include <boost/numeric/conversion/cast.hpp>
12
13 namespace armnn
14 {
15
16 constexpr char const* GetStatusAsCString(Status compute)
17 {
18     switch (compute)
19     {
20         case armnn::Status::Success: return "Status::Success";
21         case armnn::Status::Failure: return "Status::Failure";
22         default:                     return "Unknown";
23     }
24 }
25
26 constexpr char const* GetComputeDeviceAsCString(Compute compute)
27 {
28     switch (compute)
29     {
30         case armnn::Compute::CpuRef: return "CpuRef";
31         case armnn::Compute::CpuAcc: return "CpuAcc";
32         case armnn::Compute::GpuAcc: return "GpuAcc";
33         default:                     return "Unknown";
34     }
35 }
36
37 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
38 {
39     switch (activation)
40     {
41         case ActivationFunction::Sigmoid:       return "Sigmoid";
42         case ActivationFunction::TanH:          return "TanH";
43         case ActivationFunction::Linear:        return "Linear";
44         case ActivationFunction::ReLu:          return "ReLu";
45         case ActivationFunction::BoundedReLu:   return "BoundedReLu";
46         case ActivationFunction::SoftReLu:      return "SoftReLu";
47         case ActivationFunction::LeakyReLu:     return "LeakyReLu";
48         case ActivationFunction::Abs:           return "Abs";
49         case ActivationFunction::Sqrt:          return "Sqrt";
50         case ActivationFunction::Square:        return "Square";
51         default:                                return "Unknown";
52     }
53 }
54
55 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
56 {
57     switch (pooling)
58     {
59         case PoolingAlgorithm::Average:  return "Average";
60         case PoolingAlgorithm::Max:      return "Max";
61         case PoolingAlgorithm::L2:       return "L2";
62         default:                         return "Unknown";
63     }
64 }
65
66 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
67 {
68     switch (rounding)
69     {
70         case OutputShapeRounding::Ceiling:  return "Ceiling";
71         case OutputShapeRounding::Floor:    return "Floor";
72         default:                            return "Unknown";
73     }
74 }
75
76
77 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
78 {
79     switch (method)
80     {
81         case PaddingMethod::Exclude:       return "Exclude";
82         case PaddingMethod::IgnoreValue:   return "IgnoreValue";
83         default:                           return "Unknown";
84     }
85 }
86
87 constexpr unsigned int GetDataTypeSize(DataType dataType)
88 {
89     switch (dataType)
90     {
91         case DataType::Signed32:
92         case DataType::Float32:   return 4U;
93         case DataType::QuantisedAsymm8: return 1U;
94         default:                  return 0U;
95     }
96 }
97
98 template <int N>
99 constexpr bool StrEqual(const char* strA, const char (&strB)[N])
100 {
101     bool isEqual = true;
102     for (int i = 0; isEqual && (i < N); ++i)
103     {
104         isEqual = (strA[i] == strB[i]);
105     }
106     return isEqual;
107 }
108
109 constexpr Compute ParseComputeDevice(const char* str)
110 {
111     if (StrEqual(str, "CpuAcc"))
112     {
113         return armnn::Compute::CpuAcc;
114     }
115     else if (StrEqual(str, "CpuRef"))
116     {
117         return armnn::Compute::CpuRef;
118     }
119     else if (StrEqual(str, "GpuAcc"))
120     {
121         return armnn::Compute::GpuAcc;
122     }
123     else
124     {
125         return armnn::Compute::Undefined;
126     }
127 }
128
129 constexpr const char* GetDataTypeName(DataType dataType)
130 {
131     switch (dataType)
132     {
133         case DataType::Float32:   return "Float32";
134         case DataType::QuantisedAsymm8: return "Unsigned8";
135         case DataType::Signed32:  return "Signed32";
136         default:                  return "Unknown";
137     }
138 }
139
140 template <typename T>
141 constexpr DataType GetDataType();
142
143 template <>
144 constexpr DataType GetDataType<float>()
145 {
146     return DataType::Float32;
147 }
148
149 template <>
150 constexpr DataType GetDataType<uint8_t>()
151 {
152     return DataType::QuantisedAsymm8;
153 }
154
155 template <>
156 constexpr DataType GetDataType<int32_t>()
157 {
158     return DataType::Signed32;
159 }
160
161 template<typename T>
162 constexpr bool IsQuantizedType()
163 {
164     return std::is_integral<T>::value;
165 }
166
167
168 template<DataType DT>
169 struct ResolveTypeImpl;
170
171 template<>
172 struct ResolveTypeImpl<DataType::QuantisedAsymm8>
173 {
174     using Type = uint8_t;
175 };
176
177 template<>
178 struct ResolveTypeImpl<DataType::Float32>
179 {
180     using Type = float;
181 };
182
183 template<DataType DT>
184 using ResolveType = typename ResolveTypeImpl<DT>::Type;
185
186
187 inline std::ostream& operator<<(std::ostream& os, Status stat)
188 {
189     os << GetStatusAsCString(stat);
190     return os;
191 }
192
193 inline std::ostream& operator<<(std::ostream& os, Compute compute)
194 {
195     os << GetComputeDeviceAsCString(compute);
196     return os;
197 }
198
199 /// Quantize a floating point data type into an 8-bit data type
200 /// @param value The value to quantize
201 /// @param scale The scale (must be non-zero)
202 /// @param offset The offset
203 /// @return The quantized value calculated as round(value/scale)+offset
204 ///
205 template<typename QuantizedType>
206 inline QuantizedType Quantize(float value, float scale, int32_t offset)
207 {
208     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
209     constexpr QuantizedType max = std::numeric_limits<QuantizedType>::max();
210     constexpr QuantizedType min = std::numeric_limits<QuantizedType>::lowest();
211     BOOST_ASSERT(scale != 0.f);
212     int quantized = boost::numeric_cast<int>(round(value / scale)) + offset;
213     QuantizedType quantizedBits = quantized < min ? min : quantized > max ? max : static_cast<QuantizedType>(quantized);
214     return quantizedBits;
215 }
216
217 /// Dequantize an 8-bit data type into a floating point data type
218 /// @param value The value to dequantize
219 /// @param scale The scale (must be non-zero)
220 /// @param offset The offset
221 /// @return The dequantized value calculated as (value-offset)*scale
222 ///
223 template <typename QuantizedType>
224 inline float Dequantize(QuantizedType value, float scale, int32_t offset)
225 {
226     static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
227     BOOST_ASSERT(scale != 0.f);
228     float dequantized = boost::numeric_cast<float>(value - offset) * scale;
229     return dequantized;
230 }
231
232 } //namespace armnn