From 297c9f52724f8e5b751c279f1c65772e4a495389 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Fri, 10 Jul 2020 11:22:49 +0300 Subject: [PATCH] Simplified usage of CNNNetworkIterator (#1260) --- .../include/details/ie_cnn_network_iterator.hpp | 34 +++++++++++++++------- .../ngraph_reader/ngraph_reader_tests.hpp | 6 ++-- .../functional_test_utils/network_utils.hpp | 6 ++-- .../mkldnn/single_layer_tests/conv_tests.cpp | 3 +- .../single_layer_tests.hpp | 9 ++---- .../shared_tests/single_layer_tests/conv_tests.hpp | 3 +- .../unit/engines/gna/gna_matcher.cpp | 3 +- .../graph/structure/graph_conv_concat_tests.cpp | 3 +- .../graph/structure/graph_structure_test.cpp | 6 ++-- .../unit/graph_tools/graph_tools_test.cpp | 12 +++----- 10 files changed, 40 insertions(+), 45 deletions(-) diff --git a/inference-engine/src/legacy_api/include/details/ie_cnn_network_iterator.hpp b/inference-engine/src/legacy_api/include/details/ie_cnn_network_iterator.hpp index 2df26c7..6e7d1d1 100644 --- a/inference-engine/src/legacy_api/include/details/ie_cnn_network_iterator.hpp +++ b/inference-engine/src/legacy_api/include/details/ie_cnn_network_iterator.hpp @@ -16,6 +16,8 @@ #include "ie_api.h" #include "ie_layers.h" #include "ie_icnn_network.hpp" +#include "cnn_network_impl.hpp" +#include "cpp/ie_cnn_network.h" #include "ie_locked_memory.hpp" namespace InferenceEngine { @@ -34,6 +36,21 @@ CNNNetworkIterator { InferenceEngine::CNNLayerPtr currentLayer; ICNNNetwork* network = nullptr; + void init(const ICNNNetwork* network) { + if (network == nullptr) THROW_IE_EXCEPTION << "ICNNNetwork object is nullptr"; + // IE_ASSERT(dynamic_cast(network) != nullptr); + InputsDataMap inputs; + network->getInputsInfo(inputs); + if (!inputs.empty()) { + auto& nextLayers = getInputTo(inputs.begin()->second->getInputData()); + if (!nextLayers.empty()) { + currentLayer = nextLayers.begin()->second; + nextLayersTovisit.push_back(currentLayer); + visited.insert(currentLayer.get()); + } + } + } + public: /** * iterator trait definitions @@ -54,17 +71,12 @@ public: * scope. */ explicit CNNNetworkIterator(const ICNNNetwork* network) { - if (network == nullptr) THROW_IE_EXCEPTION << "ICNNNetwork object is nullptr"; - InputsDataMap inputs; - network->getInputsInfo(inputs); - if (!inputs.empty()) { - auto& nextLayers = getInputTo(inputs.begin()->second->getInputData()); - if (!nextLayers.empty()) { - currentLayer = nextLayers.begin()->second; - nextLayersTovisit.push_back(currentLayer); - visited.insert(currentLayer.get()); - } - } + init(network); + } + + explicit CNNNetworkIterator(const CNNNetwork & network) { + const auto & inetwork = static_cast(network); + init(&inetwork); } /** diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/ngraph_reader_tests.hpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/ngraph_reader_tests.hpp index e659ea4..b54fc02 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/ngraph_reader_tests.hpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/ngraph_reader_tests.hpp @@ -41,14 +41,12 @@ protected: FuncTestUtils::compareCNNNetworks(network, cnnNetwork, false); IE_SUPPRESS_DEPRECATED_START - auto & inetwork = static_cast(network); - for (auto it = details::CNNNetworkIterator(&inetwork); it != details::CNNNetworkIterator(); it++) { + for (auto it = details::CNNNetworkIterator(network); it != details::CNNNetworkIterator(); it++) { InferenceEngine::CNNLayerPtr layer = *it; ASSERT_NE(nullptr, layer->getNode()); } - auto & icnnnetwork = static_cast(cnnNetwork); - for (auto it = details::CNNNetworkIterator(&icnnnetwork); it != details::CNNNetworkIterator(); it++) { + for (auto it = details::CNNNetworkIterator(cnnNetwork); it != details::CNNNetworkIterator(); it++) { InferenceEngine::CNNLayerPtr layer = *it; ASSERT_EQ(nullptr, layer->getNode()); } diff --git a/inference-engine/tests/ie_test_utils/functional_test_utils/network_utils.hpp b/inference-engine/tests/ie_test_utils/functional_test_utils/network_utils.hpp index c54cb5d..1628115 100644 --- a/inference-engine/tests/ie_test_utils/functional_test_utils/network_utils.hpp +++ b/inference-engine/tests/ie_test_utils/functional_test_utils/network_utils.hpp @@ -16,10 +16,8 @@ void compareCNNNLayers(const InferenceEngine::CNNLayerPtr &layer, const Inferenc IE_SUPPRESS_DEPRECATED_START template inline void compareLayerByLayer(const T& network, const T& refNetwork, bool sameNetVersions = true) { - auto & inetwork = static_cast(network); - auto iterator = InferenceEngine::details::CNNNetworkIterator(&inetwork); - auto & irefNetwork = static_cast(refNetwork); - auto refIterator = InferenceEngine::details::CNNNetworkIterator(&irefNetwork); + auto iterator = InferenceEngine::details::CNNNetworkIterator(network); + auto refIterator = InferenceEngine::details::CNNNetworkIterator(refNetwork); auto end = InferenceEngine::details::CNNNetworkIterator(); if (network.layerCount() != refNetwork.layerCount()) THROW_IE_EXCEPTION << "CNNNetworks have different number of layers: " << network.layerCount() << " vs " << refNetwork.layerCount(); diff --git a/inference-engine/tests_deprecated/functional/mkldnn/single_layer_tests/conv_tests.cpp b/inference-engine/tests_deprecated/functional/mkldnn/single_layer_tests/conv_tests.cpp index 146ceab..811b789 100644 --- a/inference-engine/tests_deprecated/functional/mkldnn/single_layer_tests/conv_tests.cpp +++ b/inference-engine/tests_deprecated/functional/mkldnn/single_layer_tests/conv_tests.cpp @@ -306,8 +306,7 @@ protected: } void updatePaddings(const CNNNetwork &network, conv_test_params& p) { - auto & inetwork = (const ICNNNetwork &)network; - details::CNNNetworkIterator i(&inetwork), end; + details::CNNNetworkIterator i(network), end; auto found = std::find_if(i, end, [](const CNNLayer::Ptr& layer) { return layer->type == "Convolution"; }); diff --git a/inference-engine/tests_deprecated/functional/shared_tests/common_single_layer_tests/single_layer_tests.hpp b/inference-engine/tests_deprecated/functional/shared_tests/common_single_layer_tests/single_layer_tests.hpp index 1fa61c8..ea96f4c 100644 --- a/inference-engine/tests_deprecated/functional/shared_tests/common_single_layer_tests/single_layer_tests.hpp +++ b/inference-engine/tests_deprecated/functional/shared_tests/common_single_layer_tests/single_layer_tests.hpp @@ -185,8 +185,7 @@ std::string LayerTestHelper::propertyToString(const PropertyVector ConvolutionTestHelper::ConvolutionTestHelper(const CommonTestUtils::conv_common_params &_convParams) : LayerTestHelper("Convolution"), convParams(_convParams) {} void ConvolutionTestHelper::updatePaddingValues(const CNNNetwork &network) { - auto & inetwork = (const ICNNNetwork &)network; - details::CNNNetworkIterator i(&inetwork), end; + details::CNNNetworkIterator i(network), end; auto found = std::find_if(i, end, [this](const CNNLayer::Ptr &layer) { return layer->type == type; }); @@ -280,8 +279,7 @@ void DeformableConvolutionTestHelper::ref_fp16(const std::vectortype == type; }); @@ -352,8 +350,7 @@ void PoolingTestHelper::ref_fp16(const std::vector s } void PoolingTestHelper::updatePaddingValues(const InferenceEngine::CNNNetwork &network) { - auto & inetwork = (const ICNNNetwork &)network; - details::CNNNetworkIterator i(&inetwork), end; + details::CNNNetworkIterator i(network), end; auto found = std::find_if(i, end, [this](const CNNLayer::Ptr &layer) { return layer->type == type; }); diff --git a/inference-engine/tests_deprecated/functional/shared_tests/single_layer_tests/conv_tests.hpp b/inference-engine/tests_deprecated/functional/shared_tests/single_layer_tests/conv_tests.hpp index 6d502a3..fca233d 100644 --- a/inference-engine/tests_deprecated/functional/shared_tests/single_layer_tests/conv_tests.hpp +++ b/inference-engine/tests_deprecated/functional/shared_tests/single_layer_tests/conv_tests.hpp @@ -493,8 +493,7 @@ protected: } void updatePaddings(const CNNNetwork &network, conv_test_params& p) { - auto & inetwork = static_cast(network); - details::CNNNetworkIterator i(&inetwork), end; + details::CNNNetworkIterator i(network), end; auto found = std::find_if(i, end, [](const CNNLayer::Ptr& layer) { return layer->type == "Convolution"; }); diff --git a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp index e7db7ff..536d2e2 100644 --- a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp +++ b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp @@ -93,8 +93,7 @@ void GNAPropagateMatcher :: match() { std::vector tiBodies; - const auto & inetwork = static_cast(net_original); - for (auto layerIt = details::CNNNetworkIterator(&inetwork), end = details::CNNNetworkIterator(); + for (auto layerIt = details::CNNNetworkIterator(net_original), end = details::CNNNetworkIterator(); layerIt != end; ++layerIt) { auto layer = *layerIt; if (layer->type == "TensorIterator") { diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_conv_concat_tests.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_conv_concat_tests.cpp index d3da39b..153411a 100644 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_conv_concat_tests.cpp +++ b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_conv_concat_tests.cpp @@ -211,8 +211,7 @@ protected: graph.Infer(srcs, outputBlobs); - const auto & inetwork = static_cast(network); - details::CNNNetworkIterator l(&inetwork), end; + details::CNNNetworkIterator l(network), end; for ( ; l != end; ++l) { (*l)->params["PrimitivesPriority"] = "cpu:ref,cpu:ref_any"; } diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_structure_test.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_structure_test.cpp index d02c654..7ac445a 100644 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_structure_test.cpp +++ b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/structure/graph_structure_test.cpp @@ -5140,8 +5140,7 @@ TEST_F(MKLDNNGraphStructureTests, TestGemmConvolutionWithConcat) { auto graphInfer = [](InferenceEngine::CNNNetwork network, InferenceEngine::BlobMap& inBlobs, InferenceEngine::BlobMap& outBlobs, std::string primitivesPriority) { - const auto & inetwork = static_cast(network); - for (auto it = InferenceEngine::details::CNNNetworkIterator(&inetwork); !primitivesPriority.empty() && + for (auto it = InferenceEngine::details::CNNNetworkIterator(network); !primitivesPriority.empty() && it != InferenceEngine::details::CNNNetworkIterator(); it++) { (*it)->params["PrimitivesPriority"] = primitivesPriority; } @@ -5426,8 +5425,7 @@ TEST_F(MKLDNNGraphStructureTests, TestRefPoolingWithConcat) { auto graphInfer = [](InferenceEngine::CNNNetwork network, InferenceEngine::BlobMap& inBlobs, InferenceEngine::BlobMap& outBlobs, std::string primitivesPriority) { - const auto & inetwork = static_cast(network); - for (auto it = InferenceEngine::details::CNNNetworkIterator(&inetwork); !primitivesPriority.empty() && + for (auto it = InferenceEngine::details::CNNNetworkIterator(network); !primitivesPriority.empty() && it != InferenceEngine::details::CNNNetworkIterator(); it++) { (*it)->params["PrimitivesPriority"] = primitivesPriority; } diff --git a/inference-engine/tests_deprecated/unit/graph_tools/graph_tools_test.cpp b/inference-engine/tests_deprecated/unit/graph_tools/graph_tools_test.cpp index 601cf70..ee010d0 100644 --- a/inference-engine/tests_deprecated/unit/graph_tools/graph_tools_test.cpp +++ b/inference-engine/tests_deprecated/unit/graph_tools/graph_tools_test.cpp @@ -257,8 +257,7 @@ TEST_F(GraphToolsTest, canIterateOverCNNNetwork) { }))); std::vector resultedOrder; - const auto & inetwork = static_cast(wrap); - details::CNNNetworkIterator l(&inetwork), end; + details::CNNNetworkIterator l(wrap), end; for ( ; l != end; ++l) { resultedOrder.push_back(*l); } @@ -285,8 +284,7 @@ TEST_F(GraphToolsTest, canIterateOverCNNNetworkWithCycle) { }))); std::vector resultedOrder; - const auto & inetwork = static_cast(wrap); - details::CNNNetworkIterator l(&inetwork), end; + details::CNNNetworkIterator l(wrap), end; for (; l != end; ++l) { resultedOrder.push_back(*l); } @@ -306,8 +304,7 @@ TEST_F(GraphToolsTest, canCompareCNNNetworkIterators) { prepareInputs(maps); }))); - const auto & inetwork = static_cast(wrap); - details::CNNNetworkIterator i(&inetwork); + details::CNNNetworkIterator i(wrap); auto i2 = i; i2++; @@ -324,8 +321,7 @@ TEST_F(GraphToolsTest, canIterateOverEmptyNetwork) { prepareInputs(maps); }))); - const auto & inetwork = static_cast(wrap); - details::CNNNetworkIterator beg(&inetwork), end; + details::CNNNetworkIterator beg(wrap), end; ASSERT_EQ(beg, end); } -- 2.7.4