Fix reshape delegate intermittent error
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Tue, 24 Nov 2020 18:40:42 +0000 (18:40 +0000)
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Tue, 24 Nov 2020 19:00:51 +0000 (19:00 +0000)
 * Make sue that incorrect corrupted data from reshapeOptions is not used
instead of shape from input tensor
 * Remove redundant check

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ib30f632b5fdb039a618234c1faae183c98033e57

delegate/src/Redefine.hpp

index 9129576..e880383 100644 (file)
@@ -90,62 +90,53 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
 
     armnn::ReshapeDescriptor reshapeDesc;
-
-    // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
-    TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
     std::vector<int32_t> targetShape;
-    bool targetShapeFound = false;
 
-    if (reshapeOptions != nullptr)
+    // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
+    if (numInputs == 2)
     {
-        // Options might be set without valid data. we need to check the dimensions are in a valid range.
-        if (reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
+        // Get shape from the second input tensor
+        const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
+        if (IsDynamicTensor(tfLiteShapeInputTensor))
         {
-            uint64_t elementCounter = 1;
-            for (int i=0; i < reshapeOptions->num_dimensions; ++i)
-            {
-                targetShape.push_back(reshapeOptions->shape[i]);
-                if (reshapeOptions->shape[i] > 0)
-                {
-                    elementCounter = elementCounter * reshapeOptions->shape[i];
-                }
-            }
-            // Check the number of elements match, otherwise fall back to using the second input tensor.
-            if (elementCounter <= inputTensorInfo0.GetNumElements())
-            {
-                targetShapeFound = true;
-            }
+            TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+                                     "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
+                                     "operator #%d node #%d: ",
+                                     operatorCode, nodeIndex);
+            return kTfLiteError;
         }
-    }
-    if (!targetShapeFound)
-    {
-        if (numInputs == 2)
+
+        if (tfLiteShapeInputTensor.dims->size != 1)
         {
-            const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
-            if (IsDynamicTensor(tfLiteShapeInputTensor))
-            {
-                TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
-                                         "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
-                                         "operator #%d node #%d: ",
-                                         operatorCode, nodeIndex);
-                return kTfLiteError;
-            }
+            TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+                                     "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
+                                     "operator #%d node #%d: ",
+                                     operatorCode, nodeIndex);
+            return kTfLiteError;
+        }
 
-            if (tfLiteShapeInputTensor.dims->size != 1)
-            {
-                TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
-                         "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
-                         "operator #%d node #%d: ",
-                         operatorCode, nodeIndex);
-                return kTfLiteError;
-            }
+        // Get the shape data out of the input tensor
+        auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
+        auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
+        for (auto i=0; i < shapeTensorNumValues; ++i)
+        {
+            targetShape.push_back(*(shapeTensorDataPtr+i));
+        }
+    }
+    else
+    {
+        // Get shape from the builtin data
+        TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
 
-            // Get the shape data out of the input tensor
-            auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
-            auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
-            for (auto i=0; i < shapeTensorNumValues; ++i)
+        if (reshapeOptions != nullptr)
+        {
+            // Options might be set without valid data. we need to check the dimensions are in a valid range.
+            if (reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
             {
-                targetShape.push_back(*(shapeTensorDataPtr+i));
+                for (int i=0; i < reshapeOptions->num_dimensions; ++i)
+                {
+                    targetShape.push_back(reshapeOptions->shape[i]);
+                }
             }
         }
         else
@@ -170,10 +161,11 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
 
     if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
     {
-        TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
-                         "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
-                         "operator #%d node #%d: ",
-                         operatorCode, nodeIndex);
+        TF_LITE_MAYBE_KERNEL_LOG(
+            tfLiteContext,
+            "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
+            "operator #%d node #%d: ",
+            operatorCode, nodeIndex);
         return kTfLiteError;
     }