Do reshape only if input shapes will be changed (#2632)
authorIlya Churaev <ilya.churaev@intel.com>
Wed, 14 Oct 2020 06:42:39 +0000 (09:42 +0300)
committerGitHub <noreply@github.com>
Wed, 14 Oct 2020 06:42:39 +0000 (09:42 +0300)
* Added private reshape

* Removed incorrect check

inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp

index 5d51b3e..243d1bf 100644 (file)
@@ -301,101 +301,117 @@ void CNNNetworkNGraphImpl::reshape() {
 
     // Disable reshape for generic nodes
     ::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
-    StatusCode ret = reshape({}, &desc);
-    if (ret != OK)
-        THROW_IE_EXCEPTION << desc.msg;
+    reshape({});
 }
 
 StatusCode
 CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
-                        ResponseDesc* responseDesc) noexcept {
-    OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
-
+                              ResponseDesc* responseDesc) noexcept {
     if (cnnNetwork)
         return cnnNetwork->reshape(inputShapes, responseDesc);
     try {
         auto params = _ngraph_function->get_parameters();
 
-        for (size_t i = 0; i < params.size(); i++) {
+        // Check that we need to do reshape only if input shapes will be changed
+        bool needReshape = false;
+        for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) {
             const auto& param = params[i];
-            if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
+            auto it = inputShapes.find(param->get_friendly_name());
+            if (it == inputShapes.end())
                 continue;
-            ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
-            auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
-            newParam->set_friendly_name(param->get_friendly_name());
-            _ngraph_function->replace_parameter(i, newParam);
+            if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second)
+                needReshape = true;
         }
-        _ngraph_function->validate_nodes_and_infer_types();
+        if (needReshape)
+            reshape(inputShapes);
+    } catch (std::exception& ex) {
+        return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
+    }
+
+    return OK;
+}
+
+void
+CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes) {
+    OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
+
+    auto params = _ngraph_function->get_parameters();
+
+    for (size_t i = 0; i < params.size(); i++) {
+        const auto& param = params[i];
+        if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
+            continue;
+        ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
+        auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
+        newParam->set_friendly_name(param->get_friendly_name());
+        _ngraph_function->replace_parameter(i, newParam);
+    }
+    _ngraph_function->validate_nodes_and_infer_types();
 
+    {
+        auto specialized_ngraph_function = cloneFunction(true);
+        // Call this transformation because OneHot IE and nGraph have different output precisions
         {
-            auto specialized_ngraph_function = cloneFunction(true);
-            // Call this transformation because OneHot IE and nGraph have different output precisions
-            {
-                OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
-                ::ngraph::pass::Manager manager;
-                manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
-                        specialized_ngraph_function);
-                manager.run_passes(specialized_ngraph_function);
-            }
-            specialized_ngraph_function->validate_nodes_and_infer_types();
+            OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
+            ::ngraph::pass::Manager manager;
+            manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
+                    specialized_ngraph_function);
+            manager.run_passes(specialized_ngraph_function);
+        }
+        specialized_ngraph_function->validate_nodes_and_infer_types();
 
 #if 0
-            for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
-                cout << "[ " <<  op->description() << " ] " << op->get_friendly_name() << endl;
-                cout << "    Inputs: ";
-                for (const auto &in : op->inputs()) {
-                    cout << "[" << in.get_element_type().get_type_name() << "]";
-                    if (in.get_partial_shape().is_dynamic()) {
-                        cout << "dyn_shape";
-                    } else {
-                        cout << "{";
-                        bool first = true;
-                        for (auto i : in.get_shape()) {
-                            if (!first) cout << ",";
-                            cout << i;
-                            first = false;
-                        }
-                        cout << "} ";
+        for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
+            cout << "[ " <<  op->description() << " ] " << op->get_friendly_name() << endl;
+            cout << "    Inputs: ";
+            for (const auto &in : op->inputs()) {
+                cout << "[" << in.get_element_type().get_type_name() << "]";
+                if (in.get_partial_shape().is_dynamic()) {
+                    cout << "dyn_shape";
+                } else {
+                    cout << "{";
+                    bool first = true;
+                    for (auto i : in.get_shape()) {
+                        if (!first) cout << ",";
+                        cout << i;
+                        first = false;
                     }
+                    cout << "} ";
                 }
-                cout << endl << "    Outputs: ";
-                for (const auto &in : op->outputs()) {
-                    cout << "[" << in.get_element_type().get_type_name() << "]";
-                    if (in.get_partial_shape().is_dynamic()) {
-                        cout << "dyn_shape";
-                    } else {
-                        cout << "{";
-                        bool first = true;
-                        for (auto i : in.get_shape()) {
-                            if (!first) cout << ",";
-                            cout << i;
-                            first = false;
-                        }
-                        cout << "} ";
+            }
+            cout << endl << "    Outputs: ";
+            for (const auto &in : op->outputs()) {
+                cout << "[" << in.get_element_type().get_type_name() << "]";
+                if (in.get_partial_shape().is_dynamic()) {
+                    cout << "dyn_shape";
+                } else {
+                    cout << "{";
+                    bool first = true;
+                    for (auto i : in.get_shape()) {
+                        if (!first) cout << ",";
+                        cout << i;
+                        first = false;
                     }
+                    cout << "} ";
                 }
-                cout << endl;
             }
+            cout << endl;
+        }
 #endif
-            std::unordered_set<std::string> opName;
-            for (const auto &result : specialized_ngraph_function->get_results()) {
-                addOutput(result->input_value(0));
-            }
+        std::unordered_set<std::string> opName;
+        for (const auto &result : specialized_ngraph_function->get_results()) {
+            addOutput(result->input_value(0));
+        }
 
-            for (const auto &parameter : specialized_ngraph_function->get_parameters()) {
-                const auto &outName = parameter->get_friendly_name();
-                if (opName.find(outName) != opName.end()) {
-                    THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
-                }
-                opName.insert(outName);
-                createDataForResult(parameter, outName, _data[outName]);
+        for (const auto &parameter : specialized_ngraph_function->get_parameters()) {
+            const auto &outName = parameter->get_friendly_name();
+            if (opName.find(outName) != opName.end()) {
+                THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
             }
+            opName.insert(outName);
+            createDataForResult(parameter, outName, _data[outName]);
         }
-    } catch (std::exception& ex) {
-        return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
     }
-
-    return OK;
 }
 
 StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,
index fa80f44..f9160bd 100644 (file)
@@ -118,6 +118,7 @@ private:
      * @brief Reshape on the same shape
      */
     void reshape();
+    void reshape(const std::map<std::string, std::vector<size_t>>& inputShapes);
 };
 
 class TINGraphBody : public CNNNetworkNGraphImpl {