IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonLstmFloatWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>

#include "arm_compute/graph/Tensor.h"
#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"

#include <memory>
13
14 namespace armnn
15 {
16
17 class NeonLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor>
18 {
19 public:
20     NeonLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
21     virtual void Execute() const override;
22
23 private:
24     mutable arm_compute::NELSTMLayer m_LstmLayer;
25
26     std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
27     std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
28     std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
29     std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
30     std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
31     std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
32     std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
33     std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
34     std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
35     std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
36     std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
37     std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
38     std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
39     std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
40     std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
41     std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
42     std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
43
44     std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
45
46     void FreeUnusedTensors();
47 };
48
49 arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
50                                                   const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
51                                                   const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
52                                                   const TensorInfo& output, const LstmDescriptor &descriptor,
53                                                   const TensorInfo& inputToForgetWeights,
54                                                   const TensorInfo& inputToCellWeights,
55                                                   const TensorInfo& inputToOutputWeights,
56                                                   const TensorInfo& recurrentToForgetWeights,
57                                                   const TensorInfo& recurrentToCellWeights,
58                                                   const TensorInfo& recurrentToOutputWeights,
59                                                   const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
60                                                   const TensorInfo& outputGateBias,
61                                                   const TensorInfo* inputToInputWeights,
62                                                   const TensorInfo* recurrentToInputWeights,
63                                                   const TensorInfo* cellToInputWeights,
64                                                   const TensorInfo* inputGateBias,
65                                                   const TensorInfo* projectionWeights,
66                                                   const TensorInfo* projectionBias,
67                                                   const TensorInfo* cellToForgetWeights,
68                                                   const TensorInfo* cellToOutputWeights);
69
70 } //namespace armnn