IVGCVSW-3519 Refactor TransposeConvolution2dWorkload for CL backend
[platform/upstream/armnn.git] / src / backends / cl / workloads / ClTransposeConvolution2dWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClTransposeConvolution2dWorkload.hpp"
7
8 #include "ClWorkloadUtils.hpp"
9
10 #include <cl/ClLayerSupport.hpp>
11 #include <cl/ClTensorHandle.hpp>
12 #include <cl/ClLayerSupport.hpp>
13
14 #include <aclCommon/ArmComputeUtils.hpp>
15 #include <aclCommon/ArmComputeTensorUtils.hpp>
16
17 #include <backendsCommon/CpuTensorHandle.hpp>
18
19 #include <arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h>
20
21 namespace armnn
22 {
23
24 using namespace armcomputetensorutils;
25
26 arm_compute::Status ClTransposeConvolution2dWorkloadValidate(const TensorInfo& input,
27                                                              const TensorInfo& output,
28                                                              const TransposeConvolution2dDescriptor& descriptor,
29                                                              const TensorInfo& weights,
30                                                              const Optional<TensorInfo>& biases)
31 {
32     arm_compute::TensorInfo aclInputInfo   = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
33     arm_compute::TensorInfo aclOutputInfo  = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
34     arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
35
36     arm_compute::TensorInfo aclBiasesInfo;
37     arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
38
39     if (descriptor.m_BiasEnabled)
40     {
41         BOOST_ASSERT(biases.has_value());
42
43         aclBiasesInfo = BuildArmComputeTensorInfo(biases.value(), descriptor.m_DataLayout);
44         optionalAclBiasesInfo = &aclBiasesInfo;
45     }
46
47     arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(descriptor);
48
49     return arm_compute::CLDeconvolutionLayer::validate(&aclInputInfo,
50                                                        &aclWeightsInfo,
51                                                        optionalAclBiasesInfo,
52                                                        &aclOutputInfo,
53                                                        padStrideInfo);
54 }
55
56 ClTransposeConvolution2dWorkload::ClTransposeConvolution2dWorkload(
57     const TransposeConvolution2dQueueDescriptor& descriptor,
58     const WorkloadInfo& info,
59     std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager) :
60     BaseWorkload<TransposeConvolution2dQueueDescriptor>(descriptor, info),
61     m_Layer(memoryManager)
62 {
63     const TensorInfo& weightInfo = m_Data.m_Weight->GetTensorInfo();
64
65     m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
66     BuildArmComputeTensor(*m_WeightsTensor, weightInfo, m_Data.m_Parameters.m_DataLayout);
67
68     arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
69                                              m_Data.m_Parameters.m_StrideY,
70                                              m_Data.m_Parameters.m_PadLeft,
71                                              m_Data.m_Parameters.m_PadRight,
72                                              m_Data.m_Parameters.m_PadTop,
73                                              m_Data.m_Parameters.m_PadBottom,
74                                              arm_compute::DimensionRoundingType::FLOOR);
75
76     if (m_Data.m_Parameters.m_BiasEnabled)
77     {
78         m_BiasesTensor = std::make_unique<arm_compute::CLTensor>();
79         BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo(), m_Data.m_Parameters.m_DataLayout);
80     }
81
82     m_Data.ValidateInputsOutputs("ClTransposeConvolution2dWorkload", 1, 1);
83
84     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
85     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
86
87     arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
88
89     input.info()->set_data_layout(aclDataLayout);
90     output.info()->set_data_layout(aclDataLayout);
91
92     m_Layer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, padStrideInfo);
93
94     InitializeArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight);
95     if (m_BiasesTensor)
96     {
97         InitializeArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias);
98     }
99
100     m_Layer.prepare();
101
102     FreeUnusedTensors();
103 }
104
105 void ClTransposeConvolution2dWorkload::Execute() const
106 {
107     ARMNN_SCOPED_PROFILING_EVENT_CL("ClTransposeConvolution2dWorkload_Execute");
108     RunClFunction(m_Layer, CHECK_LOCATION());
109 }
110
// Passes both constant-data tensors to FreeTensorIfUnused. It is called from the constructor
// after m_Layer.prepare(). FreeTensorIfUnused presumably releases a tensor that exists but is no
// longer used by the prepared layer — TODO confirm against its definition.
// m_BiasesTensor may be unset here when bias is disabled.
void ClTransposeConvolution2dWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_WeightsTensor);
    FreeTensorIfUnused(m_BiasesTensor);
}
116
117 } // namespace armnn