m_ParserFunctions[Layer_PermuteLayer] = &Deserializer::ParsePermute;
m_ParserFunctions[Layer_Pooling2dLayer] = &Deserializer::ParsePooling2d;
m_ParserFunctions[Layer_PreluLayer] = &Deserializer::ParsePrelu;
+ m_ParserFunctions[Layer_QLstmLayer] = &Deserializer::ParseQLstm;
m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize;
m_ParserFunctions[Layer_QuantizedLstmLayer] = &Deserializer::ParseQuantizedLstm;
m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape;
return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
case Layer::Layer_PreluLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
+ case Layer::Layer_QLstmLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
case Layer::Layer_QuantizeLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
case Layer::Layer_QuantizedLstmLayer:
RegisterOutputSlots(graph, layerIndex, layer);
}
+armnn::QLstmDescriptor Deserializer::GetQLstmDescriptor(Deserializer::QLstmDescriptorPtr qLstmDescriptor)
+{
+ armnn::QLstmDescriptor desc;
+
+ desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
+ desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
+ desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
+ desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
+
+ desc.m_CellClip = qLstmDescriptor->cellClip();
+ desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
+
+ desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
+ desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
+ desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
+ desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
+
+ desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
+ desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
+
+ return desc;
+}
+
// Deserializes a QLstm layer from the flatbuffer graph: rebuilds the
// descriptor and (conditionally, per the enabled features) the weight/bias
// tensors, adds the layer to the network and wires up its slots.
void Deserializer::ParseQLstm(GraphPtr graph, unsigned int layerIndex)
{
    CHECK_LAYERS(graph, 0, layerIndex);

    // QLstm always has 3 inputs and 3 outputs
    auto inputs = GetInputs(graph, layerIndex);
    CHECK_VALID_SIZE(inputs.size(), 3);

    auto outputs = GetOutputs(graph, layerIndex);
    CHECK_VALID_SIZE(outputs.size(), 3);

    auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer();
    auto layerName = GetLayerName(graph, layerIndex);
    auto flatBufferDescriptor = flatBufferLayer->descriptor();
    auto flatBufferInputParams = flatBufferLayer->inputParams();

    auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor);
    armnn::LstmInputParams qLstmInputParams;

    // Mandatory params
    // NOTE: qLstmInputParams only stores raw pointers to these stack-local
    // ConstTensors, so they must stay in scope until AddQLstmLayer is called
    // below (which presumably copies the tensor data — the pointers are not
    // retained past that call here).
    armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
    armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
    armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
    armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
    armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
    armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
    armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
    armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
    armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());

    qLstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
    qLstmInputParams.m_InputToCellWeights = &inputToCellWeights;
    qLstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
    qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
    qLstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
    qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
    qLstmInputParams.m_ForgetGateBias = &forgetGateBias;
    qLstmInputParams.m_CellBias = &cellBias;
    qLstmInputParams.m_OutputGateBias = &outputGateBias;

    // Optional CIFG params (input gate): only present when CIFG is disabled
    armnn::ConstTensor inputToInputWeights;
    armnn::ConstTensor recurrentToInputWeights;
    armnn::ConstTensor inputGateBias;

    if (!qLstmDescriptor.m_CifgEnabled)
    {
        inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
        recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
        inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());

        qLstmInputParams.m_InputToInputWeights = &inputToInputWeights;
        qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
        qLstmInputParams.m_InputGateBias = &inputGateBias;
    }

    // Optional projection params
    armnn::ConstTensor projectionWeights;
    armnn::ConstTensor projectionBias;

    if (qLstmDescriptor.m_ProjectionEnabled)
    {
        projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
        projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());

        qLstmInputParams.m_ProjectionWeights = &projectionWeights;
        qLstmInputParams.m_ProjectionBias = &projectionBias;
    }

    // Optional peephole params
    armnn::ConstTensor cellToInputWeights;
    armnn::ConstTensor cellToForgetWeights;
    armnn::ConstTensor cellToOutputWeights;

    if (qLstmDescriptor.m_PeepholeEnabled)
    {
        // cellToInputWeights belong to the input gate, which only exists
        // when CIFG is disabled
        if (!qLstmDescriptor.m_CifgEnabled)
        {
            cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
            qLstmInputParams.m_CellToInputWeights = &cellToInputWeights;
        }

        cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
        cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());

        qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
        qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
    }

    // Optional layer norm params
    armnn::ConstTensor inputLayerNormWeights;
    armnn::ConstTensor forgetLayerNormWeights;
    armnn::ConstTensor cellLayerNormWeights;
    armnn::ConstTensor outputLayerNormWeights;

    if (qLstmDescriptor.m_LayerNormEnabled)
    {
        // inputLayerNormWeights belong to the input gate, which only exists
        // when CIFG is disabled
        if (!qLstmDescriptor.m_CifgEnabled)
        {
            inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
            qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
        }

        forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
        cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
        outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());

        qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
        qLstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
        qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
    }

    IConnectableLayer* layer = m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str());

    // Set the info on the 3 outputs: output state, cell state, output
    armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]);
    layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo);

    armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]);
    layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo);

    armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]);
    layer->GetOutputSlot(2).SetTensorInfo(outputInfo);

    RegisterInputSlots(graph, layerIndex, layer);
    RegisterOutputSlots(graph, layerIndex, layer);
}
+
void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
deserializedNetwork->Accept(checker);
}
+class VerifyQLstmLayer : public LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>
+{
+public:
+ VerifyQLstmLayer(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const armnn::QLstmDescriptor& descriptor,
+ const armnn::LstmInputParams& inputParams)
+ : LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>(layerName, inputInfos, outputInfos, descriptor)
+ , m_InputParams(inputParams) {}
+
+ void VisitQLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::QLstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name)
+ {
+ VerifyNameAndConnections(layer, name);
+ VerifyDescriptor(descriptor);
+ VerifyInputParameters(params);
+ }
+
+protected:
+ void VerifyInputParameters(const armnn::LstmInputParams& params)
+ {
+ VerifyConstTensors(
+ "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+ VerifyConstTensors(
+ "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+ VerifyConstTensors(
+ "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+ VerifyConstTensors(
+ "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+ VerifyConstTensors(
+ "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+ VerifyConstTensors(
+ "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+ VerifyConstTensors(
+ "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+ VerifyConstTensors(
+ "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights);
+ VerifyConstTensors(
+ "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights);
+ VerifyConstTensors(
+ "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias);
+ VerifyConstTensors(
+ "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+ VerifyConstTensors(
+ "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias);
+ VerifyConstTensors(
+ "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+ VerifyConstTensors(
+ "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights);
+ VerifyConstTensors(
+ "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias);
+ VerifyConstTensors(
+ "m_InputLayerNormWeights", m_InputParams.m_InputLayerNormWeights, params.m_InputLayerNormWeights);
+ VerifyConstTensors(
+ "m_ForgetLayerNormWeights", m_InputParams.m_ForgetLayerNormWeights, params.m_ForgetLayerNormWeights);
+ VerifyConstTensors(
+ "m_CellLayerNormWeights", m_InputParams.m_CellLayerNormWeights, params.m_CellLayerNormWeights);
+ VerifyConstTensors(
+ "m_OutputLayerNormWeights", m_InputParams.m_OutputLayerNormWeights, params.m_OutputLayerNormWeights);
+ }
+
+private:
+ armnn::LstmInputParams m_InputParams;
+};
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic)
+{
+ armnn::QLstmDescriptor descriptor;
+
+ descriptor.m_CifgEnabled = true;
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = false;
+ descriptor.m_LayerNormEnabled = false;
+
+ descriptor.m_CellClip = 0.0f;
+ descriptor.m_ProjectionClip = 0.0f;
+
+ descriptor.m_InputIntermediateScale = 0.00001f;
+ descriptor.m_ForgetIntermediateScale = 0.00001f;
+ descriptor.m_CellIntermediateScale = 0.00001f;
+ descriptor.m_OutputIntermediateScale = 0.00001f;
+
+ descriptor.m_HiddenStateScale = 0.07f;
+ descriptor.m_HiddenStateZeroPoint = 0;
+
+ const unsigned int numBatches = 2;
+ const unsigned int inputSize = 5;
+ const unsigned int outputSize = 4;
+ const unsigned int numUnits = 4;
+
+ // Scale/Offset quantization info
+ float inputScale = 0.0078f;
+ int32_t inputOffset = 0;
+
+ float outputScale = 0.0078f;
+ int32_t outputOffset = 0;
+
+ float cellStateScale = 3.5002e-05f;
+ int32_t cellStateOffset = 0;
+
+ float weightsScale = 0.007f;
+ int32_t weightsOffset = 0;
+
+ float biasScale = 3.5002e-05f / 1024;
+ int32_t biasOffset = 0;
+
+ // Weights and bias tensor and quantization info
+ armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
+
+ std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+ std::vector<int8_t> recurrentToForgetWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToCellWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToOutputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+ armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+ armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+ std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+ std::vector<int32_t> cellBiasData(numUnits, 0);
+ std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+ armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+ armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+ armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+ // Set up params
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // Create network
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ const std::string layerName("qLstm");
+
+ armnn::IConnectableLayer* const input = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2);
+
+ armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2);
+
+ // Input/Output tensor info
+ armnn::TensorInfo inputInfo({numBatches , inputSize},
+ armnn::DataType::QAsymmS8,
+ inputScale,
+ inputOffset);
+
+ armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+ armnn::DataType::QSymmS16,
+ cellStateScale,
+ cellStateOffset);
+
+ armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+ armnn::DataType::QAsymmS8,
+ outputScale,
+ outputOffset);
+
+ // Connect input/output slots
+ input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+ qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyQLstmLayer checker(layerName,
+ {inputInfo, cellStateInfo, outputStateInfo},
+ {outputStateInfo, cellStateInfo, outputStateInfo},
+ descriptor,
+ params);
+
+ deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm)
+{
+ armnn::QLstmDescriptor descriptor;
+
+ // CIFG params are used when CIFG is disabled
+ descriptor.m_CifgEnabled = true;
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = false;
+ descriptor.m_LayerNormEnabled = true;
+
+ descriptor.m_CellClip = 0.0f;
+ descriptor.m_ProjectionClip = 0.0f;
+
+ descriptor.m_InputIntermediateScale = 0.00001f;
+ descriptor.m_ForgetIntermediateScale = 0.00001f;
+ descriptor.m_CellIntermediateScale = 0.00001f;
+ descriptor.m_OutputIntermediateScale = 0.00001f;
+
+ descriptor.m_HiddenStateScale = 0.07f;
+ descriptor.m_HiddenStateZeroPoint = 0;
+
+ const unsigned int numBatches = 2;
+ const unsigned int inputSize = 5;
+ const unsigned int outputSize = 4;
+ const unsigned int numUnits = 4;
+
+ // Scale/Offset quantization info
+ float inputScale = 0.0078f;
+ int32_t inputOffset = 0;
+
+ float outputScale = 0.0078f;
+ int32_t outputOffset = 0;
+
+ float cellStateScale = 3.5002e-05f;
+ int32_t cellStateOffset = 0;
+
+ float weightsScale = 0.007f;
+ int32_t weightsOffset = 0;
+
+ float layerNormScale = 3.5002e-05f;
+ int32_t layerNormOffset = 0;
+
+ float biasScale = layerNormScale / 1024;
+ int32_t biasOffset = 0;
+
+ // Weights and bias tensor and quantization info
+ armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo biasInfo({numUnits},
+ armnn::DataType::Signed32,
+ biasScale,
+ biasOffset);
+
+ armnn::TensorInfo layerNormWeightsInfo({numUnits},
+ armnn::DataType::QSymmS16,
+ layerNormScale,
+ layerNormOffset);
+
+ // Mandatory params
+ std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+ std::vector<int8_t> recurrentToForgetWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToCellWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToOutputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+ armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+ armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+ std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+ std::vector<int32_t> cellBiasData(numUnits, 0);
+ std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+ armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+ armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+ armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+ // Layer Norm
+ std::vector<int16_t> forgetLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> cellLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> outputLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData);
+ armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData);
+ armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData);
+
+ // Set up params
+ armnn::LstmInputParams params;
+
+ // Mandatory params
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // Layer Norm
+ params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ params.m_CellLayerNormWeights = &cellLayerNormWeights;
+ params.m_OutputLayerNormWeights = &outputLayerNormWeights;
+
+ // Create network
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ const std::string layerName("qLstm");
+
+ armnn::IConnectableLayer* const input = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2);
+
+ armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2);
+
+ // Input/Output tensor info
+ armnn::TensorInfo inputInfo({numBatches , inputSize},
+ armnn::DataType::QAsymmS8,
+ inputScale,
+ inputOffset);
+
+ armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+ armnn::DataType::QSymmS16,
+ cellStateScale,
+ cellStateOffset);
+
+ armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+ armnn::DataType::QAsymmS8,
+ outputScale,
+ outputOffset);
+
+ // Connect input/output slots
+ input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+ qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyQLstmLayer checker(layerName,
+ {inputInfo, cellStateInfo, outputStateInfo},
+ {outputStateInfo, cellStateInfo, outputStateInfo},
+ descriptor,
+ params);
+
+ deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced)
+{
+ armnn::QLstmDescriptor descriptor;
+
+ descriptor.m_CifgEnabled = false;
+ descriptor.m_ProjectionEnabled = true;
+ descriptor.m_PeepholeEnabled = true;
+ descriptor.m_LayerNormEnabled = true;
+
+ descriptor.m_CellClip = 0.1f;
+ descriptor.m_ProjectionClip = 0.1f;
+
+ descriptor.m_InputIntermediateScale = 0.00001f;
+ descriptor.m_ForgetIntermediateScale = 0.00001f;
+ descriptor.m_CellIntermediateScale = 0.00001f;
+ descriptor.m_OutputIntermediateScale = 0.00001f;
+
+ descriptor.m_HiddenStateScale = 0.07f;
+ descriptor.m_HiddenStateZeroPoint = 0;
+
+ const unsigned int numBatches = 2;
+ const unsigned int inputSize = 5;
+ const unsigned int outputSize = 4;
+ const unsigned int numUnits = 4;
+
+ // Scale/Offset quantization info
+ float inputScale = 0.0078f;
+ int32_t inputOffset = 0;
+
+ float outputScale = 0.0078f;
+ int32_t outputOffset = 0;
+
+ float cellStateScale = 3.5002e-05f;
+ int32_t cellStateOffset = 0;
+
+ float weightsScale = 0.007f;
+ int32_t weightsOffset = 0;
+
+ float layerNormScale = 3.5002e-05f;
+ int32_t layerNormOffset = 0;
+
+ float biasScale = layerNormScale / 1024;
+ int32_t biasOffset = 0;
+
+ // Weights and bias tensor and quantization info
+ armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo biasInfo({numUnits},
+ armnn::DataType::Signed32,
+ biasScale,
+ biasOffset);
+
+ armnn::TensorInfo peepholeWeightsInfo({numUnits},
+ armnn::DataType::QSymmS16,
+ weightsScale,
+ weightsOffset);
+
+ armnn::TensorInfo layerNormWeightsInfo({numUnits},
+ armnn::DataType::QSymmS16,
+ layerNormScale,
+ layerNormOffset);
+
+ armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
+ armnn::DataType::QSymmS8,
+ weightsScale,
+ weightsOffset);
+
+ // Mandatory params
+ std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData);
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData);
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData);
+
+ std::vector<int8_t> recurrentToForgetWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToCellWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToOutputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData);
+ armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData);
+ armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData);
+
+ std::vector<int32_t> forgetGateBiasData(numUnits, 1);
+ std::vector<int32_t> cellBiasData(numUnits, 0);
+ std::vector<int32_t> outputGateBiasData(numUnits, 0);
+
+ armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData);
+ armnn::ConstTensor cellBias(biasInfo, cellBiasData);
+ armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData);
+
+ // CIFG
+ std::vector<int8_t> inputToInputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
+ std::vector<int8_t> recurrentToInputWeightsData =
+ GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements());
+ std::vector<int32_t> inputGateBiasData(numUnits, 1);
+
+ armnn::ConstTensor inputToInputWeights(inputWeightsInfo, inputToInputWeightsData);
+ armnn::ConstTensor recurrentToInputWeights(recurrentWeightsInfo, recurrentToInputWeightsData);
+ armnn::ConstTensor inputGateBias(biasInfo, inputGateBiasData);
+
+ // Peephole
+ std::vector<int16_t> cellToInputWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+ std::vector<int16_t> cellToForgetWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+ std::vector<int16_t> cellToOutputWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor cellToInputWeights(peepholeWeightsInfo, cellToInputWeightsData);
+ armnn::ConstTensor cellToForgetWeights(peepholeWeightsInfo, cellToForgetWeightsData);
+ armnn::ConstTensor cellToOutputWeights(peepholeWeightsInfo, cellToOutputWeightsData);
+
+ // Projection
+ std::vector<int8_t> projectionWeightsData = GenerateRandomData<int8_t>(projectionWeightsInfo.GetNumElements());
+ std::vector<int32_t> projectionBiasData(outputSize, 1);
+
+ armnn::ConstTensor projectionWeights(projectionWeightsInfo, projectionWeightsData);
+ armnn::ConstTensor projectionBias(biasInfo, projectionBiasData);
+
+ // Layer Norm
+ std::vector<int16_t> inputLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> forgetLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> cellLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+ std::vector<int16_t> outputLayerNormWeightsData =
+ GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements());
+
+ armnn::ConstTensor inputLayerNormWeights(layerNormWeightsInfo, inputLayerNormWeightsData);
+ armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData);
+ armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData);
+ armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData);
+
+ // Set up params
+ armnn::LstmInputParams params;
+
+ // Mandatory params
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // CIFG
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_InputGateBias = &inputGateBias;
+
+ // Peephole
+ params.m_CellToInputWeights = &cellToInputWeights;
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ // Projection
+ params.m_ProjectionWeights = &projectionWeights;
+ params.m_ProjectionBias = &projectionBias;
+
+ // Layer Norm
+ params.m_InputLayerNormWeights = &inputLayerNormWeights;
+ params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ params.m_CellLayerNormWeights = &cellLayerNormWeights;
+ params.m_OutputLayerNormWeights = &outputLayerNormWeights;
+
+ // Create network
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ const std::string layerName("qLstm");
+
+ armnn::IConnectableLayer* const input = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2);
+
+ armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str());
+
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2);
+
+ // Input/Output tensor info
+ armnn::TensorInfo inputInfo({numBatches , inputSize},
+ armnn::DataType::QAsymmS8,
+ inputScale,
+ inputOffset);
+
+ armnn::TensorInfo cellStateInfo({numBatches , numUnits},
+ armnn::DataType::QSymmS16,
+ cellStateScale,
+ cellStateOffset);
+
+ armnn::TensorInfo outputStateInfo({numBatches , outputSize},
+ armnn::DataType::QAsymmS8,
+ outputScale,
+ outputOffset);
+
+ // Connect input/output slots
+ input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0));
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo);
+
+ qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo);
+
+ qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyQLstmLayer checker(layerName,
+ {inputInfo, cellStateInfo, outputStateInfo},
+ {outputStateInfo, cellStateInfo, outputStateInfo},
+ descriptor,
+ params);
+
+ deserializedNetwork->Accept(checker);
+}
+
BOOST_AUTO_TEST_SUITE_END()