From a2ec9092f0bff018bfe7ae0cacb7e30bcc17c1c7 Mon Sep 17 00:00:00 2001 From: Jan Eilers <jan.eilers@arm.com> Date: Mon, 8 Jul 2019 15:56:59 +0100 Subject: [PATCH] IVGCVSW-3338 Add CL backend support for LSTM normalization * Enable calls to LSTM normalization unit tests on CL backend. * Update CL workload to set the layer normalization parameters. !android-nn-driver:1461 Change-Id: Ia5a29918961c391c1f1d8f331add377a38822ddd Signed-off-by: Francis Murtagh <francis.murtagh@arm.com> Signed-off-by: Jan Eilers <jan.eilers@arm.com> --- src/backends/cl/test/ClLayerTests.cpp | 3 + src/backends/cl/workloads/ClLstmFloatWorkload.cpp | 77 ++++++++++++++++++++--- src/backends/cl/workloads/ClLstmFloatWorkload.hpp | 4 ++ 3 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index ac96bf8..5575a05 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -354,6 +354,9 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgNoPeepholeNoProjection, ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection, LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest) +ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm, + LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest) + // Convert from Float16 to Float32 ARMNN_AUTO_TEST_CASE(SimpleConvertFp16ToFp32, SimpleConvertFp16ToFp32Test) // Convert from Float32 to Float16 diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp index 3dbbbc3..f5d081e 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp @@ -100,6 +100,28 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get()); } + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>(); + 
m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>(); + m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>(); + m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>(); + + if (!m_Data.m_Parameters.m_CifgEnabled) + { + BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo()); + } + BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo()); + BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo()); + BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo()); + + lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ? nullptr : + m_InputLayerNormWeightsTensor.get(), + m_ForgetLayerNormWeightsTensor.get(), + m_CellLayerNormWeightsTensor.get(), + m_OutputLayerNormWeightsTensor.get()); + } + const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); const arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); const arm_compute::ICLTensor& cell_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor(); @@ -161,7 +183,6 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, throw armnn::Exception("Wrong Type of Activation Function!"); } - m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(), m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(), m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(), @@ -172,15 +193,15 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer); - InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights); - InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights); - 
InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights); + InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights); + InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights); + InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights); InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights); - InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights); + InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights); InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights); - InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias); - InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias); - InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias); + InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias); + InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias); + InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias); if (!m_Data.m_Parameters.m_CifgEnabled) { @@ -208,6 +229,18 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights); } + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + if (!m_Data.m_Parameters.m_CifgEnabled) + { + InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights); + } + + InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights); + InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights); + InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, 
m_Data.m_OutputLayerNormWeights); + } + // Force Compute Library to perform the necessary copying and reshaping, after which // delete all the input tensors that will no longer be needed m_LstmLayer.prepare(); @@ -262,6 +295,10 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T arm_compute::TensorInfo aclProjectionBiasInfo; arm_compute::TensorInfo aclCellToForgetWeightsInfo; arm_compute::TensorInfo aclCellToOutputWeightsInfo; + arm_compute::TensorInfo aclInputLayerNormWeightsInfo; + arm_compute::TensorInfo aclForgetLayerNormWeightsInfo; + arm_compute::TensorInfo aclCellLayerNormWeightsInfo; + arm_compute::TensorInfo aclOutputLayerNormWeightsInfo; if (!descriptor.m_CifgEnabled) { @@ -333,6 +370,26 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T throw armnn::Exception("Wrong Type of Activation Function!"); } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights()); + } + + aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights()); + + aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights()); + + aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights()); + + lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? 
+ nullptr : &aclInputLayerNormWeightsInfo, + &aclForgetLayerNormWeightsInfo, + &aclCellLayerNormWeightsInfo, + &aclOutputLayerNormWeightsInfo); + } + return arm_compute::CLLSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo, &aclInputToCellWeightsInfo, &aclInputToOutputWeightsInfo, @@ -369,6 +426,10 @@ void ClLstmFloatWorkload::FreeUnusedTensors() FreeTensorIfUnused(m_ProjectionWeightsTensor); FreeTensorIfUnused(m_ProjectionBiasTensor); FreeTensorIfUnused(m_ScratchBuffer); + FreeTensorIfUnused(m_InputLayerNormWeightsTensor); + FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor); + FreeTensorIfUnused(m_CellLayerNormWeightsTensor); + FreeTensorIfUnused(m_OutputLayerNormWeightsTensor); } } //namespace armnn diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp index 9a3211a..5bd67c2 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp @@ -39,6 +39,10 @@ private: std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor; std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor; + std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor; + std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor; + std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor; + std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor; std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer; -- 2.7.4