From e2062cdf1eb31b87860f9889f0e799e89f0dfa30 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Mon, 30 Mar 2020 15:07:45 +0100 Subject: [PATCH] IVGCVSW-4590 Fix Lstm layers CellToInputWeights * CellToInputWeights were not handled correctly * Changed CellToInputWeights from Cifg to peephole parameter * Modified existing unit tests * Added unit test to cover new configuration * Added more descriptive error messages Signed-off-by: Jan Eilers Change-Id: Ied5dc1253d3df1fd1a79b887a58603d0a9c8f396 --- src/armnn/Network.cpp | 49 +++++--- src/armnn/layers/LstmLayer.cpp | 28 +++-- src/armnn/layers/LstmLayer.hpp | 4 +- src/armnn/test/ConstTensorLayerVisitor.cpp | 126 +++++++++++++++++++-- src/armnn/test/OptimizerTests.cpp | 9 +- src/backends/backendsCommon/WorkloadFactory.cpp | 13 ++- .../test/IsLayerSupportedTestImpl.hpp | 2 - 7 files changed, 181 insertions(+), 50 deletions(-) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 0272b3d..c2da4da 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1520,27 +1520,24 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor, { if(params.m_InputToInputWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Input To Input Weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Input To Input Weights cannot be NULL " + "when CIFG is disabled."); } if(params.m_RecurrentToInputWeights == nullptr) { throw InvalidArgumentException( - "AddLstmLayer: Recurrent To Input Weights cannot be NULL"); + "AddLstmLayer: Recurrent To Input Weights cannot be NULL " + "when CIFG is disabled."); } if(params.m_InputGateBias == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Input Gate Bias cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Input Gate Bias cannot be NULL " + "when CIFG is disabled."); } layer->m_CifgParameters.m_InputToInputWeights = std::make_unique(*(params.m_InputToInputWeights)); layer->m_CifgParameters.m_RecurrentToInputWeights 
= std::make_unique(*(params.m_RecurrentToInputWeights)); - // In the VTS tests, cell-to-input weights may be null, even if the other CIFG params are not. - if(params.m_CellToInputWeights != nullptr) - { - layer->m_CifgParameters.m_CellToInputWeights = - std::make_unique(*(params.m_CellToInputWeights)); - } layer->m_CifgParameters.m_InputGateBias = std::make_unique(*(params.m_InputGateBias)); } @@ -1550,7 +1547,8 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor, { if(params.m_ProjectionWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Projection Weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Projection Weights cannot be NULL " + "when projection is enabled."); } layer->m_ProjectionParameters.m_ProjectionWeights = std::make_unique(*(params.m_ProjectionWeights)); @@ -1564,14 +1562,29 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor, //Lstm Peephole params if(descriptor.m_PeepholeEnabled) { + if(!descriptor.m_CifgEnabled) + { + if(params.m_CellToInputWeights == nullptr) + { + throw InvalidArgumentException("AddLstmLayer: Cell To Input Weights cannot be NULL " + "when Peephole is enabled and CIFG disabled."); + } + + layer->m_PeepholeParameters.m_CellToInputWeights = + std::make_unique(*(params.m_CellToInputWeights)); + } + if(params.m_CellToForgetWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Cell To Forget Weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Cell To Forget Weights cannot be NULL " + "when Peephole is enabled."); } if(params.m_CellToOutputWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Cell To Output Weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Cell To Output Weights cannot be NULL " + "when Peephole is enabled."); } + layer->m_PeepholeParameters.m_CellToForgetWeights = std::make_unique(*(params.m_CellToForgetWeights)); 
layer->m_PeepholeParameters.m_CellToOutputWeights = @@ -1585,7 +1598,8 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor, { if(params.m_InputLayerNormWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Input layer normalization weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Input layer normalization weights cannot be NULL " + "when layer normalization is enabled and CIFG disabled."); } layer->m_LayerNormParameters.m_InputLayerNormWeights = std::make_unique(*(params.m_InputLayerNormWeights)); @@ -1593,15 +1607,18 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor, if(params.m_ForgetLayerNormWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Forget layer normalization weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Forget layer normalization weights cannot be NULL " + "when layer normalization is enabled."); } if(params.m_CellLayerNormWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Cell layer normalization weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Cell layer normalization weights cannot be NULL " + "when layer normalization is enabled."); } if(params.m_OutputLayerNormWeights == nullptr) { - throw InvalidArgumentException("AddLstmLayer: Output layer normalization weights cannot be NULL"); + throw InvalidArgumentException("AddLstmLayer: Output layer normalization weights cannot be NULL " + "when layer normalization is enabled."); } layer->m_LayerNormParameters.m_ForgetLayerNormWeights = std::make_unique(*(params.m_ForgetLayerNormWeights)); diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 581ba45..1d94569 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -39,7 +39,6 @@ std::unique_ptr LstmLayer::CreateWorkload(const IWorkloadFactory& fac { descriptor.m_InputToInputWeights = 
m_CifgParameters.m_InputToInputWeights.get(); descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get(); - descriptor.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights.get(); descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get(); } @@ -53,6 +52,10 @@ std::unique_ptr LstmLayer::CreateWorkload(const IWorkloadFactory& fac // Peephole parameters if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get(); + } descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get(); descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get(); } @@ -102,8 +105,6 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const std::make_unique(*m_CifgParameters.m_InputToInputWeights) : nullptr; layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ? std::make_unique(*m_CifgParameters.m_RecurrentToInputWeights) : nullptr; - layer->m_CifgParameters.m_CellToInputWeights = m_CifgParameters.m_CellToInputWeights ? - std::make_unique(*m_CifgParameters.m_CellToInputWeights) : nullptr; layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ? std::make_unique(*m_CifgParameters.m_InputGateBias) : nullptr; } @@ -118,6 +119,11 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ? + std::make_unique(*m_PeepholeParameters.m_CellToInputWeights) : nullptr; + } layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ? std::make_unique(*m_PeepholeParameters.m_CellToForgetWeights) : nullptr; layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ? 
@@ -209,8 +215,6 @@ void LstmLayer::ValidateTensorShapesFromInputs() "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled."); - BOOST_ASSERT_MSG(m_CifgParameters.m_CellToInputWeights == nullptr, - "LstmLayer: m_CifgParameters.m_CellToInputWeights should not have a value when CIFG is enabled."); BOOST_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled."); @@ -228,6 +232,12 @@ void LstmLayer::ValidateTensorShapesFromInputs() if (m_Param.m_PeepholeEnabled) { + if (!m_Param.m_CifgEnabled) + { + BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, + "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null."); BOOST_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, @@ -278,7 +288,6 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() // Cifg parameters m_CifgParameters.m_InputToInputWeights, m_CifgParameters.m_RecurrentToInputWeights, - m_CifgParameters.m_CellToInputWeights, m_CifgParameters.m_InputGateBias, // Projection parameters @@ -286,6 +295,7 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() m_ProjectionParameters.m_ProjectionBias, // Peephole parameters + m_PeepholeParameters.m_CellToInputWeights, m_PeepholeParameters.m_CellToForgetWeights, m_PeepholeParameters.m_CellToOutputWeights, @@ -368,10 +378,10 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } ConstTensor cellToInputWeightsTensor; - if 
(m_CifgParameters.m_CellToInputWeights != nullptr) + if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), - m_CifgParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToInputWeights->Map(true)); cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp index 21421f2..5ccb4bc 100644 --- a/src/armnn/layers/LstmLayer.hpp +++ b/src/armnn/layers/LstmLayer.hpp @@ -30,8 +30,6 @@ struct LstmOptCifgParameters /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. std::unique_ptr m_RecurrentToInputWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. - std::unique_ptr m_CellToInputWeights; - /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr m_InputGateBias; }; @@ -46,6 +44,8 @@ struct LstmOptProjectionParameters struct LstmOptPeepholeParameters { /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr m_CellToInputWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. std::unique_ptr m_CellToForgetWeights; /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. 
std::unique_ptr m_CellToOutputWeights; diff --git a/src/armnn/test/ConstTensorLayerVisitor.cpp b/src/armnn/test/ConstTensorLayerVisitor.cpp index 7ef3dd2..ab83a89 100644 --- a/src/armnn/test/ConstTensorLayerVisitor.cpp +++ b/src/armnn/test/ConstTensorLayerVisitor.cpp @@ -861,11 +861,6 @@ BOOST_AUTO_TEST_CASE(CheckLstmLayerCifgDisabled) ConstTensor recurrentToInputWeights(TensorInfo( 4, recurrentToInputWeightsDimensions.data(), DataType::Float32), recurrentToInputWeightsData); - std::vector cellToInputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; - std::vector cellToInputWeightsDimensions = {1, 1, 3, 3}; - ConstTensor cellToInputWeights( - TensorInfo(4, cellToInputWeightsDimensions.data(), DataType::Float32), cellToInputWeightsData); - std::vector inputGateBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; std::vector inputGateBiasDimensions = {1, 1, 3, 3}; ConstTensor inputGateBias( @@ -884,7 +879,6 @@ BOOST_AUTO_TEST_CASE(CheckLstmLayerCifgDisabled) params.m_InputToInputWeights = &inputToInputWeights; params.m_RecurrentToInputWeights = &recurrentToInputWeights; - params.m_CellToInputWeights = &cellToInputWeights; params.m_InputGateBias = &inputGateBias; TestLstmLayerVisitor visitor(descriptor, params); @@ -959,11 +953,6 @@ BOOST_AUTO_TEST_CASE(CheckNamedLstmLayerCifgDisabled) ConstTensor recurrentToInputWeights(TensorInfo( 4, recurrentToInputWeightsDimensions.data(), DataType::Float32), recurrentToInputWeightsData); - std::vector cellToInputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; - std::vector cellToInputWeightsDimensions = {1, 1, 3, 3}; - ConstTensor cellToInputWeights( - TensorInfo(4, cellToInputWeightsDimensions.data(), DataType::Float32), cellToInputWeightsData); - std::vector inputGateBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; std::vector inputGateBiasDimensions = {1, 1, 3, 3}; ConstTensor inputGateBias( @@ -982,7 +971,6 @@ BOOST_AUTO_TEST_CASE(CheckNamedLstmLayerCifgDisabled) 
params.m_InputToInputWeights = &inputToInputWeights; params.m_RecurrentToInputWeights = &recurrentToInputWeights; - params.m_CellToInputWeights = &cellToInputWeights; params.m_InputGateBias = &inputGateBias; TestLstmLayerVisitor visitor(descriptor, params, layerName); @@ -1080,6 +1068,120 @@ BOOST_AUTO_TEST_CASE(CheckLstmLayerPeephole) layer->Accept(visitor); } +BOOST_AUTO_TEST_CASE(CheckLstmLayerPeepholeCifgDisabled) +{ + LstmDescriptor descriptor; + descriptor.m_ActivationFunc = 3; + descriptor.m_ClippingThresProj = 0.5f; + descriptor.m_ClippingThresCell = 0.3f; + descriptor.m_CifgEnabled = false; + descriptor.m_PeepholeEnabled = true; + + std::vector inputToForgetWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector inputToForgetWeightsDimensions = {1, 1, 3, 3}; + ConstTensor inputToForgetWeights( + TensorInfo(4, inputToForgetWeightsDimensions.data(), DataType::Float32), inputToForgetWeightsData); + + std::vector inputToCellWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector inputToCellWeightsDimensions = {1, 1, 3, 3}; + ConstTensor inputToCellWeights( + TensorInfo(4, inputToCellWeightsDimensions.data(), DataType::Float32), inputToCellWeightsData); + + std::vector inputToOutputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector inputToOutputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor inputToOutputWeights( + TensorInfo(4, inputToOutputWeightsDimensions.data(), DataType::Float32), inputToOutputWeightsData); + + std::vector recurrentToForgetWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector recurrentToForgetWeightsDimensions = {1, 1, 3, 3}; + ConstTensor recurrentToForgetWeights(TensorInfo( + 4, recurrentToForgetWeightsDimensions.data(), DataType::Float32), recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector recurrentToCellWeightsDimensions = {1, 1, 3, 3}; + ConstTensor 
recurrentToCellWeights(TensorInfo( + 4, recurrentToCellWeightsDimensions.data(), DataType::Float32), recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector recurrentToOutputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor recurrentToOutputWeights(TensorInfo( + 4, recurrentToOutputWeightsDimensions.data(), DataType::Float32), recurrentToOutputWeightsData); + + std::vector forgetGateBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector forgetGateBiasDimensions = {1, 1, 3, 3}; + ConstTensor forgetGateBias(TensorInfo( + 4, forgetGateBiasDimensions.data(), DataType::Float32), forgetGateBiasData); + + std::vector cellBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector cellBiasDimensions = {1, 1, 3, 3}; + ConstTensor cellBias(TensorInfo( + 4, cellBiasDimensions.data(), DataType::Float32), cellBiasData); + + std::vector outputGateBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector outputGateBiasDimensions = {1, 1, 3, 3}; + ConstTensor outputGateBias(TensorInfo( + 4, outputGateBiasDimensions.data(), DataType::Float32), outputGateBiasData); + + std::vector cellToInputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector cellToInputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor cellToInputWeights( + TensorInfo(4, cellToInputWeightsDimensions.data(), DataType::Float32), cellToInputWeightsData); + + std::vector cellToForgetWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector cellToForgetWeightsDimensions = {1, 1, 3, 3}; + ConstTensor cellToForgetWeights( + TensorInfo(4, cellToForgetWeightsDimensions.data(), DataType::Float32), cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector cellToOutputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor cellToOutputWeights( + TensorInfo(4, cellToOutputWeightsDimensions.data(), 
DataType::Float32), cellToOutputWeightsData); + + std::vector inputToInputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector inputToInputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor inputToInputWeights( + TensorInfo(4, inputToInputWeightsDimensions.data(), DataType::Float32), inputToInputWeightsData); + + std::vector recurrentToInputWeightsData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector recurrentToInputWeightsDimensions = {1, 1, 3, 3}; + ConstTensor recurrentToInputWeights(TensorInfo( + 4, recurrentToInputWeightsDimensions.data(), DataType::Float32), recurrentToInputWeightsData); + + std::vector inputGateBiasData = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + std::vector inputGateBiasDimensions = {1, 1, 3, 3}; + ConstTensor inputGateBias( + TensorInfo(4, inputGateBiasDimensions.data(), DataType::Float32), inputGateBiasData); + + LstmInputParams params; + // Basic params + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // Peephole params + params.m_CellToInputWeights = &cellToInputWeights; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + // Cifg params + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_InputGateBias = &inputGateBias; + + TestLstmLayerVisitor visitor(descriptor, params); + + Network net; + + IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params); + layer->Accept(visitor); +} + 
BOOST_AUTO_TEST_CASE(CheckNamedLstmLayerPeephole) { const char* layerName = "LstmLayer"; diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp index 56032ad..a7b23db 100644 --- a/src/armnn/test/OptimizerTests.cpp +++ b/src/armnn/test/OptimizerTests.cpp @@ -78,13 +78,10 @@ void CreateLSTMLayerHelper(Graph &graph, bool CifgEnabled) (TensorInfo({ numUnits, inputSize }, DataType::Float32)); layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique (TensorInfo({ numUnits, outputSize }, DataType::Float32)); - layer->m_CifgParameters.m_CellToInputWeights = std::make_unique - (TensorInfo({ numUnits }, DataType::Float32)); layer->m_CifgParameters.m_InputGateBias = std::make_unique (TensorInfo({ numUnits }, DataType::Float32)); layer->m_CifgParameters.m_InputToInputWeights->Allocate(); layer->m_CifgParameters.m_RecurrentToInputWeights->Allocate(); - layer->m_CifgParameters.m_CellToInputWeights->Allocate(); layer->m_CifgParameters.m_InputGateBias->Allocate(); } @@ -100,6 +97,12 @@ void CreateLSTMLayerHelper(Graph &graph, bool CifgEnabled) if (layerDesc.m_PeepholeEnabled) { + if (!layerDesc.m_CifgEnabled) + { + layer->m_PeepholeParameters.m_CellToInputWeights = std::make_unique + (TensorInfo({ numUnits }, DataType::Float32)); + layer->m_PeepholeParameters.m_CellToInputWeights->Allocate(); + } layer->m_PeepholeParameters.m_CellToForgetWeights = std::make_unique (TensorInfo({ numUnits }, DataType::Float32)); layer->m_PeepholeParameters.m_CellToOutputWeights = std::make_unique diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 40ab798..5628c36 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -529,12 +529,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); 
paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; - if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) - { - optCellToInputWeights = - OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - paramsInfo.m_CellToInputWeights = &optCellToInputWeights; - } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); paramsInfo.m_InputGateBias = &optInputGateBias; @@ -555,6 +549,13 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, if(descriptor.m_PeepholeEnabled) { + if(!descriptor.m_CifgEnabled) + { + optCellToInputWeights = + OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), + dataType); + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; + } optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 7534c8a..dccfd1e 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -288,8 +288,6 @@ struct DummyLstmLayer armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32)); m_Layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique( armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32)); - m_Layer->m_CifgParameters.m_CellToInputWeights = std::make_unique( - armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32)); m_Layer->m_CifgParameters.m_InputGateBias = std::make_unique( armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32)); } -- 2.7.4