[IE] Add RTTI macro to ReshapeFullyConnectedFusion ngrap pass (#2837)
authorZoran Zomborat <zoran.zomborat@intel.com>
Thu, 29 Oct 2020 04:33:55 +0000 (06:33 +0200)
committerGitHub <noreply@github.com>
Thu, 29 Oct 2020 04:33:55 +0000 (07:33 +0300)
inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp
inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp [new file with mode: 0644]

index 66ebdb1..d91c4c2 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class ReshapeFullyConnectedFusion;
+class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion);
 
 }  // namespace pass
 }  // namespace ngraph
 
 class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite {
 public:
+    NGRAPH_RTTI_DECLARATION;
     ReshapeFullyConnectedFusion() : GraphRewrite() {
         construct_reshape_fc();
     }
@@ -44,43 +45,5 @@ public:
     }
 
 private:
-    void construct_reshape_fc() {
-        auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
-        auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
-                                                            pattern::any_input(),
-                                                            pattern::any_input()});
-
-        ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
-            auto & pattern_to_output = m.get_pattern_value_map();
-            auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
-            auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
-
-            // Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
-            auto shape_in = reshape->input_value(0).get_shape();
-            auto shape_out = reshape->get_shape();
-            if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
-                return false;
-            }
-
-            // Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
-            auto shape_w = fc->input_value(1).get_shape();
-            if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
-                return false;
-            }
-
-            auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
-                                                               fc->input_value(1),
-                                                               fc->input_value(2),
-                                                               fc->get_shape(),
-                                                               fc->output(0).get_element_type());
-
-            new_fc->set_friendly_name(fc->get_friendly_name());
-            ngraph::copy_runtime_info({reshape, fc}, new_fc);
-            ngraph::replace_node(fc, new_fc);
-            return true;
-        };
-
-        auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
-        this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
-    }
+    void construct_reshape_fc();
 };
diff --git a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/reshape_fc_fusion.cpp
new file mode 100644 (file)
index 0000000..610424d
--- /dev/null
@@ -0,0 +1,56 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include <ngraph/opsets/opset1.hpp>
+
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
+
+void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
+    auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
+    auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
+                                                        pattern::any_input(),
+                                                        pattern::any_input()});
+
+    ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
+        auto & pattern_to_output = m.get_pattern_value_map();
+        auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
+        auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
+
+        // Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
+        auto shape_in = reshape->input_value(0).get_shape();
+        auto shape_out = reshape->get_shape();
+        if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
+            return false;
+        }
+
+        // Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
+        auto shape_w = fc->input_value(1).get_shape();
+        if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
+            return false;
+        }
+
+        auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
+                                                            fc->input_value(1),
+                                                            fc->input_value(2),
+                                                            fc->get_shape(),
+                                                            fc->output(0).get_element_type());
+
+        new_fc->set_friendly_name(fc->get_friendly_name());
+        ngraph::copy_runtime_info({reshape, fc}, new_fc);
+        ngraph::replace_node(fc, new_fc);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
+    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+}