IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / cl / workloads / ClLstmFloatWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <backendsCommon/Workload.hpp>
9 #include <backendsCommon/WorkloadData.hpp>
10
11 #include <arm_compute/runtime/CL/CLFunctions.h>
12
13 namespace armnn
14 {
15
16 class ClLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor>
17 {
18 public:
19     ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
20     void Execute() const override;
21
22 private:
23     mutable arm_compute::CLLSTMLayer m_LstmLayer;
24
25     std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
26     std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
27     std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
28     std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
29     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
30     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
31     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
32     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
33     std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
34     std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
35     std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
36     std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
37     std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
38     std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
39     std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
40     std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
41     std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
42
43     std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;
44
45     void FreeUnusedTensors();
46 };
47
48 arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
49                                                 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
50                                                 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
51                                                 const TensorInfo& output, const LstmDescriptor &descriptor,
52                                                 const TensorInfo& inputToForgetWeights,
53                                                 const TensorInfo& inputToCellWeights,
54                                                 const TensorInfo& inputToOutputWeights,
55                                                 const TensorInfo& recurrentToForgetWeights,
56                                                 const TensorInfo& recurrentToCellWeights,
57                                                 const TensorInfo& recurrentToOutputWeights,
58                                                 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
59                                                 const TensorInfo& outputGateBias,
60                                                 const TensorInfo* inputToInputWeights,
61                                                 const TensorInfo* recurrentToInputWeights,
62                                                 const TensorInfo* cellToInputWeights,
63                                                 const TensorInfo* inputGateBias,
64                                                 const TensorInfo* projectionWeights,
65                                                 const TensorInfo* projectionBias,
66                                                 const TensorInfo* cellToForgetWeights,
67                                                 const TensorInfo* cellToOutputWeights);
68 } //namespace armnn