IVGCVSW-4777 Add QLstm serialization support
[platform/upstream/armnn.git] / src / armnnSerializer / Serializer.cpp
index 3556736..c4d3cfb 100644 (file)
@@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
                                         const armnn::LstmInputParams& params,
                                         const char* name)
 {
-    IgnoreUnused(layer, descriptor, params, name);
+    IgnoreUnused(name);
+
+    auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
+
+    auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
+            m_flatBufferBuilder,
+            descriptor.m_CifgEnabled,
+            descriptor.m_PeepholeEnabled,
+            descriptor.m_ProjectionEnabled,
+            descriptor.m_LayerNormEnabled,
+            descriptor.m_CellClip,
+            descriptor.m_ProjectionClip,
+            descriptor.m_InputIntermediateScale,
+            descriptor.m_ForgetIntermediateScale,
+            descriptor.m_CellIntermediateScale,
+            descriptor.m_OutputIntermediateScale,
+            descriptor.m_HiddenStateZeroPoint,
+            descriptor.m_HiddenStateScale
+            );
+
+    // Mandatory params
+    auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+    auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+    auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+    auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+    auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+    auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+    auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+    auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+    auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+    // CIFG
+    flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+
+    if (!descriptor.m_CifgEnabled)
+    {
+        inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+        recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+        inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+    }
+
+    // Projectiom
+    flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+    flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+
+    if (descriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+        projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+    }
+
+    // Peephole
+    flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+    if (descriptor.m_PeepholeEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            cellToInputWeights  = CreateConstTensorInfo(*params.m_CellToInputWeights);
+        }
+
+        cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+        cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+    }
+
+    // Layer norm
+    flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+        }
+
+        forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+        cellLayerNormWeights   = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+        outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+    }
+
+    auto fbQLstmParams = serializer::CreateQLstmInputParams(
+            m_flatBufferBuilder,
+            inputToForgetWeights,
+            inputToCellWeights,
+            inputToOutputWeights,
+            recurrentToForgetWeights,
+            recurrentToCellWeights,
+            recurrentToOutputWeights,
+            forgetGateBias,
+            cellBias,
+            outputGateBias,
+            inputToInputWeights,
+            recurrentToInputWeights,
+            inputGateBias,
+            projectionWeights,
+            projectionBias,
+            cellToInputWeights,
+            cellToForgetWeights,
+            cellToOutputWeights,
+            inputLayerNormWeights,
+            forgetLayerNormWeights,
+            cellLayerNormWeights,
+            outputLayerNormWeights);
+
+    auto fbQLstmLayer = serializer::CreateQLstmLayer(
+            m_flatBufferBuilder,
+            fbQLstmBaseLayer,
+            fbQLstmDescriptor,
+            fbQLstmParams);
 
-    throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented");
+    CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
 }
 
 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,