2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/Tensor.hpp>
8 #include <armnn/Types.hpp>
17 constexpr char const* GetStatusAsCString(Status status)
21 case armnn::Status::Success: return "Status::Success";
22 case armnn::Status::Failure: return "Status::Failure";
23 default: return "Unknown";
27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
31 case ActivationFunction::Sigmoid: return "Sigmoid";
32 case ActivationFunction::TanH: return "TanH";
33 case ActivationFunction::Linear: return "Linear";
34 case ActivationFunction::ReLu: return "ReLu";
35 case ActivationFunction::BoundedReLu: return "BoundedReLu";
36 case ActivationFunction::SoftReLu: return "SoftReLu";
37 case ActivationFunction::LeakyReLu: return "LeakyReLu";
38 case ActivationFunction::Abs: return "Abs";
39 case ActivationFunction::Sqrt: return "Sqrt";
40 case ActivationFunction::Square: return "Square";
41 default: return "Unknown";
45 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
49 case ArgMinMaxFunction::Max: return "Max";
50 case ArgMinMaxFunction::Min: return "Min";
51 default: return "Unknown";
55 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
59 case ComparisonOperation::Equal: return "Equal";
60 case ComparisonOperation::Greater: return "Greater";
61 case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
62 case ComparisonOperation::Less: return "Less";
63 case ComparisonOperation::LessOrEqual: return "LessOrEqual";
64 case ComparisonOperation::NotEqual: return "NotEqual";
65 default: return "Unknown";
69 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
73 case PoolingAlgorithm::Average: return "Average";
74 case PoolingAlgorithm::Max: return "Max";
75 case PoolingAlgorithm::L2: return "L2";
76 default: return "Unknown";
80 constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
84 case OutputShapeRounding::Ceiling: return "Ceiling";
85 case OutputShapeRounding::Floor: return "Floor";
86 default: return "Unknown";
90 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
94 case PaddingMethod::Exclude: return "Exclude";
95 case PaddingMethod::IgnoreValue: return "IgnoreValue";
96 default: return "Unknown";
100 constexpr unsigned int GetDataTypeSize(DataType dataType)
104 case DataType::Float16: return 2U;
105 case DataType::Float32:
106 case DataType::Signed32: return 4U;
107 case DataType::QAsymmU8: return 1U;
108 case DataType::QSymmS8: return 1U;
109 case DataType::QuantizedSymm8PerAxis: return 1U;
110 case DataType::QSymmS16: return 2U;
111 case DataType::Boolean: return 1U;
/// Compares a C string against a character array (typically a string literal) for equality.
///
/// All N elements of strB take part in the comparison. For a string literal that includes the
/// terminating NUL, so strA has to match the literal in full rather than merely start with it.
/// The comparison stops at the first mismatching character.
///
/// @tparam N - Number of elements in strB (deduced from the argument).
/// @param strA - The string to test. A null pointer never compares equal.
/// @param strB - The reference array to compare against.
/// @return - true if the first N characters of strA are identical to strB, false otherwise.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against dereferencing a null input; callers may forward unchecked pointers.
    if (strA == nullptr)
    {
        return false;
    }

    for (unsigned i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
127 /// Deprecated function that will be removed together with
129 constexpr armnn::Compute ParseComputeDevice(const char* str)
131 if (armnn::StrEqual(str, "CpuAcc"))
133 return armnn::Compute::CpuAcc;
135 else if (armnn::StrEqual(str, "CpuRef"))
137 return armnn::Compute::CpuRef;
139 else if (armnn::StrEqual(str, "GpuAcc"))
141 return armnn::Compute::GpuAcc;
145 return armnn::Compute::Undefined;
149 constexpr const char* GetDataTypeName(DataType dataType)
153 case DataType::Float16: return "Float16";
154 case DataType::Float32: return "Float32";
155 case DataType::QAsymmU8: return "QAsymmU8";
156 case DataType::QSymmS8: return "QSymmS8";
157 case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
158 case DataType::QSymmS16: return "QSymm16";
159 case DataType::Signed32: return "Signed32";
160 case DataType::Boolean: return "Boolean";
167 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
171 case DataLayout::NCHW: return "NCHW";
172 case DataLayout::NHWC: return "NHWC";
173 default: return "Unknown";
177 constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
181 case NormalizationAlgorithmChannel::Across: return "Across";
182 case NormalizationAlgorithmChannel::Within: return "Within";
183 default: return "Unknown";
187 constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
191 case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
192 case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
193 default: return "Unknown";
197 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
201 case ResizeMethod::Bilinear: return "Bilinear";
202 case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
203 default: return "Unknown";
// Base-class clause of a type trait over T: the resulting constant is true only when T is a
// floating-point type (std::is_floating_point) whose sizeof is exactly 2, and false otherwise.
// NOTE(review): the declaration this clause belongs to is not visible in this excerpt -
// presumably a half-precision trait such as IsHalfType<T>; confirm against the full header.
209 : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
/// Compile-time check of whether a C++ type is treated as a quantized type.
///
/// The test is simply std::is_integral, so every integral type qualifies (including bool and
/// the character types), while floating-point types do not.
/// @tparam T - The type to inspect.
/// @return - true if T is an integral type, false otherwise.
template<typename T>
constexpr bool IsQuantizedType()
{
    return std::is_integral<T>::value;
}
218 constexpr bool IsQuantizedType(DataType dataType)
220 return dataType == DataType::QAsymmU8 ||
221 dataType == DataType::QSymmS8 ||
222 dataType == DataType::QSymmS16 ||
223 dataType == DataType::QuantizedSymm8PerAxis;
226 inline std::ostream& operator<<(std::ostream& os, Status stat)
228 os << GetStatusAsCString(stat);
/// Stream-insertion operator for armnn::TensorShape.
/// NOTE(review): only the signature and the loop header are visible in this excerpt; the
/// per-dimension formatting and the return statement must be confirmed against the full header.
233 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
// Visits every dimension index of the shape in ascending order, from 0 up to
// GetNumDimensions() - 1.
236 for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
/// Quantize a floating point value into a quantized data type.
/// The width of the result is determined by QuantizedType; it is not fixed at 8 bits.
/// @tparam QuantizedType - The type that receives the quantized value.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The quantized value calculated as round(value/scale)+offset.
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
/// Dequantize a quantized value into a floating point value.
/// The width of the input is determined by QuantizedType; it is not fixed at 8 bits.
/// @tparam QuantizedType - The type of the quantized input value.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset.
/// @return - The dequantized value calculated as (value-offset)*scale.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
266 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
268 if (info.GetDataType() != dataType)
270 std::stringstream ss;
271 ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
272 << " for tensor:" << info.GetShape()
273 << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
274 throw armnn::Exception(ss.str());