[ SSR ] Transpose->MatMul (#2525)
author Evgenya Stepyreva <evgenya.stepyreva@intel.com>
Wed, 28 Oct 2020 17:06:42 +0000 (20:06 +0300)
committer GitHub <noreply@github.com>
Wed, 28 Oct 2020 17:06:42 +0000 (20:06 +0300)
* [ SSR ] Reshape->Softmax->Reshape

* Call DepthToSpaceFusion during SmartReshape

* rtti

* remove softmax workaround

inference-engine/src/transformations/include/transformations/smart_reshape/matmul_sr.hpp [moved from inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp with 84% similarity]
inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/smart_reshape/matmul_sr.cpp [moved from inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp with 55% similarity]
inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp
inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp [deleted file]

@@ -16,6 +16,7 @@ namespace pass {
 
 class TRANSFORMATIONS_API ReshapeAMatMul;
 class TRANSFORMATIONS_API ReshapeBMatMul;
+class TRANSFORMATIONS_API TransposeMatMul;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -37,4 +38,9 @@ class ngraph::pass::ReshapeBMatMul: public ngraph::pass::MatcherPass {
 public:
     NGRAPH_RTTI_DECLARATION;
     ReshapeBMatMul();
+};
+class ngraph::pass::TransposeMatMul: public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    TransposeMatMul();
 };
\ No newline at end of file
index a9d0888..84765b0 100644 (file)
@@ -4,16 +4,8 @@
 
 #pragma once
 
-#include <functional>
-#include <memory>
-#include <assert.h>
-#include <vector>
-#include <limits>
-
 #include <transformations_visibility.hpp>
 #include <ngraph/op/util/op_annotations.hpp>
-#include <ngraph/op/constant.hpp>
-#include <ngraph/opsets/opset3.hpp>
 #include <ngraph/opsets/opset4.hpp>
 
 namespace ngraph {
index ecce914..cbc040f 100644 (file)
@@ -61,8 +61,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
     manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
-    manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
     manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
+    manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
     manager.register_pass<ngraph::pass::ConstantFolding>();
     manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
     manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
@@ -2,11 +2,11 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "transformations/smart_reshape/reshape_with_hc_output.hpp"
+#include "transformations/smart_reshape/matmul_sr.hpp"
 #include "transformations/smart_reshape/utils.hpp"
 
+#include <numeric>
 #include <memory>
-#include <vector>
 
 #include <ngraph/ngraph.hpp>
 #include <ngraph/pattern/matcher.hpp>
@@ -20,19 +20,18 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap
                                          const std::shared_ptr<ngraph::Node> & other_input_label,
                                          const std::shared_ptr<ngraph::Node> & reshape_pattern_label,
                                          bool reshape_is_A_input) {
-    auto reshape_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr());
+    const auto & reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank();
     const auto & matmul = std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
-    if (!reshape_pattern || !matmul || reshape_pattern->get_shape() != ngraph::Shape{2})
+    if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2)
         return false;
     const auto &shape_source = pattern_to_output.at(other_input_label);
     if (ngraph::is_type<ngraph::opset4::Transpose>(shape_source.get_node_shared_ptr()) ||
             ngraph::is_type<ngraph::opset4::Reshape>(shape_source.get_node_shared_ptr()))
         // avoiding loop creation
         return false;
-    const auto & reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr();
 
     const auto & raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
-    const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape->get_output_partial_shape(0).rank());
+    const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
     const auto & C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx);
     const auto & N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
     const auto & pattern_vector = reshape_is_A_input ?
@@ -40,6 +39,7 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap
             (matmul->get_transpose_b() ? ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N}));
     const auto & new_reshape_pattern = std::make_shared<ngraph::opset4::Concat>(pattern_vector, 0);
 
+    auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr();
     new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name());
     copy_runtime_info(reshape_pattern, new_reshape_pattern);
     replace_node(reshape_pattern, new_reshape_pattern);
@@ -51,7 +51,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeAMatMul, "ReshapeAMatMul", 0);
 ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() {
     auto other_input_label = pattern::any_input();
     auto reshape_input_label = pattern::any_input();
-    auto reshape_pattern_label = ngraph::pattern::wrap_type<opset4::Constant>();
+    auto reshape_pattern_label = pattern::any_input();
     auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
     auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({reshape_label, other_input_label});
 
@@ -69,7 +69,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeBMatMul, "ReshapeBMatMul", 0);
 ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() {
     auto other_input_label = pattern::any_input();
     auto reshape_input_label = pattern::any_input();
-    auto reshape_pattern_label = ngraph::pattern::wrap_type<opset4::Constant>();
+    auto reshape_pattern_label = pattern::any_input();
     auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
     auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({other_input_label, reshape_label});
 
@@ -80,4 +80,62 @@ ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() {
     };
     auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, "ReshapeMatMul_B");
     register_matcher(m, callback);
+}
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeMatMul, "TransposeMatMul", 0);
+
+ngraph::pass::TransposeMatMul::TransposeMatMul() {
+    auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>();
+
+    matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool {
+        const auto & pattern_to_output = m.get_pattern_value_map();
+        auto matmul = std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
+        if (!matmul)
+            return false;
+
+        auto transpose_is_fusable = [](const std::shared_ptr<ngraph::Node>& input) {
+            const auto & input_rank = input->get_output_partial_shape(0).rank();
+            if (input_rank.is_static() && input_rank.get_length() >= 2) {
+                if (auto transpose = std::dynamic_pointer_cast<ngraph::opset4::Transpose>(input)) {
+                    if (auto order = std::dynamic_pointer_cast<opset4::Constant>(transpose->get_input_node_shared_ptr(1))) {
+                        const auto & order_vector = order->cast_vector<int64_t>();
+                        std::vector<int64_t> fusable_order(input_rank.get_length());
+                        std::iota(fusable_order.begin(), fusable_order.end(), 0);
+                        std::swap(fusable_order[input_rank.get_length() - 1], fusable_order[input_rank.get_length() - 2]);
+                        return order_vector == fusable_order;
+                    }
+                }
+            }
+            return false;
+        };
+
+        NodeVector fused_nodes;
+        auto input_A = matmul->get_input_node_shared_ptr(0);
+        bool transpose_A = matmul->get_transpose_a();
+        if (transpose_is_fusable(input_A)) {
+            fused_nodes.push_back(input_A);
+            input_A = input_A->get_input_node_shared_ptr(0);
+            transpose_A = !transpose_A;
+        }
+
+        auto input_B = matmul->get_input_node_shared_ptr(1);
+        auto transpose_B = matmul->get_transpose_b();
+        if (transpose_is_fusable(input_B)) {
+            fused_nodes.push_back(input_B);
+            input_B = input_B->get_input_node_shared_ptr(0);
+            transpose_B = !transpose_B;
+        }
+
+        if (!fused_nodes.empty()) {
+            auto updated_matmul = std::make_shared<opset4::MatMul>(input_A, input_B, transpose_A, transpose_B);
+            fused_nodes.push_back(matmul);
+            copy_runtime_info(fused_nodes, updated_matmul);
+            updated_matmul->set_friendly_name(matmul->get_friendly_name());
+            replace_node(matmul, updated_matmul);
+            return true;
+        }
+        return false;
+    };
+    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, "TransposeMatMul");
+    register_matcher(m, callback);
 }
\ No newline at end of file
index 5c35000..305b6b1 100644 (file)
@@ -10,7 +10,7 @@
 #include <transformations/itt.hpp>
 #include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
 #include <transformations/smart_reshape/reshape_to_1D.hpp>
-#include <transformations/smart_reshape/reshape_with_hc_output.hpp>
+#include <transformations/smart_reshape/matmul_sr.hpp>
 #include <transformations/smart_reshape/smart_reshape.hpp>
 #include <transformations/smart_reshape/strided_slice_squeeze.hpp>
 #include <transformations/smart_reshape/mimic_set_batch_size.hpp>
@@ -30,6 +30,7 @@ bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr<ngraph::Functio
     static_manager.register_pass<ngraph::pass::SqueezeStridedSlice>();
     static_manager.register_pass<ngraph::pass::StridedSliceSqueeze>();
     static_manager.register_pass<ngraph::pass::ReshapeTo1D>();
+    static_manager.register_pass<ngraph::pass::TransposeMatMul>();
     static_manager.run_passes(f);
 
     ngraph::pass::Manager dynamic_manager;
diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp
new file mode 100644 (file)
index 0000000..4941cc9
--- /dev/null
@@ -0,0 +1,299 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <map>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/function.hpp>
+#include <common_test_utils/ngraph_test_utils.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/smart_reshape/matmul_sr.hpp>
+
+#include "cnn_network_ngraph_impl.hpp"
+
+using namespace testing;
+using namespace InferenceEngine;
+
+namespace {
+
+using reshape_map = std::map<std::string, std::vector<size_t>>;
+
+struct ReshapeMatMulTestCase {
+    bool reshape_is_A_input;
+    ngraph::PartialShape A_shape, B_shape;
+    std::vector<int64_t> reshape_pattern;
+    bool transpose_a, transpose_b;
+    reshape_map new_shapes;
+};
+
+class SmartReshapeMatMulTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<ReshapeMatMulTestCase>> {
+public:
+    void SetUp() override {
+        // Builds Parameter(s) -> Reshape -> MatMul from the test case and checks that network.reshape() reports OK.
+        const auto& test_case = std::get<0>(GetParam());
+
+        std::shared_ptr<ngraph::Function> ngraph;
+        {
+            auto input_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.A_shape);
+            input_A->set_friendly_name("input_A");
+            auto input_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.B_shape);
+            input_B->set_friendly_name("input_B");
+
+            auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
+                    ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
+            reshape_pattern->set_friendly_name("reshape_pattern");
+            auto reshape = std::make_shared<ngraph::opset4::Reshape>(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true);
+            reshape->set_friendly_name("reshape");
+
+            auto mat_mul = std::make_shared<ngraph::opset4::MatMul>(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0),
+                                                                    test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0),
+                                                                    test_case.transpose_a, test_case.transpose_b);
+            mat_mul->set_friendly_name("matmul");  // name the MatMul itself; the Reshape keeps its own "reshape" name
+
+            auto result = std::make_shared<ngraph::op::Result>(mat_mul);
+            ngraph::ParameterVector params = {input_A, input_B};
+            ngraph::ResultVector results = {result};
+            ngraph = std::make_shared<ngraph::Function>(results, params);
+        }
+        // The Reshape pattern is a constant, yet reshaping the network to test_case.new_shapes is expected to succeed.
+        InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph);
+        const auto & resp = network.reshape(test_case.new_shapes, nullptr);
+        ASSERT_EQ(resp, StatusCode::OK);
+    }
+};
+
+TEST_P(SmartReshapeMatMulTests, ReshapeMatMul) {
+}
+
+INSTANTIATE_TEST_CASE_P(NGraph, SmartReshapeMatMulTests, testing::Values(
+        ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}},
+        ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}},
+        ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}},
+        ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}},
+        ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}},
+        ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}},
+        ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}},
+        ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}}));
+}  // namespace
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_A, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(transpose, data_B, false, false);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, true, false);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_B, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, transpose, false, false);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, false, true);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulWithAttrFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_A, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(transpose, data_B, true, false);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, false, false);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulWithAttrFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_B, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, transpose, false, true);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, false, false);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulSideAttrFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_A, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(transpose, data_B, true, true);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, false, true);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulSideAttrFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose = std::make_shared<ngraph::opset4::Transpose>(data_B, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, transpose, true, true);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, true, false);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose_A = std::make_shared<ngraph::opset4::Transpose>(data_A, order);
+        auto transpose_B = std::make_shared<ngraph::opset4::Transpose>(data_B, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(transpose_A, transpose_B, false, false);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 5, 3});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, true, true);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+        auto transpose_A = std::make_shared<ngraph::opset4::Transpose>(data_A, order);
+        auto transpose_B = std::make_shared<ngraph::opset4::Transpose>(data_B, order);
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(transpose_A, transpose_B, false, true);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::TransposeMatMul>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+    {
+        auto data_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 2});
+        auto data_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5});
+        auto matmul = std::make_shared<ngraph::opset4::MatMul>(data_A, data_B, true, false);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp
deleted file mode 100644 (file)
index f51d1ae..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright (C) 2018-2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include <gtest/gtest.h>
-
-#include <string>
-#include <memory>
-#include <map>
-
-#include <ngraph/opsets/opset4.hpp>
-#include <ngraph/function.hpp>
-#include <common_test_utils/ngraph_test_utils.hpp>
-
-#include "cnn_network_ngraph_impl.hpp"
-
-using namespace testing;
-using namespace InferenceEngine;
-
-namespace {
-
-using reshape_map = std::map<std::string, std::vector<size_t>>;
-
-struct ReshapeMatMulTestCase {
-    bool reshape_is_A_input;
-    ngraph::PartialShape A_shape, B_shape;
-    std::vector<int64_t> reshape_pattern;
-    bool transpose_a, transpose_b;
-    reshape_map new_shapes;
-};
-
-class CNNNGraphImplSmartReshapeTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<ReshapeMatMulTestCase>> {
-public:
-    void SetUp() override {
-        const auto& parameters = GetParam();
-        const auto& test_case = std::get<0>(GetParam());
-
-        std::shared_ptr<ngraph::Function> ngraph;
-        {
-            auto input_A = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.A_shape);
-            input_A->set_friendly_name("input_A");
-            auto input_B = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.B_shape);
-            input_B->set_friendly_name("input_B");
-
-            auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
-                    ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
-            reshape_pattern->set_friendly_name("reshape_pattern");
-            auto reshape = std::make_shared<ngraph::opset4::Reshape>(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true);
-            reshape->set_friendly_name("reshape");
-
-            auto mat_mul = std::make_shared<ngraph::opset4::MatMul>(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0),
-                                                                    test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0),
-                                                                    test_case.transpose_a, test_case.transpose_b);
-            reshape->set_friendly_name("matmul");
-
-            auto result = std::make_shared<ngraph::op::Result>(mat_mul);
-            ngraph::ParameterVector params = {input_A, input_B};
-            ngraph::ResultVector results = {result};
-            ngraph = std::make_shared<ngraph::Function>(results, params);
-        }
-
-        InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph);
-        const auto & resp = network.reshape(test_case.new_shapes, nullptr);
-        ASSERT_EQ(resp, StatusCode::OK);
-    }
-};
-
-TEST_P(CNNNGraphImplSmartReshapeTests, ReshapeMatMul) {
-}
-
-INSTANTIATE_TEST_CASE_P(NGraph, CNNNGraphImplSmartReshapeTests, testing::Values(
-        ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}},
-        ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}},
-        ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}},
-        ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}},
-        ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}},
-        ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}},
-        ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}},
-        ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}}));
-}  // namespace
\ No newline at end of file