From: Zoran Zomborat Date: Thu, 29 Oct 2020 04:33:55 +0000 (+0200) Subject: [IE] Add RTTI macro to ReshapeFullyConnectedFusion ngrap pass (#2837) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4021e144b53d9d9708a856f667d9847cdecabcfa;p=platform%2Fupstream%2Fdldt.git [IE] Add RTTI macro to ReshapeFullyConnectedFusion ngrap pass (#2837) --- diff --git a/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp b/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp index 66ebdb1..d91c4c2 100644 --- a/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp +++ b/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp @@ -25,13 +25,14 @@ namespace ngraph { namespace pass { -class ReshapeFullyConnectedFusion; +class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion); } // namespace pass } // namespace ngraph class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite { public: + NGRAPH_RTTI_DECLARATION; ReshapeFullyConnectedFusion() : GraphRewrite() { construct_reshape_fc(); } @@ -44,43 +45,5 @@ public: } private: - void construct_reshape_fc() { - auto m_reshape = pattern::wrap_type(pattern::has_static_shape()); - auto m_fc = pattern::wrap_type({m_reshape, - pattern::any_input(), - pattern::any_input()}); - - ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) { - auto & pattern_to_output = m.get_pattern_value_map(); - auto fc = pattern_to_output[m_fc].get_node_shared_ptr(); - auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr(); - - // Check that Reshape reshapes 4D tensor to 2D or input shape = output shape - auto shape_in = reshape->input_value(0).get_shape(); - auto shape_out = reshape->get_shape(); - if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) { - return false; - } - - // Check that Weights[O, C*H*W] consistent with Input[N, C, H, W] - auto shape_w = fc->input_value(1).get_shape(); - if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies()) != shape_w[1]) { - return false; - } - - auto new_fc = std::make_shared(reshape->input_value(0), - fc->input_value(1), - fc->input_value(2), - fc->get_shape(), - fc->output(0).get_element_type()); - - new_fc->set_friendly_name(fc->get_friendly_name()); - ngraph::copy_runtime_info({reshape, fc}, new_fc); - ngraph::replace_node(fc, new_fc); - return true; - }; - - auto m = std::make_shared(m_fc, "ReshapeFullyConnectedFusion"); - this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); - } + void construct_reshape_fc(); }; diff --git a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp new file mode 100644 index 0000000..610424d --- /dev/null +++ b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp" + +#include +#include +#include + +#include + +#include +#include + +NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0); + +void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() { + auto m_reshape = pattern::wrap_type(pattern::has_static_shape()); + auto m_fc = pattern::wrap_type({m_reshape, + pattern::any_input(), + pattern::any_input()}); + + ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto fc = pattern_to_output[m_fc].get_node_shared_ptr(); + auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr(); + + // Check that Reshape reshapes 4D tensor to 2D or input shape = output shape + auto shape_in = reshape->input_value(0).get_shape(); + auto shape_out = reshape->get_shape(); + if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) { + return false; + } + + // Check that Weights[O, C*H*W] consistent with Input[N, C, H, W] + auto shape_w = fc->input_value(1).get_shape(); + if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies()) != shape_w[1]) { + return false; + } + + auto new_fc = std::make_shared(reshape->input_value(0), + fc->input_value(1), + fc->input_value(2), + fc->get_shape(), + fc->output(0).get_element_type()); + + new_fc->set_friendly_name(fc->get_friendly_name()); + ngraph::copy_runtime_info({reshape, fc}, new_fc); + ngraph::replace_node(fc, new_fc); + return true; + }; + + auto m = std::make_shared(m_fc, "ReshapeFullyConnectedFusion"); + this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); +}