Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloads / ClLstmFloat32Workload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include "backends/ClWorkloadUtils.hpp"
9 #include "backends/Workload.hpp"
10 #include "backends/WorkloadData.hpp"
11
12 namespace armnn
13 {
14
15 class ClLstmFloat32Workload : public FloatWorkload<LstmQueueDescriptor>
16 {
17 public:
18     ClLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
19     void Execute() const override;
20
21 private:
22     mutable arm_compute::CLLSTMLayer m_LstmLayer;
23
24     std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
25     std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
26     std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
27     std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
28     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
29     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
30     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
31     std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
32     std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
33     std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
34     std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
35     std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
36     std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
37     std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
38     std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
39     std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
40     std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
41
42     std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;
43
44     void FreeUnusedTensors();
45 };
46
47 arm_compute::Status ClLstmFloat32WorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
48                                                   const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
49                                                   const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
50                                                   const TensorInfo& output, const LstmDescriptor &descriptor,
51                                                   const TensorInfo& inputToForgetWeights,
52                                                   const TensorInfo& inputToCellWeights,
53                                                   const TensorInfo& inputToOutputWeights,
54                                                   const TensorInfo& recurrentToForgetWeights,
55                                                   const TensorInfo& recurrentToCellWeights,
56                                                   const TensorInfo& recurrentToOutputWeights,
57                                                   const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
58                                                   const TensorInfo& outputGateBias,
59                                                   const TensorInfo* inputToInputWeights,
60                                                   const TensorInfo* recurrentToInputWeights,
61                                                   const TensorInfo* cellToInputWeights,
62                                                   const TensorInfo* inputGateBias,
63                                                   const TensorInfo* projectionWeights,
64                                                   const TensorInfo* projectionBias,
65                                                   const TensorInfo* cellToForgetWeights,
66                                                   const TensorInfo* cellToOutputWeights);
67 } //namespace armnn