IVGCVSW-3474 Refactor Lstm and QuantizedLstm Param Getters
authorFrancis Murtagh <francis.murtagh@arm.com>
Wed, 14 Aug 2019 08:51:36 +0000 (09:51 +0100)
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>
Wed, 14 Aug 2019 10:37:35 +0000 (10:37 +0000)
 * Change Getter Signatures to follow coding guidelines

Change-Id: Ic02621e834dbf79b9df63f8b4c6339f71651e944
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
include/armnn/LstmParams.hpp
include/armnn/QuantizedLstmParams.hpp
src/armnn/Network.cpp
src/armnnSerializer/Serializer.cpp
src/backends/cl/workloads/ClLstmFloatWorkload.cpp
src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp
src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp
src/backends/reference/RefLayerSupport.cpp

index 0c8e66d..89f37ab 100644 (file)
@@ -108,7 +108,7 @@ struct LstmInputParamsInfo
     const TensorInfo* m_CellLayerNormWeights;
     const TensorInfo* m_OutputLayerNormWeights;
 
-    const TensorInfo& deref(const TensorInfo* tensorInfo) const
+    const TensorInfo& Deref(const TensorInfo* tensorInfo) const
     {
         if (tensorInfo != nullptr)
         {
@@ -118,89 +118,89 @@ struct LstmInputParamsInfo
         throw InvalidArgumentException("Can't dereference a null pointer");
     }
 
-    const TensorInfo& get_InputToInputWeights() const
+    const TensorInfo& GetInputToInputWeights() const
     {
-        return deref(m_InputToInputWeights);
+        return Deref(m_InputToInputWeights);
     }
-    const TensorInfo& get_InputToForgetWeights() const
+    const TensorInfo& GetInputToForgetWeights() const
     {
-        return deref(m_InputToForgetWeights);
+        return Deref(m_InputToForgetWeights);
     }
-    const TensorInfo& get_InputToCellWeights() const
+    const TensorInfo& GetInputToCellWeights() const
     {
-        return deref(m_InputToCellWeights);
+        return Deref(m_InputToCellWeights);
     }
-    const TensorInfo& get_InputToOutputWeights() const
+    const TensorInfo& GetInputToOutputWeights() const
     {
-        return deref(m_InputToOutputWeights);
+        return Deref(m_InputToOutputWeights);
     }
-    const TensorInfo& get_RecurrentToInputWeights() const
+    const TensorInfo& GetRecurrentToInputWeights() const
     {
-        return deref(m_RecurrentToInputWeights);
+        return Deref(m_RecurrentToInputWeights);
     }
-    const TensorInfo& get_RecurrentToForgetWeights() const
+    const TensorInfo& GetRecurrentToForgetWeights() const
     {
-        return deref(m_RecurrentToForgetWeights);
+        return Deref(m_RecurrentToForgetWeights);
     }
-    const TensorInfo& get_RecurrentToCellWeights() const
+    const TensorInfo& GetRecurrentToCellWeights() const
     {
-        return deref(m_RecurrentToCellWeights);
+        return Deref(m_RecurrentToCellWeights);
     }
-    const TensorInfo& get_RecurrentToOutputWeights() const
+    const TensorInfo& GetRecurrentToOutputWeights() const
     {
-        return deref(m_RecurrentToOutputWeights);
+        return Deref(m_RecurrentToOutputWeights);
     }
-    const TensorInfo& get_CellToInputWeights() const
+    const TensorInfo& GetCellToInputWeights() const
     {
-        return deref(m_CellToInputWeights);
+        return Deref(m_CellToInputWeights);
     }
-    const TensorInfo& get_CellToForgetWeights() const
+    const TensorInfo& GetCellToForgetWeights() const
     {
-        return deref(m_CellToForgetWeights);
+        return Deref(m_CellToForgetWeights);
     }
-    const TensorInfo& get_CellToOutputWeights() const
+    const TensorInfo& GetCellToOutputWeights() const
     {
-        return deref(m_CellToOutputWeights);
+        return Deref(m_CellToOutputWeights);
     }
-    const TensorInfo& get_InputGateBias() const
+    const TensorInfo& GetInputGateBias() const
     {
-        return deref(m_InputGateBias);
+        return Deref(m_InputGateBias);
     }
-    const TensorInfo& get_ForgetGateBias() const
+    const TensorInfo& GetForgetGateBias() const
     {
-        return deref(m_ForgetGateBias);
+        return Deref(m_ForgetGateBias);
     }
-    const TensorInfo& get_CellBias() const
+    const TensorInfo& GetCellBias() const
     {
-        return deref(m_CellBias);
+        return Deref(m_CellBias);
     }
-    const TensorInfo& get_OutputGateBias() const
+    const TensorInfo& GetOutputGateBias() const
     {
-        return deref(m_OutputGateBias);
+        return Deref(m_OutputGateBias);
     }
-    const TensorInfo& get_ProjectionWeights() const
+    const TensorInfo& GetProjectionWeights() const
     {
-        return deref(m_ProjectionWeights);
+        return Deref(m_ProjectionWeights);
     }
-    const TensorInfo& get_ProjectionBias() const
+    const TensorInfo& GetProjectionBias() const
     {
-        return deref(m_ProjectionBias);
+        return Deref(m_ProjectionBias);
     }
-    const TensorInfo& get_InputLayerNormWeights() const
+    const TensorInfo& GetInputLayerNormWeights() const
     {
-        return deref(m_InputLayerNormWeights);
+        return Deref(m_InputLayerNormWeights);
     }
-    const TensorInfo& get_ForgetLayerNormWeights() const
+    const TensorInfo& GetForgetLayerNormWeights() const
     {
-        return deref(m_ForgetLayerNormWeights);
+        return Deref(m_ForgetLayerNormWeights);
     }
-    const TensorInfo& get_CellLayerNormWeights() const
+    const TensorInfo& GetCellLayerNormWeights() const
     {
-        return deref(m_CellLayerNormWeights);
+        return Deref(m_CellLayerNormWeights);
     }
-    const TensorInfo& get_OutputLayerNormWeights() const
+    const TensorInfo& GetOutputLayerNormWeights() const
     {
-        return deref(m_OutputLayerNormWeights);
+        return Deref(m_OutputLayerNormWeights);
     }
 };
 
index b3033ac..f68e607 100644 (file)
@@ -45,7 +45,7 @@ struct QuantizedLstmInputParams
     const ConstTensor* m_CellBias;
     const ConstTensor* m_OutputGateBias;
 
-    const ConstTensor& deref(const ConstTensor* tensorPtr) const
+    const ConstTensor& Deref(const ConstTensor* tensorPtr) const
     {
         if (tensorPtr != nullptr)
         {
@@ -55,64 +55,64 @@ struct QuantizedLstmInputParams
         throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
     }
 
-    const ConstTensor& get_InputToInputWeights() const
+    const ConstTensor& GetInputToInputWeights() const
     {
-        return deref(m_InputToInputWeights);
+        return Deref(m_InputToInputWeights);
     }
 
-    const ConstTensor& get_InputToForgetWeights() const
+    const ConstTensor& GetInputToForgetWeights() const
     {
-        return deref(m_InputToForgetWeights);
+        return Deref(m_InputToForgetWeights);
     }
 
-    const ConstTensor& get_InputToCellWeights() const
+    const ConstTensor& GetInputToCellWeights() const
     {
-        return deref(m_InputToCellWeights);
+        return Deref(m_InputToCellWeights);
     }
 
-    const ConstTensor& get_InputToOutputWeights() const
+    const ConstTensor& GetInputToOutputWeights() const
     {
-        return deref(m_InputToOutputWeights);
+        return Deref(m_InputToOutputWeights);
     }
 
-    const ConstTensor& get_RecurrentToInputWeights() const
+    const ConstTensor& GetRecurrentToInputWeights() const
     {
-        return deref(m_RecurrentToInputWeights);
+        return Deref(m_RecurrentToInputWeights);
     }
 
-    const ConstTensor& get_RecurrentToForgetWeights() const
+    const ConstTensor& GetRecurrentToForgetWeights() const
     {
-        return deref(m_RecurrentToForgetWeights);
+        return Deref(m_RecurrentToForgetWeights);
     }
 
-    const ConstTensor& get_RecurrentToCellWeights() const
+    const ConstTensor& GetRecurrentToCellWeights() const
     {
-        return deref(m_RecurrentToCellWeights);
+        return Deref(m_RecurrentToCellWeights);
     }
 
-    const ConstTensor& get_RecurrentToOutputWeights() const
+    const ConstTensor& GetRecurrentToOutputWeights() const
     {
-        return deref(m_RecurrentToOutputWeights);
+        return Deref(m_RecurrentToOutputWeights);
     }
 
-    const ConstTensor& get_InputGateBias() const
+    const ConstTensor& GetInputGateBias() const
     {
-        return deref(m_InputGateBias);
+        return Deref(m_InputGateBias);
     }
 
-    const ConstTensor& get_ForgetGateBias() const
+    const ConstTensor& GetForgetGateBias() const
     {
-        return deref(m_ForgetGateBias);
+        return Deref(m_ForgetGateBias);
     }
 
-    const ConstTensor& get_CellBias() const
+    const ConstTensor& GetCellBias() const
     {
-        return deref(m_CellBias);
+        return Deref(m_CellBias);
     }
 
-    const ConstTensor& get_OutputGateBias() const
+    const ConstTensor& GetOutputGateBias() const
     {
-        return deref(m_OutputGateBias);
+        return Deref(m_OutputGateBias);
     }
 };
 
@@ -152,7 +152,7 @@ struct QuantizedLstmInputParamsInfo
     const TensorInfo* m_OutputGateBias;
 
 
-    const TensorInfo& deref(const TensorInfo* tensorInfo) const
+    const TensorInfo& Deref(const TensorInfo* tensorInfo) const
     {
         if (tensorInfo != nullptr)
         {
@@ -162,55 +162,55 @@ struct QuantizedLstmInputParamsInfo
         throw InvalidArgumentException("Can't dereference a null pointer");
     }
 
-    const TensorInfo& get_InputToInputWeights() const
+    const TensorInfo& GetInputToInputWeights() const
     {
-        return deref(m_InputToInputWeights);
+        return Deref(m_InputToInputWeights);
     }
-    const TensorInfo& get_InputToForgetWeights() const
+    const TensorInfo& GetInputToForgetWeights() const
     {
-        return deref(m_InputToForgetWeights);
+        return Deref(m_InputToForgetWeights);
     }
-    const TensorInfo& get_InputToCellWeights() const
+    const TensorInfo& GetInputToCellWeights() const
     {
-        return deref(m_InputToCellWeights);
+        return Deref(m_InputToCellWeights);
     }
-    const TensorInfo& get_InputToOutputWeights() const
+    const TensorInfo& GetInputToOutputWeights() const
     {
-        return deref(m_InputToOutputWeights);
+        return Deref(m_InputToOutputWeights);
     }
 
-    const TensorInfo& get_RecurrentToInputWeights() const
+    const TensorInfo& GetRecurrentToInputWeights() const
     {
-        return deref(m_RecurrentToInputWeights);
+        return Deref(m_RecurrentToInputWeights);
     }
-    const TensorInfo& get_RecurrentToForgetWeights() const
+    const TensorInfo& GetRecurrentToForgetWeights() const
     {
-        return deref(m_RecurrentToForgetWeights);
+        return Deref(m_RecurrentToForgetWeights);
     }
-    const TensorInfo& get_RecurrentToCellWeights() const
+    const TensorInfo& GetRecurrentToCellWeights() const
     {
-        return deref(m_RecurrentToCellWeights);
+        return Deref(m_RecurrentToCellWeights);
     }
-    const TensorInfo& get_RecurrentToOutputWeights() const
+    const TensorInfo& GetRecurrentToOutputWeights() const
     {
-        return deref(m_RecurrentToOutputWeights);
+        return Deref(m_RecurrentToOutputWeights);
     }
 
-    const TensorInfo& get_InputGateBias() const
+    const TensorInfo& GetInputGateBias() const
     {
-        return deref(m_InputGateBias);
+        return Deref(m_InputGateBias);
     }
-    const TensorInfo& get_ForgetGateBias() const
+    const TensorInfo& GetForgetGateBias() const
     {
-        return deref(m_ForgetGateBias);
+        return Deref(m_ForgetGateBias);
     }
-    const TensorInfo& get_CellBias() const
+    const TensorInfo& GetCellBias() const
     {
-        return deref(m_CellBias);
+        return Deref(m_CellBias);
     }
-    const TensorInfo& get_OutputGateBias() const
+    const TensorInfo& GetOutputGateBias() const
     {
-        return deref(m_OutputGateBias);
+        return Deref(m_OutputGateBias);
     }
 };
 
index b30cd9f..932f9eb 100644 (file)
@@ -1468,33 +1468,33 @@ IConnectableLayer* Network::AddQuantizedLstmLayer(const QuantizedLstmInputParams
 
     // InputToX weights
     layer->m_QuantizedLstmParameters.m_InputToInputWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_InputToInputWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetInputToInputWeights());
     layer->m_QuantizedLstmParameters.m_InputToForgetWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_InputToForgetWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetInputToForgetWeights());
     layer->m_QuantizedLstmParameters.m_InputToCellWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_InputToCellWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetInputToCellWeights());
     layer->m_QuantizedLstmParameters.m_InputToOutputWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_InputToOutputWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetInputToOutputWeights());
 
     // RecurrentToX weights
     layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToInputWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetRecurrentToInputWeights());
     layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToForgetWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetRecurrentToForgetWeights());
     layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToCellWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetRecurrentToCellWeights());
     layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToOutputWeights());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetRecurrentToOutputWeights());
 
     // Bias
     layer->m_QuantizedLstmParameters.m_InputGateBias =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_InputGateBias());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetInputGateBias());
     layer->m_QuantizedLstmParameters.m_ForgetGateBias =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_ForgetGateBias());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetForgetGateBias());
     layer->m_QuantizedLstmParameters.m_CellBias =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_CellBias());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetCellBias());
     layer->m_QuantizedLstmParameters.m_OutputGateBias =
-            std::make_unique<ScopedCpuTensorHandle>(params.get_OutputGateBias());
+            std::make_unique<ScopedCpuTensorHandle>(params.GetOutputGateBias());
 
     return layer;
 }
index af4dc7a..d35be6f 100644 (file)
@@ -1049,20 +1049,20 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer*
     auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
 
     // Get input parameters
-    auto inputToInputWeights = CreateConstTensorInfo(params.get_InputToInputWeights());
-    auto inputToForgetWeights = CreateConstTensorInfo(params.get_InputToForgetWeights());
-    auto inputToCellWeights = CreateConstTensorInfo(params.get_InputToCellWeights());
-    auto inputToOutputWeights = CreateConstTensorInfo(params.get_InputToOutputWeights());
-
-    auto recurrentToInputWeights = CreateConstTensorInfo(params.get_RecurrentToInputWeights());
-    auto recurrentToForgetWeights = CreateConstTensorInfo(params.get_RecurrentToForgetWeights());
-    auto recurrentToCellWeights = CreateConstTensorInfo(params.get_RecurrentToCellWeights());
-    auto recurrentToOutputWeights = CreateConstTensorInfo(params.get_RecurrentToOutputWeights());
-
-    auto inputGateBias = CreateConstTensorInfo(params.get_InputGateBias());
-    auto forgetGateBias = CreateConstTensorInfo(params.get_ForgetGateBias());
-    auto cellBias = CreateConstTensorInfo(params.get_CellBias());
-    auto outputGateBias = CreateConstTensorInfo(params.get_OutputGateBias());
+    auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
+    auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
+    auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
+    auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
+
+    auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
+    auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
+    auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
+    auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
+
+    auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
+    auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
+    auto cellBias = CreateConstTensorInfo(params.GetCellBias());
+    auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
 
     auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
         m_flatBufferBuilder,
index f5d081e..2f3ba75 100644 (file)
@@ -272,20 +272,20 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
 
     // Basic parameters
     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
     const arm_compute::TensorInfo aclInputToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
-    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -302,14 +302,14 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
 
     if (!descriptor.m_CifgEnabled)
     {
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
 
         if (paramsInfo.m_CellToInputWeights != nullptr)
         {
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
         }
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
                                          paramsInfo.m_CellToInputWeights != nullptr ?
                                          &aclCellToInputWeightsInfo: nullptr,
@@ -318,11 +318,11 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
 
     if (descriptor.m_ProjectionEnabled)
     {
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
 
         if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
         }
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                                paramsInfo.m_ProjectionBias != nullptr ?
@@ -331,8 +331,8 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
 
     if (descriptor.m_PeepholeEnabled)
     {
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }
 
@@ -374,14 +374,14 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T
     {
         if (!descriptor.m_CifgEnabled)
         {
-            aclInputLayerNormWeightsInfo  = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+            aclInputLayerNormWeightsInfo  = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
         }
 
-        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
 
-        aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+        aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
 
-        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
 
         lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
                                                         nullptr : &aclInputLayerNormWeightsInfo,
index 76a6694..688ebf9 100644 (file)
@@ -31,25 +31,25 @@ arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, con
 
     // Basic parameters
     const arm_compute::TensorInfo aclInputToInputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
     const arm_compute::TensorInfo aclInputToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToInputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
-    const arm_compute::TensorInfo aclInputGateBiasInfo  = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
-    const arm_compute::TensorInfo aclCellBiasInfo       = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclInputGateBiasInfo  = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo       = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
 
     return arm_compute::CLLSTMLayerQuantized::validate(&aclInputInfo, &aclInputToInputWeightsInfo,
                                                        &aclInputToForgetWeightsInfo, &aclInputToCellWeightsInfo,
index 6dd9f4f..2f29610 100644 (file)
@@ -291,23 +291,23 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
 
     // Basic parameters
     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
     const arm_compute::TensorInfo aclInputToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
     const arm_compute::TensorInfo aclForgetGateBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
     const arm_compute::TensorInfo aclCellBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
     const arm_compute::TensorInfo aclOutputGateBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -328,11 +328,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
     {
         if (descriptor.m_PeepholeEnabled)
         {
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
         }
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
 
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
                                          descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
@@ -343,9 +343,9 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
     {
         if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
         }
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
 
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                                paramsInfo.m_ProjectionBias != nullptr ?
@@ -354,8 +354,8 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
 
     if (descriptor.m_PeepholeEnabled)
     {
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
 
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }
@@ -364,11 +364,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
     {
         if (!descriptor.m_CifgEnabled)
         {
-            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
         }
-        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
-        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
-        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
+        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
+        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
 
         lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
                                                         nullptr : &aclInputLayerNormWeightsInfo,
index d4319d4..4c2ba75 100644 (file)
@@ -143,31 +143,31 @@ arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input,
 
     // Basic parameters
     const arm_compute::TensorInfo aclInputToInputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
     const arm_compute::TensorInfo aclInputToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
 
     const arm_compute::TensorInfo aclRecurrentToInputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
 
     const arm_compute::TensorInfo aclInputGateBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
     const arm_compute::TensorInfo aclForgetGateBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
     const arm_compute::TensorInfo aclCellBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
     const arm_compute::TensorInfo aclOutputGateBiasInfo
-                                  = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
+                                  = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
 
     return arm_compute::NELSTMLayerQuantized::validate(&aclInputInfo,
                                                        &aclInputToInputWeightsInfo,
index 2648f45..56ca437 100644 (file)
@@ -808,54 +808,54 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference Lstm: input and output types are mismatched");
     // check layer parameters
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and InputToForgetWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and InputToCellWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and InputToOutputWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                   "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                   "Reference Lstm: input and ForgetGateBias types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                   "Reference Lstm: input and CellBias types are mismatched");
-    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                   "Reference Lstm: input and OutputGateBias types are mismatched");
     if (!descriptor.m_CifgEnabled)
     {
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                       "Reference Lstm: input and InputToInputWeights types are mismatched");
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                       reasonIfUnsupported,
                                       "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                       "Reference Lstm: input and InputGateBias types are mismatched");
         if (descriptor.m_PeepholeEnabled)
         {
-            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                           reasonIfUnsupported,
                                           "Reference Lstm: input and CellToInputWeights types are mismatched");
         }
     }
     if (descriptor.m_PeepholeEnabled)
     {
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                       "Reference Lstm: input and CellToForgetWeights types are mismatched");
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                       "Reference Lstm: input and CellToOutputWeights types are mismatched");
     }
     if (descriptor.m_ProjectionEnabled)
     {
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                       "Reference Lstm: input and mProjectionWeights types are mismatched");
         if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                           "Reference Lstm: input and ProjectionBias types are mismatched");
         }
     }
@@ -863,17 +863,17 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
     {
         if (!descriptor.m_CifgEnabled)
         {
-            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                           reasonIfUnsupported,
                                           "Reference Lstm: input and InputLayerNormWeights types are mismatched");
         }
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                       reasonIfUnsupported,
                                       "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                       reasonIfUnsupported,
                                       "Reference Lstm: input and CellLayerNormWeights types are mismatched");
-        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                       reasonIfUnsupported,
                                       "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
     }