IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonLstmFloatWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "NeonLstmFloatWorkload.hpp"
7 #include "NeonWorkloadUtils.hpp"
8
9 #include "backendsCommon/CpuTensorHandle.hpp"
10 #include "aclCommon/ArmComputeTensorUtils.hpp"
11 #include "neon/NeonTensorHandle.hpp"
12
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16
17 NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
18         : FloatWorkload<LstmQueueDescriptor>(descriptor, info)
19 {
20     arm_compute::LSTMParams<arm_compute::ITensor> lstm_param;
21
22     // Basic parameters
23     m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
24     BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
25
26     m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
27     BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
28
29     m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
30     BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
31
32     m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
33     BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
34
35     m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
36     BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
37
38     m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
39     BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
40
41     m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
42     BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
43
44     m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
45     BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
46
47     m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
48     BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
49
50     // for future reference: check the AndroidNN API for the logic here
51     if (!m_Data.m_Parameters.m_CifgEnabled)
52     {
53         m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
54         BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
55
56         m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
57         BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
58
59         m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
60         if (m_Data.m_CellToInputWeights != nullptr)
61         {
62             BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
63         }
64
65         m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
66         BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
67
68         lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
69                                    m_RecurrentToInputWeightsTensor.get(),
70                                    m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
71                                    m_InputGateBiasTensor.get());
72     }
73
74     if (m_Data.m_Parameters.m_ProjectionEnabled)
75     {
76         m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
77         BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
78
79         m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
80         if (m_Data.m_ProjectionBias != nullptr)
81         {
82             BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
83         }
84
85         lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
86                                          m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
87     }
88
89     if (m_Data.m_Parameters.m_PeepholeEnabled)
90     {
91         m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
92         BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
93
94         m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
95         BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
96
97         lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
98     }
99
100     const arm_compute::ITensor& input           = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
101     const arm_compute::ITensor& output_state_in = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
102     const arm_compute::ITensor& cell_state_in   = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
103
104     arm_compute::ITensor& output_state_out      = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
105     arm_compute::ITensor& cell_state_out        = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
106     arm_compute::ITensor& output                = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[3])->GetTensor();
107
108     // Get the batch_size and the num_units from the cellStateIn dimensions
109     const TensorInfo& inputTensorInfo = info.m_InputTensorInfos[2];
110     const unsigned int batch_size = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[0]);
111     const unsigned int num_units  = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[1]);
112
113     m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
114     if (m_Data.m_Parameters.m_CifgEnabled)
115     {
116         // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
117         armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 4 }, DataType::Float32);
118         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
119     }
120     else
121     {
122         // scratch_buffer [num_units * 3, batch_size] without CIFG
123         armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 3 }, DataType::Float32);
124         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
125     }
126
127     float cell_threshold = m_Data.m_Parameters.m_ClippingThresCell;
128     float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;
129
130     // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
131     arm_compute::ActivationLayerInfo activationLayerInfo;
132     if (m_Data.m_Parameters.m_ActivationFunc == 0)
133     {
134         // no activation, do nothing
135     }
136     else if (m_Data.m_Parameters.m_ActivationFunc == 1)
137     {
138         activationLayerInfo = arm_compute::ActivationLayerInfo(
139                 arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
140     }
141     else if (m_Data.m_Parameters.m_ActivationFunc == 3)
142     {
143         activationLayerInfo = arm_compute::ActivationLayerInfo(
144                 arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
145     }
146     else if (m_Data.m_Parameters.m_ActivationFunc == 4)
147     {
148         activationLayerInfo = arm_compute::ActivationLayerInfo(
149                 arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
150     }
151     else if (m_Data.m_Parameters.m_ActivationFunc == 6)
152     {
153         activationLayerInfo = arm_compute::ActivationLayerInfo(
154                 arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
155     }
156     else
157     {
158         throw armnn::Exception("Wrong Type of Activation Function!");
159     }
160
161
162     m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
163                           m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
164                           m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
165                           m_ForgetGateBiasTensor.get(), m_CellBiasTensor.get(), m_OutputGateBiasTensor.get(),
166                           &output_state_in, &cell_state_in, m_ScratchBuffer.get(), &output_state_out,
167                           &cell_state_out, &output, lstm_param, activationLayerInfo,
168                           cell_threshold, projection_threshold);
169
170     armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
171
172     InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor,
173                                    m_Data.m_InputToForgetWeights);
174     InitializeArmComputeTensorData(*m_InputToCellWeightsTensor,
175                                    m_Data.m_InputToCellWeights);
176     InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor,
177                                    m_Data.m_InputToOutputWeights);
178     InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor,
179                                    m_Data.m_RecurrentToForgetWeights);
180     InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor,
181                                    m_Data.m_RecurrentToCellWeights);
182     InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor,
183                                    m_Data.m_RecurrentToOutputWeights);
184     InitializeArmComputeTensorData(*m_ForgetGateBiasTensor,
185                                    m_Data.m_ForgetGateBias);
186     InitializeArmComputeTensorData(*m_CellBiasTensor,
187                                    m_Data.m_CellBias);
188     InitializeArmComputeTensorData(*m_OutputGateBiasTensor,
189                                    m_Data.m_OutputGateBias);
190
191     if (!m_Data.m_Parameters.m_CifgEnabled)
192     {
193         InitializeArmComputeTensorData(*m_InputToInputWeightsTensor,
194                                        m_Data.m_InputToInputWeights);
195         InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor,
196                                        m_Data.m_RecurrentToInputWeights);
197         if (m_Data.m_CellToInputWeights != nullptr)
198         {
199             InitializeArmComputeTensorData(*m_CellToInputWeightsTensor,
200                                            m_Data.m_CellToInputWeights);
201         }
202         InitializeArmComputeTensorData(*m_InputGateBiasTensor,
203                                        m_Data.m_InputGateBias);
204     }
205
206     if (m_Data.m_Parameters.m_ProjectionEnabled)
207     {
208         InitializeArmComputeTensorData(*m_ProjectionWeightsTensor,
209                                        m_Data.m_ProjectionWeights);
210         if (m_Data.m_ProjectionBias != nullptr)
211         {
212             InitializeArmComputeTensorData(*m_ProjectionBiasTensor,
213                                            m_Data.m_ProjectionBias);
214         }
215     }
216
217     if (m_Data.m_Parameters.m_PeepholeEnabled)
218     {
219         InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor,
220                                        m_Data.m_CellToForgetWeights);
221         InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor,
222                                        m_Data.m_CellToOutputWeights);
223     }
224
225     // Force Compute Library to perform the necessary copying and reshaping, after which
226     // delete all the input tensors that will no longer be needed
227     m_LstmLayer.prepare();
228     FreeUnusedTensors();
229 }
230
231 void NeonLstmFloatWorkload::Execute() const
232 {
233     m_LstmLayer.run();
234 }
235
236 arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
237                                                   const TensorInfo& outputStateIn,
238                                                   const TensorInfo& cellStateIn,
239                                                   const TensorInfo& scratchBuffer,
240                                                   const TensorInfo& outputStateOut,
241                                                   const TensorInfo& cellStateOut,
242                                                   const TensorInfo& output,
243                                                   const LstmDescriptor& descriptor,
244                                                   const TensorInfo& inputToForgetWeights,
245                                                   const TensorInfo& inputToCellWeights,
246                                                   const TensorInfo& inputToOutputWeights,
247                                                   const TensorInfo& recurrentToForgetWeights,
248                                                   const TensorInfo& recurrentToCellWeights,
249                                                   const TensorInfo& recurrentToOutputWeights,
250                                                   const TensorInfo& forgetGateBias,
251                                                   const TensorInfo& cellBias,
252                                                   const TensorInfo& outputGateBias,
253                                                   const TensorInfo* inputToInputWeights,
254                                                   const TensorInfo* recurrentToInputWeights,
255                                                   const TensorInfo* cellToInputWeights,
256                                                   const TensorInfo* inputGateBias,
257                                                   const TensorInfo* projectionWeights,
258                                                   const TensorInfo* projectionBias,
259                                                   const TensorInfo* cellToForgetWeights,
260                                                   const TensorInfo* cellToOutputWeights)
261 {
262     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
263
264     // The inputs and the outputs
265     const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
266     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
267     const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
268     const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
269     const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
270     const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
271     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
272
273     // Basic parameters
274     const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
275     const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
276     const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
277     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
278                                   = BuildArmComputeTensorInfo(recurrentToForgetWeights);
279     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
280                                   = BuildArmComputeTensorInfo(recurrentToCellWeights);
281     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
282                                   = BuildArmComputeTensorInfo(recurrentToOutputWeights);
283     const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
284     const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
285     const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
286
287     arm_compute::TensorInfo aclInputToInputWeightsInfo;
288     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
289     arm_compute::TensorInfo aclCellToInputWeightsInfo;
290     arm_compute::TensorInfo aclInputGateBiasInfo;
291     arm_compute::TensorInfo aclProjectionWeightsInfo;
292     arm_compute::TensorInfo aclProjectionBiasInfo;
293     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
294     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
295
296     if (!descriptor.m_CifgEnabled)
297     {
298         armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
299         aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
300         armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
301         aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
302
303         if (cellToInputWeights != nullptr)
304         {
305             armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
306             aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
307         }
308         armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
309         aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
310         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
311                                          cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
312                                          &aclInputGateBiasInfo);
313     }
314
315     if (descriptor.m_ProjectionEnabled)
316     {
317         const armnn::TensorInfo& projectionWInfo = *projectionWeights;
318         aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
319
320         if (projectionBias != nullptr)
321         {
322             const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
323             aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
324         }
325         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
326                                                projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
327     }
328
329     if (descriptor.m_PeepholeEnabled)
330     {
331         const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
332         aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
333         const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
334         aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
335         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
336     }
337
338     float cell_threshold = descriptor.m_ClippingThresCell;
339     float projection_threshold = descriptor.m_ClippingThresProj;
340
341     // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
342     arm_compute::ActivationLayerInfo activationLayerInfo;
343     switch (descriptor.m_ActivationFunc)
344     {
345         case 0:
346             // no activation, do nothing
347             break;
348         case 1:
349             activationLayerInfo = arm_compute::ActivationLayerInfo(
350                     arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
351             break;
352         case 3:
353             activationLayerInfo = arm_compute::ActivationLayerInfo(
354                     arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
355             break;
356         case 4:
357             activationLayerInfo = arm_compute::ActivationLayerInfo(
358                     arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
359             break;
360         case 6:
361             activationLayerInfo = arm_compute::ActivationLayerInfo(
362                     arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
363             break;
364         default:
365             throw armnn::Exception("Wrong Type of Activation Function!");
366     }
367
368     return arm_compute::NELSTMLayer::validate(&aclInputInfo,
369                                               &aclInputToForgetWeightsInfo,
370                                               &aclInputToCellWeightsInfo,
371                                               &aclInputToOutputWeightsInfo,
372                                               &aclRecurrentToForgetWeightsInfo,
373                                               &aclRecurrentToCellWeightsInfo,
374                                               &aclRecurrentToOutputWeightsInfo,
375                                               &aclForgetGateBiasInfo,
376                                               &aclCellBiasInfo,
377                                               &aclOutputGateBiasInfo,
378                                               &aclOutputStateInInfo,
379                                               &aclCellStateInInfo,
380                                               &aclScratchBufferInfo,
381                                               &aclOutputStateOutInfo,
382                                               &aclCellStateOutInfo,
383                                               &aclOutputInfo,
384                                               lstm_params_info,
385                                               activationLayerInfo,
386                                               cell_threshold,
387                                               projection_threshold);
388 }
389
390 void NeonLstmFloatWorkload::FreeUnusedTensors()
391 {
392     FreeTensorIfUnused(m_InputToInputWeightsTensor);
393     FreeTensorIfUnused(m_InputToForgetWeightsTensor);
394     FreeTensorIfUnused(m_InputToCellWeightsTensor);
395     FreeTensorIfUnused(m_InputToOutputWeightsTensor);
396     FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
397     FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
398     FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
399     FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
400     FreeTensorIfUnused(m_CellToInputWeightsTensor);
401     FreeTensorIfUnused(m_CellToForgetWeightsTensor);
402     FreeTensorIfUnused(m_CellToOutputWeightsTensor);
403     FreeTensorIfUnused(m_InputGateBiasTensor);
404     FreeTensorIfUnused(m_ForgetGateBiasTensor);
405     FreeTensorIfUnused(m_CellBiasTensor);
406     FreeTensorIfUnused(m_OutputGateBiasTensor);
407     FreeTensorIfUnused(m_ProjectionWeightsTensor);
408     FreeTensorIfUnused(m_ProjectionBiasTensor);
409     FreeTensorIfUnused(m_ScratchBuffer);
410 }
411
412 } //namespace armnn