// Disable reshape for generic nodes
::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
- StatusCode ret = reshape({}, &desc);
- if (ret != OK)
- THROW_IE_EXCEPTION << desc.msg;
+ reshape({});
}
StatusCode
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
- ResponseDesc* responseDesc) noexcept {
- OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
-
+ ResponseDesc* responseDesc) noexcept {
if (cnnNetwork)
return cnnNetwork->reshape(inputShapes, responseDesc);
try {
auto params = _ngraph_function->get_parameters();
- for (size_t i = 0; i < params.size(); i++) {
+ // Check that we need to do reshape only if input shapes will be changed
+ bool needReshape = false;
+ for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) {
const auto& param = params[i];
- if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
+ auto it = inputShapes.find(param->get_friendly_name());
+ if (it == inputShapes.end())
continue;
- ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
- auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
- newParam->set_friendly_name(param->get_friendly_name());
- _ngraph_function->replace_parameter(i, newParam);
+ if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second)
+ needReshape = true;
}
- _ngraph_function->validate_nodes_and_infer_types();
+ if (needReshape)
+ reshape(inputShapes);
+ } catch (std::exception& ex) {
+ return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
+ }
+
+ return OK;
+}
+
+void
+CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes) {
+ OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
+
+ auto params = _ngraph_function->get_parameters();
+
+ for (size_t i = 0; i < params.size(); i++) {
+ const auto& param = params[i];
+ if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
+ continue;
+ ::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
+ auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
+ newParam->set_friendly_name(param->get_friendly_name());
+ _ngraph_function->replace_parameter(i, newParam);
+ }
+ _ngraph_function->validate_nodes_and_infer_types();
+ {
+ auto specialized_ngraph_function = cloneFunction(true);
+ // Call this transformation because OneHot IE and nGraph have different output precisions
{
- auto specialized_ngraph_function = cloneFunction(true);
- // Call this transformation because OneHot IE and nGraph have different output precisions
- {
- OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
- ::ngraph::pass::Manager manager;
- manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
- specialized_ngraph_function);
- manager.run_passes(specialized_ngraph_function);
- }
- specialized_ngraph_function->validate_nodes_and_infer_types();
+ OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
+ ::ngraph::pass::Manager manager;
+ manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
+ specialized_ngraph_function);
+ manager.run_passes(specialized_ngraph_function);
+ }
+ specialized_ngraph_function->validate_nodes_and_infer_types();
#if 0
- for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
- cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
- cout << " Inputs: ";
- for (const auto &in : op->inputs()) {
- cout << "[" << in.get_element_type().get_type_name() << "]";
- if (in.get_partial_shape().is_dynamic()) {
- cout << "dyn_shape";
- } else {
- cout << "{";
- bool first = true;
- for (auto i : in.get_shape()) {
- if (!first) cout << ",";
- cout << i;
- first = false;
- }
- cout << "} ";
+ for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
+ cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
+ cout << " Inputs: ";
+ for (const auto &in : op->inputs()) {
+ cout << "[" << in.get_element_type().get_type_name() << "]";
+ if (in.get_partial_shape().is_dynamic()) {
+ cout << "dyn_shape";
+ } else {
+ cout << "{";
+ bool first = true;
+ for (auto i : in.get_shape()) {
+ if (!first) cout << ",";
+ cout << i;
+ first = false;
}
+ cout << "} ";
}
- cout << endl << " Outputs: ";
- for (const auto &in : op->outputs()) {
- cout << "[" << in.get_element_type().get_type_name() << "]";
- if (in.get_partial_shape().is_dynamic()) {
- cout << "dyn_shape";
- } else {
- cout << "{";
- bool first = true;
- for (auto i : in.get_shape()) {
- if (!first) cout << ",";
- cout << i;
- first = false;
- }
- cout << "} ";
+ }
+ cout << endl << " Outputs: ";
+ for (const auto &in : op->outputs()) {
+ cout << "[" << in.get_element_type().get_type_name() << "]";
+ if (in.get_partial_shape().is_dynamic()) {
+ cout << "dyn_shape";
+ } else {
+ cout << "{";
+ bool first = true;
+ for (auto i : in.get_shape()) {
+ if (!first) cout << ",";
+ cout << i;
+ first = false;
}
+ cout << "} ";
}
- cout << endl;
}
+ cout << endl;
+ }
#endif
- std::unordered_set<std::string> opName;
- for (const auto &result : specialized_ngraph_function->get_results()) {
- addOutput(result->input_value(0));
- }
+ std::unordered_set<std::string> opName;
+ for (const auto &result : specialized_ngraph_function->get_results()) {
+ addOutput(result->input_value(0));
+ }
- for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
- const auto &outName = parameter->get_friendly_name();
- if (opName.find(outName) != opName.end()) {
- THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
- }
- opName.insert(outName);
- createDataForResult(parameter, outName, _data[outName]);
+ for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
+ const auto &outName = parameter->get_friendly_name();
+ if (opName.find(outName) != opName.end()) {
+ THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
}
+ opName.insert(outName);
+ createDataForResult(parameter, outName, _data[outName]);
}
- } catch (std::exception& ex) {
- return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
}
-
- return OK;
}
StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,