lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+ m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+ }
+ BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+ BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+ BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+ lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ? nullptr :
+ m_InputLayerNormWeightsTensor.get(),
+ m_ForgetLayerNormWeightsTensor.get(),
+ m_CellLayerNormWeightsTensor.get(),
+ m_OutputLayerNormWeightsTensor.get());
+ }
+
const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
const arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
const arm_compute::ICLTensor& cell_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
throw armnn::Exception("Wrong Type of Activation Function!");
}
-
m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
- InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
- InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
- InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+ InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+ InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+ InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
- InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
- InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
- InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
- InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
+ InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+ InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+ InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
if (!m_Data.m_Parameters.m_CifgEnabled)
{
InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+ }
+
+ InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+ InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+ InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+ }
+
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
m_LstmLayer.prepare();
arm_compute::TensorInfo aclProjectionBiasInfo;
arm_compute::TensorInfo aclCellToForgetWeightsInfo;
arm_compute::TensorInfo aclCellToOutputWeightsInfo;
+ arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
if (!descriptor.m_CifgEnabled)
{
throw armnn::Exception("Wrong Type of Activation Function!");
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+ }
+
+ aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+
+ aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+
+ aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+ lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+ nullptr : &aclInputLayerNormWeightsInfo,
+ &aclForgetLayerNormWeightsInfo,
+ &aclCellLayerNormWeightsInfo,
+ &aclOutputLayerNormWeightsInfo);
+ }
+
return arm_compute::CLLSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo,
&aclInputToCellWeightsInfo,
&aclInputToOutputWeightsInfo,
FreeTensorIfUnused(m_ProjectionWeightsTensor);
FreeTensorIfUnused(m_ProjectionBiasTensor);
FreeTensorIfUnused(m_ScratchBuffer);
+ FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}
} //namespace armnn