//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1);
+ uint32_t inputSlotId = 1;
+ uint32_t reshapeSlotId = 0;
+
if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
{
uint32_t id = reshapedInputId;
reshapedTensorInfo = ToTensorInfo(tensorPtr1);
inputTensorInfo = ToTensorInfo(tensorPtr);
+
+ inputSlotId = 0;
+ reshapeSlotId = 1;
}
uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
- reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+ reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(reshapeSlotId));
RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
- armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+ armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(inputSlotId));
RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
}