Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / RefWorkloads / RefDepthwiseConvolution2dUint8Workload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include "RefDepthwiseConvolution2dUint8Workload.hpp"
7
8 #include "ConvImpl.hpp"
9 #include "RefWorkloadUtils.hpp"
10
11 #include "Profiling.hpp"
12
13 namespace armnn
14 {
15
16 RefDepthwiseConvolution2dUint8Workload::RefDepthwiseConvolution2dUint8Workload(
17         const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info)
18         : Uint8Workload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info),
19           m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))),
20           m_Bias(descriptor.m_Parameters.m_BiasEnabled
21                  ? std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)) : nullptr) {}
22
23 void RefDepthwiseConvolution2dUint8Workload::Execute() const
24 {
25     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefDepthwiseConvolution2dUint8Workload_Execute");
26
27     const uint8_t* inputData = GetInputTensorDataU8(0, m_Data);
28     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
29     const uint8_t* weightsData = m_Weight->template GetConstTensor<uint8_t>();
30     const TensorInfo& weightsInfo = GetTensorInfo(m_Weight.get());
31     const int32_t* biasData = m_Data.m_Parameters.m_BiasEnabled ?
32         m_Bias->template GetConstTensor<int32_t>() :
33         nullptr;
34     uint8_t* outputData = GetOutputTensorDataU8(0, m_Data);
35     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
36     const TensorInfo& filterInfo = m_Weight->GetTensorInfo();
37
38     ConvImpl<armnn::DepthwiseConvolution2dQueueDescriptor, uint8_t, int32_t, int32_t>(
39         m_Data,
40         inputData, inputInfo.GetQuantizationScale(),  inputInfo.GetQuantizationOffset(),
41         weightsData, weightsInfo.GetQuantizationScale(), weightsInfo.GetQuantizationOffset(),
42         biasData,
43         outputData, outputInfo.GetQuantizationScale(), outputInfo.GetQuantizationOffset(), filterInfo, true);
44 }
45
46 } //namespace armnn