//
-// Copyright © 2020 Arm Ltd. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
{
arm_compute::LSTMParams<arm_compute::ITensor> qLstmParams;
- // Mandatory tensors
+ // Mandatory params
m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
// Set projection params
qLstmParams.set_projection_params(
- m_ProjectionWeightsTensor.get(),
- m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
+ m_ProjectionWeightsTensor.get(),
+ m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
}
if (m_Data.m_Parameters.m_LayerNormEnabled)
// Set layer norm params
qLstmParams.set_layer_normalization_params(
- m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
- m_ForgetLayerNormWeightsTensor.get(),
- m_CellLayerNormWeightsTensor.get(),
- m_OutputLayerNormWeightsTensor.get());
+ m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
+ m_ForgetLayerNormWeightsTensor.get(),
+ m_CellLayerNormWeightsTensor.get(),
+ m_OutputLayerNormWeightsTensor.get());
}
if (!m_Data.m_Parameters.m_CifgEnabled)
// Set CIFG params
qLstmParams.set_cifg_params(
- m_InputToInputWeightsTensor.get(),
- m_RecurrentToInputWeightsTensor.get(),
- m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
- m_InputGateBiasTensor.get());
+ m_InputToInputWeightsTensor.get(),
+ m_RecurrentToInputWeightsTensor.get(),
+ m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
+ m_InputGateBiasTensor.get());
}
- // Input/output tensors
+ // Input/Output tensors
const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
const arm_compute::ITensor& outputStateIn = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
const arm_compute::ITensor& cellStateIn = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
arm_compute::ITensor& cellStateOut = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
arm_compute::ITensor& output = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
-
// Set scalar descriptor params
qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
- const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
// Mandatory tensor info
const arm_compute::TensorInfo aclInputToForgetWeightsInfo
arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
-
// Create tensor info for optional params if they are enabled
if (descriptor.m_PeepholeEnabled)
{
// Set projection params info
aclParamsInfo.set_projection_params(
- &aclProjectionWeightsInfo,
- paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
+ &aclProjectionWeightsInfo,
+ paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
}
-
-
if (descriptor.m_LayerNormEnabled)
{
if (!descriptor.m_CifgEnabled)
{
aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
-
}
aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
// Set layer norm params info
aclParamsInfo.set_layer_normalization_params(
- paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
- &aclForgetLayerNormWeightsInfo,
- &aclCellLayerNormWeightsInfo,
- &aclOutputLayerNormWeightsInfo);
+ paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
+ &aclForgetLayerNormWeightsInfo,
+ &aclCellLayerNormWeightsInfo,
+ &aclOutputLayerNormWeightsInfo);
}
if (!descriptor.m_CifgEnabled)
// Set CIFG params info
aclParamsInfo.set_cifg_params(
- &aclInputToInputWeightsInfo,
- &aclRecurrentToInputWeightsInfo,
- paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
- &aclInputGateBiasInfo);
+ &aclInputToInputWeightsInfo,
+ &aclRecurrentToInputWeightsInfo,
+ paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
+ &aclInputGateBiasInfo);
}
// Set scalar descriptor params