From 89a6f926a4784110b34848a432994e4632f860ce Mon Sep 17 00:00:00 2001 From: Anton Pankratv Date: Tue, 15 Sep 2020 14:03:24 +0300 Subject: [PATCH] Eliminated invalid subgraphs (#2196) --- docs/template_plugin/src/template_plugin.cpp | 40 ++++++++++++++++++++-- .../src/mkldnn_plugin/mkldnn_plugin.cpp | 24 +++++++++---- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/docs/template_plugin/src/template_plugin.cpp b/docs/template_plugin/src/template_plugin.cpp index b7481d7..c640210 100644 --- a/docs/template_plugin/src/template_plugin.cpp +++ b/docs/template_plugin/src/template_plugin.cpp @@ -175,11 +175,45 @@ void Plugin::QueryNetwork(const ICNNNetwork &network, const ConfigMap& config, Q } // 4. The result set should contains just nodes from supported set - for (auto&& layerName : supported) { - if (!contains(unsupported, layerName)) { - res.supportedLayersMap.emplace(layerName, GetName()); + for (auto&& unsupportedNode : unsupported) { + supported.erase(unsupportedNode); + } + + // 5. If some housekeeping nodes were not added add them. + for (auto&& node : function->get_ops()) { + if (contains(supported, node->get_friendly_name())) { + for (auto&& inputNodeOutput : node->input_values()) { + if (ngraph::op::is_constant(inputNodeOutput.get_node()) || ngraph::op::is_parameter(inputNodeOutput.get_node())) { + supported.emplace(inputNodeOutput.get_node()->get_friendly_name()); + } + } + for (auto&& outputs : node->outputs()) { + for (auto&& outputNodeInput : outputs.get_target_inputs()) { + if (ngraph::op::is_output(outputNodeInput.get_node())) { + supported.emplace(outputNodeInput.get_node()->get_friendly_name()); + } + } + } } } + + // 6. Eliminate subgraphs that consists of housekeeping nodes only + for (auto&& node : function->get_ops()) { + if (ngraph::op::is_constant(node) || ngraph::op::is_parameter(node)) { + if (!contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) { + supported.erase(node->get_friendly_name()); + } + } else if (ngraph::op::is_output(node)) { + if (!contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) { + supported.erase(node->get_friendly_name()); + } + } + } + + // 7. Produce the result + for (auto&& layerName : supported) { + res.supportedLayersMap.emplace(layerName, GetName()); + } } // ! [plugin:query_network] diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 0631c27..6e98ea8 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -311,11 +311,13 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::mapget_ops()) { - if (!contains(unsupported, node->get_friendly_name())) { + if (contains(supported, node->get_friendly_name())) { for (auto&& inputNodeOutput : node->input_values()) { - if (ngraph::op::is_constant(inputNodeOutput.get_node())) { + if (ngraph::op::is_constant(inputNodeOutput.get_node()) || ngraph::op::is_parameter(inputNodeOutput.get_node())) { supported.emplace(inputNodeOutput.get_node()->get_friendly_name()); } } @@ -328,12 +330,20 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::mapget_ops()) { + if (ngraph::op::is_constant(node) || ngraph::op::is_parameter(node)) { + if (!contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) { + supported.erase(node->get_friendly_name()); + } + } else if (ngraph::op::is_output(node)) { + if (!contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) { + supported.erase(node->get_friendly_name()); + } } } + for (auto&& layerName : supported) { + res.supportedLayersMap.emplace(layerName, GetName()); + } } else { details::CNNNetworkIterator i(&network); while (i != details::CNNNetworkIterator()) { -- 2.7.4