Fix for Kaldi models with a batch of more than 1 (#1012)
author    Ivan Tikhonov <ivan.tikhonov@intel.com>
Tue, 23 Jun 2020 05:22:12 +0000 (08:22 +0300)
committer GitHub <noreply@github.com>
Tue, 23 Jun 2020 05:22:12 +0000 (08:22 +0300)
* Fix Kaldi models (batch > 1)

* Apply ngraph code style

* Fix ngraph to IE conversion

* Added comment

* Apply review comments

* Added a test for calling SetBatchSize when a ReadValue op is present in the network

* Check status code instead of message

* Use new ngraph API

inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp

inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
index b34e599..89df86b 100644
@@ -769,6 +769,20 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
         cnnLayer->insData.resize(inputCount);
 
         for (size_t i = 0; i < layer->get_output_size(); i++) {
+            // A Memory node with index = 1 has no inputs, according to the specification.
+            // For a proper conversion, we have to cut off all layers and data nodes above ReadValue
+            // if they are connected only to this layer.
+            // Currently, MO generates only constants or constant sub-graphs as inputs to the ReadValue op.
+            if (std::dynamic_pointer_cast<::ngraph::op::Constant>(layer)) {
+                bool all_to_read_value = !layer->output(i).get_target_inputs().empty();
+                for (const auto &output_input : layer->output(i).get_target_inputs()) {
+                    all_to_read_value
+                            &= dynamic_cast<ngraph::op::ReadValue *>(output_input.get_node()) != nullptr;
+                }
+                if (all_to_read_value)
+                    continue;
+            }
+
             if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
                 cnnLayer->outData.clear();
                 continue;
@@ -776,7 +790,6 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
             std::string outName = layer->get_friendly_name();
             if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
             DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
-
             SizeVector dims;
             dims = layer->get_output_shape(i);
             for (const auto &dim : dims) {
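For reference, the skip condition added above can be read in isolation. The following is a minimal sketch, not part of the commit: the helper name feeds_only_read_value is hypothetical, and the header paths are assumed for the ngraph revision this PR targets.

    #include <memory>
    #include <ngraph/node.hpp>
    #include <ngraph/op/read_value.hpp> // assumed location of ngraph::op::ReadValue

    // A node output may be dropped during conversion only if it is actually
    // consumed and every consumer is a ReadValue (Memory) op.
    static bool feeds_only_read_value(const std::shared_ptr<ngraph::Node> &node, size_t i) {
        const auto &targets = node->output(i).get_target_inputs();
        if (targets.empty()) return false;  // unconsumed output: keep its data node
        for (const auto &target : targets) {
            if (dynamic_cast<ngraph::op::ReadValue *>(target.get_node()) == nullptr)
                return false;               // a non-ReadValue consumer exists
        }
        return true;                        // the constant sub-graph feeds only ReadValue
    }

Note the empty-target guard: starting from !empty() ensures that a Constant whose output has no consumers at all is still converted rather than cut off.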
inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp
index 3ef0137..0841449 100644
@@ -17,6 +17,7 @@
 #include <ie_core.hpp>
 #include <net_pass.h>
 
+#include <ngraph/opsets/opset3.hpp>
 #include <ngraph/function.hpp>
 #include <ngraph/variant.hpp>
 #include <ngraph/op/maximum.hpp>
@@ -677,4 +678,25 @@ TEST(CNNNGraphImplTests, TestCheckStats) {
     InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
 }
 
+TEST(CNNNGraphImplTests, CanSetBatchReadValue) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        auto input = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2});
+        auto constant = std::make_shared<ngraph::opset3::Constant>(ngraph::element::f32, ngraph::Shape{1, 2},
+                std::vector<float>{1, 2});
+
+        auto read_value = std::make_shared<ngraph::opset3::ReadValue>(constant, "variable_id");
+        auto add = std::make_shared<ngraph::opset3::Add>(input, read_value);
+        auto result = std::make_shared<ngraph::op::Result>(add);
+
+        ngraph::ParameterVector params = {input};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+    auto status = cnnNet.getCNNNetwork()->setBatchSize(4, nullptr);
+    EXPECT_EQ(status, StatusCode::OK);
+}
 IE_SUPPRESS_DEPRECATED_END
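
For context, a minimal usage sketch of the scenario this fix unblocks, assuming the 2020.x Inference Engine API; the IR file names are placeholders for a Kaldi model converted by Model Optimizer.

    #include <ie_core.hpp>

    int main() {
        InferenceEngine::Core core;
        // Placeholder IR of a Kaldi model; such networks contain
        // ReadValue/Assign (Memory) pairs after conversion by MO.
        auto network = core.ReadNetwork("kaldi_model.xml", "kaldi_model.bin");
        // Before this fix, setting batch > 1 on such a network failed
        // during the ngraph-to-CNNNetwork conversion.
        network.setBatchSize(4);
        auto executable = core.LoadNetwork(network, "CPU");
        return 0;
    }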