Super smart reshape: HC Reshape to 2D followed by MatMul (#2183)
authorEvgenya Stepyreva <evgenya.stepyreva@intel.com>
Mon, 14 Sep 2020 10:45:27 +0000 (13:45 +0300)
committerGitHub <noreply@github.com>
Mon, 14 Sep 2020 10:45:27 +0000 (13:45 +0300)
* Initial commit

* [SSR] Reshape(2D)->MatMul constrain relaxation

* Moved common pattern mechanics to the common function

* Moving SmartReshape to CNNNetworkNgraphImpl ctors

* Review comment

* Tests

inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp [new file with mode: 0644]

index a920eac..3f3dae7 100644 (file)
@@ -22,6 +22,7 @@
 
 #include <transformations/utils/utils.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.hpp>
+#include <transformations/smart_reshape/smart_reshape.hpp>
 
 #include "ngraph_ops/eltwise.hpp"
 #include "exec_graph_info.hpp"
@@ -126,6 +127,10 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGra
     // Add shape infer method for old operations which are not included to opset1, opset2 and opset3
     ::ngraph::op::GenericIE::addExtension(_ngraph_function, std::make_shared<ShapeInfer::BuiltInShapeInferHolder>());
 
+    ngraph::pass::Manager ssr_manager;
+    ssr_manager.register_pass<ngraph::pass::SmartReshape>();
+    ssr_manager.run_passes(_ngraph_function);
+
     reshape();
     for (const auto& layer : _ngraph_function->get_parameters()) {
         std::string outName = layer->get_friendly_name();
diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp
new file mode 100644 (file)
index 0000000..c925928
--- /dev/null
@@ -0,0 +1,30 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <functional>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ReshapeAMatMul;
+class TRANSFORMATIONS_API ReshapeBMatMul;
+
+}  // namespace pass
+}  // namespace ngraph
+
+class ngraph::pass::ReshapeAMatMul: public ngraph::pass::MatcherPass {
+public:
+    ReshapeAMatMul();
+};
+class ngraph::pass::ReshapeBMatMul: public ngraph::pass::MatcherPass {
+public:
+    ReshapeBMatMul();
+};
\ No newline at end of file
diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp
new file mode 100644 (file)
index 0000000..344787e
--- /dev/null
@@ -0,0 +1,26 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API SmartReshape;
+
+}  // namespace pass
+}  // namespace ngraph
+
+class ngraph::pass::SmartReshape: public ngraph::pass::FunctionPass {
+public:
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
+};
diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp
new file mode 100644 (file)
index 0000000..a9d0888
--- /dev/null
@@ -0,0 +1,38 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <assert.h>
+#include <vector>
+#include <limits>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/op/util/op_annotations.hpp>
+#include <ngraph/op/constant.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/opsets/opset4.hpp>
+
+namespace ngraph {
+namespace op {
+namespace util {
+
+std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr<ngraph::Node>& shape_node,
+                                                                                 const std::vector<size_t>& indices) {
+    return std::make_shared<ngraph::opset4::Gather>(shape_node,
+                                                    ngraph::opset4::Constant::create(ngraph::element::i64, {indices.size()}, indices),
+                                                    ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0}));
+}
+
+std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output<ngraph::Node>& shape_source,
+                                                                                   const std::vector<size_t>& indices) {
+    const auto & shape_node = std::make_shared<ngraph::opset4::ShapeOf>(shape_source);
+    return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices);
+}
+
+}  // namespace util
+}  // namespace op
+}  // namespace ngraph
\ No newline at end of file
diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp
new file mode 100644 (file)
index 0000000..2ee2d3c
--- /dev/null
@@ -0,0 +1,79 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/smart_reshape/reshape_with_hc_output.hpp"
+#include "transformations/smart_reshape/utils.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pattern/matcher.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/opsets/opset4.hpp>
+
+bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap & pattern_to_output,
+                                         const std::shared_ptr<ngraph::Node> & matmul_label,
+                                         const std::shared_ptr<ngraph::Node> & reshape_label,
+                                         const std::shared_ptr<ngraph::Node> & other_input_label,
+                                         const std::shared_ptr<ngraph::Node> & reshape_pattern_label,
+                                         bool reshape_is_A_input) {
+    auto reshape_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr());
+    const auto & matmul = std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
+    if (!reshape_pattern || !matmul || reshape_pattern->get_shape() != ngraph::Shape{2})
+        return false;
+    const auto &shape_source = pattern_to_output.at(other_input_label);
+    if (ngraph::is_type<ngraph::opset4::Transpose>(shape_source.get_node_shared_ptr()) ||
+            ngraph::is_type<ngraph::opset4::Reshape>(shape_source.get_node_shared_ptr()))
+        // avoiding loop creation
+        return false;
+    const auto & reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr();
+
+    const auto & raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
+    const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape->get_output_partial_shape(0).rank());
+    const auto & C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx);
+    const auto & N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
+    const auto & pattern_vector = reshape_is_A_input ?
+            (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C})) :
+            (matmul->get_transpose_b() ? ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N}));
+    const auto & new_reshape_pattern = std::make_shared<ngraph::opset4::Concat>(pattern_vector, 0);
+
+    new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name());
+    copy_runtime_info(reshape_pattern, new_reshape_pattern);
+    replace_node(reshape_pattern, new_reshape_pattern);
+    return true;
+}
+
+ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() {
+    auto other_input_label = pattern::any_input();
+    auto reshape_input_label = pattern::any_input();
+    auto reshape_pattern_label = ngraph::pattern::wrap_type<opset4::Constant>();
+    auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
+    auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({reshape_label, other_input_label});
+
+    matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool {
+        const auto & pattern_to_output = m.get_pattern_value_map();
+        return relax_hc_reshape_followed_by_matmul(
+                pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, true);
+    };
+    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, "ReshapeMatMul_A");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() {
+    auto other_input_label = pattern::any_input();
+    auto reshape_input_label = pattern::any_input();
+    auto reshape_pattern_label = ngraph::pattern::wrap_type<opset4::Constant>();
+    auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
+    auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({other_input_label, reshape_label});
+
+    matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool {
+        const auto & pattern_to_output = m.get_pattern_value_map();
+        return relax_hc_reshape_followed_by_matmul(
+                pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, false);
+    };
+    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, "ReshapeMatMul_B");
+    register_matcher(m, callback);
+}
\ No newline at end of file
diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp
new file mode 100644 (file)
index 0000000..52a903a
--- /dev/null
@@ -0,0 +1,27 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <memory>
+
+#include "transformations/smart_reshape/smart_reshape.hpp"
+#include "transformations/smart_reshape/reshape_with_hc_output.hpp"
+#include "transformations/itt.hpp"
+
+#include <ngraph/pass/manager.hpp>
+#include <ngraph/pass/constant_folding.hpp>
+#include <transformations/init_node_info.hpp>
+
+bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr<ngraph::Function> f) {
+    OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::SmartReshape");
+
+    ngraph::pass::Manager manager;
+    // This pass must be called first in pipeline
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+
+    manager.register_pass<ngraph::pass::ReshapeAMatMul>();
+    manager.register_pass<ngraph::pass::ReshapeBMatMul>();
+
+    manager.run_passes(f);
+    return true;
+}
diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp
new file mode 100644 (file)
index 0000000..f51d1ae
--- /dev/null
@@ -0,0 +1,80 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <map>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/function.hpp>
+#include <common_test_utils/ngraph_test_utils.hpp>
+
+#include "cnn_network_ngraph_impl.hpp"
+
+using namespace testing;
+using namespace InferenceEngine;
+
+namespace {
+
+using reshape_map = std::map<std::string, std::vector<size_t>>;
+
+struct ReshapeMatMulTestCase {
+    bool reshape_is_A_input;
+    ngraph::PartialShape A_shape, B_shape;
+    std::vector<int64_t> reshape_pattern;
+    bool transpose_a, transpose_b;
+    reshape_map new_shapes;
+};
+
+class CNNNGraphImplSmartReshapeTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<ReshapeMatMulTestCase>> {
+public:
+    void SetUp() override {
+        const auto& parameters = GetParam();
+        const auto& test_case = std::get<0>(GetParam());
+
+        std::shared_ptr<ngraph::Function> ngraph;
+        {
+            auto input_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.A_shape);
+            input_A->set_friendly_name("input_A");
+            auto input_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.B_shape);
+            input_B->set_friendly_name("input_B");
+
+            auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
+                    ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
+            reshape_pattern->set_friendly_name("reshape_pattern");
+            auto reshape = std::make_shared<ngraph::opset4::Reshape>(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true);
+            reshape->set_friendly_name("reshape");
+
+            auto mat_mul = std::make_shared<ngraph::opset4::MatMul>(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0),
+                                                                    test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0),
+                                                                    test_case.transpose_a, test_case.transpose_b);
+            reshape->set_friendly_name("matmul");
+
+            auto result = std::make_shared<ngraph::op::Result>(mat_mul);
+            ngraph::ParameterVector params = {input_A, input_B};
+            ngraph::ResultVector results = {result};
+            ngraph = std::make_shared<ngraph::Function>(results, params);
+        }
+
+        InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph);
+        const auto & resp = network.reshape(test_case.new_shapes, nullptr);
+        ASSERT_EQ(resp, StatusCode::OK);
+    }
+};
+
+TEST_P(CNNNGraphImplSmartReshapeTests, ReshapeMatMul) {
+}
+
+INSTANTIATE_TEST_CASE_P(NGraph, CNNNGraphImplSmartReshapeTests, testing::Values(
+        ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}},
+        ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}},
+        ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}},
+        ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}},
+        ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}},
+        ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}},
+        ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}},
+        ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}}));
+}  // namespace
\ No newline at end of file