Fixed QueryNetwork for networks with KSO (#2201)
author     Ilya Churaev <ilya.churaev@intel.com>
Tue, 15 Sep 2020 11:02:15 +0000 (14:02 +0300)
committer  GitHub <noreply@github.com>
Tue, 15 Sep 2020 11:02:15 +0000 (14:02 +0300)
* Added a test reproducing the QueryNetwork issue for networks with KSO

* Fixed QueryNetwork for networks with KSO

* Added an additional test (SetAffinityWithKSO)
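
The idea behind the workaround in the ie_core.cpp hunk below is that operations the core can
constant-fold never reach the plugin, so QueryNetwork cannot expect the plugin to report them and
instead marks them as supported on the queried device itself. A minimal standalone sketch of that
detection step, assuming the same ngraph headers the patch adds; the helper name detectFoldedOps is
illustrative and not part of the patch:

    #include <memory>
    #include <string>
    #include <unordered_set>

    #include <ngraph/ngraph.hpp>
    #include <ngraph/graph_util.hpp>
    #include <ngraph/pass/constant_folding.hpp>

    // Friendly names of ops that disappear after constant folding, i.e. ops a
    // plugin never sees and therefore cannot list in supportedLayersMap.
    std::unordered_set<std::string> detectFoldedOps(const std::shared_ptr<ngraph::Function>& func) {
        // Fold a clone so the original function is left untouched.
        auto folded = ngraph::clone_function(*func);
        ngraph::pass::ConstantFolding().run_on_function(folded);

        std::unordered_set<std::string> remaining;
        for (const auto& op : folded->get_ops())
            remaining.emplace(op->get_friendly_name());

        std::unordered_set<std::string> foldedOps;
        for (const auto& op : func->get_ops())
            if (!remaining.count(op->get_friendly_name()))
                foldedOps.emplace(op->get_friendly_name());
        return foldedOps;
    }

In the actual patch this comparison is done inline in Core::QueryNetwork and each folded op is
assigned to the queried device in res.supportedLayersMap, which is what lets the new KSO tests find
an affinity for every node of the original function.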

inference-engine/src/inference_engine/ie_core.cpp
inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp

diff --git a/inference-engine/src/inference_engine/ie_core.cpp b/inference-engine/src/inference_engine/ie_core.cpp
index 89e8805..cf6725c 100644
@@ -12,6 +12,9 @@
 #include <ie_core.hpp>
 #include <multi-device/multi_device_config.hpp>
 #include <ngraph/opsets/opset.hpp>
+#include <ngraph/ngraph.hpp>
+#include <ngraph/graph_util.hpp>
+#include <ngraph/pass/constant_folding.hpp>
 
 #include <cpp_interfaces/exception2status.hpp>
 #include "ie_plugin_cpp.hpp"
@@ -294,6 +297,23 @@ public:
         QueryNetworkResult res;
         auto parsed = parseDeviceNameIntoConfig(deviceName, config);
         GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config, res);
+        if (!network.getFunction())
+            return res;
+
+        // WA for constant folded operations (plugins should support all folded ops)
+        const auto& func = network.getFunction();
+        auto specialized_function = ngraph::clone_function(*func);
+
+        ngraph::pass::ConstantFolding().run_on_function(specialized_function);
+        std::unordered_set<std::string> operationNames;
+        for (const auto& op : specialized_function->get_ops())
+            operationNames.emplace(op->get_friendly_name());
+
+        for (const auto& op : func->get_ops()) {
+            if (operationNames.find(op->get_friendly_name()) != operationNames.end())
+                continue;
+            res.supportedLayersMap[op->get_friendly_name()] = deviceName;
+        }
         return res;
     }
 
diff --git a/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp b/inference-engine/tests/functional/plugin/shared/include/behavior/core_integration.hpp
index cec77e9..4bbeb94 100644
@@ -86,7 +86,7 @@ public:
 
 class IEClassNetworkTest : public ::testing::Test {
 public:
-    CNNNetwork actualNetwork, simpleNetwork, multinputNetwork;
+    CNNNetwork actualNetwork, simpleNetwork, multinputNetwork, ksoNetwork;
 
     void SetUp() override {
         // Generic network
@@ -104,6 +104,11 @@ public:
             auto fnPtr = ngraph::builder::subgraph::make2InputSubtract();
             multinputNetwork = InferenceEngine::CNNNetwork(fnPtr);
         }
+        // Network with KSO
+        {
+            auto fnPtr = ngraph::builder::subgraph::makeKSOFunction();
+            ksoNetwork = InferenceEngine::CNNNetwork(fnPtr);
+        }
     }
     void setHeteroNetworkAffinity(const std::string& targetDevice) {
         const std::map<std::string, std::string> deviceMapping = {
@@ -549,6 +554,49 @@ TEST_P(IEClassNetworkTestP, QueryNetworkActualNoThrow) {
     }
 }
 
+TEST_P(IEClassNetworkTestP, QueryNetworkWithKSO) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    Core ie;
+
+    try {
+        auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
+        auto rl_map = rres.supportedLayersMap;
+        auto func = ksoNetwork.getFunction();
+        for (const auto & op : func->get_ops()) {
+            if (!rl_map.count(op->get_friendly_name())) {
+                FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
+            }
+        }
+    } catch (const InferenceEngine::details::InferenceEngineException & ex) {
+        std::string message = ex.what();
+        ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED]  ngraph::Function is not supported natively");
+    }
+}
+
+TEST_P(IEClassNetworkTestP, SetAffinityWithKSO) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    Core ie;
+
+    try {
+        auto rres = ie.QueryNetwork(ksoNetwork, deviceName);
+        auto rl_map = rres.supportedLayersMap;
+        auto func = ksoNetwork.getFunction();
+        for (const auto & op : func->get_ops()) {
+            if (!rl_map.count(op->get_friendly_name())) {
+                FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
+            }
+        }
+        for (const auto & op : ksoNetwork.getFunction()->get_ops()) {
+            std::string affinity = rl_map[op->get_friendly_name()];
+            op->get_rt_info()["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>(affinity);
+        }
+        ExecutableNetwork exeNetwork = ie.LoadNetwork(ksoNetwork, deviceName);
+    } catch (const InferenceEngine::details::InferenceEngineException & ex) {
+        std::string message = ex.what();
+        ASSERT_STR_CONTAINS(message, "[NOT_IMPLEMENTED]  ngraph::Function is not supported natively");
+    }
+}
+
 TEST_P(IEClassNetworkTestP, QueryNetworkHeteroActualNoThrow) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     Core ie;
diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
index d6f002f..8064ffb 100644
@@ -60,6 +60,26 @@ static std::shared_ptr<ngraph::Function> makeSplitConvConcat(std::vector<size_t>
     return fnPtr;
 }
 
+static std::shared_ptr<ngraph::Function> makeKSOFunction(std::vector<size_t> inputShape = {1, 4, 20, 20},
+                                                         InferenceEngine::Precision netPrecision = InferenceEngine::Precision::FP32) {
+    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
+    auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
+
+    auto shapeOf = std::make_shared<ngraph::opset4::ShapeOf>(params[0]);
+    auto convert = std::make_shared<ngraph::opset4::Convert>(shapeOf, ngPrc);
+    auto newShape = ngraph::builder::makeConstant<int64_t>(ngraph::element::i64, {4}, {1, 4, 1, 1});
+    auto reshape = std::make_shared<ngraph::opset4::Reshape>(convert, newShape, false);
+    auto conv1 = ngraph::builder::makeConvolution(params[0], ngPrc, {3, 3}, {1, 1}, {0, 0}, {0, 0}, {1, 1},
+                                                  ngraph::op::PadType::EXPLICIT, 4);
+    auto relu1 = std::make_shared<ngraph::opset4::Relu>(conv1);
+    auto add = std::make_shared<ngraph::opset4::Add>(relu1, reshape);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(add)};
+    std::shared_ptr<ngraph::Function> fnPtr = std::make_shared<ngraph::Function>(results, params);
+    fnPtr->set_friendly_name("KSOFunction");
+    return fnPtr;
+}
+
 static std::shared_ptr<ngraph::Function> makeSplitMultiConvConcat(std::vector<size_t> inputShape = {1, 4, 20, 20}) {
     auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(InferenceEngine::Precision::FP32);
     auto params = ngraph::builder::makeParams(ngPrc, {inputShape});