Release 18.05.01
[platform/upstream/armnn.git] / src / armnn / backends / NeonWorkloads / NeonConvolution2dBaseWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include <backends/Workload.hpp>
9 #include <backends/NeonWorkloadUtils.hpp>
10
11 #include "backends/CpuTensorHandle.hpp"
12 #include "backends/ArmComputeTensorUtils.hpp"
13 #include "backends/NeonLayerSupport.hpp"
14
15 #include "arm_compute/runtime/MemoryManagerOnDemand.h"
16
17 #include <memory>
18
19 namespace armnn
20 {
21
22 arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
23     const TensorInfo& output,
24     const Convolution2dDescriptor& descriptor,
25     const TensorInfo& weights,
26     const TensorInfo& biases);
27
28 template<armnn::DataType dataType>
29 class NeonConvolution2dBaseWorkload : public TypedWorkload<Convolution2dQueueDescriptor, dataType>
30 {
31 public:
32     using TypedWorkload<Convolution2dQueueDescriptor, dataType>::m_Data;
33
34     NeonConvolution2dBaseWorkload(const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info,
35                                   std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
36
37     virtual void ValidateData() const {};
38
39 protected:
40     std::unique_ptr<arm_compute::IFunction> m_ConvolutionLayer;
41     arm_compute::Tensor m_KernelTensor;
42     arm_compute::Tensor m_BiasTensor;
43 };
44
45 } //namespace armnn