Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloads / ClLstmFloat32Workload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include "ClLstmFloat32Workload.hpp"
7 #include "backends/ClTensorHandle.hpp"
8 #include "backends/CpuTensorHandle.hpp"
9 #include "backends/ArmComputeTensorUtils.hpp"
10 #include "backends/ClLayerSupport.hpp"
11 #include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
12
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16
17 ClLstmFloat32Workload::ClLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
18         : FloatWorkload<LstmQueueDescriptor>(descriptor, info)
19 {
20     arm_compute::LSTMParams<arm_compute::ICLTensor> lstm_param;
21
22     // Basic parameters
23     m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
24     BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
25
26     m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
27     BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
28
29     m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
30     BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
31
32     m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
33     BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
34
35     m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
36     BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
37
38     m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
39     BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
40
41     m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
42     BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
43
44     m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
45     BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
46
47     m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
48     BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
49
50     // for future reference: check the AndroidNN API for the logic here
51     if (!m_Data.m_Parameters.m_CifgEnabled)
52     {
53         m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
54         BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
55
56         m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
57         BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
58
59         m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
60         if (m_Data.m_CellToInputWeights != nullptr)
61         {
62             BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
63         }
64
65         m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
66         BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
67
68         lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
69                                    m_RecurrentToInputWeightsTensor.get(),
70                                    m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
71                                    m_InputGateBiasTensor.get());
72     }
73
74     if (m_Data.m_Parameters.m_ProjectionEnabled)
75     {
76         m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
77         BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
78
79         m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
80         if (m_Data.m_ProjectionBias != nullptr)
81         {
82             BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
83         }
84
85         lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
86                                          m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
87     }
88
89     if (m_Data.m_Parameters.m_PeepholeEnabled)
90     {
91         m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
92         BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
93
94         m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
95         BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
96
97         lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
98     }
99
100     const arm_compute::ICLTensor& input           = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
101     const arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
102     const arm_compute::ICLTensor& cell_state_in   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
103
104     arm_compute::ICLTensor& output_state_out      = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
105     arm_compute::ICLTensor& cell_state_out        = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
106     arm_compute::ICLTensor& output                = static_cast<IClTensorHandle*>(m_Data.m_Outputs[3])->GetTensor();
107
108     // Get the batch_size and the num_units from the cellStateIn dimensions
109     const TensorInfo& inputTensorInfo = info.m_InputTensorInfos[2];
110     const unsigned int batch_size = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[0]);
111     const unsigned int num_units  = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[1]);
112
113     m_ScratchBuffer = std::make_unique<arm_compute::CLTensor>();
114     if (m_Data.m_Parameters.m_CifgEnabled)
115     {
116         // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
117         armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 4 }, DataType::Float32);
118         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
119     }
120     else
121     {
122         // scratch_buffer [num_units * 3, batch_size] without CIFG
123         armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 3 }, DataType::Float32);
124         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
125     }
126
127     float cell_threshold = m_Data.m_Parameters.m_ClippingThresCell;
128     float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;
129
130     // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
131     arm_compute::ActivationLayerInfo activationLayerInfo;
132     if (m_Data.m_Parameters.m_ActivationFunc == 0)
133     {
134         // no activation, do nothing
135     }
136     else if (m_Data.m_Parameters.m_ActivationFunc == 1)
137     {
138         activationLayerInfo = arm_compute::ActivationLayerInfo(
139                 arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
140     }
141     else if (m_Data.m_Parameters.m_ActivationFunc == 3)
142     {
143         activationLayerInfo = arm_compute::ActivationLayerInfo(
144                 arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
145     }
146     else if (m_Data.m_Parameters.m_ActivationFunc == 4)
147     {
148         activationLayerInfo =  arm_compute::ActivationLayerInfo(
149                 arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
150     }
151     else if (m_Data.m_Parameters.m_ActivationFunc == 6)
152     {
153         activationLayerInfo =  arm_compute::ActivationLayerInfo(
154                 arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
155     }
156     else
157     {
158         throw armnn::Exception("Wrong Type of Activation Function!");
159     }
160
161
162     m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
163                           m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
164                           m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
165                           m_ForgetGateBiasTensor.get(), m_CellBiasTensor.get(), m_OutputGateBiasTensor.get(),
166                           &output_state_in, &cell_state_in, m_ScratchBuffer.get(), &output_state_out,
167                           &cell_state_out, &output, lstm_param, activationLayerInfo,
168                           cell_threshold, projection_threshold);
169
170     armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
171
172     InitialiseArmComputeClTensorData(*m_InputToForgetWeightsTensor,
173                                      m_Data.m_InputToForgetWeights->GetConstTensor<float>());
174     InitialiseArmComputeClTensorData(*m_InputToCellWeightsTensor,
175                                      m_Data.m_InputToCellWeights->GetConstTensor<float>());
176     InitialiseArmComputeClTensorData(*m_InputToOutputWeightsTensor,
177                                      m_Data.m_InputToOutputWeights->GetConstTensor<float>());
178     InitialiseArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor,
179                                      m_Data.m_RecurrentToForgetWeights->GetConstTensor<float>());
180     InitialiseArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,
181                                      m_Data.m_RecurrentToCellWeights->GetConstTensor<float>());
182     InitialiseArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor,
183                                      m_Data.m_RecurrentToOutputWeights->GetConstTensor<float>());
184     InitialiseArmComputeClTensorData(*m_ForgetGateBiasTensor,
185                                      m_Data.m_ForgetGateBias->GetConstTensor<float>());
186     InitialiseArmComputeClTensorData(*m_CellBiasTensor,
187                                      m_Data.m_CellBias->GetConstTensor<float>());
188     InitialiseArmComputeClTensorData(*m_OutputGateBiasTensor,
189                                      m_Data.m_OutputGateBias->GetConstTensor<float>());
190
191     if (!m_Data.m_Parameters.m_CifgEnabled)
192     {
193         InitialiseArmComputeClTensorData(*m_InputToInputWeightsTensor,
194                                          m_Data.m_InputToInputWeights->GetConstTensor<float>());
195         InitialiseArmComputeClTensorData(*m_RecurrentToInputWeightsTensor,
196                                          m_Data.m_RecurrentToInputWeights->GetConstTensor<float>());
197         if (m_Data.m_CellToInputWeights != nullptr)
198         {
199             InitialiseArmComputeClTensorData(*m_CellToInputWeightsTensor,
200                                              m_Data.m_CellToInputWeights->GetConstTensor<float>());
201         }
202         InitialiseArmComputeClTensorData(*m_InputGateBiasTensor,
203                                          m_Data.m_InputGateBias->GetConstTensor<float>());
204     }
205
206     if (m_Data.m_Parameters.m_ProjectionEnabled)
207     {
208         InitialiseArmComputeClTensorData(*m_ProjectionWeightsTensor,
209                                          m_Data.m_ProjectionWeights->GetConstTensor<float>());
210         if (m_Data.m_ProjectionBias != nullptr)
211         {
212             InitialiseArmComputeClTensorData(*m_ProjectionBiasTensor,
213                                              m_Data.m_ProjectionBias->GetConstTensor<float>());
214         }
215     }
216
217     if (m_Data.m_Parameters.m_PeepholeEnabled)
218     {
219         InitialiseArmComputeClTensorData(*m_CellToForgetWeightsTensor,
220                                          m_Data.m_CellToForgetWeights->GetConstTensor<float>());
221         InitialiseArmComputeClTensorData(*m_CellToOutputWeightsTensor,
222                                          m_Data.m_CellToOutputWeights->GetConstTensor<float>());
223     }
224
225     // Force Compute Library to perform the necessary copying and reshaping, after which
226     // delete all the input tensors that will no longer be needed
227     m_LstmLayer.prepare();
228     FreeUnusedTensors();
229 }
230
// Enqueues the previously-configured CLLSTMLayer on the OpenCL command queue.
// All tensor setup/upload was done once in the constructor; each call only runs the kernels.
void ClLstmFloat32Workload::Execute() const
{
    m_LstmLayer.run();
}
235
236 arm_compute::Status ClLstmFloat32WorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
237                                                   const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
238                                                   const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
239                                                   const TensorInfo& output, const LstmDescriptor& descriptor,
240                                                   const TensorInfo& inputToForgetWeights,
241                                                   const TensorInfo& inputToCellWeights,
242                                                   const TensorInfo& inputToOutputWeights,
243                                                   const TensorInfo& recurrentToForgetWeights,
244                                                   const TensorInfo& recurrentToCellWeights,
245                                                   const TensorInfo& recurrentToOutputWeights,
246                                                   const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
247                                                   const TensorInfo& outputGateBias,
248                                                   const TensorInfo* inputToInputWeights,
249                                                   const TensorInfo* recurrentToInputWeights,
250                                                   const TensorInfo* cellToInputWeights,
251                                                   const TensorInfo* inputGateBias,
252                                                   const TensorInfo* projectionWeights,
253                                                   const TensorInfo* projectionBias,
254                                                   const TensorInfo* cellToForgetWeights,
255                                                   const TensorInfo* cellToOutputWeights)
256 {
257     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
258
259     // The inputs and the outputs
260     const arm_compute::TensorInfo aclInputInfo  = BuildArmComputeTensorInfo(input);
261     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
262     const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
263     const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
264     const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
265     const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
266     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
267
268     // Basic parameters
269     const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
270     const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
271     const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
272     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
273                                   = BuildArmComputeTensorInfo(recurrentToForgetWeights);
274     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
275                                   = BuildArmComputeTensorInfo(recurrentToCellWeights);
276     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
277                                   = BuildArmComputeTensorInfo(recurrentToOutputWeights);
278     const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
279     const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
280     const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
281
282     arm_compute::TensorInfo aclInputToInputWeightsInfo;
283     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
284     arm_compute::TensorInfo aclCellToInputWeightsInfo;
285     arm_compute::TensorInfo aclInputGateBiasInfo;
286     arm_compute::TensorInfo aclProjectionWeightsInfo;
287     arm_compute::TensorInfo aclProjectionBiasInfo;
288     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
289     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
290
291     if (!descriptor.m_CifgEnabled)
292     {
293         armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
294         aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
295         armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
296         aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
297
298         if (cellToInputWeights != nullptr)
299         {
300             armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
301             aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
302         }
303         armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
304         aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
305         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
306                                          cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
307                                          &aclInputGateBiasInfo);
308     }
309
310     if (descriptor.m_ProjectionEnabled)
311     {
312         const armnn::TensorInfo& projectionWInfo = *projectionWeights;
313         aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
314
315         if (projectionBias != nullptr)
316         {
317             const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
318             aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
319         }
320         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
321                                                projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
322     }
323
324     if (descriptor.m_PeepholeEnabled)
325     {
326         const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
327         aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
328         const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
329         aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
330         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
331     }
332
333     float cell_threshold = descriptor.m_ClippingThresCell;
334     float projection_threshold = descriptor.m_ClippingThresProj;
335
336     // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
337     arm_compute::ActivationLayerInfo activationLayerInfo;
338     if (descriptor.m_ActivationFunc == 0)
339     {
340         // no activation, do nothing
341     }
342     else if (descriptor.m_ActivationFunc == 1)
343     {
344         activationLayerInfo = arm_compute::ActivationLayerInfo(
345                 arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
346     }
347     else if (descriptor.m_ActivationFunc == 3)
348     {
349         activationLayerInfo = arm_compute::ActivationLayerInfo(
350                 arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
351     }
352     else if (descriptor.m_ActivationFunc == 4)
353     {
354         activationLayerInfo =  arm_compute::ActivationLayerInfo(
355                 arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
356     }
357     else if (descriptor.m_ActivationFunc == 6)
358     {
359         activationLayerInfo =  arm_compute::ActivationLayerInfo(
360                 arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
361     }
362     else
363     {
364         throw armnn::Exception("Wrong Type of Activation Function!");
365     }
366
367     return arm_compute::CLLSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo,
368                                               &aclInputToCellWeightsInfo,
369                                               &aclInputToOutputWeightsInfo,
370                                               &aclRecurrentToForgetWeightsInfo,
371                                               &aclRecurrentToCellWeightsInfo,
372                                               &aclRecurrentToOutputWeightsInfo,
373                                               &aclForgetGateBiasInfo,
374                                               &aclCellBiasInfo,
375                                               &aclOutputGateBiasInfo,
376                                               &aclOutputStateInInfo, &aclCellStateInInfo,
377                                               &aclScratchBufferInfo, &aclOutputStateOutInfo,
378                                               &aclCellStateOutInfo, &aclOutputInfo,
379                                               lstm_params_info, activationLayerInfo,
380                                               cell_threshold, projection_threshold);
381 }
382
383 void ClLstmFloat32Workload::FreeUnusedTensors()
384 {
385     FreeTensorIfUnused(m_InputToInputWeightsTensor);
386     FreeTensorIfUnused(m_InputToForgetWeightsTensor);
387     FreeTensorIfUnused(m_InputToCellWeightsTensor);
388     FreeTensorIfUnused(m_InputToOutputWeightsTensor);
389     FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
390     FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
391     FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
392     FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
393     FreeTensorIfUnused(m_CellToInputWeightsTensor);
394     FreeTensorIfUnused(m_CellToForgetWeightsTensor);
395     FreeTensorIfUnused(m_CellToOutputWeightsTensor);
396     FreeTensorIfUnused(m_InputGateBiasTensor);
397     FreeTensorIfUnused(m_ForgetGateBiasTensor);
398     FreeTensorIfUnused(m_CellBiasTensor);
399     FreeTensorIfUnused(m_OutputGateBiasTensor);
400     FreeTensorIfUnused(m_ProjectionWeightsTensor);
401     FreeTensorIfUnused(m_ProjectionBiasTensor);
402     FreeTensorIfUnused(m_ScratchBuffer);
403 }
404
405 } //namespace armnn