IVGCVSW-4777 Add QLstm serialization support
[platform/upstream/armnn.git] src/armnnDeserializer/Deserializer.cpp
index 42b0052..36beebc 100644
@@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
     m_ParserFunctions[Layer_PermuteLayer]                = &Deserializer::ParsePermute;
     m_ParserFunctions[Layer_Pooling2dLayer]              = &Deserializer::ParsePooling2d;
     m_ParserFunctions[Layer_PreluLayer]                  = &Deserializer::ParsePrelu;
+    m_ParserFunctions[Layer_QLstmLayer]                  = &Deserializer::ParseQLstm;
     m_ParserFunctions[Layer_QuantizeLayer]               = &Deserializer::ParseQuantize;
     m_ParserFunctions[Layer_QuantizedLstmLayer]          = &Deserializer::ParseQuantizedLstm;
     m_ParserFunctions[Layer_ReshapeLayer]                = &Deserializer::ParseReshape;
@@ -322,6 +323,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
             return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
         case Layer::Layer_PreluLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
+        case Layer::Layer_QLstmLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
         case Layer::Layer_QuantizeLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
         case Layer::Layer_QuantizedLstmLayer:
@@ -2610,6 +2613,155 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+armnn::QLstmDescriptor Deserializer::GetQLstmDescriptor(Deserializer::QLstmDescriptorPtr qLstmDescriptor)
+{
+    armnn::QLstmDescriptor desc;
+
+    desc.m_CifgEnabled       = qLstmDescriptor->cifgEnabled();
+    desc.m_PeepholeEnabled   = qLstmDescriptor->peepholeEnabled();
+    desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
+    desc.m_LayerNormEnabled  = qLstmDescriptor->layerNormEnabled();
+
+    desc.m_CellClip       = qLstmDescriptor->cellClip();
+    desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
+
+    desc.m_InputIntermediateScale  = qLstmDescriptor->inputIntermediateScale();
+    desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
+    desc.m_CellIntermediateScale   = qLstmDescriptor->cellIntermediateScale();
+    desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
+
+    desc.m_HiddenStateScale     = qLstmDescriptor->hiddenStateScale();
+    desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
+
+    return desc;
+}
+
+void Deserializer::ParseQLstm(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
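+    // A QLstm layer takes three inputs: the input plus the previous output state and cell state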
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
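+    // ...and produces three outputs: outputStateOut, cellStateOut and output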
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 3);
+
+    auto flatBufferLayer       = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer();
+    auto layerName             = GetLayerName(graph, layerIndex);
+    auto flatBufferDescriptor  = flatBufferLayer->descriptor();
+    auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+    auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor);
+    armnn::LstmInputParams qLstmInputParams;
+
+    // Mandatory params
+    armnn::ConstTensor inputToForgetWeights     = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+    armnn::ConstTensor inputToCellWeights       = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+    armnn::ConstTensor inputToOutputWeights     = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+    armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+    armnn::ConstTensor recurrentToCellWeights   = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+    armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+    armnn::ConstTensor forgetGateBias           = ToConstTensor(flatBufferInputParams->forgetGateBias());
+    armnn::ConstTensor cellBias                 = ToConstTensor(flatBufferInputParams->cellBias());
+    armnn::ConstTensor outputGateBias           = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+    qLstmInputParams.m_InputToForgetWeights     = &inputToForgetWeights;
+    qLstmInputParams.m_InputToCellWeights       = &inputToCellWeights;
+    qLstmInputParams.m_InputToOutputWeights     = &inputToOutputWeights;
+    qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    qLstmInputParams.m_RecurrentToCellWeights   = &recurrentToCellWeights;
+    qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    qLstmInputParams.m_ForgetGateBias           = &forgetGateBias;
+    qLstmInputParams.m_CellBias                 = &cellBias;
+    qLstmInputParams.m_OutputGateBias           = &outputGateBias;
+
+    // Optional CIFG params
+    armnn::ConstTensor inputToInputWeights;
+    armnn::ConstTensor recurrentToInputWeights;
+    armnn::ConstTensor inputGateBias;
+
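+    // The input gate only has its own weights and bias when CIFG is disabled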
+    if (!qLstmDescriptor.m_CifgEnabled)
+    {
+        inputToInputWeights     = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+        recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+        inputGateBias           = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+        qLstmInputParams.m_InputToInputWeights     = &inputToInputWeights;
+        qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+        qLstmInputParams.m_InputGateBias           = &inputGateBias;
+    }
+
+    // Optional projection params
+    armnn::ConstTensor projectionWeights;
+    armnn::ConstTensor projectionBias;
+
+    if (qLstmDescriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+        projectionBias    = ToConstTensor(flatBufferInputParams->projectionBias());
+
+        qLstmInputParams.m_ProjectionWeights = &projectionWeights;
+        qLstmInputParams.m_ProjectionBias    = &projectionBias;
+    }
+
+    // Optional peephole params
+    armnn::ConstTensor cellToInputWeights;
+    armnn::ConstTensor cellToForgetWeights;
+    armnn::ConstTensor cellToOutputWeights;
+
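+    // Peephole connections feed the cell state back into the gates;
+    // the cell-to-input connection only exists when CIFG is disabled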
+    if (qLstmDescriptor.m_PeepholeEnabled)
+    {
+        if (!qLstmDescriptor.m_CifgEnabled)
+        {
+            cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+            qLstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+        }
+
+        cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+        cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+        qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+        qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+    }
+
+    // Optional layer norm params
+    armnn::ConstTensor inputLayerNormWeights;
+    armnn::ConstTensor forgetLayerNormWeights;
+    armnn::ConstTensor cellLayerNormWeights;
+    armnn::ConstTensor outputLayerNormWeights;
+
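+    // Layer norm weights; as above, the input gate set is omitted when CIFG is enabled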
+    if (qLstmDescriptor.m_LayerNormEnabled)
+    {
+        if (!qLstmDescriptor.m_CifgEnabled)
+        {
+            inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
+            qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
+        }
+
+        forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
+        cellLayerNormWeights   = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
+        outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
+
+        qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+        qLstmInputParams.m_CellLayerNormWeights   = &cellLayerNormWeights;
+        qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
+    }
+
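+    // The network takes its own copy of the weight and bias data,
+    // so the local ConstTensors only need to remain valid for this call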
+    IConnectableLayer* layer = m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str());
+
+    armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo);
+
+    armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]);
+    layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo);
+
+    armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]);
+    layer->GetOutputSlot(2).SetTensorInfo(outputInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex)
 {
     CHECK_LAYERS(graph, 0, layerIndex);