GitHub #418 AddBroadcastReshapeLayer can cause inputs to be connected incorrectly
authorMike Kelly <mike.kelly@arm.com>
Mon, 6 Jul 2020 18:24:15 +0000 (19:24 +0100)
committermike.kelly <mike.kelly@arm.com>
Mon, 6 Jul 2020 19:00:14 +0000 (19:00 +0000)
 * Fixed issue where AddBroadcastReshapeLayer would always connect the Reshaped input to the first input slot and the other input to the first input slot.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ifd2745a819eb0f72ff9433690afc92a6a34f2ec3

src/armnnTfLiteParser/TfLiteParser.cpp

index 1b93aad..a690e53 100644 (file)
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -563,6 +563,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
     armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
     armnn::TensorInfo inputTensorInfo    = ToTensorInfo(tensorPtr1);
 
+    uint32_t inputSlotId   = 1;
+    uint32_t reshapeSlotId = 0;
+
     if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
     {
         uint32_t id = reshapedInputId;
@@ -571,6 +574,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
 
         reshapedTensorInfo = ToTensorInfo(tensorPtr1);
         inputTensorInfo = ToTensorInfo(tensorPtr);
+
+        inputSlotId = 0;
+        reshapeSlotId = 1;
     }
 
     uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
@@ -592,11 +598,11 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
     armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
 
     reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
-    reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+    reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(reshapeSlotId));
 
     RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
 
-    armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+    armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(inputSlotId));
     RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
 }