Simplified usage of CNNNetworkIterator (#1260)
authorIlya Lavrenov <ilya.lavrenov@intel.com>
Fri, 10 Jul 2020 08:22:49 +0000 (11:22 +0300)
committerGitHub <noreply@github.com>
Fri, 10 Jul 2020 08:22:49 +0000 (11:22 +0300)
inference-engine/src/legacy_api/include/details/ie_cnn_network_iterator.hpp
inference-engine/tests/functional/inference_engine/ngraph_reader/ngraph_reader_tests.hpp
inference-engine/tests/ie_test_utils/functional_test_utils/network_utils.hpp
inference-engine/tests_deprecated/functional/mkldnn/single_layer_tests/conv_tests.cpp
inference-engine/tests_deprecated/functional/shared_tests/common_single_layer_tests/single_layer_tests.hpp
inference-engine/tests_deprecated/functional/shared_tests/single_layer_tests/conv_tests.hpp
inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp
inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_conv_concat_tests.cpp
inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_structure_test.cpp
inference-engine/tests_deprecated/unit/graph_tools/graph_tools_test.cpp

index 2df26c7..6e7d1d1 100644 (file)
@@ -16,6 +16,8 @@
 #include "ie_api.h"
 #include "ie_layers.h"
 #include "ie_icnn_network.hpp"
+#include "cnn_network_impl.hpp"
+#include "cpp/ie_cnn_network.h"
 #include "ie_locked_memory.hpp"
 
 namespace InferenceEngine {
@@ -34,6 +36,21 @@ CNNNetworkIterator {
     InferenceEngine::CNNLayerPtr currentLayer;
     ICNNNetwork* network = nullptr;
 
+    void init(const ICNNNetwork* network) {
+        if (network == nullptr) THROW_IE_EXCEPTION << "ICNNNetwork object is nullptr";
+        // IE_ASSERT(dynamic_cast<const details::CNNNetworkImpl*>(network) != nullptr);
+        InputsDataMap inputs;
+        network->getInputsInfo(inputs);
+        if (!inputs.empty()) {
+            auto& nextLayers = getInputTo(inputs.begin()->second->getInputData());
+            if (!nextLayers.empty()) {
+                currentLayer = nextLayers.begin()->second;
+                nextLayersTovisit.push_back(currentLayer);
+                visited.insert(currentLayer.get());
+            }
+        }
+    }
+
 public:
     /**
      * iterator trait definitions
@@ -54,17 +71,12 @@ public:
      * scope.
      */
     explicit CNNNetworkIterator(const ICNNNetwork* network) {
-        if (network == nullptr) THROW_IE_EXCEPTION << "ICNNNetwork object is nullptr";
-        InputsDataMap inputs;
-        network->getInputsInfo(inputs);
-        if (!inputs.empty()) {
-            auto& nextLayers = getInputTo(inputs.begin()->second->getInputData());
-            if (!nextLayers.empty()) {
-                currentLayer = nextLayers.begin()->second;
-                nextLayersTovisit.push_back(currentLayer);
-                visited.insert(currentLayer.get());
-            }
-        }
+        init(network);
+    }
+
+    explicit CNNNetworkIterator(const CNNNetwork & network) {
+        const auto & inetwork = static_cast<const InferenceEngine::ICNNNetwork&>(network);
+        init(&inetwork);
     }
 
     /**
index e659ea4..b54fc02 100644 (file)
@@ -41,14 +41,12 @@ protected:
 
         FuncTestUtils::compareCNNNetworks(network, cnnNetwork, false);
         IE_SUPPRESS_DEPRECATED_START
-        auto & inetwork = static_cast<const ICNNNetwork&>(network);
-        for (auto it = details::CNNNetworkIterator(&inetwork); it != details::CNNNetworkIterator(); it++) {
+        for (auto it = details::CNNNetworkIterator(network); it != details::CNNNetworkIterator(); it++) {
             InferenceEngine::CNNLayerPtr layer = *it;
             ASSERT_NE(nullptr, layer->getNode());
         }
 
-        auto & icnnnetwork = static_cast<const ICNNNetwork&>(cnnNetwork);
-        for (auto it = details::CNNNetworkIterator(&icnnnetwork); it != details::CNNNetworkIterator(); it++) {
+        for (auto it = details::CNNNetworkIterator(cnnNetwork); it != details::CNNNetworkIterator(); it++) {
             InferenceEngine::CNNLayerPtr layer = *it;
             ASSERT_EQ(nullptr, layer->getNode());
         }
index c54cb5d..1628115 100644 (file)
@@ -16,10 +16,8 @@ void compareCNNNLayers(const InferenceEngine::CNNLayerPtr &layer, const Inferenc
 IE_SUPPRESS_DEPRECATED_START
 template <class T>
 inline void compareLayerByLayer(const T& network, const T& refNetwork, bool sameNetVersions = true) {
-    auto & inetwork = static_cast<const InferenceEngine::ICNNNetwork&>(network);
-    auto iterator = InferenceEngine::details::CNNNetworkIterator(&inetwork);
-    auto & irefNetwork = static_cast<const InferenceEngine::ICNNNetwork&>(refNetwork);
-    auto refIterator = InferenceEngine::details::CNNNetworkIterator(&irefNetwork);
+    auto iterator = InferenceEngine::details::CNNNetworkIterator(network);
+    auto refIterator = InferenceEngine::details::CNNNetworkIterator(refNetwork);
     auto end = InferenceEngine::details::CNNNetworkIterator();
     if (network.layerCount() != refNetwork.layerCount())
         THROW_IE_EXCEPTION << "CNNNetworks have different number of layers: " << network.layerCount() << " vs " << refNetwork.layerCount();
index 146ceab..811b789 100644 (file)
@@ -306,8 +306,7 @@ protected:
     }
 
     void updatePaddings(const CNNNetwork &network, conv_test_params& p) {
-        auto & inetwork = (const ICNNNetwork &)network;
-        details::CNNNetworkIterator i(&inetwork), end;
+        details::CNNNetworkIterator i(network), end;
         auto found = std::find_if(i, end, [](const CNNLayer::Ptr& layer) {
             return layer->type == "Convolution";
         });
index 1fa61c8..ea96f4c 100644 (file)
@@ -185,8 +185,7 @@ std::string LayerTestHelper::propertyToString(const PropertyVector<unsigned int>
 ConvolutionTestHelper::ConvolutionTestHelper(const CommonTestUtils::conv_common_params &_convParams) : LayerTestHelper("Convolution"), convParams(_convParams) {}
 
 void ConvolutionTestHelper::updatePaddingValues(const CNNNetwork &network) {
-    auto & inetwork = (const ICNNNetwork &)network;
-    details::CNNNetworkIterator i(&inetwork), end;
+    details::CNNNetworkIterator i(network), end;
     auto found = std::find_if(i, end, [this](const CNNLayer::Ptr &layer) {
         return layer->type == type;
     });
@@ -280,8 +279,7 @@ void DeformableConvolutionTestHelper::ref_fp16(const std::vector<InferenceEngine
 }
 
 void DeformableConvolutionTestHelper::updatePaddingValues(const CNNNetwork &network) {
-    auto & inetwork = (const ICNNNetwork &)network;
-    details::CNNNetworkIterator i(&inetwork), end;
+    details::CNNNetworkIterator i(network), end;
     auto found = std::find_if(i, end, [this](const CNNLayer::Ptr &layer) {
         return layer->type == type;
     });
@@ -352,8 +350,7 @@ void PoolingTestHelper::ref_fp16(const std::vector<InferenceEngine::Blob::Ptr> s
 }
 
 void PoolingTestHelper::updatePaddingValues(const InferenceEngine::CNNNetwork &network) {
-    auto & inetwork = (const ICNNNetwork &)network;
-    details::CNNNetworkIterator i(&inetwork), end;
+    details::CNNNetworkIterator i(network), end;
     auto found = std::find_if(i, end, [this](const CNNLayer::Ptr &layer) {
         return layer->type == type;
     });
index 6d502a3..fca233d 100644 (file)
@@ -493,8 +493,7 @@ protected:
     }
 
     void updatePaddings(const CNNNetwork &network, conv_test_params& p) {
-        auto & inetwork = static_cast<const ICNNNetwork &>(network);
-        details::CNNNetworkIterator i(&inetwork), end;
+        details::CNNNetworkIterator i(network), end;
         auto found = std::find_if(i, end, [](const CNNLayer::Ptr& layer) {
             return layer->type == "Convolution";
         });
index e7db7ff..536d2e2 100644 (file)
@@ -93,8 +93,7 @@ void GNAPropagateMatcher :: match() {
 
             std::vector<InferenceEngine::CNNLayerPtr> tiBodies;
 
-            const auto & inetwork = static_cast<const ICNNNetwork&>(net_original);
-            for (auto layerIt = details::CNNNetworkIterator(&inetwork), end = details::CNNNetworkIterator();
+            for (auto layerIt = details::CNNNetworkIterator(net_original), end = details::CNNNetworkIterator();
                      layerIt != end; ++layerIt) {
                 auto layer = *layerIt;
                 if (layer->type == "TensorIterator") {
index d3da39b..153411a 100644 (file)
@@ -211,8 +211,7 @@ protected:
 
             graph.Infer(srcs, outputBlobs);
 
-            const auto & inetwork = static_cast<const ICNNNetwork&>(network);
-            details::CNNNetworkIterator l(&inetwork), end;
+            details::CNNNetworkIterator l(network), end;
             for ( ; l != end; ++l) {
                 (*l)->params["PrimitivesPriority"] = "cpu:ref,cpu:ref_any";
             }
index d02c654..7ac445a 100644 (file)
@@ -5140,8 +5140,7 @@ TEST_F(MKLDNNGraphStructureTests, TestGemmConvolutionWithConcat) {
 
     auto graphInfer = [](InferenceEngine::CNNNetwork network, InferenceEngine::BlobMap& inBlobs,
             InferenceEngine::BlobMap& outBlobs, std::string primitivesPriority) {
-        const auto & inetwork = static_cast<const InferenceEngine::ICNNNetwork&>(network);
-        for (auto it = InferenceEngine::details::CNNNetworkIterator(&inetwork); !primitivesPriority.empty() &&
+        for (auto it = InferenceEngine::details::CNNNetworkIterator(network); !primitivesPriority.empty() &&
             it != InferenceEngine::details::CNNNetworkIterator(); it++) {
             (*it)->params["PrimitivesPriority"] = primitivesPriority;
         }
@@ -5426,8 +5425,7 @@ TEST_F(MKLDNNGraphStructureTests, TestRefPoolingWithConcat) {
 
     auto graphInfer = [](InferenceEngine::CNNNetwork network, InferenceEngine::BlobMap& inBlobs,
                          InferenceEngine::BlobMap& outBlobs, std::string primitivesPriority) {
-        const auto & inetwork = static_cast<const InferenceEngine::ICNNNetwork&>(network);
-        for (auto it = InferenceEngine::details::CNNNetworkIterator(&inetwork); !primitivesPriority.empty() &&
+        for (auto it = InferenceEngine::details::CNNNetworkIterator(network); !primitivesPriority.empty() &&
             it != InferenceEngine::details::CNNNetworkIterator(); it++) {
             (*it)->params["PrimitivesPriority"] = primitivesPriority;
         }
index 601cf70..ee010d0 100644 (file)
@@ -257,8 +257,7 @@ TEST_F(GraphToolsTest, canIterateOverCNNNetwork) {
     })));
 
     std::vector<CNNLayerPtr> resultedOrder;
-    const auto & inetwork = static_cast<const ICNNNetwork&>(wrap);
-    details::CNNNetworkIterator l(&inetwork), end;
+    details::CNNNetworkIterator l(wrap), end;
     for ( ; l != end; ++l) {
         resultedOrder.push_back(*l);
     }
@@ -285,8 +284,7 @@ TEST_F(GraphToolsTest, canIterateOverCNNNetworkWithCycle) {
     })));
 
     std::vector<CNNLayerPtr> resultedOrder;
-    const auto & inetwork = static_cast<const ICNNNetwork&>(wrap);
-    details::CNNNetworkIterator l(&inetwork), end;
+    details::CNNNetworkIterator l(wrap), end;
     for (; l != end; ++l) {
         resultedOrder.push_back(*l);
     }
@@ -306,8 +304,7 @@ TEST_F(GraphToolsTest, canCompareCNNNetworkIterators) {
         prepareInputs(maps);
     })));
 
-    const auto & inetwork = static_cast<const ICNNNetwork&>(wrap);
-    details::CNNNetworkIterator i(&inetwork);
+    details::CNNNetworkIterator i(wrap);
     auto i2 = i;
     i2++;
 
@@ -324,8 +321,7 @@ TEST_F(GraphToolsTest, canIterateOverEmptyNetwork) {
         prepareInputs(maps);
     })));
 
-    const auto & inetwork = static_cast<const ICNNNetwork&>(wrap);
-    details::CNNNetworkIterator beg(&inetwork), end;
+    details::CNNNetworkIterator beg(wrap), end;
     ASSERT_EQ(beg, end);
 }