// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
16 constexpr char const* GetStatusAsCString(Status compute)
20 case armnn::Status::Success: return "Status::Success";
21 case armnn::Status::Failure: return "Status::Failure";
22 default: return "Unknown";
26 constexpr char const* GetComputeDeviceAsCString(Compute compute)
30 case armnn::Compute::CpuRef: return "CpuRef";
31 case armnn::Compute::CpuAcc: return "CpuAcc";
32 case armnn::Compute::GpuAcc: return "GpuAcc";
33 default: return "Unknown";
37 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
41 case ActivationFunction::Sigmoid: return "Sigmoid";
42 case ActivationFunction::TanH: return "TanH";
43 case ActivationFunction::Linear: return "Linear";
44 case ActivationFunction::ReLu: return "ReLu";
45 case ActivationFunction::BoundedReLu: return "BoundedReLu";
46 case ActivationFunction::SoftReLu: return "SoftReLu";
47 case ActivationFunction::LeakyReLu: return "LeakyReLu";
48 case ActivationFunction::Abs: return "Abs";
49 case ActivationFunction::Sqrt: return "Sqrt";
50 case ActivationFunction::Square: return "Square";
51 default: return "Unknown";
55 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
59 case PoolingAlgorithm::Average: return "Average";
60 case PoolingAlgorithm::Max: return "Max";
61 case PoolingAlgorithm::L2: return "L2";
62 default: return "Unknown";
66 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
70 case OutputShapeRounding::Ceiling: return "Ceiling";
71 case OutputShapeRounding::Floor: return "Floor";
72 default: return "Unknown";
77 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
81 case PaddingMethod::Exclude: return "Exclude";
82 case PaddingMethod::IgnoreValue: return "IgnoreValue";
83 default: return "Unknown";
87 constexpr unsigned int GetDataTypeSize(DataType dataType)
91 case DataType::Signed32:
92 case DataType::Float32: return 4U;
93 case DataType::QuantisedAsymm8: return 1U;
/// Compares a C string against a character array (typically a string literal); usable in
/// constant expressions.
/// The first N characters of @p strA are compared with all N elements of @p strB (including
/// a literal's terminating '\0'); the scan stops at the first mismatching character.
/// @param strA The string to test; a null pointer never compares equal
/// @param strB The reference array
/// @return true if @p strA is non-null and its first N characters match @p strB
template <std::size_t N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    if (strA == nullptr)
    {
        return false;
    }
    bool isEqual = true;
    for (std::size_t i = 0; isEqual && (i < N); ++i)
    {
        isEqual = (strA[i] == strB[i]);
    }
    return isEqual;
}
109 constexpr Compute ParseComputeDevice(const char* str)
111 if (StrEqual(str, "CpuAcc"))
113 return armnn::Compute::CpuAcc;
115 else if (StrEqual(str, "CpuRef"))
117 return armnn::Compute::CpuRef;
119 else if (StrEqual(str, "GpuAcc"))
121 return armnn::Compute::GpuAcc;
125 return armnn::Compute::Undefined;
129 constexpr const char* GetDataTypeName(DataType dataType)
133 case DataType::Float32: return "Float32";
134 case DataType::QuantisedAsymm8: return "Unsigned8";
135 case DataType::Signed32: return "Signed32";
136 default: return "Unknown";
140 template <typename T>
141 constexpr DataType GetDataType();
144 constexpr DataType GetDataType<float>()
146 return DataType::Float32;
150 constexpr DataType GetDataType<uint8_t>()
152 return DataType::QuantisedAsymm8;
156 constexpr DataType GetDataType<int32_t>()
158 return DataType::Signed32;
/// Reports whether T is an integral type, i.e. a type accepted as a quantized element type.
/// NOTE(review): bool is integral too, so it also satisfies this check.
template <typename T>
constexpr bool IsQuantizedType()
{
    // integral_constant converts to its bool value.
    return std::is_integral<T>{};
}
168 template<DataType DT>
169 struct ResolveTypeImpl;
172 struct ResolveTypeImpl<DataType::QuantisedAsymm8>
174 using Type = uint8_t;
178 struct ResolveTypeImpl<DataType::Float32>
183 template<DataType DT>
184 using ResolveType = typename ResolveTypeImpl<DT>::Type;
187 inline std::ostream& operator<<(std::ostream& os, Status stat)
189 os << GetStatusAsCString(stat);
193 inline std::ostream& operator<<(std::ostream& os, Compute compute)
195 os << GetComputeDeviceAsCString(compute);
199 /// Quantize a floating point data type into an 8-bit data type
200 /// @param value The value to quantize
201 /// @param scale The scale (must be non-zero)
202 /// @param offset The offset
203 /// @return The quantized value calculated as round(value/scale)+offset
205 template<typename QuantizedType>
206 inline QuantizedType Quantize(float value, float scale, int32_t offset)
208 static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
209 constexpr QuantizedType max = std::numeric_limits<QuantizedType>::max();
210 constexpr QuantizedType min = std::numeric_limits<QuantizedType>::lowest();
211 BOOST_ASSERT(scale != 0.f);
212 int quantized = boost::numeric_cast<int>(round(value / scale)) + offset;
213 QuantizedType quantizedBits = quantized < min ? min : quantized > max ? max : static_cast<QuantizedType>(quantized);
214 return quantizedBits;
217 /// Dequantize an 8-bit data type into a floating point data type
218 /// @param value The value to dequantize
219 /// @param scale The scale (must be non-zero)
220 /// @param offset The offset
221 /// @return The dequantized value calculated as (value-offset)*scale
223 template <typename QuantizedType>
224 inline float Dequantize(QuantizedType value, float scale, int32_t offset)
226 static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
227 BOOST_ASSERT(scale != 0.f);
228 float dequantized = boost::numeric_cast<float>(value - offset) * scale;