IVGCVSW-4449 Add missing QLstm nullptr checks
authorJames Conroy <james.conroy@arm.com>
Mon, 18 May 2020 14:16:42 +0000 (15:16 +0100)
committerJames Conroy <james.conroy@arm.com>
Mon, 18 May 2020 14:22:15 +0000 (15:22 +0100)
* Adds missing nullptr checks for peephole bias for
  QLstm.

Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: Ib04c8352141977cc7ea11a4859f1b11d46472722

src/armnn/Network.cpp
src/armnn/layers/QLstmLayer.cpp
src/backends/backendsCommon/WorkloadFactory.cpp

index c2bf27a..a047297 100644 (file)
@@ -1874,15 +1874,16 @@ IConnectableLayer* Network::AddQLstmLayer(const QLstmDescriptor&  descriptor,
             throw InvalidArgumentException("AddQLstmLayer: Projection Weights cannot be NULL");
         }
 
-        if(params.m_ProjectionBias == nullptr)
+        layer->m_ProjectionParameters.m_ProjectionWeights =
+                std::make_unique<ScopedCpuTensorHandle>(*(params.m_ProjectionWeights));
+
+        // Projection bias is optional even if projection is enabled
+        if(params.m_ProjectionWeights != nullptr)
         {
-            throw InvalidArgumentException("AddQLstmLayer: Projection Biases cannot be NULL");
+            layer->m_ProjectionParameters.m_ProjectionBias =
+                    std::make_unique<ScopedCpuTensorHandle>(*(params.m_ProjectionBias));
         }
 
-        layer->m_ProjectionParameters.m_ProjectionWeights =
-                std::make_unique<ScopedCpuTensorHandle>(*(params.m_ProjectionWeights));
-        layer->m_ProjectionParameters.m_ProjectionBias =
-                std::make_unique<ScopedCpuTensorHandle>(*(params.m_ProjectionBias));
     }
 
     // QLstm Peephole params
index 9b940c1..7e61548 100644 (file)
@@ -232,8 +232,6 @@ void QLstmLayer::ValidateTensorShapesFromInputs()
     {
         ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
                          "QLstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
-        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionBias != nullptr,
-                         "QLstmLayer: m_ProjectionParameters.m_ProjectionBias should not be null.");
     }
 
     if (m_Param.m_PeepholeEnabled)
index c55c70a..34bfd7c 100644 (file)
@@ -795,7 +795,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
             if(descriptor.m_ProjectionEnabled)
             {
                 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
-                paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
+
+                // Projection bias is optional even if projection is enabled
+                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+                {
+                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
+                }
             }
 
             if(descriptor.m_PeepholeEnabled)