// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#include "NeonLstmFloatWorkload.hpp"
#include "NeonWorkloadUtils.hpp"

#include "backendsCommon/CpuTensorHandle.hpp"
#include "aclCommon/ArmComputeTensorUtils.hpp"
#include "neon/NeonTensorHandle.hpp"

namespace armnn
{
using namespace armcomputetensorutils;

NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : FloatWorkload<LstmQueueDescriptor>(descriptor, info)
{
    arm_compute::LSTMParams<arm_compute::ITensor> lstm_param;
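
    // Basic parameters: the weight and bias tensors that every LSTM configuration requires.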
    m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());

    m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());

    m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());

    m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());

    m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());

    m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());

    m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());

    m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());

    m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());

    // The input-gate tensors are only provided when CIFG (coupled input-forget gate) is disabled;
    // see the Android NN API for the logic behind these optional parameters.
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());

        m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());

        m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
        }

        m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());

        lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
                                   m_RecurrentToInputWeightsTensor.get(),
                                   m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
                                   m_InputGateBiasTensor.get());
    }
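
    // Projection tensors are only needed when the projection layer is enabled; the projection bias is optional.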
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());

        m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_ProjectionBias != nullptr)
        {
            BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
        }

        lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
                                         m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());

        m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());

        lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
    }
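
    // Wrap the ArmNN tensor handles as ACL ITensors. Inputs: the layer input plus the output state and
    // cell state from the previous step; outputs: the new output state, new cell state and the layer output.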
    const arm_compute::ITensor& input           = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    const arm_compute::ITensor& output_state_in = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    const arm_compute::ITensor& cell_state_in   = static_cast<INeonTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();

    arm_compute::ITensor& output_state_out = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
    arm_compute::ITensor& cell_state_out   = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
    arm_compute::ITensor& output           = static_cast<INeonTensorHandle*>(m_Data.m_Outputs[3])->GetTensor();

    // Get the batch_size and the num_units from the cellStateIn dimensions
    const TensorInfo& inputTensorInfo = info.m_InputTensorInfos[2];
    const unsigned int batch_size = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[0]);
    const unsigned int num_units  = boost::numeric_cast<unsigned int>(inputTensorInfo.GetShape()[1]);
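
    // The scratch buffer holds the intermediate gate activations: four gates in the general case,
    // or three when CIFG couples the input gate to the forget gate.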
    m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
    if (m_Data.m_Parameters.m_CifgEnabled)
    {
        // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
        armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32);
        BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
    }
    else
    {
        // scratch_buffer [num_units * 4, batch_size] without CIFG
        armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32);
        BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
    }
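
    // Clipping thresholds for the cell state and the projected output, as supplied in the LstmDescriptor.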
    float cell_threshold       = m_Data.m_Parameters.m_ClippingThresCell;
    float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;

    // Build the ActivationLayerInfo; five activation-function values are supported by the descriptor.
    arm_compute::ActivationLayerInfo activationLayerInfo;
    if (m_Data.m_Parameters.m_ActivationFunc == 0)
    {
        // no activation, do nothing
    }
    else if (m_Data.m_Parameters.m_ActivationFunc == 1)
    {
        activationLayerInfo = arm_compute::ActivationLayerInfo(
            arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
    }
    else if (m_Data.m_Parameters.m_ActivationFunc == 3)
    {
        activationLayerInfo = arm_compute::ActivationLayerInfo(
            arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
    }
    else if (m_Data.m_Parameters.m_ActivationFunc == 4)
    {
        activationLayerInfo = arm_compute::ActivationLayerInfo(
            arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
    }
    else if (m_Data.m_Parameters.m_ActivationFunc == 6)
    {
        activationLayerInfo = arm_compute::ActivationLayerInfo(
            arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
    }
    else
    {
        throw armnn::Exception("Wrong Type of Activation Function!");
    }

    m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
                          m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
                          m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
                          m_ForgetGateBiasTensor.get(), m_CellBiasTensor.get(), m_OutputGateBiasTensor.get(),
                          &output_state_in, &cell_state_in, m_ScratchBuffer.get(), &output_state_out,
                          &cell_state_out, &output, lstm_param, activationLayerInfo,
                          cell_threshold, projection_threshold);

    armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
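
    // The constant weight and bias data can now be copied into the ACL tensors built above.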
    InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor,
                                   m_Data.m_InputToForgetWeights);
    InitializeArmComputeTensorData(*m_InputToCellWeightsTensor,
                                   m_Data.m_InputToCellWeights);
    InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor,
                                   m_Data.m_InputToOutputWeights);
    InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor,
                                   m_Data.m_RecurrentToForgetWeights);
    InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor,
                                   m_Data.m_RecurrentToCellWeights);
    InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor,
                                   m_Data.m_RecurrentToOutputWeights);
    InitializeArmComputeTensorData(*m_ForgetGateBiasTensor,
                                   m_Data.m_ForgetGateBias);
    InitializeArmComputeTensorData(*m_CellBiasTensor,
                                   m_Data.m_CellBias);
    InitializeArmComputeTensorData(*m_OutputGateBiasTensor,
                                   m_Data.m_OutputGateBias);

    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        InitializeArmComputeTensorData(*m_InputToInputWeightsTensor,
                                       m_Data.m_InputToInputWeights);
        InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor,
                                       m_Data.m_RecurrentToInputWeights);
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            InitializeArmComputeTensorData(*m_CellToInputWeightsTensor,
                                           m_Data.m_CellToInputWeights);
        }
        InitializeArmComputeTensorData(*m_InputGateBiasTensor,
                                       m_Data.m_InputGateBias);
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        InitializeArmComputeTensorData(*m_ProjectionWeightsTensor,
                                       m_Data.m_ProjectionWeights);
        if (m_Data.m_ProjectionBias != nullptr)
        {
            InitializeArmComputeTensorData(*m_ProjectionBiasTensor,
                                           m_Data.m_ProjectionBias);
        }
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor,
                                       m_Data.m_CellToForgetWeights);
        InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor,
                                       m_Data.m_CellToOutputWeights);
    }

    // Force Compute Library to perform the necessary copying and reshaping, after which
    // the staging tensors that held the weight data are no longer needed and can be freed.
    m_LstmLayer.prepare();
    FreeUnusedTensors();
}

void NeonLstmFloatWorkload::Execute() const
{
    // Execution just runs the ACL LSTM layer configured and prepared in the constructor.
    m_LstmLayer.run();
}
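
// Checks whether the Compute Library NELSTMLayer supports the given tensor shapes, data types and
// LSTM descriptor settings, without allocating any tensors. The optional parameters are passed as
// pointers and may be null depending on the CIFG/projection/peephole configuration.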
arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
                                                  const TensorInfo& outputStateIn,
                                                  const TensorInfo& cellStateIn,
                                                  const TensorInfo& scratchBuffer,
                                                  const TensorInfo& outputStateOut,
                                                  const TensorInfo& cellStateOut,
                                                  const TensorInfo& output,
                                                  const LstmDescriptor& descriptor,
                                                  const TensorInfo& inputToForgetWeights,
                                                  const TensorInfo& inputToCellWeights,
                                                  const TensorInfo& inputToOutputWeights,
                                                  const TensorInfo& recurrentToForgetWeights,
                                                  const TensorInfo& recurrentToCellWeights,
                                                  const TensorInfo& recurrentToOutputWeights,
                                                  const TensorInfo& forgetGateBias,
                                                  const TensorInfo& cellBias,
                                                  const TensorInfo& outputGateBias,
                                                  const TensorInfo* inputToInputWeights,
                                                  const TensorInfo* recurrentToInputWeights,
                                                  const TensorInfo* cellToInputWeights,
                                                  const TensorInfo* inputGateBias,
                                                  const TensorInfo* projectionWeights,
                                                  const TensorInfo* projectionBias,
                                                  const TensorInfo* cellToForgetWeights,
                                                  const TensorInfo* cellToOutputWeights)
{
    arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;

    // The inputs and the outputs
    const arm_compute::TensorInfo aclInputInfo          = BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputStateInInfo  = BuildArmComputeTensorInfo(outputStateIn);
    const arm_compute::TensorInfo aclCellStateInInfo    = BuildArmComputeTensorInfo(cellStateIn);
    const arm_compute::TensorInfo aclScratchBufferInfo  = BuildArmComputeTensorInfo(scratchBuffer);
    const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
    const arm_compute::TensorInfo aclCellStateOutInfo   = BuildArmComputeTensorInfo(cellStateOut);
    const arm_compute::TensorInfo aclOutputInfo         = BuildArmComputeTensorInfo(output);
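
    // Basic parameters: the weights and biases that are required in every configuration.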
    const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
    const arm_compute::TensorInfo aclInputToCellWeightsInfo   = BuildArmComputeTensorInfo(inputToCellWeights);
    const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
    const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
                                  = BuildArmComputeTensorInfo(recurrentToForgetWeights);
    const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
                                  = BuildArmComputeTensorInfo(recurrentToCellWeights);
    const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
                                  = BuildArmComputeTensorInfo(recurrentToOutputWeights);
    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
    const arm_compute::TensorInfo aclCellBiasInfo       = BuildArmComputeTensorInfo(cellBias);
    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
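
    // Optional parameters: converted only when the corresponding feature (CIFG disabled, projection,
    // peephole) requires them; otherwise the default-constructed infos remain unused.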
    arm_compute::TensorInfo aclInputToInputWeightsInfo;
    arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
    arm_compute::TensorInfo aclCellToInputWeightsInfo;
    arm_compute::TensorInfo aclInputGateBiasInfo;
    arm_compute::TensorInfo aclProjectionWeightsInfo;
    arm_compute::TensorInfo aclProjectionBiasInfo;
    arm_compute::TensorInfo aclCellToForgetWeightsInfo;
    arm_compute::TensorInfo aclCellToOutputWeightsInfo;

    if (!descriptor.m_CifgEnabled)
    {
        const armnn::TensorInfo& inputToInputWInfo = *inputToInputWeights;
        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
        const armnn::TensorInfo& recurrentToInputWInfo = *recurrentToInputWeights;
        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);

        if (cellToInputWeights != nullptr)
        {
            const armnn::TensorInfo& cellToInputWInfo = *cellToInputWeights;
            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
        }
        const armnn::TensorInfo& inputGateBiasInfo = *inputGateBias;
        aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
        lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
                                         cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
                                         &aclInputGateBiasInfo);
    }

    if (descriptor.m_ProjectionEnabled)
    {
        const armnn::TensorInfo& projectionWInfo = *projectionWeights;
        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);

        if (projectionBias != nullptr)
        {
            const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
            aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
        }
        lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                               projectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
    }

    if (descriptor.m_PeepholeEnabled)
    {
        const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
        const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
        lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
    }

    float cell_threshold       = descriptor.m_ClippingThresCell;
    float projection_threshold = descriptor.m_ClippingThresProj;

    // Build the ActivationLayerInfo; the descriptor encodes the activation function as an integer,
    // with the same five supported values handled in the workload constructor.
    arm_compute::ActivationLayerInfo activationLayerInfo;
    switch (descriptor.m_ActivationFunc)
    {
        case 0:
            // no activation, do nothing
            break;
        case 1:
            activationLayerInfo = arm_compute::ActivationLayerInfo(
                arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
            break;
        case 3:
            activationLayerInfo = arm_compute::ActivationLayerInfo(
                arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
            break;
        case 4:
            activationLayerInfo = arm_compute::ActivationLayerInfo(
                arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
            break;
        case 6:
            activationLayerInfo = arm_compute::ActivationLayerInfo(
                arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
            break;
        default:
            throw armnn::Exception("Wrong Type of Activation Function!");
    }
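
    // Delegate the final check to the Compute Library's own LSTM validation.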
    return arm_compute::NELSTMLayer::validate(&aclInputInfo,
                                              &aclInputToForgetWeightsInfo,
                                              &aclInputToCellWeightsInfo,
                                              &aclInputToOutputWeightsInfo,
                                              &aclRecurrentToForgetWeightsInfo,
                                              &aclRecurrentToCellWeightsInfo,
                                              &aclRecurrentToOutputWeightsInfo,
                                              &aclForgetGateBiasInfo,
                                              &aclCellBiasInfo,
                                              &aclOutputGateBiasInfo,
                                              &aclOutputStateInInfo,
                                              &aclCellStateInInfo,
                                              &aclScratchBufferInfo,
                                              &aclOutputStateOutInfo,
                                              &aclCellStateOutInfo,
                                              &aclOutputInfo,
                                              lstm_params_info,
                                              activationLayerInfo,
                                              cell_threshold,
                                              projection_threshold);
}

void NeonLstmFloatWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_InputToInputWeightsTensor);
    FreeTensorIfUnused(m_InputToForgetWeightsTensor);
    FreeTensorIfUnused(m_InputToCellWeightsTensor);
    FreeTensorIfUnused(m_InputToOutputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
    FreeTensorIfUnused(m_CellToInputWeightsTensor);
    FreeTensorIfUnused(m_CellToForgetWeightsTensor);
    FreeTensorIfUnused(m_CellToOutputWeightsTensor);
    FreeTensorIfUnused(m_InputGateBiasTensor);
    FreeTensorIfUnused(m_ForgetGateBiasTensor);
    FreeTensorIfUnused(m_CellBiasTensor);
    FreeTensorIfUnused(m_OutputGateBiasTensor);
    FreeTensorIfUnused(m_ProjectionWeightsTensor);
    FreeTensorIfUnused(m_ProjectionBiasTensor);
    FreeTensorIfUnused(m_ScratchBuffer);
}

} //namespace armnn