From 3571d448969eeaddd44007c2133bb6751c64c838 Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Fri, 5 Jun 2020 13:36:35 +0300 Subject: [PATCH] Save the name of output data if we remove previous layer (#760) * Save the name of output data if we remove previous layer * Added test --- .../src/legacy_api/src/net_pass.cpp | 5 ++++ .../cnn_network/cnn_ngraph_impl_tests.cpp | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/inference-engine/src/legacy_api/src/net_pass.cpp b/inference-engine/src/legacy_api/src/net_pass.cpp index 534467ddc..77c972cfe 100644 --- a/inference-engine/src/legacy_api/src/net_pass.cpp +++ b/inference-engine/src/legacy_api/src/net_pass.cpp @@ -298,6 +298,11 @@ void RemoveLayer(CNNLayerPtr& layer) { // transfer output connections into parent data CombineData(in_data, out_data); + + // Save name for output data + if (out_data->getInputTo().empty()) { + in_data->setName(out_data->getName()); + } } /************************************************************/ diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp index a1388035d..fce269a49 100644 --- a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp @@ -15,11 +15,13 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -56,6 +58,32 @@ TEST(CNNNGraphImplTests, TestConvertNetwork) { ASSERT_EQ(cnnRefNet, cnnNet.getCNNNetwork()); } +TEST(CNNNGraphImplTests, TestConvertWithRemoveLastLayerNetwork) { + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + ngraph::element::Type type(ngraph::element::Type_t::i32); + auto param = std::make_shared(type, shape); + param->set_friendly_name("param"); + auto relu = std::make_shared(param); + relu->set_friendly_name("relu"); + auto convert = std::make_shared(relu, ngraph::element::Type_t::i64); + convert->set_friendly_name("convert"); + auto result = std::make_shared(convert); + + ngraph::ParameterVector params = {param}; + ngraph::ResultVector results = {result}; + + ngraph = std::make_shared(results, params); + } + + InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph); + InferenceEngine::ICNNNetwork& cnnRefNet = *cnnNet.getCNNNetwork(); + // Remove convert layer + InferenceEngine::NetPass::ConvertPrecision(cnnRefNet, Precision::I64, Precision::I32); + ASSERT_NO_THROW(cloneNet(cnnRefNet)); +} + TEST(CNNNGraphImplTests, TestResultWithNotEqualName) { std::shared_ptr ngraph; { -- 2.34.1