// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "NeonConvolution2dBaseWorkload.hpp"

#include "backends/ArmComputeTensorUtils.hpp"
#include "backends/CpuTensorHandle.hpp"
#include "backends/NeonLayerSupport.hpp"

#include <memory>
15 template<armnn::DataType dataType>
16 NeonConvolution2dBaseWorkload<dataType>::NeonConvolution2dBaseWorkload(const Convolution2dQueueDescriptor& descriptor,
17 const WorkloadInfo& info)
18 : TypedWorkload<Convolution2dQueueDescriptor, dataType>(descriptor, info)
20 using arm_compute::NEDirectConvolutionLayer;
21 using namespace armcomputetensorutils;
25 // todo: check tensor shapes match
27 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
28 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
30 BuildArmComputeTensor(m_KernelTensor, m_Data.m_Weight->GetTensorInfo());
32 arm_compute::Tensor* optionalBiasTensor = nullptr;
33 if (m_Data.m_Parameters.m_BiasEnabled)
35 BuildArmComputeTensor(m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
36 optionalBiasTensor = &m_BiasTensor;
39 arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
40 m_Data.m_Parameters.m_StrideY,
41 m_Data.m_Parameters.m_PadLeft,
42 m_Data.m_Parameters.m_PadRight,
43 m_Data.m_Parameters.m_PadTop,
44 m_Data.m_Parameters.m_PadBottom,
45 arm_compute::DimensionRoundingType::FLOOR);
47 const bool preferDirectConvolution =
48 IsNeonDirectConvolutionPreferred(m_Data.m_Weight->GetTensorInfo(),
51 if (preferDirectConvolution)
53 auto directConvolutionLayer = std::make_unique<arm_compute::NEDirectConvolutionLayer>();
54 directConvolutionLayer->configure(&input,
59 m_ConvolutionLayer.reset(directConvolutionLayer.release());
63 auto convolutionLayer = std::make_unique<arm_compute::NEConvolutionLayer>();
64 convolutionLayer->configure(&input,
69 m_ConvolutionLayer.reset(convolutionLayer.release());
71 BOOST_ASSERT(m_ConvolutionLayer);
73 using Type = ResolveType<dataType>;
75 InitialiseArmComputeTensorData(m_KernelTensor, m_Data.m_Weight->template GetConstTensor<Type>());
78 // Generate known implementations for linker
79 template class NeonConvolution2dBaseWorkload<DataType::Float32>;
80 template class NeonConvolution2dBaseWorkload<DataType::QuantisedAsymm8>;