IVGCVSW-2915 Add Merge Layer and no-op factory method
[platform/upstream/armnn.git] / src / armnn / test / NetworkTests.cpp
index 4de09a2..dd8eb77 100644 (file)
@@ -417,4 +417,56 @@ BOOST_AUTO_TEST_CASE(Network_AddQuantize)
 
 }
 
+BOOST_AUTO_TEST_CASE(Network_AddMerge)
+{
+    struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+    {
+        void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+        {
+            m_Visited = true;
+
+            BOOST_TEST(layer);
+
+            std::string expectedName = std::string("merge");
+            BOOST_TEST(std::string(layer->GetName()) == expectedName);
+            BOOST_TEST(std::string(name) == expectedName);
+
+            BOOST_TEST(layer->GetNumInputSlots() == 2);
+            BOOST_TEST(layer->GetNumOutputSlots() == 1);
+
+            const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+            BOOST_TEST((infoIn0.GetDataType() == armnn::DataType::Float32));
+
+            const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+            BOOST_TEST((infoIn1.GetDataType() == armnn::DataType::Float32));
+
+            const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+            BOOST_TEST((infoOut.GetDataType() == armnn::DataType::Float32));
+        }
+
+        bool m_Visited = false;
+    };
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+
+    armnn::IConnectableLayer* input0 = network->AddInputLayer(0);
+    armnn::IConnectableLayer* input1 = network->AddInputLayer(1);
+    armnn::IConnectableLayer* merge = network->AddMergeLayer("merge");
+    armnn::IConnectableLayer* output = network->AddOutputLayer(0);
+
+    input0->GetOutputSlot(0).Connect(merge->GetInputSlot(0));
+    input1->GetOutputSlot(0).Connect(merge->GetInputSlot(1));
+    merge->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    const armnn::TensorInfo info({3,1}, armnn::DataType::Float32);
+    input0->GetOutputSlot(0).SetTensorInfo(info);
+    input1->GetOutputSlot(0).SetTensorInfo(info);
+    merge->GetOutputSlot(0).SetTensorInfo(info);
+
+    Test testMerge;
+    network->Accept(testMerge);
+
+    BOOST_TEST(testMerge.m_Visited == true);
+}
+
 BOOST_AUTO_TEST_SUITE_END()