IVGCVSW-5394 TfLiteDelegate: Implement the Lstm operator
author Mike Kelly <mike.kelly@arm.com>
Wed, 17 Feb 2021 13:45:50 +0000 (13:45 +0000)
committer Keith Davis <keith.davis@arm.com>
Thu, 18 Feb 2021 12:39:00 +0000 (12:39 +0000)
 * Add LSTM operator

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: If8c667685fa1176738ffe2e6d08b1c684e7ee6b2

delegate/CMakeLists.txt
delegate/src/Lstm.hpp
delegate/src/test/LstmTest.cpp [new file with mode: 0644]
delegate/src/test/LstmTestHelper.hpp [new file with mode: 0644]

index 7de168f..981fc9f 100644 (file)
@@ -137,6 +137,8 @@ if(BUILD_UNIT_TESTS)
         src/test/GatherTestHelper.hpp
         src/test/LogicalTest.cpp
         src/test/LogicalTestHelper.hpp
+        src/test/LstmTest.cpp
+        src/test/LstmTestHelper.hpp
         src/test/NormalizationTest.cpp
         src/test/NormalizationTestHelper.hpp
         src/test/PadTest.cpp
index b81b256..829e3bf 100644 (file)
@@ -5,6 +5,10 @@
 
 #pragma once
 
+#include "DelegateUtils.hpp"
+
+#include <armnn/LstmParams.hpp>
+#include <armnn/Tensor.hpp>
 #include <armnn/utility/IgnoreUnused.hpp>
 
 #include <tensorflow/lite/builtin_ops.h>
 namespace armnnDelegate
 {
 
+bool IsOptional(TfLiteNode* tfLiteNode, const int index)
+{
+    return tfLiteNode->inputs->data[index] < 0;
+}
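+// TfLite marks an omitted optional input with a negative tensor index (kTfLiteOptionalTensor == -1),
+// which is the convention the check above relies on.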
+
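+// Wraps a constant TfLite tensor (weights or bias) in an armnn::ConstTensor pointing at the
+// TfLite-owned buffer; the pointers are collected in LstmInputParams for the AddLstmLayer call.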
+armnn::ConstTensor* CreateConstTensor(const TfLiteTensor* tfLiteTensors, TfLiteNode* tfLiteNode, int index)
+{
+    const TfLiteTensor &tfLiteTensor = tfLiteTensors[tfLiteNode->inputs->data[index]];
+    armnn::TensorInfo tensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensor);
+    return new armnn::ConstTensor(tensorInfo, tfLiteTensor.data.data);
+}
+
 TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
                                TfLiteContext* tfLiteContext,
                                TfLiteNode* tfLiteNode,
                                int nodeIndex,
                                int32_t operatorCode)
 {
-    armnn::IgnoreUnused(delegateData,
-                        tfLiteContext,
-                        tfLiteNode,
-                        nodeIndex,
-                        operatorCode);
+    auto numInputs = tfLiteNode->inputs->size;
+    if (numInputs < 2)
+    {
+        TF_LITE_MAYBE_KERNEL_LOG(
+                tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d) not met (%d given) in node #%d",
+                2, numInputs, nodeIndex);
+        return kTfLiteError;
+    }
+
+    const auto nodeParams = reinterpret_cast<TfLiteLSTMParams*>(tfLiteNode->builtin_data);
+    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
+
+    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
+    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
+    {
+        return kTfLiteError;
+    }
+
+    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
+    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
+    {
+        return kTfLiteError;
+    }
+
+    // Set the params structure for the AddLstmLayer call
+    armnn::LstmInputParams params;
+
+    if (!IsOptional(tfLiteNode, 1))
+    {
+        params.m_InputToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 1);
+    }
+
+    params.m_InputToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 2);
+    params.m_InputToCellWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 3);
+    params.m_InputToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 4);
+
+    // Recurrent weight tensors of size {n_cell, n_output}
+    if (!IsOptional(tfLiteNode, 5))
+    {
+        params.m_RecurrentToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 5);
+    }
+
+    params.m_RecurrentToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 6);
+    params.m_RecurrentToCellWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 7);
+    params.m_RecurrentToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 8);
+
+    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+    if (!IsOptional(tfLiteNode, 9))
+    {
+        params.m_CellToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 9);
+    }
+
+    if (!IsOptional(tfLiteNode, 10))
+    {
+        params.m_CellToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 10);
+    }
+
+    if (!IsOptional(tfLiteNode, 11))
+    {
+        params.m_CellToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 11);
+    }
+
+    // Gates bias tensors of size {n_cell}
+    if (!IsOptional(tfLiteNode, 12))
+    {
+        params.m_InputGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 12);
+    }
+
+    params.m_ForgetGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 13);
+    params.m_CellBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 14);
+    params.m_OutputGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 15);
+
+    // Projection weight tensor of size {n_output, n_cell}
+    if (!IsOptional(tfLiteNode, 16))
+    {
+        params.m_ProjectionWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 16);
+    }
+    // Projection bias tensor of size {n_output}
+    if (!IsOptional(tfLiteNode, 17))
+    {
+        params.m_ProjectionBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 17);
+    }
+
+    // These state tensors are defined as variable tensors, and will be modified by this op.
+    armnn::TensorInfo outputStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[18]]);
+    armnn::TensorInfo cellStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[19]]);
+
+    // Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
+    if (tfLiteNode->inputs->size >= 21 && !IsOptional(tfLiteNode, 20))
+    {
+        params.m_InputLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 20);
+    }
+
+    if (tfLiteNode->inputs->size >= 22 && !IsOptional(tfLiteNode, 21))
+    {
+        params.m_ForgetLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 21);
+    }
+
+    if (tfLiteNode->inputs->size >= 23 && !IsOptional(tfLiteNode, 22))
+    {
+        params.m_CellLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 22);
+    }
+
+    if (tfLiteNode->inputs->size >= 24 && !IsOptional(tfLiteNode, 23))
+    {
+        params.m_OutputLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 23);
+    }
+
+    // Set the layer descriptor
+    armnn::LstmDescriptor desc;
+    desc.m_ActivationFunc    = NonNegative(nodeParams->activation, nodeIndex);
+    desc.m_ClippingThresCell = nodeParams->cell_clip;
+    desc.m_ClippingThresProj = nodeParams->proj_clip;
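+    // CIFG ("coupled input and forget gate") applies when the input gate tensors are omitted;
+    // the input gate is then derived from the forget gate rather than having its own weights.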
+    desc.m_CifgEnabled       = (params.m_InputToInputWeights == nullptr
+                                || params.m_RecurrentToInputWeights == nullptr
+                                || params.m_InputGateBias == nullptr);
+    desc.m_PeepholeEnabled   = (params.m_CellToForgetWeights != nullptr || params.m_CellToOutputWeights != nullptr);
+    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
+    desc.m_LayerNormEnabled  = (params.m_InputLayerNormWeights != nullptr
+                                || params.m_ForgetLayerNormWeights != nullptr
+                                || params.m_CellLayerNormWeights != nullptr
+                                || params.m_OutputLayerNormWeights != nullptr);
+
+    const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
+    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
+
+    unsigned int batchSize  = inputTensorInfo.GetShape()[0];
+    unsigned int outputSize = outputTensorInfo.GetShape()[1];
+    unsigned int numUnits   = cellStateInInfo.GetShape()[1];
+
+    armnn::DataType dataType = inputTensorInfo.GetDataType();
+    float qScale = inputTensorInfo.GetQuantizationScale();
+    int32_t qOffset = inputTensorInfo.GetQuantizationOffset();
+
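+    // The scratch buffer holds the intermediate gate data: 3 * numUnits with CIFG, 4 * numUnits otherwise.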
+    armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 3}, dataType, qScale, qOffset);
+    if (!desc.m_CifgEnabled)
+    {
+        scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
+    }
+    armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
+    armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
+
+    armnn::LstmInputParamsInfo paramsInfo;
+    paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
+    paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
+    paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
+    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
+    paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
+    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
+    paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
+    paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
+    paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
+
+    if (!desc.m_CifgEnabled)
+    {
+        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
+        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
+        if (params.m_CellToInputWeights != nullptr)
+        {
+            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
+        }
+        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
+    }
+
+    if (desc.m_ProjectionEnabled)
+    {
+        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
+        if (params.m_ProjectionBias != nullptr)
+        {
+            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
+        }
+    }
+
+    if (desc.m_PeepholeEnabled)
+    {
+        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
+        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
+    }
+
+    if (desc.m_LayerNormEnabled)
+    {
+        if(!desc.m_CifgEnabled)
+        {
+            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
+        }
+        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
+        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
+        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
+    }
+
+    bool isSupported = false;
+    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
+    {
+        FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                                   tfLiteContext,
+                                   IsLstmSupported,
+                                   delegateData.m_Backends,
+                                   isSupported,
+                                   inputTensorInfo,
+                                   outputStateInInfo,
+                                   cellStateInInfo,
+                                   scratchBufferTensorInfo,
+                                   outputStateOutTensorInfo,
+                                   cellStateOutTensorInfo,
+                                   outputInfo,
+                                   desc,
+                                   paramsInfo);
+    };
+
+    if (!delegateData.m_Network)
+    {
+        validateFunc(outputTensorInfo, isSupported);
+        return isSupported ? kTfLiteOk : kTfLiteError;
+    }
+
+    armnn::IConnectableLayer* layer = delegateData.m_Network->AddLstmLayer(desc, params);
+    ARMNN_ASSERT(layer != nullptr);
+
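+    // ArmNN LSTM output slot order: 0 = scratch buffer, 1 = outputStateOut, 2 = cellStateOut, 3 = output.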
+    layer->GetOutputSlot(0).SetTensorInfo(scratchBufferTensorInfo);
+    layer->GetOutputSlot(1).SetTensorInfo(outputStateOutTensorInfo);
+    layer->GetOutputSlot(2).SetTensorInfo(cellStateOutTensorInfo);
+    layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo);
+
+    // Connect the inputs
+    // input
+    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(layer->GetInputSlot(0));
+    // outputStateIn
+    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[18]]->Connect(layer->GetInputSlot(1));
+    // cellStateIn
+    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2));
 
-    return kTfLiteError;
+    // The LSTM node in this model exposes a single TfLite output, mapped to outputStateOut (slot 1);
+    // for LSTM the output state holds the same values as the output tensor.
+    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(1);
+    delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;
+    return kTfLiteOk;
 }
 
-} // namespace armnnDelegate
+} // namespace armnnDelegate
\ No newline at end of file
diff --git a/delegate/src/test/LstmTest.cpp b/delegate/src/test/LstmTest.cpp
new file mode 100644 (file)
index 0000000..1fa9f0c
--- /dev/null
@@ -0,0 +1,189 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LstmTestHelper.hpp"
+
+#include <armnn_delegate.hpp>
+
+#include <flatbuffers/flatbuffers.h>
+#include <tensorflow/lite/schema/schema_generated.h>
+#include <doctest/doctest.h>
+
+namespace armnnDelegate
+{
+
+void LstmTest(std::vector<armnn::BackendId>& backends)
+{
+    int32_t batchSize = 2;
+    int32_t inputSize = 2;
+    int32_t outputSize = 4;
+    // cellSize (numUnits) and outputSize are equal when there is no projection.
+    int32_t numUnits = outputSize;
+
+    std::vector<int32_t> inputShape {batchSize , inputSize};
+    std::vector<int32_t> cellStateInTensorInfo {batchSize , numUnits};
+    std::vector<int32_t> outputStateInTensorInfo {batchSize , outputSize};
+
+    std::vector<int32_t> scratchBufferTensorInfo {batchSize, numUnits * 4};
+    std::vector<int32_t> cellStateOutTensorInfo {batchSize, numUnits};
+    std::vector<int32_t> outputStateOutTensorInfo {batchSize, outputSize};
+    std::vector<int32_t> outputTensorInfo {batchSize, outputSize};
+
+    std::vector<int32_t> tensorInfo4 {numUnits};
+    std::vector<int32_t> tensorInfo8 {numUnits, 2};
+    std::vector<int32_t> tensorInfo16 {numUnits, 4};
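+    // tensorInfo4/8/16 name the weight shapes {numUnits}, {numUnits, inputSize} and
+    // {numUnits, outputSize}: 4, 8 and 16 elements respectively for this configuration.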
+
+    // tensorInfo8
+    bool hasInputToInputWeights = true;
+    std::vector<float> inputToInputWeights {-0.45018822f, -0.02338299f, -0.0870589f,
+                                            -0.34550029f, 0.04266912f, -0.15680569f,
+                                            -0.34856534f, 0.43890524f};
+
+    std::vector<float> inputToForgetWeights {0.09701663f, 0.20334584f, -0.50592935f,
+                                             -0.31343272f, -0.40032279f, 0.44781327f,
+                                             0.01387155f, -0.35593212f};
+
+    std::vector<float> inputToCellWeights {-0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
+                                           -0.20583314f, 0.44344562f, 0.22077113f,
+                                           -0.29909778f};
+
+    std::vector<float> inputToOutputWeights {-0.25065863f, -0.28290087f, 0.04613829f,
+                                             0.40525138f, 0.44272184f, 0.03897077f,
+                                             -0.1556896f, 0.19487578f};
+
+    // tensorInfo16
+    bool hasRecurrentToInputWeights = true;
+    std::vector<float> recurrentToInputWeights {-0.0063535f, -0.2042388f, 0.31454784f,
+                                                -0.35746509f, 0.28902304f, 0.08183324f,
+                                                -0.16555229f, 0.02286911f, -0.13566875f,
+                                                0.03034258f, 0.48091322f, -0.12528998f,
+                                                0.24077177f, -0.51332325f, -0.33502164f,
+                                                0.10629296f};
+
+    std::vector<float> recurrentToForgetWeights {-0.48684245f, -0.06655136f, 0.42224967f,
+                                                 0.2112639f, 0.27654213f, 0.20864892f,
+                                                 -0.07646349f, 0.45877004f, 0.00141793f,
+                                                 -0.14609534f, 0.36447752f, 0.09196436f,
+                                                 0.28053468f, 0.01560611f, -0.20127171f,
+                                                 -0.01140004f};
+
+    std::vector<float> recurrentToCellWeights {-0.3407414f, 0.24443203f, -0.2078532f,
+                                               0.26320225f, 0.05695659f, -0.00123841f,
+                                               -0.4744786f, -0.35869038f, -0.06418842f,
+                                               -0.13502428f, -0.501764f, 0.22830659f,
+                                               -0.46367589f, 0.26016325f, -0.03894562f,
+                                               -0.16368064f};
+
+    std::vector<float> recurrentToOutputWeights {0.43385774f, -0.17194885f, 0.2718237f,
+                                                 0.09215671f, 0.24107647f, -0.39835793f,
+                                                 0.18212086f, 0.01301402f, 0.48572797f,
+                                                 -0.50656658f, 0.20047462f, -0.20607421f,
+                                                 -0.51818722f, -0.15390486f, 0.0468148f,
+                                                 0.39922136f};
+    // tensorInfo4
+    bool hasCellToInputWeights = false;
+    std::vector<float> cellToInputWeights {};
+    bool hasCellToForgetWeights = false;
+    std::vector<float> cellToForgetWeights {};
+    bool hasCellToOutputWeights = false;
+    std::vector<float> cellToOutputWeights {};
+
+    bool hasInputGateBias = true;
+    std::vector<float> inputGateBias {0., 0., 0., 0.};
+    std::vector<float> forgetGateBias {1., 1., 1., 1.};
+    std::vector<float> cellBias {0., 0., 0., 0.};
+    std::vector<float> outputGateBias {0., 0., 0., 0.};
+
+    bool hasProjectionWeights = false;
+    std::vector<float> projectionWeights;
+    bool hasProjectionBias = false;
+    std::vector<float> projectionBias;
+
+    bool hasInputLayerNormWeights = false;
+    std::vector<float> inputLayerNormWeights;
+    bool hasForgetLayerNormWeights = false;
+    std::vector<float> forgetLayerNormWeights;
+    bool hasCellLayerNormWeights = false;
+    std::vector<float> cellLayerNormWeights;
+    bool hasOutputLayerNormWeights = false;
+    std::vector<float> outputLayerNormWeights;
+
+    std::vector<float> inputValues {2., 3., 3., 4.};
+    std::vector<float> expectedOutputValues {-0.02973187f, 0.1229473f,   0.20885126f, -0.15358765f,
+                                             -0.0185422f,   0.11281417f,  0.24466537f, -0.1826292f};
+
+    tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
+    float clippingThresCell = 0.f;
+    float clippingThresProj = 0.f;
+
+    LstmTestImpl<float>(backends,
+                        ::tflite::TensorType_FLOAT32,
+                        batchSize,
+                        inputSize,
+                        outputSize,
+                        numUnits,
+                        hasInputToInputWeights,
+                        inputToInputWeights,
+                        inputToForgetWeights,
+                        inputToCellWeights,
+                        inputToOutputWeights,
+                        hasRecurrentToInputWeights,
+                        recurrentToInputWeights,
+                        recurrentToForgetWeights,
+                        recurrentToCellWeights,
+                        recurrentToOutputWeights,
+                        hasCellToInputWeights,
+                        cellToInputWeights,
+                        hasCellToForgetWeights,
+                        cellToForgetWeights,
+                        hasCellToOutputWeights,
+                        cellToOutputWeights,
+                        hasInputGateBias,
+                        inputGateBias,
+                        forgetGateBias,
+                        cellBias,
+                        outputGateBias,
+                        hasProjectionWeights,
+                        projectionWeights,
+                        hasProjectionBias,
+                        projectionBias,
+                        hasInputLayerNormWeights,
+                        inputLayerNormWeights,
+                        hasForgetLayerNormWeights,
+                        forgetLayerNormWeights,
+                        hasCellLayerNormWeights,
+                        cellLayerNormWeights,
+                        hasOutputLayerNormWeights,
+                        outputLayerNormWeights,
+                        inputValues,
+                        expectedOutputValues,
+                        activationFunction,
+                        clippingThresCell,
+                        clippingThresProj);
+}
+
+TEST_SUITE("LstmTest_CpuRefTests")
+{
+
+TEST_CASE ("LstmTest_CpuRef_Test")
+{
+    std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
+    LstmTest(backends);
+}
+
+} // End of TEST_SUITE("LstmTest_CpuRefTests")
+
+TEST_SUITE("LstmTest_CpuAccTests")
+{
+
+TEST_CASE ("LstmTest_CpuAcc_Test")
+{
+    std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
+    LstmTest(backends);
+}
+
+} // End of TEST_SUITE("LstmTest_CpuAccTests")
+
+} // namespace armnnDelegate
\ No newline at end of file
diff --git a/delegate/src/test/LstmTestHelper.hpp b/delegate/src/test/LstmTestHelper.hpp
new file mode 100644 (file)
index 0000000..36a6061
--- /dev/null
@@ -0,0 +1,691 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "TestUtils.hpp"
+
+#include <armnn_delegate.hpp>
+
+#include <flatbuffers/flatbuffers.h>
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+#include <tensorflow/lite/schema/schema_generated.h>
+#include <tensorflow/lite/version.h>
+#include <tensorflow/lite/c/common.h>
+
+#include <doctest/doctest.h>
+
+namespace
+{
+
+template <typename T>
+std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
+                                        int32_t batchSize,
+                                        int32_t inputSize,
+                                        int32_t outputSize,
+                                        int32_t numUnits,
+                                        bool hasInputToInputWeights,
+                                        const std::vector<T>& inputToInputWeights,
+                                        const std::vector<T>& inputToForgetWeights,
+                                        const std::vector<T>& inputToCellWeights,
+                                        const std::vector<T>& inputToOutputWeights,
+                                        bool hasRecurrentToInputWeights,
+                                        const std::vector<T>& recurrentToInputWeights,
+                                        const std::vector<T>& recurrentToForgetWeights,
+                                        const std::vector<T>& recurrentToCellWeights,
+                                        const std::vector<T>& recurrentToOutputWeights,
+                                        bool hasCellToInputWeights,
+                                        const std::vector<T>& cellToInputWeights,
+                                        bool hasCellToForgetWeights,
+                                        const std::vector<T>& cellToForgetWeights,
+                                        bool hasCellToOutputWeights,
+                                        const std::vector<T>& cellToOutputWeights,
+                                        bool hasInputGateBias,
+                                        const std::vector<T>& inputGateBias,
+                                        const std::vector<T>& forgetGateBias,
+                                        const std::vector<T>& cellBias,
+                                        const std::vector<T>& outputGateBias,
+                                        bool hasProjectionWeights,
+                                        const std::vector<T>& projectionWeights,
+                                        bool hasProjectionBias,
+                                        const std::vector<T>& projectionBias,
+                                        bool hasInputLayerNormWeights,
+                                        const std::vector<T>& inputLayerNormWeights,
+                                        bool hasForgetLayerNormWeights,
+                                        const std::vector<T>& forgetLayerNormWeights,
+                                        bool hasCellLayerNormWeights,
+                                        const std::vector<T>& cellLayerNormWeights,
+                                        bool hasOutputLayerNormWeights,
+                                        const std::vector<T>& outputLayerNormWeights,
+                                        tflite::ActivationFunctionType activationFunction,
+                                        float clippingThresCell,
+                                        float clippingThresProj,
+                                        float quantScale = 1.0f,
+                                        int quantOffset  = 0,
+                                        float outputQuantScale = 2.0f,
+                                        int outputQuantOffset  = 0)
+{
+
+    std::vector <int32_t> tensorInfo0 {};
+    std::vector <int32_t> tensorInfo4 {numUnits};
+    std::vector <int32_t> tensorInfo8 {numUnits, inputSize};
+    std::vector <int32_t> tensorInfo16 {numUnits, outputSize};
+
+    std::vector<int32_t> inputShape {batchSize , inputSize};
+    std::vector<int32_t> outputShape {batchSize , outputSize};
+
+    std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
+    std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
+
+    std::vector<int> operatorInputs;
+    using namespace tflite;
+    flatbuffers::FlatBufferBuilder flatBufferBuilder;
+    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
+    std::vector<flatbuffers::Offset<Tensor>> tensors;
+
+    auto quantizationParameters =
+        CreateQuantizationParameters(flatBufferBuilder,
+                                     0,
+                                     0,
+                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
+                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
+
+    auto outputQuantizationParameters =
+        CreateQuantizationParameters(flatBufferBuilder,
+                                     0,
+                                     0,
+                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
+                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));
+
+    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
+                                                                           inputShape.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("input_0"),
+                                   quantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
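+    // Buffers and tensors are pushed in lock-step, so buffers.size() - 1 is also the index of the
+    // tensor just created and can be reused as the operator input index.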
+
+    if (hasInputToInputWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToInputWeights.data()),
+                                                        sizeof(T) * inputToInputWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
+                                                                               tensorInfo8.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("inputToInputWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
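+    // Omitted optional inputs are encoded as kTfLiteOptionalTensor (-1) in the operator input list;
+    // the same pattern is repeated for every optional tensor below.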
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToForgetWeights.data()),
+                                                    sizeof(T) * inputToForgetWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
+                                                                           tensorInfo8.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("inputToForgetWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToCellWeights.data()),
+                                                    sizeof(T) * inputToCellWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
+                                                                           tensorInfo8.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("inputToCellWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToOutputWeights.data()),
+                                                    sizeof(T) * inputToOutputWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
+                                                                           tensorInfo8.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("inputToOutputWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    if (hasRecurrentToInputWeights)
+    {
+        buffers.push_back(CreateBuffer(
+            flatBufferBuilder,
+            flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
+                                           sizeof(T) * recurrentToInputWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
+                                                                               tensorInfo16.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("recurrentToInputWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToForgetWeights.data()),
+                                                    sizeof(T) * recurrentToForgetWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
+                                                                           tensorInfo16.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("recurrentToForgetWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToCellWeights.data()),
+                                                    sizeof(T) * recurrentToCellWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
+                                                                           tensorInfo16.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("recurrentToCellWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToOutputWeights.data()),
+                                                    sizeof(T) * recurrentToOutputWeights.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
+                                                                           tensorInfo16.size()),
+                                   tensorType,
+                                   buffers.size() - 1 ,
+                                   flatBufferBuilder.CreateString("recurrentToOutputWeights"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    if (hasCellToInputWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
+                                                        sizeof(T) * cellToInputWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("cellToInputWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasCellToForgetWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
+                                                        sizeof(T) * cellToForgetWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("cellToForgetWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasCellToOutputWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
+                                                        sizeof(T) * cellToOutputWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("cellToOutputWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasInputGateBias)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
+                                                        sizeof(T) * inputGateBias.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("inputGateBias"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(forgetGateBias.data()),
+                                                    sizeof(T) * forgetGateBias.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                           tensorInfo4.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("forgetGateBias"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellBias.data()),
+                                                    sizeof(T) * cellBias.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                           tensorInfo4.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("cellBias"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(
+        CreateBuffer(flatBufferBuilder,
+                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(outputGateBias.data()),
+                                                    sizeof(T) * outputGateBias.size())));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                           tensorInfo4.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("outputGateBias"),
+                                   outputQuantizationParameters));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    if (hasProjectionWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionWeights.data()),
+                                                        sizeof(T) * projectionWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("projectionWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasProjectionBias)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionBias.data()),
+                                                        sizeof(T) * projectionBias.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("projectionBias"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
+                                                                           outputStateInDimensions.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("outputStateInInfo"),
+                                   outputQuantizationParameters,
+                                   true));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
+                                                                           cellStateInDimensions.size()),
+                                   tensorType,
+                                   buffers.size() - 1,
+                                   flatBufferBuilder.CreateString("cellStateInInfo"),
+                                   outputQuantizationParameters,
+                                   true));
+    operatorInputs.push_back(buffers.size() - 1);
+
+    if (hasInputLayerNormWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(
+                                              reinterpret_cast<const uint8_t *>(inputLayerNormWeights.data()),
+                                              sizeof(T) * inputLayerNormWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("inputLayerNormWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasForgetLayerNormWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(
+                                              reinterpret_cast<const uint8_t *>(forgetLayerNormWeights.data()),
+                                              sizeof(T) * forgetLayerNormWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("forgetLayerNormWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasCellLayerNormWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellLayerNormWeights.data()),
+                                                        sizeof(T) * cellLayerNormWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("cellLayerNormWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+
+    if (hasOutputLayerNormWeights)
+    {
+        buffers.push_back(
+            CreateBuffer(flatBufferBuilder,
+                         flatBufferBuilder.CreateVector(
+                             reinterpret_cast<const uint8_t *>(outputLayerNormWeights.data()),
+                             sizeof(T) * outputLayerNormWeights.size())));
+        tensors.push_back(CreateTensor(flatBufferBuilder,
+                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
+                                                                               tensorInfo4.size()),
+                                       tensorType,
+                                       buffers.size() - 1,
+                                       flatBufferBuilder.CreateString("outputLayerNormWeights"),
+                                       outputQuantizationParameters));
+        operatorInputs.push_back(buffers.size() - 1);
+    }
+    else
+    {
+        operatorInputs.push_back(kTfLiteOptionalTensor);
+    }
+    int outputBufferId = buffers.size();
+    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
+    tensors.push_back(CreateTensor(flatBufferBuilder,
+                                   flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
+                                                                           outputShape.size()),
+                                   tensorType,
+                                   outputBufferId,
+                                   flatBufferBuilder.CreateString("output"),
+                                   outputQuantizationParameters));
+    std::vector<int> operatorOutputs;
+    operatorOutputs.push_back(buffers.size() - 1);
+
+    // create operator
+    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
+    flatbuffers::Offset<void> operatorBuiltinOptions =
+        CreateLSTMOptions(flatBufferBuilder,
+                          activationFunction,
+                          clippingThresCell,
+                          clippingThresProj).Union();
+
+    flatbuffers::Offset <Operator> lstmOperator =
+        CreateOperator(flatBufferBuilder,
+                       0,
+                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
+                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
+                       operatorBuiltinOptionsType, operatorBuiltinOptions);
+
+    flatbuffers::Offset <SubGraph> subgraph =
+        CreateSubGraph(flatBufferBuilder,
+                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
+                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
+                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
+                       flatBufferBuilder.CreateVector(&lstmOperator, 1));
+
+    flatbuffers::Offset <flatbuffers::String> modelDescription =
+        flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
+    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
+                                                                         tflite::BuiltinOperator_LSTM);
+
+    flatbuffers::Offset <Model> flatbufferModel =
+        CreateModel(flatBufferBuilder,
+                    TFLITE_SCHEMA_VERSION,
+                    flatBufferBuilder.CreateVector(&operatorCode, 1),
+                    flatBufferBuilder.CreateVector(&subgraph, 1),
+                    modelDescription,
+                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
+
+    flatBufferBuilder.Finish(flatbufferModel);
+
+    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
+                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
+}
+
+template <typename T>
+void LstmTestImpl(std::vector<armnn::BackendId>& backends,
+                  tflite::TensorType tensorType,
+                  int32_t batchSize,
+                  int32_t inputSize,
+                  int32_t outputSize,
+                  int32_t numUnits,
+                  bool hasInputToInputWeights,
+                  const std::vector<T>& inputToInputWeights,
+                  const std::vector<T>& inputToForgetWeights,
+                  const std::vector<T>& inputToCellWeights,
+                  const std::vector<T>& inputToOutputWeights,
+                  bool hasRecurrentToInputWeights,
+                  const std::vector<T>& recurrentToInputWeights,
+                  const std::vector<T>& recurrentToForgetWeights,
+                  const std::vector<T>& recurrentToCellWeights,
+                  const std::vector<T>& recurrentToOutputWeights,
+                  bool hasCellToInputWeights,
+                  const std::vector<T>& cellToInputWeights,
+                  bool hasCellToForgetWeights,
+                  const std::vector<T>& cellToForgetWeights,
+                  bool hasCellToOutputWeights,
+                  const std::vector<T>& cellToOutputWeights,
+                  bool hasInputGateBias,
+                  const std::vector<T>& inputGateBias,
+                  const std::vector<T>& forgetGateBias,
+                  const std::vector<T>& cellBias,
+                  const std::vector<T>& outputGateBias,
+                  bool hasProjectionWeights,
+                  const std::vector<T>& projectionWeights,
+                  bool hasProjectionBias,
+                  const std::vector<T>& projectionBias,
+                  bool hasInputLayerNormWeights,
+                  const std::vector<T>& inputLayerNormWeights,
+                  bool hasForgetLayerNormWeights,
+                  const std::vector<T>& forgetLayerNormWeights,
+                  bool hasCellLayerNormWeights,
+                  const std::vector<T>& cellLayerNormWeights,
+                  bool hasOutputLayerNormWeights,
+                  const std::vector<T>& outputLayerNormWeights,
+                  std::vector<T>& inputValues,
+                  std::vector<T>& expectedOutputValues,
+                  tflite::ActivationFunctionType activationFunction,
+                  float clippingThresCell,
+                  float clippingThresProj)
+{
+    using namespace tflite;
+
+    std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
+                                                          batchSize,
+                                                          inputSize,
+                                                          outputSize,
+                                                          numUnits,
+                                                          hasInputToInputWeights,
+                                                          inputToInputWeights,
+                                                          inputToForgetWeights,
+                                                          inputToCellWeights,
+                                                          inputToOutputWeights,
+                                                          hasRecurrentToInputWeights,
+                                                          recurrentToInputWeights,
+                                                          recurrentToForgetWeights,
+                                                          recurrentToCellWeights,
+                                                          recurrentToOutputWeights,
+                                                          hasCellToInputWeights,
+                                                          cellToInputWeights,
+                                                          hasCellToForgetWeights,
+                                                          cellToForgetWeights,
+                                                          hasCellToOutputWeights,
+                                                          cellToOutputWeights,
+                                                          hasInputGateBias,
+                                                          inputGateBias,
+                                                          forgetGateBias,
+                                                          cellBias,
+                                                          outputGateBias,
+                                                          hasProjectionWeights,
+                                                          projectionWeights,
+                                                          hasProjectionBias,
+                                                          projectionBias,
+                                                          hasInputLayerNormWeights,
+                                                          inputLayerNormWeights,
+                                                          hasForgetLayerNormWeights,
+                                                          forgetLayerNormWeights,
+                                                          hasCellLayerNormWeights,
+                                                          cellLayerNormWeights,
+                                                          hasOutputLayerNormWeights,
+                                                          outputLayerNormWeights,
+                                                          activationFunction,
+                                                          clippingThresCell,
+                                                          clippingThresProj);
+
+    const Model* tfLiteModel = GetModel(modelBuffer.data());
+    // Create TfLite Interpreters
+    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
+    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
+                  (&armnnDelegateInterpreter) == kTfLiteOk);
+    CHECK(armnnDelegateInterpreter != nullptr);
+    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
+
+    std::unique_ptr<Interpreter> tfLiteInterpreter;
+    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
+                  (&tfLiteInterpreter) == kTfLiteOk);
+    CHECK(tfLiteInterpreter != nullptr);
+    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
+
+    // Create the ArmNN Delegate
+    armnnDelegate::DelegateOptions delegateOptions(backends);
+    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
+    theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
+                     armnnDelegate::TfLiteArmnnDelegateDelete);
+    CHECK(theArmnnDelegate != nullptr);
+    // Modify armnnDelegateInterpreter to use armnnDelegate
+    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
+
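+    // From this point the same model runs twice: once on the stock TfLite CPU kernels
+    // (tfLiteInterpreter) and once through the ArmNN delegate (armnnDelegateInterpreter),
+    // so the two sets of outputs can be cross-checked below.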
+    // Set input data
+    auto tfLiteDelegateInputId = tfLiteInterpreter->inputs()[0];
+    auto tfLiteDelegateInputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInputId);
+    for (unsigned int i = 0; i < inputValues.size(); ++i)
+    {
+        tfLiteDelegateInputData[i] = inputValues[i];
+    }
+
+    auto armnnDelegateInputId = armnnDelegateInterpreter->inputs()[0];
+    auto armnnDelegateInputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInputId);
+    for (unsigned int i = 0; i < inputValues.size(); ++i)
+    {
+        armnnDelegateInputData[i] = inputValues[i];
+    }
+
+    // Run inference on both interpreters
+    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
+    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
+
+    // Compare output data
+    auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
+    auto tfLiteDelegateOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId);
+    auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
+    auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId);
+
+    armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData, expectedOutputValues.size());
+    armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelegateOutputData, expectedOutputValues.size());
+    armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
+}
+
+} // anonymous namespace
\ No newline at end of file