a78804b709791db461f2dfcb60b1d7f4dbaecb87
[platform/upstream/armnn.git] / src / backends / reference / workloads / RefQuantizeWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefQuantizeWorkload.hpp"
7
8 #include <armnn/TypesUtils.hpp>
9
10
11 namespace armnn
12 {
13
14 namespace
15 {
16
17 template<typename T>
18 void QuantizeImpl(const void *input, void *output, size_t numValues, float scale, int offset)
19 {
20     auto in = static_cast<const float *>(input);
21     auto out = static_cast<T *>(output);
22     for (size_t i = 0; i < numValues; i++, in++, out++)
23     {
24         *out = armnn::Quantize<T>(*in, scale, offset);
25     }
26 }
27
28 } //namespace
29
30 RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info)
31     : BaseWorkload(descriptor, info)
32     , m_NumElements(info.m_InputTensorInfos[0].GetNumElements())
33     , m_TargetType(info.m_OutputTensorInfos[0].GetDataType())
34     , m_Scale(info.m_OutputTensorInfos[0].GetQuantizationScale())
35     , m_Offset(info.m_OutputTensorInfos[0].GetQuantizationOffset())
36 {
37 }
38
39 void RefQuantizeWorkload::Execute() const
40 {
41     const void* input = m_Data.m_Inputs[0]->Map(true);
42     void* output =  m_Data.m_Outputs[0]->Map(true);
43
44     switch(m_TargetType)
45     {
46         case DataType::QuantisedAsymm8:
47         {
48             QuantizeImpl<uint8_t>(input, output, m_NumElements, m_Scale, m_Offset);
49             break;
50         }
51         case DataType::QSymmS8:
52         {
53             QuantizeImpl<int8_t>(input, output, m_NumElements, m_Scale, m_Offset);
54             break;
55         }
56         case DataType::QuantisedSymm16:
57         {
58             QuantizeImpl<int16_t>(input, output, m_NumElements, m_Scale, 0);
59             break;
60         }
61         default:
62         {
63             BOOST_ASSERT_MSG(false, "RefQuantizeWorkload: Non quantized output type encountered");
64         }
65     }
66
67     m_Data.m_Inputs[0]->Unmap();
68     m_Data.m_Outputs[0]->Unmap();
69 }
70
71 } //namespace armnn