From: James Conroy
Date: Wed, 29 Apr 2020 19:01:10 +0000 (+0100)
Subject: IVGCVSW-4449 Add QLstm ref implementation
X-Git-Tag: submit/tizen/20200730.023729~80
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4f1f899da140bb0490cf7e404daeaf1206f4db8b;p=platform%2Fupstream%2Farmnn.git

IVGCVSW-4449 Add QLstm ref implementation

* Adds ref implementation for the new HAL 1.3 operator, QLstm.
* Adds Layer and CreateWorkload unit tests.
* Adds WorkloadData validation for QLstm.

Signed-off-by: James Conroy
Change-Id: I8a721f07ff06105e6495a1a0561b9503aa8146dc
---

diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 05d0e2f..f484a21 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -516,6 +516,167 @@ std::unique_ptr<QuantizedLstmWorkload> CreateQuantizedLstmWorkloadTest(armnn::IW
     return workload;
 }

+template <typename QLstmWorkload>
+std::unique_ptr<QLstmWorkload> CreateQLstmWorkloadTest(armnn::IWorkloadFactory& factory,
+                                                       armnn::Graph& graph)
+{
+    QLstmDescriptor layerDesc;
+    layerDesc.m_CifgEnabled       = true;
+    layerDesc.m_PeepholeEnabled   = false;
+    layerDesc.m_ProjectionEnabled = false;
+    layerDesc.m_LayerNormEnabled  = true;
+
+    layerDesc.m_CellClip       = 0.0f;
+    layerDesc.m_ProjectionClip = 0.0f;
+
+    layerDesc.m_HiddenStateZeroPoint = 0;
+    layerDesc.m_HiddenStateScale     = 0.007f;
+
+    layerDesc.m_InputIntermediateScale  = 0.007059f;
+    layerDesc.m_ForgetIntermediateScale = 0.007812f;
+    layerDesc.m_CellIntermediateScale   = 0.007059f;
+    layerDesc.m_OutputIntermediateScale = 0.007812f;
+
+    QLstmLayer* const layer = graph.AddLayer<QLstmLayer>(layerDesc, "qLstm");
+
+    unsigned int numBatches = 2;
+    unsigned int inputSize  = 4;
+    unsigned int numUnits   = 4;
+    unsigned int outputSize = 4;
+
+    // Scale/Offset quantization info
+    float   inputScale  = 0.0078125f;
+    int32_t inputOffset = 0;
+
+    // if (!projectionEnabled) outputScale == hiddenStateScale
+    float   outputScale  = layerDesc.m_HiddenStateScale;
+    int32_t outputOffset = layerDesc.m_HiddenStateZeroPoint;
+
+    float   cellStateScale  = 3.05176e-05f;
+    int32_t cellStateOffset = 0;
+
+    float   weightsScale  = 0.00784314f;
+    int32_t weightsOffset = 0;
+
+    float   layerNormScale  = 3.05182e-05f;
+    int32_t layerNormOffset = 0;
+
+    float   biasScale  = layerNormScale / 1024;
+    int32_t biasOffset = 0;
+
+    // Weights and bias tensor and quantization info
+    armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
+                                       armnn::DataType::QSymmS8,
+                                       weightsScale,
+                                       weightsOffset);
+
+    armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
+                                           armnn::DataType::QSymmS8,
+                                           weightsScale,
+                                           weightsOffset);
+
+    armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
+
+    armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
+
+    // Create and allocate tensors
+    layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<ScopedCpuTensorHandle>(inputWeightsInfo);
+    layer->m_BasicParameters.m_InputToCellWeights   = std::make_unique<ScopedCpuTensorHandle>(inputWeightsInfo);
+    layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<ScopedCpuTensorHandle>(inputWeightsInfo);
+
+    layer->m_BasicParameters.m_RecurrentToForgetWeights =
+            std::make_unique<ScopedCpuTensorHandle>(recurrentWeightsInfo);
+    layer->m_BasicParameters.m_RecurrentToCellWeights =
+            std::make_unique<ScopedCpuTensorHandle>(recurrentWeightsInfo);
+    layer->m_BasicParameters.m_RecurrentToOutputWeights =
+            std::make_unique<ScopedCpuTensorHandle>(recurrentWeightsInfo);
+
+    layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<ScopedCpuTensorHandle>(biasInfo);
+    layer->m_BasicParameters.m_CellBias       = std::make_unique<ScopedCpuTensorHandle>(biasInfo);
+    layer->m_BasicParameters.m_OutputGateBias = std::make_unique<ScopedCpuTensorHandle>(biasInfo);
+
+    layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
+            std::make_unique<ScopedCpuTensorHandle>(layerNormWeightsInfo);
+    layer->m_LayerNormParameters.m_CellLayerNormWeights =
+            std::make_unique<ScopedCpuTensorHandle>(layerNormWeightsInfo);
+    layer->m_LayerNormParameters.m_OutputLayerNormWeights =
+            std::make_unique<ScopedCpuTensorHandle>(layerNormWeightsInfo);
+
+    layer->m_BasicParameters.m_InputToForgetWeights->Allocate();
+    layer->m_BasicParameters.m_InputToCellWeights->Allocate();
+    layer->m_BasicParameters.m_InputToOutputWeights->Allocate();
+
+    layer->m_BasicParameters.m_RecurrentToForgetWeights->Allocate();
+    layer->m_BasicParameters.m_RecurrentToCellWeights->Allocate();
+    layer->m_BasicParameters.m_RecurrentToOutputWeights->Allocate();
+
+    layer->m_BasicParameters.m_ForgetGateBias->Allocate();
+    layer->m_BasicParameters.m_CellBias->Allocate();
+    layer->m_BasicParameters.m_OutputGateBias->Allocate();
+
+    layer->m_LayerNormParameters.m_ForgetLayerNormWeights->Allocate();
+    layer->m_LayerNormParameters.m_CellLayerNormWeights->Allocate();
+    layer->m_LayerNormParameters.m_OutputLayerNormWeights->Allocate();
+
+    // Input and output layers
+    Layer* const input         = graph.AddLayer<InputLayer>(0, "input");
+    Layer* const outputStateIn = graph.AddLayer<InputLayer>(1, "outputStateIn");
+    Layer* const cellStateIn   = graph.AddLayer<InputLayer>(2, "cellStateIn");
+
+    Layer* const outputStateOut = graph.AddLayer<OutputLayer>(0, "outputStateOut");
+    Layer* const cellStateOut   = graph.AddLayer<OutputLayer>(1, "cellStateOut");
+    Layer* const output         = graph.AddLayer<OutputLayer>(2, "output");
+
+    // Input/Output tensor info
+    armnn::TensorInfo inputInfo({numBatches , inputSize},
+                                armnn::DataType::QAsymmS8,
+                                inputScale,
+                                inputOffset);
+
+    armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+                                    armnn::DataType::QSymmS16,
+                                    cellStateScale,
+                                    cellStateOffset);
+
+    armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+                                      armnn::DataType::QAsymmS8,
+                                      outputScale,
+                                      outputOffset);
+
+    // Connect layers to slots
+    Connect(input, layer, inputInfo, 0, 0);
+    Connect(outputStateIn, layer, outputStateInfo, 0, 1);
+    Connect(cellStateIn, layer, cellStateInfo, 0, 2);
+
+    Connect(layer, outputStateOut, outputStateInfo, 0, 0);
+    Connect(layer, cellStateOut, cellStateInfo, 1, 0);
+    Connect(layer, output, outputStateInfo, 2, 0);
+
+    CreateTensorHandles(graph, factory);
+
+    // Create and check workload
+    auto workload = MakeAndCheckWorkload<QLstmWorkload>(*layer, factory);
+    QLstmQueueDescriptor queueDescriptor = workload->GetData();
+
+    BOOST_TEST(queueDescriptor.m_Parameters.m_CellClip == 0.0f);
+    BOOST_TEST(queueDescriptor.m_Parameters.m_ProjectionClip == 0.0f);
+    BOOST_TEST(queueDescriptor.m_Inputs.size() == 3);
+    BOOST_TEST(queueDescriptor.m_Outputs.size() == 3);
+
+    BOOST_TEST((queueDescriptor.m_InputToForgetWeights->GetTensorInfo() == inputWeightsInfo));
+    BOOST_TEST((queueDescriptor.m_InputToCellWeights->GetTensorInfo() == inputWeightsInfo));
+    BOOST_TEST((queueDescriptor.m_InputToOutputWeights->GetTensorInfo() == inputWeightsInfo));
+
+    BOOST_TEST((queueDescriptor.m_RecurrentToForgetWeights->GetTensorInfo() == recurrentWeightsInfo));
+    BOOST_TEST((queueDescriptor.m_RecurrentToCellWeights->GetTensorInfo() == recurrentWeightsInfo));
+    BOOST_TEST((queueDescriptor.m_RecurrentToOutputWeights->GetTensorInfo() == recurrentWeightsInfo));
+
+    BOOST_TEST((queueDescriptor.m_ForgetGateBias->GetTensorInfo() == biasInfo));
+    BOOST_TEST((queueDescriptor.m_CellBias->GetTensorInfo() == biasInfo));
+    BOOST_TEST((queueDescriptor.m_OutputGateBias->GetTensorInfo() == biasInfo));
+
+    return workload;
+}
+
 template <typename Convolution2dWorkload, armnn::DataType DataType>
 std::unique_ptr<Convolution2dWorkload>
CreateDirectConvolution2dWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph) diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index d1249a4..5796fc7 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -2844,6 +2844,292 @@ void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } +void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"QLstmQueueDescriptor"}; + + // Validate number of inputs/outputs + ValidateNumInputs(workloadInfo, descriptorName, 3); + ValidateNumOutputs(workloadInfo, descriptorName, 3); + + // Input/output tensor info + auto inputInfo = workloadInfo.m_InputTensorInfos[0]; + auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1]; + auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2]; + + auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0]; + auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1]; + auto outputInfo = workloadInfo.m_OutputTensorInfos[2]; + + // Supported types for various tensors in QLSTM + std::vector inputOutputSupportedTypes = + { + DataType::QAsymmS8 + }; + + std::vector cellStateSupportedTypes = + { + DataType::QSymmS16 + }; + + std::vector weightsSupportedTypes = + { + DataType::QSymmS8 + }; + + std::vector layerNormPeepholeWeightsSupportedTypes = + { + DataType::QSymmS16 + }; + + std::vector biasSupportedTypes = + { + DataType::Signed32 + }; + + // Validate types of input/output tensors + ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName); + ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName); + ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName); + + ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName); + ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName); + ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName); + + // Validate matching types of input/output tensors + ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn"); + ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName, + "outputStateIn", "outputStateOut"); + ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut"); + + // Infer number of batches, number of units, input size and output size from tensor dimensions + const uint32_t numBatches = inputInfo.GetShape()[0]; + const uint32_t inputSize = inputInfo.GetShape()[1]; + const uint32_t outputSize = outputStateInInfo.GetShape()[1]; + const uint32_t numUnits = cellStateInInfo.GetShape()[1]; + + // Validate number of dimensions and number of elements for input/output tensors + ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input"); + ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn"); + ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn"); + + ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut"); + ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut"); + 
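// The shapes inferred above drive every remaining check: input is [numBatches, inputSize],
+    // outputStateIn/outputStateOut/output are [numBatches, outputSize], cellStateIn/cellStateOut
+    // are [numBatches, numUnits]; input weights are [numUnits, inputSize] and recurrent weights
+    // are [numUnits, outputSize].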
+    ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
+
+    // Validate number of dimensions and number of elements for MANDATORY weight tensors
+    ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
+    auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
+
+    ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
+    auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
+
+    ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
+    auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
+
+    ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
+    auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
+                                " RecurrentToForgetWeights");
+
+    ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
+    auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
+
+    ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
+    auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize),
+                                " RecurrentToOutputWeights");
+
+    // Validate data types for MANDATORY weights tensors (all should match each other)
+    ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
+
+    ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
+                                 "inputToForgetWeights", "inputToCellWeights");
+    ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
+                                 "inputToForgetWeights", "inputToOutputWeights");
+
+    ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
+                                 "inputToForgetWeights", "recurrentToForgetWeights");
+    ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
+                                 "inputToForgetWeights", "recurrentToCellWeights");
+    ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
+                                 "inputToForgetWeights", "recurrentToOutputWeights");
+
+    // Validate number of dimensions and number of elements for MANDATORY bias tensors
+    ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
+    auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
+
+    ValidatePointer(m_CellBias, descriptorName, "CellBias");
+    auto cellBiasInfo = m_CellBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
+
+    ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
+    auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
+
+    // Validate data types for MANDATORY bias tensors
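+    // (Biases are Signed32. When layer norm is enabled, the reference workload derives the
+    // effective bias scale as layerNormWeightsScale / 1024, which is why the tests build
+    // biasInfo with biasScale = layerNormScale / 1024.)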
ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName); + + ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName, + "forgetGateBias", "cellBias"); + ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName, + "forgetGateBias", "outputGateBias"); + + // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias) + const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias && + !m_Parameters.m_CifgEnabled) || + (!m_InputToInputWeights && !m_RecurrentToInputWeights && + !m_InputGateBias && m_Parameters.m_CifgEnabled)); + + if (!allCifgParamsPresentOrNot) + { + throw InvalidArgumentException(descriptorName + + ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present " + "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be " + "set appropriately."); + } + + if (!m_Parameters.m_CifgEnabled) + { + // Validate number of dimensions and number of elements + auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights"); + + auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize), + " RecurrentToInputWeights"); + + auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias"); + + // Validate data types + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName, + "inputToForgetWeights", "inputToInputWeights"); + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName, + "inputToForgetWeights", "recurrentToInputWeights"); + ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName, + "forgetGateBias", "inputGateBias"); + } + + // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights) + bool allPeepholeWeightsPresentOrNot = + (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights + && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled) + || (!m_CellToInputWeights && !m_CellToForgetWeights + && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled)); + + if (!allPeepholeWeightsPresentOrNot) + { + throw InvalidArgumentException(descriptorName + + ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole " + "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present " + "when Peephole is enabled and CIFG is disabled. 
m_Parameters.m_PeepholeEnabled should be set "
+            "appropriately.");
+    }
+
+    if (m_Parameters.m_PeepholeEnabled)
+    {
+        auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
+        ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
+        ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
+
+        auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
+        ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
+        ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
+                                     "cellToForgetWeights", "cellToOutputWeights");
+
+        if (!m_Parameters.m_CifgEnabled)
+        {
+            auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
+            ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
+            ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
+                                         "cellToForgetWeights", "cellToInputWeights");
+        }
+    }
+
+    // Validate OPTIONAL params: Layer Norm Weights
+    bool allLayerNormWeightsPresentOrNot =
+            (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
+              && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
+             || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
+                 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
+
+    if (!allLayerNormWeightsPresentOrNot)
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": InputLayerNormWeights, ForgetLayerNormWeights, OutputLayerNormWeights "
+            "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
+            "be present at all (Layer Norm disabled). InputLayerNormWeights should "
+            "only be present when Layer Norm is enabled and CIFG is disabled. 
" + "m_Parameters.m_LayerNormEnabled should be set appropriately."); + } + + if (m_Parameters.m_LayerNormEnabled) + { + auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights"); + ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName); + + auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "cellLayerNormWeights"); + + auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "outputLayerNormWeights"); + + if (!m_Parameters.m_CifgEnabled) + { + auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "inputLayerNormWeights"); + } + } + + // Validate OPTIONAL params: Projection (projectionWeights, projectionBias) + bool correctProjectionTensorsPresent = + ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) || + (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) || + (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled)); + + if (!correctProjectionTensorsPresent) + { + throw InvalidArgumentException(descriptorName + + ": If projection is enabled, ProjectionWeights should be present and " + "ProjectionBias is optional. 
If projection is disabled, neither "
+            "ProjectionWeights nor ProjectionBias should be present.");
+    }
+
+    if (m_Parameters.m_ProjectionEnabled)
+    {
+        auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
+        ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
+        ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
+
+        if (m_ProjectionBias)
+        {
+            auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
+            ValidateTensorNumDimNumElem(projectionBiasInfo, 1, numUnits, "ProjectionBias");
+            ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
+        }
+    }
+    else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) ||
+             (outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint))
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": If projection is disabled, output quantization info (scale, offset) "
+            "should match HiddenStateScale and HiddenStateZeroPoint.");
+    }
+}
+
 void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
diff --git a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
index 50ef5c9..0ae55e4 100644
--- a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
@@ -1733,6 +1733,243 @@ LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
     return ret;
 }

+// QLSTM
+LayerTestResult<int8_t, 2> QLstmTestImpl(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const boost::multi_array<int8_t, 2>& input,
+    const boost::multi_array<int8_t, 2>& outputExpected)
+{
+    IgnoreUnused(memoryManager);
+    unsigned int numBatches = 2;
+    unsigned int inputSize  = 5;
+    unsigned int outputSize = 4;
+    unsigned int numUnits   = 4;
+
+    bool cifgEnabled       = true;
+    bool peepholeEnabled   = false;
+    bool projectionEnabled = false;
+    bool layerNormEnabled  = true;
+
+    // Scale/Offset quantization info
+    float   inputScale  = 0.0078125f;
+    int32_t inputOffset = 0;
+
+    int32_t hiddenStateZeroPoint = 0;
+    float   hiddenStateScale     = 0.007f;
+
+    // if (!projectionEnabled) outputScale == hiddenStateScale
+    float   outputScale  = hiddenStateScale;
+    int32_t outputOffset = hiddenStateZeroPoint;
+
+    float   cellStateScale  = 3.05176e-05f;
+    int32_t cellStateOffset = 0;
+
+    float   weightsScale  = 0.00784314f;
+    int32_t weightsOffset = 0;
+
+    float   layerNormScale  = 3.05182e-05f;
+    int32_t layerNormOffset = 0;
+
+    float   biasScale  = layerNormScale / 1024;
+    int32_t biasOffset = 0;
+
+    float inputIntermediateScale  = 0.007059f;
+    float forgetIntermediateScale = 0.007812f;
+    float cellIntermediateScale   = inputIntermediateScale;
+    float outputIntermediateScale = forgetIntermediateScale;
+
+    float cellClip       = 0.0f;
+    float projectionClip = 0.0f;
+
+    // Input/Output tensor info
+    armnn::TensorInfo inputInfo({numBatches , inputSize},
+                                armnn::DataType::QAsymmS8,
+                                inputScale,
+                                inputOffset);
+
+    armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+                                    armnn::DataType::QSymmS16,
+                                    cellStateScale,
+                                    cellStateOffset);
+
+    armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+                                      armnn::DataType::QAsymmS8,
+                                      outputScale,
+                                      outputOffset);
+
+    LayerTestResult<int8_t, 2> ret(outputStateInfo);
+
+    // Input tensors
+    std::vector<int8_t> inputVector;
+    inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
+    auto inputTensor = 
MakeTensor(inputInfo, inputVector); + + std::vector cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0}; + auto cellStateInTensor = MakeTensor(cellStateInfo, cellStateInVector); + + std::vector outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 02}; + auto outputStateInTensor = MakeTensor(outputStateInfo, outputStateInVector); + + // Output tensors + std::vector cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149}; + auto cellStateOutTensor = MakeTensor(cellStateInfo, cellStateOutVector); + + std::vector outputVector; + outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize)); + ret.outputExpected = MakeTensor(outputStateInfo, outputVector); + + // Create tensor handles + std::unique_ptr inputHandle = workloadFactory.CreateTensorHandle(inputInfo); + std::unique_ptr cellStateInHandle = + workloadFactory.CreateTensorHandle(cellStateInfo); + std::unique_ptr outputStateInHandle = + workloadFactory.CreateTensorHandle(outputStateInfo); + + std::unique_ptr outputStateOutHandle = workloadFactory.CreateTensorHandle(outputStateInfo); + std::unique_ptr cellStateOutHandle = + workloadFactory.CreateTensorHandle(cellStateInfo); + std::unique_ptr outputHandle = workloadFactory.CreateTensorHandle(outputStateInfo); + + armnn::QLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + // Add inputs and outputs to workload + AddInputToWorkload(data, info, inputInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get()); + AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get()); + + AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get()); + AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get()); + AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get()); + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({outputSize, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset); + + armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset); + + // Weights and bias tensor data + auto inputToForgetWeights = MakeTensor(inputWeightsInfo, + {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64}); + auto inputToCellWeights = MakeTensor(inputWeightsInfo, + {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77}); + auto inputToOutputWeights = MakeTensor(inputWeightsInfo, + {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51}); + + auto recurrentToForgetWeights = MakeTensor(recurrentWeightsInfo, + {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51}); + auto recurrentToCellWeights = MakeTensor(recurrentWeightsInfo, + {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64}); + auto recurrentToOutputWeights = MakeTensor(recurrentWeightsInfo, + {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38}); + + auto forgetGateBias = MakeTensor(biasInfo, {2147484, -6442451, -4294968, 2147484}); + auto cellBias = MakeTensor(biasInfo, {-1073742, 15461883, 5368709, 1717987}); + auto outputGateBias = MakeTensor(biasInfo, {1073742, -214748, 4294968, 2147484}); + + auto 
forgetLayerNormWeights = MakeTensor(layerNormWeightsInfo, {6553, 6553, 13107, 9830}); + auto cellLayerNormWeights = MakeTensor(layerNormWeightsInfo, {22937, 6553, 9830, 26214}); + auto outputLayerNormWeights = MakeTensor(layerNormWeightsInfo, {19660, 6553, 6553, 16384}); + + // ScopedCpuTensorHandles + armnn::ScopedCpuTensorHandle inputToForgetWeightsTensor(inputWeightsInfo); + armnn::ScopedCpuTensorHandle inputToCellWeightsTensor(inputWeightsInfo); + armnn::ScopedCpuTensorHandle inputToOutputWeightsTensor(inputWeightsInfo); + + armnn::ScopedCpuTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo); + armnn::ScopedCpuTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo); + armnn::ScopedCpuTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo); + + armnn::ScopedCpuTensorHandle forgetGateBiasTensor(biasInfo); + armnn::ScopedCpuTensorHandle cellBiasTensor(biasInfo); + armnn::ScopedCpuTensorHandle outputGateBiasTensor(biasInfo); + + armnn::ScopedCpuTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo); + armnn::ScopedCpuTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo); + armnn::ScopedCpuTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo); + + // Allocate and copy data + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, &inputToForgetWeights[0][0]); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, &inputToCellWeights[0][0]); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, &inputToOutputWeights[0][0]); + + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, &recurrentToForgetWeights[0][0]); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, &recurrentToCellWeights[0][0]); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, &recurrentToOutputWeights[0][0]); + + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, &forgetGateBias[0]); + AllocateAndCopyDataToITensorHandle(&cellBiasTensor, &cellBias[0]); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, &outputGateBias[0]); + + AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, &forgetLayerNormWeights[0]); + AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, &cellLayerNormWeights[0]); + AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, &outputLayerNormWeights[0]); + + // Setup queue descriptor + data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + + data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; + data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; + data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; + + data.m_Parameters.m_CifgEnabled = cifgEnabled; + data.m_Parameters.m_PeepholeEnabled = peepholeEnabled; + data.m_Parameters.m_ProjectionEnabled = projectionEnabled; + data.m_Parameters.m_LayerNormEnabled = layerNormEnabled; + + data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale; + data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale; + data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale; + 
data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale; + + data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint; + data.m_Parameters.m_HiddenStateScale = hiddenStateScale; + + data.m_Parameters.m_CellClip = cellClip; + data.m_Parameters.m_ProjectionClip = projectionClip; + + // Create workload and allocate tensor handles + std::unique_ptr workload = workloadFactory.CreateQLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + + outputStateOutHandle->Allocate(); + cellStateOutHandle->Allocate(); + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0]); + CopyDataToITensorHandle(outputStateInHandle.get(), &outputStateInTensor[0][0]); + CopyDataToITensorHandle(cellStateInHandle.get(), &cellStateInTensor[0][0]); + + workload->Execute(); + + CopyDataFromITensorHandle(&ret.output[0][0], outputHandle.get()); + + return ret; +} + + } // anonymous namespace #if defined(ARMNNREF_ENABLED) @@ -2107,3 +2344,19 @@ LayerTestResult QuantizedLstmTest( return QuantizedLstmTestImpl(workloadFactory, memoryManager, input, expectedOutput); } + +// QLSTM +LayerTestResult QLstmTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8); + boost::multi_array input = MakeTensor(inputDesc, std::vector( + {90, 102, 13, 26, 38, 102, 13, 26, 51, 64})); + + armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8); + boost::multi_array expectedOutput = MakeTensor(outputDesc, std::vector( + {-15, 21, 14, 20, -15, 15, 5, 27})); + + return QLstmTestImpl(workloadFactory, memoryManager, input, expectedOutput); +} diff --git a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.hpp index dad1760..f1180ae 100644 --- a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.hpp +++ b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.hpp @@ -58,3 +58,11 @@ LayerTestResult LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16Const LayerTestResult QuantizedLstmTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +// +// QLstm +// + +LayerTestResult QLstmTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 87d2921..034cd12 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1573,6 +1573,30 @@ bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + IgnoreUnused(input); + IgnoreUnused(previousOutputIn); + IgnoreUnused(previousCellStateIn); + IgnoreUnused(outputStateOut); + IgnoreUnused(cellStateOut); + IgnoreUnused(output); + IgnoreUnused(descriptor); + IgnoreUnused(paramsInfo); + + IgnoreUnused(reasonIfUnsupported); + + return true; +} + bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input, const TensorInfo& 
output, Optional reasonIfUnsupported) const diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 30f45c3..eb89946 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -249,6 +249,16 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsReshapeSupported(const TensorInfo& input, const TensorInfo& output, const ReshapeDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 4566fe5..5ce997c 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -512,6 +512,12 @@ std::unique_ptr RefWorkloadFactory::CreatePrelu(const PreluQueueDescr return std::make_unique(descriptor, info); } +std::unique_ptr RefWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + std::unique_ptr RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 9a53ae2..1c607c0 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -200,6 +200,9 @@ public: std::unique_ptr CreatePrelu(const PreluQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateQLstm(const QLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 239863f..8d7f63d 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -75,6 +75,7 @@ BACKEND_SOURCES := \ workloads/RefPermuteWorkload.cpp \ workloads/RefPooling2dWorkload.cpp \ workloads/RefPreluWorkload.cpp \ + workloads/RefQLstmWorkload.cpp \ workloads/RefQuantizeWorkload.cpp \ workloads/RefReshapeWorkload.cpp \ workloads/RefResizeBilinearWorkload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 4a57df7..437366a 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -1089,4 +1089,33 @@ BOOST_AUTO_TEST_CASE(CreateStackUint16Workload) RefCreateStackWorkloadTest({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2); } +template +static void RefCreateQLstmWorkloadTest() +{ + Graph graph; + RefWorkloadFactory factory; + + auto workload = CreateQLstmWorkloadTest(factory, graph); + + armnn::TensorInfo inputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.0078125f, 0); + + armnn::TensorInfo cellStateInfo({2 , 4}, armnn::DataType::QSymmS16, 3.05176e-05f, 0); + + armnn::TensorInfo outputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.007f, 0); + + QLstmQueueDescriptor queueDescriptor = workload->GetData(); + auto 
inputHandle = boost::polymorphic_downcast(queueDescriptor.m_Inputs[0]); + auto cellStateOutHandle = boost::polymorphic_downcast(queueDescriptor.m_Outputs[1]); + auto outputHandle = boost::polymorphic_downcast(queueDescriptor.m_Outputs[2]); + + BOOST_TEST((inputHandle->GetTensorInfo() == inputInfo)); + BOOST_TEST((cellStateOutHandle->GetTensorInfo() == cellStateInfo)); + BOOST_TEST((outputHandle->GetTensorInfo() == outputInfo)); +} + +BOOST_AUTO_TEST_CASE(CreateQLstmWorkloadTest) +{ + RefCreateQLstmWorkloadTest(); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index f50051a..d8dab3d 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1255,6 +1255,9 @@ ARMNN_AUTO_TEST_CASE(LstmLayerInt16NoCifgWithPeepholeWithProjection, ARMNN_AUTO_TEST_CASE(LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16Constant, LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest) +// QLstm +ARMNN_AUTO_TEST_CASE(QLstm, QLstmTest) + // Convert from BFloat16 to Float32 ARMNN_AUTO_TEST_CASE(ConvertBf16ToFp32, ConvertBf16ToFp32Test) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 9f3880e..1abdb0b 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -123,6 +123,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefPreluWorkload.hpp RefQuantizeWorkload.cpp RefQuantizeWorkload.hpp + RefQLstmWorkload.cpp + RefQLstmWorkload.hpp RefReshapeWorkload.cpp RefReshapeWorkload.hpp RefResizeBilinearWorkload.cpp diff --git a/src/backends/reference/workloads/RefQLstmWorkload.cpp b/src/backends/reference/workloads/RefQLstmWorkload.cpp new file mode 100644 index 0000000..34d048b --- /dev/null +++ b/src/backends/reference/workloads/RefQLstmWorkload.cpp @@ -0,0 +1,519 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "RefQLstmWorkload.hpp" +#include "Activation.hpp" +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "LstmUtils.hpp" +#include "RefWorkloadUtils.hpp" + +namespace armnn +{ + +RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) + : BaseWorkload(descriptor, info) + , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights)) + , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights)) + , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights)) + , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights)) + + , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights)) + , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights)) + , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights)) + , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights)) + + , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights)) + , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights)) + , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights)) + + , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias)) + , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias)) + , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias)) + , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias)) + + , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights)) + , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias)) + + , m_InputLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights)) + , m_ForgetLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights)) + , m_CellLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights)) + , m_OutputLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights)) +{} + +void RefQLstmWorkload::Execute() const +{ + // This is a porting of the QLSTM::Execute() method in the Android code base + // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all + // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp. 
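+    // In outline, each gate below is computed as:
+    //   gate  = inputToGateWeights * input + recurrentToGateWeights * outputStateIn
+    //   gate += cellToGateWeights . cellState                   (optional peephole)
+    //   gate  = LayerNorm(gate) * layerNormWeights + gateBias   (optional layer norm)
+    //   gate  = sigmoid(gate), or tanh(gate) for the cell gate
+    // with a requantization between stages, done by re-making the encoder/decoder pair
+    // with an updated TensorInfo scale.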
+ // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp + const DataType& internalType = armnn::DataType::QSymmS16; + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& cellStateInInfo = GetTensorInfo(m_Data.m_Inputs[2]); + + const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]); + const TensorInfo& cellStateOutInfo = GetTensorInfo(m_Data.m_Outputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[2]); + + const TensorShape& inputShape = inputInfo.GetShape(); + const TensorShape& outputStateInShape = outputStateInInfo.GetShape(); + const TensorShape& cellStateInShape = cellStateInInfo.GetShape(); + + // Infer numBatches, inputSize, outputSize and numUnits + const uint32_t numBatches = inputShape[0]; + const uint32_t inputSize = inputShape[1]; + const uint32_t outputSize = outputStateInShape[1]; + const uint32_t numUnits = cellStateInShape[1]; + + // Optional param settings + const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled; + const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled; + const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled; + const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled; + + // Input decoders + std::unique_ptr> inputDecoder = + MakeDecoder(inputInfo, m_Data.m_Inputs[0]->Map()); + std::unique_ptr> outputStateInDecoder = + MakeDecoder(outputStateInInfo, m_Data.m_Inputs[1]->Map()); + std::unique_ptr> cellStateInDecoder = + MakeDecoder(cellStateInInfo, m_Data.m_Inputs[2]->Map()); + + // Output decoders + std::unique_ptr> outputStateOutDecoder = + MakeDecoder(outputStateOutInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> cellStateOutDecoder = + MakeDecoder(cellStateOutInfo, m_Data.m_Outputs[1]->Map()); + std::unique_ptr> outputDecoder = + MakeDecoder(outputInfo, m_Data.m_Outputs[2]->Map()); + + // Output encoders + std::unique_ptr> outputStateOutEncoder = + MakeEncoder(outputStateOutInfo, m_Data.m_Outputs[0]->Map()); + std::unique_ptr> cellStateOutEncoder = + MakeEncoder(cellStateOutInfo, m_Data.m_Outputs[1]->Map()); + std::unique_ptr> outputEncoder = + MakeEncoder(outputInfo, m_Data.m_Outputs[2]->Map()); + + // Weights decoders + std::unique_ptr> inputToForgetWeightsDecoder = MakeDecoder( + m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor()); + std::unique_ptr> inputToCellWeightsDecoder = MakeDecoder( + m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor()); + std::unique_ptr> inputToOutputWeightsDecoder = MakeDecoder( + m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor()); + + std::unique_ptr> recurrentToForgetWeightsDecoder = MakeDecoder( + m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor()); + std::unique_ptr> recurrentToCellWeightsDecoder = MakeDecoder( + m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor()); + std::unique_ptr> recurrentToOutputWeightsDecoder = MakeDecoder( + m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor()); + + // Optional CIFG params + std::unique_ptr> inputToInputWeightsDecoder; + std::unique_ptr> recurrentToInputWeightsDecoder; + std::unique_ptr> inputGateBiasDecoder; + + // Optional Peephole params + std::unique_ptr> cellToInputWeightsDecoder; + std::unique_ptr> cellToForgetWeightsDecoder; + 
std::unique_ptr> cellToOutputWeightsDecoder; + + // Optional Projection params + std::unique_ptr> projectionWeightsDecoder; + std::unique_ptr> projectionBiasDecoder; + + // Optional Layer Norm params + std::unique_ptr> inputLayerNormWeightsDecoder; + std::unique_ptr> forgetLayerNormWeightsDecoder; + std::unique_ptr> cellLayerNormWeightsDecoder; + std::unique_ptr> outputLayerNormWeightsDecoder; + + // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024) + std::unique_ptr> forgetGateBiasDecoder; + std::unique_ptr> cellGateBiasDecoder; + std::unique_ptr> outputGateBiasDecoder; + + // Int16 vectors for internal state data (to be decoded/encoded) + const uint32_t stateTensorSize = numBatches * numUnits; + std::vector inputGateData(stateTensorSize); + std::vector cellGateData(stateTensorSize); + std::vector forgetGateData(stateTensorSize); + std::vector outputGateData(stateTensorSize); + std::vector hiddenStateData(stateTensorSize); + + armnn::TensorInfo inputGateInfo( + {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0); + armnn::TensorInfo cellGateInfo( + {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0); + armnn::TensorInfo forgetGateInfo( + {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0); + armnn::TensorInfo outputGateInfo( + {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0); + armnn::TensorInfo hiddenStateInfo({numBatches, numUnits}, + armnn::DataType::QAsymmS8, + m_Data.m_Parameters.m_HiddenStateScale, + m_Data.m_Parameters.m_HiddenStateZeroPoint); + + // Decoders/Encoders for internal states + std::unique_ptr> inputGateDecoder = + MakeDecoder(inputGateInfo, inputGateData.data()); + std::unique_ptr> cellGateDecoder = + MakeDecoder(cellGateInfo, cellGateData.data()); + std::unique_ptr> forgetGateDecoder = + MakeDecoder(forgetGateInfo, forgetGateData.data()); + std::unique_ptr> outputGateDecoder = + MakeDecoder(outputGateInfo, outputGateData.data()); + std::unique_ptr> hiddenStateDecoder = + MakeDecoder(hiddenStateInfo, hiddenStateData.data()); + + std::unique_ptr> inputGateEncoder = + MakeEncoder(inputGateInfo, inputGateData.data()); + std::unique_ptr> cellGateEncoder = + MakeEncoder(cellGateInfo, cellGateData.data()); + std::unique_ptr> forgetGateEncoder = + MakeEncoder(forgetGateInfo, forgetGateData.data()); + std::unique_ptr> outputGateEncoder = + MakeEncoder(outputGateInfo, outputGateData.data()); + std::unique_ptr> hiddenStateEncoder = + MakeEncoder(hiddenStateInfo, hiddenStateData.data()); + + // Create decoders for optional params if they are enabled + if (!cifgEnabled) + { + inputToInputWeightsDecoder = MakeDecoder( + m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor()); + recurrentToInputWeightsDecoder = MakeDecoder( + m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor()); + } + + if (peepholeEnabled) + { + if (!cifgEnabled) + { + cellToInputWeightsDecoder = MakeDecoder( + m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor()); + } + cellToForgetWeightsDecoder = MakeDecoder( + m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor()); + cellToOutputWeightsDecoder = MakeDecoder( + m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor()); + } + + if 
(projectionEnabled) + { + projectionWeightsDecoder = MakeDecoder( + m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor()); + if (m_ProjectionBiasTensor) + { + projectionBiasDecoder = MakeDecoder( + m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor()); + } + } + + if (layerNormEnabled) + { + if (!cifgEnabled) + { + inputLayerNormWeightsDecoder = MakeDecoder( + m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor()); + + // Bias only used if layer norm enabled + armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, + m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); + inputGateBiasDecoder = MakeDecoder( + inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor()); + } + + forgetLayerNormWeightsDecoder = MakeDecoder( + m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor()); + cellLayerNormWeightsDecoder = MakeDecoder( + m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor()); + outputLayerNormWeightsDecoder = MakeDecoder( + m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor()); + + // Bias only used if layer norm enabled + armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, + m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); + forgetGateBiasDecoder = MakeDecoder( + forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor()); + + armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, + m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); + cellGateBiasDecoder = MakeDecoder( + cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor()); + + armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, + m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); + outputGateBiasDecoder = MakeDecoder( + outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor()); + } + + // Initialize internal state tensors with zeroes. 
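+    // ZeroVector writes through the encoders, so every gate accumulator starts at exact zero in
+    // its own quantized space before the matrix/vector products below accumulate into it. With
+    // CIFG enabled the input gate is never computed; it is derived later as (1 - forgetGate)
+    // via Sub1Vector.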
+ if (!cifgEnabled) + { + ZeroVector(*inputGateEncoder, stateTensorSize); + } + ZeroVector(*forgetGateEncoder, stateTensorSize); + ZeroVector(*cellGateEncoder, stateTensorSize); + ZeroVector(*outputGateEncoder, stateTensorSize); + ZeroVector(*hiddenStateEncoder, stateTensorSize); + + // Input weights * Input + if (!cifgEnabled) + { + MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder, + numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder); + } + + MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder, + numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder); + + MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder, + numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder); + + MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder, + numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder); + + // Recurrent weights * OutputStateIn + if (!cifgEnabled) + { + MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder, + numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder); + } + + MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder, + numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder); + + MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder, + numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder); + + MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder, + numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder); + + // Input gate. + if (!cifgEnabled) + { + if (peepholeEnabled) + { + VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder, + numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder); + } + + if (layerNormEnabled) + { + inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * + m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * + 1024); + inputGateEncoder = MakeEncoder(inputGateInfo, inputGateData.data()); + + MeanStddevNormalization(*inputGateDecoder, + *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); + + inputGateDecoder = MakeDecoder(inputGateInfo, inputGateData.data()); + + VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder, + numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); + + inputGateInfo.SetQuantizationScale(1.f / 4096); + inputGateEncoder = MakeEncoder(inputGateInfo, inputGateData.data()); + + VectorBatchVectorAdd(*inputGateBiasDecoder, + numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); + + inputGateDecoder = MakeDecoder(inputGateInfo, inputGateData.data()); + } + + inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); + inputGateEncoder = MakeEncoder(inputGateInfo, inputGateData.data()); + + // Input gate sigmoid + Activation(*inputGateDecoder, *inputGateEncoder, + TensorInfo({numUnits, numBatches}, internalType), + ActivationFunction::Sigmoid, 0, 0); + + inputGateDecoder = MakeDecoder(inputGateInfo, inputGateData.data()); + } + + // Forget gate + if (peepholeEnabled) + { + VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits, + *cellStateInDecoder, numBatches, *forgetGateEncoder); + } + + if (layerNormEnabled) + { + // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024 + forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * + m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * + 1024); + forgetGateEncoder = 
MakeEncoder(forgetGateInfo, forgetGateData.data()); + + + + MeanStddevNormalization(*forgetGateDecoder, + *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); + + + forgetGateDecoder = MakeDecoder(forgetGateInfo, forgetGateData.data()); + + VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder, + numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); + + + // Dequantize layer norm output to (1 / 4096) + forgetGateInfo.SetQuantizationScale(1.f / 4096); + forgetGateEncoder = MakeEncoder(forgetGateInfo, forgetGateData.data()); + + VectorBatchVectorAdd(*forgetGateBiasDecoder, + numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); + + + forgetGateDecoder = MakeDecoder(forgetGateInfo, forgetGateData.data()); + } + + forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); + forgetGateEncoder = MakeEncoder(forgetGateInfo, forgetGateData.data()); + + // Forget gate sigmoid + Activation(*forgetGateDecoder, *forgetGateEncoder, + TensorInfo({numUnits, numBatches}, internalType), + ActivationFunction::Sigmoid, 0, 0); + + forgetGateDecoder = MakeDecoder(forgetGateInfo, forgetGateData.data()); + + // Cell (Modulation) gate + if (layerNormEnabled) + { + cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * + m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * + 1024); + cellGateEncoder = MakeEncoder(cellGateInfo, cellGateData.data()); + + MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); + + cellGateDecoder = MakeDecoder(cellGateInfo, cellGateData.data()); + + VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder, + numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); + + cellGateInfo.SetQuantizationScale(1.f / 4096); + cellGateEncoder = MakeEncoder(cellGateInfo, cellGateData.data()); + + VectorBatchVectorAdd(*cellGateBiasDecoder, + numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); + + cellGateDecoder = MakeDecoder(cellGateInfo, cellGateData.data()); + } + + cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); + cellGateEncoder = MakeEncoder(cellGateInfo, cellGateData.data()); + + // Cell (Modulation) gate tanH + Activation(*cellGateDecoder, *cellGateEncoder, + TensorInfo({numUnits, numBatches}, internalType), + ActivationFunction::TanH, 1.0f, 1.0f); + + cellGateDecoder = MakeDecoder(cellGateInfo, cellGateData.data()); + + VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder); + + if (cifgEnabled) + { + Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder); + VectorVectorCwiseProductAccumulate( + *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder); + } + else + { + VectorVectorCwiseProductAccumulate( + *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder); + } + + // Final cell state out calculated here + if (m_Data.m_Parameters.m_CellClip > 0.0) + { + ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder); + } + + // Output gate. 
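+    // The output gate runs the same peephole/layer norm/sigmoid pipeline as the forget gate,
+    // except its peephole term uses the freshly computed cellStateOut rather than cellStateIn.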
+    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);
+
+    if (cifgEnabled)
+    {
+        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
+        VectorVectorCwiseProductAccumulate(
+                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
+    }
+    else
+    {
+        VectorVectorCwiseProductAccumulate(
+                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
+    }
+
+    // Final cell state out calculated here
+    if (m_Data.m_Parameters.m_CellClip > 0.0)
+    {
+        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
+    }
+
+    // Output gate.
+    if (peepholeEnabled)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
+                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
+    }
+
+    if (layerNormEnabled)
+    {
+        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
+                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
+                                            1024);
+        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
+
+        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+
+        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
+                numBatches, *outputGateEncoder);
+
+        outputGateInfo.SetQuantizationScale(1.f / 4096);
+        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
+
+        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+    }
+
+    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
+    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
+
+    // Output gate sigmoid
+    Activation(*outputGateDecoder, *outputGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
+
+    // Hidden state tanH
+    Activation(*cellStateOutDecoder, *cellGateEncoder,
+               TensorInfo({numUnits, numBatches}, internalType),
+               ActivationFunction::TanH, 1.0f, 1.0f);
+
+    // Final hidden state output
+    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);
+
+    // Projection
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        if (m_ProjectionBiasTensor)
+        {
+            VectorBatchVectorAssign(*projectionBiasDecoder,
+                    outputSize, numBatches, *outputEncoder);
+        }
+
+        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder,
+                outputSize, numUnits, *hiddenStateDecoder, numBatches, *outputEncoder);
+
+        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
+        {
+            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
+        }
+    }
+    else
+    {
+        // Output has the same quantization scale as the hidden state when projection is disabled
+        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
+    }
+
+    // output == outputStateOut
+    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
+}
+
+} //namespace armnn
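For orientation, the per-element arithmetic of the cell update, hidden state and projection path above can be sketched in plain float. This is an illustrative standalone sketch with hypothetical gate values; it is not part of the patch and does not use the ArmNN Decoder/Encoder API:

#include <algorithm>
#include <cmath>
#include <iostream>

int main()
{
    const bool  cifgEnabled = true;
    const float cellClip    = 0.0f;  // 0.0f means no clipping is applied

    float f   = 0.75f;  // forget gate, post-sigmoid
    float g   = 0.20f;  // cell (modulation) gate, post-tanh
    float o   = 0.80f;  // output gate, post-sigmoid
    float cIn = 0.50f;  // previous cell state

    // CIFG couples the input gate to the forget gate: i = 1 - f (Sub1Vector)
    float i = cifgEnabled ? (1.0f - f) : 0.60f;

    // cellStateOut = f * cIn + i * g, optionally clipped to +/- cellClip
    float cOut = f * cIn + i * g;
    if (cellClip > 0.0f)
    {
        cOut = std::min(std::max(cOut, -cellClip), cellClip);
    }

    // hiddenState = o * tanh(cellStateOut); with projection disabled this is
    // copied directly to output and outputStateOut
    float h = o * std::tanh(cOut);

    std::cout << "cellStateOut: " << cOut << ", output: " << h << std::endl;
    return 0;
}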
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.hpp b/src/backends/reference/workloads/RefQLstmWorkload.hpp
new file mode 100644
index 0000000..19d3a2a
--- /dev/null
+++ b/src/backends/reference/workloads/RefQLstmWorkload.hpp
@@ -0,0 +1,54 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/TypesUtils.hpp>
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor>
+{
+public:
+    explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+    virtual void Execute() const override;
+
+private:
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
+
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeightsTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeightsTensor;
+
+    float m_LayerNormEpsilon = static_cast<float>(1e-8);
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index cbfade3..e396a6b 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -48,6 +48,7 @@
 #include "RefPermuteWorkload.hpp"
 #include "RefPadWorkload.hpp"
 #include "RefPreluWorkload.hpp"
+#include "RefQLstmWorkload.hpp"
 #include "RefQuantizeWorkload.hpp"
 #include "RefReshapeWorkload.hpp"
 #include "RefResizeBilinearWorkload.hpp"
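Throughout the workload, the Decoder/Encoder pattern amounts to a dequantize, compute-in-float, requantize round trip, with the target scale set on the relevant TensorInfo before each MakeEncoder call. A minimal standalone illustration of that round trip, using hypothetical scales and plain C++ in place of the ArmNN Encoders/Decoders:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Dequantize a symmetric 16-bit quantized value.
float Decode(int16_t q, float scale) { return q * scale; }

// Requantize to symmetric 16-bit, saturating at the type's limits.
int16_t Encode(float r, float scale)
{
    float q = std::round(r / scale);
    q = std::min(q, 32767.0f);
    q = std::max(q, -32768.0f);
    return static_cast<int16_t>(q);
}

int main()
{
    const float gateScale  = 1.0f / 4096.0f;  // Q3.12, as used before the sigmoid above
    const float stateScale = 1.0f / 32768.0f; // hypothetical output scale

    int16_t rawGate = Encode(1.5f, gateScale); // 6144
    float   sigmoid = 1.0f / (1.0f + std::exp(-Decode(rawGate, gateScale)));
    int16_t rawOut  = Encode(sigmoid, stateScale);

    std::cout << rawOut << std::endl; // approximately 26790
    return 0;
}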