IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonFullyConnectedWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <neon/workloads/NeonWorkloadUtils.hpp>
9
10 #include <arm_compute/runtime/MemoryManagerOnDemand.h>
11
12 #include <memory>
13
14 namespace armnn
15 {
16
17 arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
18                                                        const TensorInfo& output,
19                                                        const TensorInfo& weights,
20                                                        const TensorInfo& biases,
21                                                        const FullyConnectedDescriptor& descriptor);
22
23 class NeonFullyConnectedWorkload : public BaseWorkload<FullyConnectedQueueDescriptor>
24 {
25 public:
26     NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info,
27                                std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
28     virtual void Execute() const override;
29
30 private:
31     mutable arm_compute::NEFullyConnectedLayer m_FullyConnectedLayer;
32
33     std::unique_ptr<arm_compute::Tensor> m_WeightsTensor;
34     std::unique_ptr<arm_compute::Tensor> m_BiasesTensor;
35
36     void FreeUnusedTensors();
37 };
38
39 } //namespace armnn
40