From 72d387c702c629d986a8e991641df45c84ae3a0c Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Wed, 14 Oct 2020 09:42:39 +0300 Subject: [PATCH] Do reshape only if input shapes will be changed (#2632) * Added private reshape * Removed incorrect check --- .../inference_engine/cnn_network_ngraph_impl.cpp | 154 ++++++++++++--------- .../inference_engine/cnn_network_ngraph_impl.hpp | 1 + 2 files changed, 86 insertions(+), 69 deletions(-) diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp index 5d51b3e..243d1bf 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp @@ -301,101 +301,117 @@ void CNNNetworkNGraphImpl::reshape() { // Disable reshape for generic nodes ::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function); - StatusCode ret = reshape({}, &desc); - if (ret != OK) - THROW_IE_EXCEPTION << desc.msg; + reshape({}); } StatusCode CNNNetworkNGraphImpl::reshape(const std::map>& inputShapes, - ResponseDesc* responseDesc) noexcept { - OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape"); - + ResponseDesc* responseDesc) noexcept { if (cnnNetwork) return cnnNetwork->reshape(inputShapes, responseDesc); try { auto params = _ngraph_function->get_parameters(); - for (size_t i = 0; i < params.size(); i++) { + // Check that we need to do reshape only if input shapes will be changed + bool needReshape = false; + for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) { const auto& param = params[i]; - if (inputShapes.find(param->get_friendly_name()) == inputShapes.end()) + auto it = inputShapes.find(param->get_friendly_name()); + if (it == inputShapes.end()) continue; - ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name())); - auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape); - newParam->set_friendly_name(param->get_friendly_name()); - _ngraph_function->replace_parameter(i, newParam); + if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second) + needReshape = true; } - _ngraph_function->validate_nodes_and_infer_types(); + if (needReshape) + reshape(inputShapes); + } catch (std::exception& ex) { + return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what(); + } + + return OK; +} + +void +CNNNetworkNGraphImpl::reshape(const std::map>& inputShapes) { + OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape"); + + auto params = _ngraph_function->get_parameters(); + + for (size_t i = 0; i < params.size(); i++) { + const auto& param = params[i]; + if (inputShapes.find(param->get_friendly_name()) == inputShapes.end()) + continue; + ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name())); + auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape); + newParam->set_friendly_name(param->get_friendly_name()); + _ngraph_function->replace_parameter(i, newParam); + } + _ngraph_function->validate_nodes_and_infer_types(); + { + auto specialized_ngraph_function = cloneFunction(true); + // Call this transformation because OneHot IE and nGraph have different output precisions { - auto specialized_ngraph_function = cloneFunction(true); - // Call this transformation because OneHot IE and nGraph have different output precisions - { - OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot"); - ::ngraph::pass::Manager manager; - manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type( - specialized_ngraph_function); - manager.run_passes(specialized_ngraph_function); - } - specialized_ngraph_function->validate_nodes_and_infer_types(); + OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot"); + ::ngraph::pass::Manager manager; + manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type( + specialized_ngraph_function); + manager.run_passes(specialized_ngraph_function); + } + specialized_ngraph_function->validate_nodes_and_infer_types(); #if 0 - for (const auto &op : specialized_ngraph_function->get_ordered_ops()) { - cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl; - cout << " Inputs: "; - for (const auto &in : op->inputs()) { - cout << "[" << in.get_element_type().get_type_name() << "]"; - if (in.get_partial_shape().is_dynamic()) { - cout << "dyn_shape"; - } else { - cout << "{"; - bool first = true; - for (auto i : in.get_shape()) { - if (!first) cout << ","; - cout << i; - first = false; - } - cout << "} "; + for (const auto &op : specialized_ngraph_function->get_ordered_ops()) { + cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl; + cout << " Inputs: "; + for (const auto &in : op->inputs()) { + cout << "[" << in.get_element_type().get_type_name() << "]"; + if (in.get_partial_shape().is_dynamic()) { + cout << "dyn_shape"; + } else { + cout << "{"; + bool first = true; + for (auto i : in.get_shape()) { + if (!first) cout << ","; + cout << i; + first = false; } + cout << "} "; } - cout << endl << " Outputs: "; - for (const auto &in : op->outputs()) { - cout << "[" << in.get_element_type().get_type_name() << "]"; - if (in.get_partial_shape().is_dynamic()) { - cout << "dyn_shape"; - } else { - cout << "{"; - bool first = true; - for (auto i : in.get_shape()) { - if (!first) cout << ","; - cout << i; - first = false; - } - cout << "} "; + } + cout << endl << " Outputs: "; + for (const auto &in : op->outputs()) { + cout << "[" << in.get_element_type().get_type_name() << "]"; + if (in.get_partial_shape().is_dynamic()) { + cout << "dyn_shape"; + } else { + cout << "{"; + bool first = true; + for (auto i : in.get_shape()) { + if (!first) cout << ","; + cout << i; + first = false; } + cout << "} "; } - cout << endl; } + cout << endl; + } #endif - std::unordered_set opName; - for (const auto &result : specialized_ngraph_function->get_results()) { - addOutput(result->input_value(0)); - } + std::unordered_set opName; + for (const auto &result : specialized_ngraph_function->get_results()) { + addOutput(result->input_value(0)); + } - for (const auto ¶meter : specialized_ngraph_function->get_parameters()) { - const auto &outName = parameter->get_friendly_name(); - if (opName.find(outName) != opName.end()) { - THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!"; - } - opName.insert(outName); - createDataForResult(parameter, outName, _data[outName]); + for (const auto ¶meter : specialized_ngraph_function->get_parameters()) { + const auto &outName = parameter->get_friendly_name(); + if (opName.find(outName) != opName.end()) { + THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!"; } + opName.insert(outName); + createDataForResult(parameter, outName, _data[outName]); } - } catch (std::exception& ex) { - return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what(); } - - return OK; } StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath, diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp index fa80f44..f9160bd 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp @@ -118,6 +118,7 @@ private: * @brief Reshape on the same shape */ void reshape(); + void reshape(const std::map>& inputShapes); }; class TINGraphBody : public CNNNetworkNGraphImpl { -- 2.7.4