Added dynamic check for convertFunctionToCNNNetwork functoin (#2797)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Fri, 23 Oct 2020 15:17:26 +0000 (18:17 +0300)
committerGitHub <noreply@github.com>
Fri, 23 Oct 2020 15:17:26 +0000 (18:17 +0300)
* Keep changes

* Added dynamic check for convertFunctionToCNNNetwork

* Fixed test

inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
inference-engine/tests/functional/inference_engine/cnn_network/convert_ngraph_to_cnn_network_tests.cpp

index 0cf25b9..97fb7b7 100644 (file)
@@ -907,6 +907,30 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
         network->setInputInfo(info);
     };
 
+    // Check if some of function nodes has dynamic input or output shape
+    // we collect this nodes and then throw an exception with the list
+    // of dynamic nodes.
+    std::stringstream err_log;
+    for (const auto & node : graph->get_ordered_ops()) {
+        bool is_dynamic = false;
+        for (const auto & input : node->inputs()) {
+            if (input.get_partial_shape().is_dynamic()) {
+                is_dynamic = true;
+                break;
+            }
+        }
+        for (const auto & output : node->outputs()) {
+            if (output.get_partial_shape().is_dynamic()) {
+                is_dynamic = true;
+                break;
+            }
+        }
+        if (is_dynamic) err_log << node << std::endl;
+    }
+    if (!err_log.str().empty()) {
+        THROW_IE_EXCEPTION << "\nUnsupported dynamic ops: \n" << err_log.str();
+    }
+
     const CNNNetworkNGraphImpl* nGraphImpl = dynamic_cast<const CNNNetworkNGraphImpl*>(&network);
 
     InputsDataMap thisInputDataMap;
index f744b8d..bdb93cc 100644 (file)
@@ -3,6 +3,7 @@
 //
 
 #include <gtest/gtest.h>
+#include <gmock/gmock.h>
 
 #include <cpp/ie_cnn_network.h>
 #include <legacy/cnn_network_impl.hpp>  // deprecated API
@@ -206,4 +207,33 @@ TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
         const std::string resp_msg = err.what();
         ASSERT_TRUE(resp_msg.find(ref_msg) != std::string::npos) << resp_msg;
     }
+}
+
+TEST(ConvertFunctionToCNNNetworkTests, UnsupportedDynamicOps) {
+    std::shared_ptr<ngraph::Function> f;
+    {
+        auto param = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        param->set_friendly_name("param");
+        auto relu = std::make_shared<ngraph::opset4::Relu>(param);
+        relu->set_friendly_name("relu");
+        auto non_zero = std::make_shared<ngraph::opset4::NonZero>(relu);
+        non_zero->set_friendly_name("non_zero");
+        auto result = std::make_shared<ngraph::op::Result>(non_zero->output(0));
+        result->set_friendly_name("result");
+
+        f = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
+                                               ngraph::ParameterVector{param});
+    }
+
+    InferenceEngine::CNNNetwork nGraphImpl(f);
+    try {
+        InferenceEngine::details::convertFunctionToICNNNetwork(f, nGraphImpl);
+        FAIL() << "InferenceEngineException must be thrown";
+    } catch(InferenceEngine::details::InferenceEngineException & e) {
+        EXPECT_THAT(e.what(), testing::HasSubstr(std::string("Unsupported dynamic ops: \n"
+                                                             "v0::Parameter param () -> (f32?)\n"
+                                                             "v0::Relu relu (param[0]:f32?) -> (f32?)\n"
+                                                             "v3::NonZero non_zero (relu[0]:f32?) -> (i64{?,?})\n"
+                                                             "v0::Result result (non_zero[0]:i64{?,?}) -> (i64{?,?})")));
+    }
 }
\ No newline at end of file