// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#pragma once

#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <vector>

#include <boost/core/ignore_unused.hpp>
#include <boost/numeric/conversion/cast.hpp>
15 template<typename T, bool DoQuantize=true>
16 struct SelectiveQuantizer
18 static T Quantize(float value, float scale, int32_t offset)
20 return armnn::Quantize<T>(value, scale, offset);
23 static float Dequantize(T value, float scale, int32_t offset)
25 return armnn::Dequantize(value, scale, offset);
30 struct SelectiveQuantizer<T, false>
32 static T Quantize(float value, float scale, int32_t offset)
34 boost::ignore_unused(scale, offset);
38 static float Dequantize(T value, float scale, int32_t offset)
40 boost::ignore_unused(scale, offset);
46 T SelectiveQuantize(float value, float scale, int32_t offset)
48 return SelectiveQuantizer<T, armnn::IsQuantizedType<T>()>::Quantize(value, scale, offset);
52 float SelectiveDequantize(T value, float scale, int32_t offset)
54 return SelectiveQuantizer<T, armnn::IsQuantizedType<T>()>::Dequantize(value, scale, offset);
/// Compile-time trait: `value` is true when the iterator's value_type
/// (per std::iterator_traits) is a floating-point type.
template<typename ItType>
struct IsFloatingPointIterator
{
private:
    using ElementType = typename std::iterator_traits<ItType>::value_type;

public:
    static constexpr bool value = std::is_floating_point<ElementType>::value;
};
63 template <typename T, typename FloatIt,
64 typename std::enable_if<IsFloatingPointIterator<FloatIt>::value, int>::type=0 // Makes sure fp iterator is valid.
66 std::vector<T> QuantizedVector(float qScale, int32_t qOffset, FloatIt first, FloatIt last)
68 std::vector<T> quantized;
69 quantized.reserve(boost::numeric_cast<size_t>(std::distance(first, last)));
71 for (auto it = first; it != last; ++it)
74 T q =SelectiveQuantize<T>(f, qScale, qOffset);
75 quantized.push_back(q);
/// Convenience overload: converts each element of a std::vector<float> to T
/// by delegating to the iterator-range QuantizedVector overload.
template<typename T>
std::vector<T> QuantizedVector(float qScale, int32_t qOffset, const std::vector<float>& array)
{
    const auto rangeBegin = array.cbegin();
    const auto rangeEnd = array.cend();
    return QuantizedVector<T>(qScale, qOffset, rangeBegin, rangeEnd);
}
/// Convenience overload: converts each element of a braced list of floats to T
/// by delegating to the iterator-range QuantizedVector overload.
template<typename T>
std::vector<T> QuantizedVector(float qScale, int32_t qOffset, std::initializer_list<float> array)
{
    return QuantizedVector<T>(qScale, qOffset, std::begin(array), std::end(array));
}