IVGCVSW-2395 TfLiteParse::ParseReshape doesn't support reshape input
authorkevmay01 <kevin.may@arm.com>
Mon, 17 Dec 2018 14:28:03 +0000 (14:28 +0000)
committerLes Bell <les.bell@arm.com>
Mon, 17 Dec 2018 14:51:56 +0000 (14:51 +0000)
Change-Id: If2a31a49df3701877ce0287a81c569334a24cd20

src/armnnTfLiteParser/TfLiteParser.cpp

index 89c72c5..49bc737 100644 (file)
@@ -1065,7 +1065,6 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
 
     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
-    CHECK_VALID_SIZE(inputs.size(), 1);
 
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
@@ -1074,15 +1073,29 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
     const auto * options = operatorPtr->builtin_options.AsReshapeOptions();
 
     armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
-    armnn::TensorInfo outputTensorInfo =
+    armnn::TensorInfo actualOutputTensorInfo  = ToTensorInfo(outputs[0]);
+    armnn::TensorInfo reshapeOutputTensorInfo =
         TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
 
+    // Check for valid input size and that reshape parameters equal output shape
+    if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape))
+    {
+        std::stringstream ss;
+        ss << "New shape defined in reshape parameters "
+           << reshapeOutputTensorInfo.GetShape()
+           << " does not equal output shape "
+           << actualOutputTensorInfo.GetShape()
+           << ": "
+           << CHECK_LOCATION().AsString();
+        throw ParseException(ss.str());
+    }
+
     ReshapeDescriptor reshapeDesc;
-    reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+    reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape();
 
     auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex);
     IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
-    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+    layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
 
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});