b5f2ae92235cb6145bb96499877b569dd3e9afb3
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonDepthwiseConvolutionWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <neon/workloads/NeonWorkloadUtils.hpp>
9
10 namespace armnn
11 {
12
13 arm_compute::Status NeonDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
14                                                              const TensorInfo& output,
15                                                              const DepthwiseConvolution2dDescriptor& descriptor,
16                                                              const TensorInfo& weights,
17                                                              const Optional<TensorInfo>& biases);
18
19 class NeonDepthwiseConvolutionWorkload : public BaseWorkload<DepthwiseConvolution2dQueueDescriptor>
20 {
21 public:
22     NeonDepthwiseConvolutionWorkload(const DepthwiseConvolution2dQueueDescriptor& descriptor,
23                                      const WorkloadInfo& info);
24
25     virtual void Execute() const override;
26
27 private:
28     mutable std::unique_ptr<arm_compute::IFunction> m_pDepthwiseConvolutionLayer;
29
30     std::unique_ptr<arm_compute::Tensor> m_KernelTensor;
31     std::unique_ptr<arm_compute::Tensor> m_BiasTensor;
32
33     void FreeUnusedTensors();
34 };
35
36 } // namespace armnn