Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / graph_tools / graph_test_base.hpp
index 94c0876..79a1f4a 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -73,8 +73,27 @@ class GraphTestsBase : public ::testing::Test {
         }
         return nullptr;
     }
+
+
+    #define ASSERT_N_CONNECTIONS(a, b, n) \
+        ASSERT_EQ(countForwardConnections(#a, #b), n);\
+        ASSERT_EQ(countBackwardConnections(#a, #b), n);
+
     #define ASSERT_CONNECTION(a, b) \
-        ASSERT_TRUE(assertConnection(#a, #b));
+        ASSERT_N_CONNECTIONS(a,b,1);
+
+    #define ASSERT_2_CONNECTIONS(a, b) \
+        ASSERT_N_CONNECTIONS(a,b,2);
+
+    #define ASSERT_3_CONNECTIONS(a, b) \
+        ASSERT_N_CONNECTIONS(a,b,3);
+
+    /**
+     * @brief check connection without direction
+     */
+    #define ASSERT_NO_CONNECTION(a, b) \
+        ASSERT_EQ(countConnections(#a, #b), 0);\
+        ASSERT_EQ(countConnections(#b, #a), 0);\
 
     void ASSERT_DIMS(int x, const SizeVector & dims) {
 
@@ -84,30 +103,51 @@ class GraphTestsBase : public ::testing::Test {
         }
     }
 
-    bool assertConnection(std::string a, std::string b) {
+    int countForwardConnections(std::string a, std::string b) {
+        long int nForward = 0;
+        CNNLayerPtr layerExist;
+        try {
+            layerExist = wrap.getLayerByName(a.c_str());
+            if (!layerExist) {
+                return 0;
+            }
+        } catch(...) {
+            return 0;
+        }
 
-        bool bForward = false;
-        for (auto && outData : wrap.getLayerByName(a.c_str())->outData) {
+        for (auto && outData : layerExist->outData) {
             auto &inputMap = outData->inputTo;
-            auto i =
-                std::find_if(inputMap.begin(), inputMap.end(), [&](std::map<std::string, CNNLayerPtr>::value_type &vt) {
+            nForward +=
+                std::count_if(inputMap.begin(), inputMap.end(), [&](std::map<std::string, CNNLayerPtr>::value_type &vt) {
                     return vt.second->name == b;
                 });
-            if (i != inputMap.end()) {
-                bForward = true;
-                break;
-            }
         }
-        if (!bForward) {
-            return false;
+
+        return nForward;
+    }
+
+    int countBackwardConnections(std::string a, std::string b) {
+        CNNLayerPtr layerExist;
+        try {
+            layerExist = wrap.getLayerByName(b.c_str());
+            if (!layerExist) {
+                return 0;
+            }
+        } catch(...) {
+            return 0;
         }
 
-        auto prevData = wrap.getLayerByName(b.c_str())->insData;
+        auto prevData = layerExist->insData;
 
-        auto j = std::find_if(prevData.begin(), prevData.end(), [&](DataWeakPtr wp) {
+        auto nBackward = std::count_if(prevData.begin(), prevData.end(), [&](DataWeakPtr wp) {
             return wp.lock()->getCreatorLayer().lock()->name == a;
         });
-        return  j != prevData.end();
+
+        return  nBackward;
+    }
+
+    int countConnections(std::string a, std::string b) {
+        return  countForwardConnections(a, b) + countBackwardConnections(a, b);
     }
 
     int numCreated = 0;
@@ -189,6 +229,17 @@ class GraphTestsBase : public ::testing::Test {
         }
     }
 
+    void TearDown() override {
+        // Reset shared_pointer circular dependencies to mitigate memory leaks.
+        for (auto& items : datas) {
+            for (auto& data : items) {
+                for (auto& input : data->getInputTo()) {
+                    input.second.reset();
+                }
+            }
+        }
+    }
+
     int ID(const CNNLayerPtr &ptr) {
         for (int i = 0; i < layers.size(); i++) {
             if (layers[i].get() == ptr.get())