// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "backends/ClWorkloadUtils.hpp"
#include "backends/Workload.hpp"
#include "backends/WorkloadData.hpp"

#include <memory>
15 class ClLstmFloat32Workload : public FloatWorkload<LstmQueueDescriptor>
18 ClLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
19 void Execute() const override;
22 mutable arm_compute::CLLSTMLayer m_LstmLayer;
24 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
25 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
26 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
27 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
28 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
29 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
30 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
31 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
32 std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
33 std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
34 std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
35 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
36 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
38 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
40 std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
42 std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;
44 void FreeUnusedTensors();
47 arm_compute::Status ClLstmFloat32WorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
48 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
49 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
50 const TensorInfo& output, const LstmDescriptor &descriptor,
51 const TensorInfo& inputToForgetWeights,
52 const TensorInfo& inputToCellWeights,
53 const TensorInfo& inputToOutputWeights,
54 const TensorInfo& recurrentToForgetWeights,
55 const TensorInfo& recurrentToCellWeights,
56 const TensorInfo& recurrentToOutputWeights,
57 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
58 const TensorInfo& outputGateBias,
59 const TensorInfo* inputToInputWeights,
60 const TensorInfo* recurrentToInputWeights,
61 const TensorInfo* cellToInputWeights,
62 const TensorInfo* inputGateBias,
63 const TensorInfo* projectionWeights,
64 const TensorInfo* projectionBias,
65 const TensorInfo* cellToForgetWeights,
66 const TensorInfo* cellToOutputWeights);