IVGCVSW-4331 Calling RemoveDebugLayers can break connections
author     Mike Kelly <mike.kelly@arm.com>
           Mon, 20 Jan 2020 17:18:18 +0000 (17:18 +0000)
committer  James Conroy <james.conroy@arm.com>
           Mon, 20 Jan 2020 18:13:16 +0000 (18:13 +0000)
 * Changed RemoveDebugLayers to move every connection from the DebugLayer's OutputSlot back to the preceding layer's OutputSlot, instead of re-wiring only the first connection.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I3c649e3f660804ca48f3c2af993a5af6a7ed4d4a
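
For context, the bug was that only the first connection of a DebugLayer's OutputSlot was re-wired when the layer was removed, so any additional consumers of that output were left disconnected. Below is a minimal, self-contained sketch of the bypass pattern the fix applies; the OutSlot/InSlot types and the BypassAllConnections helper are hypothetical stand-ins, not the ArmNN OutputSlot/InputSlot classes.

    // Sketch only: stand-in slot types, not ArmNN code.
    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct InSlot {};

    struct OutSlot
    {
        std::vector<InSlot*> connections;

        void Connect(InSlot& in) { connections.push_back(&in); }

        void Disconnect(InSlot& in)
        {
            connections.erase(std::remove(connections.begin(), connections.end(), &in),
                              connections.end());
        }
    };

    // Re-wire every consumer of 'removed' to 'preceding', i.e. the pattern the
    // patch applies per DebugLayer instead of moving only connection 0.
    void BypassAllConnections(OutSlot& preceding, OutSlot& removed)
    {
        // Iterate a copy, since Disconnect mutates the live connection list.
        const std::vector<InSlot*> consumers = removed.connections;
        for (InSlot* consumer : consumers)
        {
            removed.Disconnect(*consumer);
            preceding.Connect(*consumer);
        }
    }

    int main()
    {
        OutSlot precedingOutput;   // output feeding the layer being removed
        OutSlot debugOutput;       // output of the layer being removed
        InSlot reLU2Input;         // two consumers, mirroring the new unit test
        InSlot additionInput;

        debugOutput.Connect(reLU2Input);
        debugOutput.Connect(additionInput);

        BypassAllConnections(precedingOutput, debugOutput);

        assert(debugOutput.connections.empty());
        assert(precedingOutput.connections.size() == 2); // both edges preserved
        return 0;
    }

The sketch iterates a copy of the connection list so that Disconnect cannot invalidate the loop; the patch below performs the same per-consumer disconnect-and-reconnect directly on the DebugLayer's OutputSlot.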

src/armnn/DynamicQuantizationVisitor.cpp
src/armnn/test/QuantizerTest.cpp

diff --git a/src/armnn/DynamicQuantizationVisitor.cpp b/src/armnn/DynamicQuantizationVisitor.cpp
index ba87c6d..4b1dce0 100644
--- a/src/armnn/DynamicQuantizationVisitor.cpp
+++ b/src/armnn/DynamicQuantizationVisitor.cpp
@@ -63,13 +63,14 @@ void DynamicQuantizationVisitor::RemoveDebugLayers()
     for (DebugLayer* debugLayer : m_DebugLayers)
     {
         OutputSlot& proceedingOutputSlot = *debugLayer->GetInputSlot(0).GetConnectedOutputSlot();
-        InputSlot& succeedingInputSlot = *debugLayer->GetOutputSlot(0).GetConnection(0);
         proceedingOutputSlot.Disconnect(debugLayer->GetInputSlot(0));
-        debugLayer->GetOutputSlot(0).Disconnect(succeedingInputSlot);
 
+        for (InputSlot* succeedingInputSlot : debugLayer->GetOutputSlot(0).GetConnections())
+        {
+            debugLayer->GetOutputSlot(0).Disconnect(*succeedingInputSlot);
+            proceedingOutputSlot.Connect(*succeedingInputSlot);
+        }
         m_Graph.EraseLayer(debugLayer);
-
-        proceedingOutputSlot.Connect(succeedingInputSlot);
     }
     m_DebugLayers.clear();
 }
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 900aa18..52beb63 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -2708,5 +2708,85 @@ BOOST_AUTO_TEST_CASE(PreserveTypeQsymm16)
     PreserveTypeTestImpl(DataType::QSymmS16);
 }
 
+BOOST_AUTO_TEST_CASE(TestConnectionPreservationAfterDynamicQuant)
+{
+    class TestConnectionPreservation : public LayerVisitorBase<VisitorNoThrowPolicy>
+    {
+    public:
+        TestConnectionPreservation(const Graph& graph)
+            : LayerVisitorBase<VisitorNoThrowPolicy>()
+            , m_Graph(graph)
+        {}
+
+        void VisitAdditionLayer(const IConnectableLayer* layer, const char*) override
+        {
+            CheckLayerName(layer->GetInputSlot(0).GetConnection()->GetOwningLayerGuid(), "reLU1");
+            CheckLayerName(layer->GetInputSlot(1).GetConnection()->GetOwningLayerGuid(), "reLU2");
+        }
+
+        void CheckLayerName(LayerGuid guid, std::string expectedName)
+        {
+            bool guidFound = false;
+            for (Layer* layer : m_Graph)
+            {
+                if (layer->GetGuid() == guid)
+                {
+                    BOOST_CHECK_EQUAL(layer->GetName(), expectedName.c_str());
+                    guidFound = true;
+                    break;
+                }
+            }
+            if (!guidFound)
+            {
+                BOOST_FAIL("No layer matching the GUID was found");
+            }
+        }
+
+    private:
+        Graph m_Graph;
+    };
+
+    INetworkPtr network = INetwork::Create();
+
+    IConnectableLayer* inputLayer =  network->AddInputLayer(0,"inputLayer1");
+    armnn::ActivationDescriptor ReLUDesc;
+    ReLUDesc.m_Function = ActivationFunction::ReLu;
+
+    IConnectableLayer* reLULayer1 = network->AddActivationLayer(ReLUDesc, "reLU1");
+    IConnectableLayer* reLULayer2 = network->AddActivationLayer(ReLUDesc, "reLU2");
+    IConnectableLayer* addLayer1 = network->AddAdditionLayer("addLayer1");
+    IConnectableLayer* outputLayer = network->AddOutputLayer(0,"outPutLayer1");
+
+    inputLayer->GetOutputSlot(0).Connect(reLULayer1->GetInputSlot(0));
+    reLULayer1->GetOutputSlot(0).Connect(reLULayer2->GetInputSlot(0));
+    reLULayer1->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(0));
+    reLULayer2->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(1));
+    addLayer1->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+    inputLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+    reLULayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+    reLULayer2->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+    addLayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+
+    TestConnectionPreservation visitor1(boost::polymorphic_downcast<const Network*>(network.get())->GetGraph());
+    VisitLayersTopologically(network.get(), visitor1);
+
+    armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get());
+
+    armnn::TensorInfo tensorInfo = GetInputTensorInfo(boost::polymorphic_downcast<const Network*>(network.get()));
+
+    std::vector<float> inputData({0, 2, 0, 4});
+    armnn::ConstTensor inputTensor(tensorInfo, inputData.data());
+
+    InputTensors inputTensors;
+    inputTensors.push_back(std::make_pair(0, inputTensor));
+    quantizer->Refine(inputTensors);
+
+    INetworkPtr quantNetwork = quantizer->ExportNetwork();
+
+    TestConnectionPreservation visitor2(boost::polymorphic_downcast<const Network*>(quantNetwork.get())->GetGraph());
+    VisitLayersTopologically(quantNetwork.get(), visitor2);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
 } // namespace armnn