Eliminated invalid subgraphs (#2196)
authorAnton Pankratv <anton.pankratov@intel.com>
Tue, 15 Sep 2020 11:03:24 +0000 (14:03 +0300)
committerGitHub <noreply@github.com>
Tue, 15 Sep 2020 11:03:24 +0000 (14:03 +0300)
docs/template_plugin/src/template_plugin.cpp
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp

index b7481d7..c640210 100644 (file)
@@ -175,11 +175,45 @@ void Plugin::QueryNetwork(const ICNNNetwork &network, const ConfigMap& config, Q
     }
 
     // 4. The result set should contains just nodes from supported set
-    for (auto&& layerName : supported) {
-        if (!contains(unsupported, layerName)) {
-            res.supportedLayersMap.emplace(layerName, GetName());
+    for (auto&& unsupportedNode : unsupported) {
+        supported.erase(unsupportedNode);
+    }
+
+    // 5. If some housekeeping nodes were not added add them.
+    for (auto&& node : function->get_ops()) {
+        if (contains(supported, node->get_friendly_name())) {
+            for (auto&& inputNodeOutput : node->input_values()) {
+                if (ngraph::op::is_constant(inputNodeOutput.get_node()) || ngraph::op::is_parameter(inputNodeOutput.get_node())) {
+                    supported.emplace(inputNodeOutput.get_node()->get_friendly_name());
+                }
+            }
+            for (auto&& outputs : node->outputs()) {
+                for (auto&& outputNodeInput : outputs.get_target_inputs()) {
+                    if (ngraph::op::is_output(outputNodeInput.get_node())) {
+                        supported.emplace(outputNodeInput.get_node()->get_friendly_name());
+                    }
+                }
+            }
         }
     }
+
+    // 6. Eliminate subgraphs that consists of housekeeping nodes only
+    for (auto&& node : function->get_ops()) {
+        if (ngraph::op::is_constant(node) || ngraph::op::is_parameter(node)) {
+            if (!contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) {
+                supported.erase(node->get_friendly_name());
+            }
+        } else if (ngraph::op::is_output(node)) {
+            if (!contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) {
+                supported.erase(node->get_friendly_name());
+            }
+        }
+    }
+
+    // 7. Produce the result
+    for (auto&& layerName : supported) {
+        res.supportedLayersMap.emplace(layerName, GetName());
+    }
 }
 // ! [plugin:query_network]
 
index 0631c27..6e98ea8 100644 (file)
@@ -311,11 +311,13 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::map<std::string
                 }
             }
         }
-
+        for (auto&& unsupportedNode : unsupported) {
+            supported.erase(unsupportedNode);
+        }
         for (auto&& node : function->get_ops()) {
-            if (!contains(unsupported, node->get_friendly_name())) {
+            if (contains(supported, node->get_friendly_name())) {
                 for (auto&& inputNodeOutput : node->input_values()) {
-                    if (ngraph::op::is_constant(inputNodeOutput.get_node())) {
+                    if (ngraph::op::is_constant(inputNodeOutput.get_node()) || ngraph::op::is_parameter(inputNodeOutput.get_node())) {
                         supported.emplace(inputNodeOutput.get_node()->get_friendly_name());
                     }
                 }
@@ -328,12 +330,20 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::map<std::string
                 }
             }
         }
-
-        for (auto&& layerName : supported) {
-            if (!contains(unsupported, layerName)) {
-                res.supportedLayersMap.emplace(layerName, GetName());
+        for (auto&& node : function->get_ops()) {
+            if (ngraph::op::is_constant(node) || ngraph::op::is_parameter(node)) {
+                if (!contains(supported, node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) {
+                    supported.erase(node->get_friendly_name());
+                }
+            } else if (ngraph::op::is_output(node)) {
+                if (!contains(supported, node->input_values().begin()->get_node()->get_friendly_name())) {
+                    supported.erase(node->get_friendly_name());
+                }
             }
         }
+        for (auto&& layerName : supported) {
+            res.supportedLayersMap.emplace(layerName, GetName());
+        }
     } else {
         details::CNNNetworkIterator i(&network);
         while (i != details::CNNNetworkIterator()) {