From 13a9ba6a2baa70da260b7cae6b124ae09a88423c Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Wed, 28 Oct 2020 20:06:42 +0300 Subject: [PATCH] [ SSR ] Transpose->MatMul (#2525) * [ SSR ] Reshape->Softmax->Reshape * Call DepthToSpaceFusion during SmartReshape * rtti * remove softmax wa --- .../{reshape_with_hc_output.hpp => matmul_sr.hpp} | 6 + .../transformations/smart_reshape/utils.hpp | 8 - .../common_optimizations/common_optimizations.cpp | 2 +- .../{reshape_with_hc_output.cpp => matmul_sr.cpp} | 74 ++++- .../smart_reshape/smart_reshape.cpp | 3 +- .../cnn_network/matmul_sr_tests.cpp | 299 +++++++++++++++++++++ .../cnn_network/smart_reshape_tests.cpp | 80 ------ 7 files changed, 374 insertions(+), 98 deletions(-) rename inference-engine/src/transformations/include/transformations/smart_reshape/{reshape_with_hc_output.hpp => matmul_sr.hpp} (84%) rename inference-engine/src/transformations/src/transformations/smart_reshape/{reshape_with_hc_output.cpp => matmul_sr.cpp} (55%) create mode 100644 inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp delete mode 100644 inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/matmul_sr.hpp similarity index 84% rename from inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp rename to inference-engine/src/transformations/include/transformations/smart_reshape/matmul_sr.hpp index 798a01d..5149a57 100644 --- a/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/matmul_sr.hpp @@ -16,6 +16,7 @@ namespace pass { class TRANSFORMATIONS_API ReshapeAMatMul; class TRANSFORMATIONS_API ReshapeBMatMul; +class TRANSFORMATIONS_API TransposeMatMul; } // namespace pass } // namespace ngraph @@ -37,4 +38,9 @@ class ngraph::pass::ReshapeBMatMul: public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; ReshapeBMatMul(); +}; +class ngraph::pass::TransposeMatMul: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + TransposeMatMul(); }; \ No newline at end of file diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp index a9d0888..84765b0 100644 --- a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp @@ -4,16 +4,8 @@ #pragma once -#include -#include -#include -#include -#include - #include #include -#include -#include #include namespace ngraph { diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index ecce914..cbc040f 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -61,8 +61,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); // depends on CF - manager.register_pass(); // may introduce fake dynamism manager.register_pass(); // may introduce fake dynamism + manager.register_pass(); // may introduce fake dynamism manager.register_pass(); manager.register_pass(); // partially depends on CF manager.register_pass(); diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/matmul_sr.cpp similarity index 55% rename from inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp rename to inference-engine/src/transformations/src/transformations/smart_reshape/matmul_sr.cpp index 31f48e5..311392e 100644 --- a/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/matmul_sr.cpp @@ -2,11 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/smart_reshape/reshape_with_hc_output.hpp" +#include "transformations/smart_reshape/matmul_sr.hpp" #include "transformations/smart_reshape/utils.hpp" +#include #include -#include #include #include @@ -20,19 +20,18 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap const std::shared_ptr & other_input_label, const std::shared_ptr & reshape_pattern_label, bool reshape_is_A_input) { - auto reshape_pattern = std::dynamic_pointer_cast(pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr()); + const auto & reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank(); const auto & matmul = std::dynamic_pointer_cast(pattern_to_output.at(matmul_label).get_node_shared_ptr()); - if (!reshape_pattern || !matmul || reshape_pattern->get_shape() != ngraph::Shape{2}) + if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2) return false; const auto &shape_source = pattern_to_output.at(other_input_label); if (ngraph::is_type(shape_source.get_node_shared_ptr()) || ngraph::is_type(shape_source.get_node_shared_ptr())) // avoiding loop creation return false; - const auto & reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr(); const auto & raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1); - const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape->get_output_partial_shape(0).rank()); + const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank); const auto & C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx); const auto & N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1}); const auto & pattern_vector = reshape_is_A_input ? @@ -40,6 +39,7 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap (matmul->get_transpose_b() ? ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N})); const auto & new_reshape_pattern = std::make_shared(pattern_vector, 0); + auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr(); new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name()); copy_runtime_info(reshape_pattern, new_reshape_pattern); replace_node(reshape_pattern, new_reshape_pattern); @@ -51,7 +51,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeAMatMul, "ReshapeAMatMul", 0); ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() { auto other_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input(); - auto reshape_pattern_label = ngraph::pattern::wrap_type(); + auto reshape_pattern_label = pattern::any_input(); auto reshape_label = ngraph::pattern::wrap_type({reshape_input_label, reshape_pattern_label}); auto matmul_label = ngraph::pattern::wrap_type({reshape_label, other_input_label}); @@ -69,7 +69,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeBMatMul, "ReshapeBMatMul", 0); ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() { auto other_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input(); - auto reshape_pattern_label = ngraph::pattern::wrap_type(); + auto reshape_pattern_label = pattern::any_input(); auto reshape_label = ngraph::pattern::wrap_type({reshape_input_label, reshape_pattern_label}); auto matmul_label = ngraph::pattern::wrap_type({other_input_label, reshape_label}); @@ -80,4 +80,62 @@ ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() { }; auto m = std::make_shared(matmul_label, "ReshapeMatMul_B"); register_matcher(m, callback); +} + +NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeMatMul, "TransposeMatMul", 0); + +ngraph::pass::TransposeMatMul::TransposeMatMul() { + auto matmul_label = ngraph::pattern::wrap_type(); + + matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { + const auto & pattern_to_output = m.get_pattern_value_map(); + auto matmul = std::dynamic_pointer_cast(pattern_to_output.at(matmul_label).get_node_shared_ptr()); + if (!matmul) + return false; + + auto transpose_is_fusable = [](const std::shared_ptr& input) { + const auto & input_rank = input->get_output_partial_shape(0).rank(); + if (input_rank.is_static() && input_rank.get_length() >= 2) { + if (auto transpose = std::dynamic_pointer_cast(input)) { + if (auto order = std::dynamic_pointer_cast(transpose->get_input_node_shared_ptr(1))) { + const auto & order_vector = order->cast_vector(); + std::vector fusable_order(input_rank.get_length()); + std::iota(fusable_order.begin(), fusable_order.end(), 0); + std::swap(fusable_order[input_rank.get_length() - 1], fusable_order[input_rank.get_length() - 2]); + return order_vector == fusable_order; + } + } + } + return false; + }; + + NodeVector fused_nodes; + auto input_A = matmul->get_input_node_shared_ptr(0); + bool transpose_A = matmul->get_transpose_a(); + if (transpose_is_fusable(input_A)) { + fused_nodes.push_back(input_A); + input_A = input_A->get_input_node_shared_ptr(0); + transpose_A = !transpose_A; + } + + auto input_B = matmul->get_input_node_shared_ptr(1); + auto transpose_B = matmul->get_transpose_b(); + if (transpose_is_fusable(input_B)) { + fused_nodes.push_back(input_B); + input_B = input_B->get_input_node_shared_ptr(0); + transpose_B = !transpose_B; + } + + if (!fused_nodes.empty()) { + auto updated_matmul = std::make_shared(input_A, input_B, transpose_A, transpose_B); + fused_nodes.push_back(matmul); + copy_runtime_info(fused_nodes, updated_matmul); + updated_matmul->set_friendly_name(matmul->get_friendly_name()); + replace_node(matmul, updated_matmul); + return true; + } + return false; + }; + auto m = std::make_shared(matmul_label, "TransposeMatMul"); + register_matcher(m, callback); } \ No newline at end of file diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp index 5c35000..305b6b1 100644 --- a/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -30,6 +30,7 @@ bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr(); static_manager.register_pass(); static_manager.register_pass(); + static_manager.register_pass(); static_manager.run_passes(f); ngraph::pass::Manager dynamic_manager; diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp new file mode 100644 index 0000000..4941cc9 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/cnn_network/matmul_sr_tests.cpp @@ -0,0 +1,299 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cnn_network_ngraph_impl.hpp" + +using namespace testing; +using namespace InferenceEngine; + +namespace { + +using reshape_map = std::map>; + +struct ReshapeMatMulTestCase { + bool reshape_is_A_input; + ngraph::PartialShape A_shape, B_shape; + std::vector reshape_pattern; + bool transpose_a, transpose_b; + reshape_map new_shapes; +}; + +class SmartReshapeMatMulTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface> { +public: + void SetUp() override { + const auto& parameters = GetParam(); + const auto& test_case = std::get<0>(GetParam()); + + std::shared_ptr ngraph; + { + auto input_A = std::make_shared(ngraph::element::f32, test_case.A_shape); + input_A->set_friendly_name("input_A"); + auto input_B = std::make_shared(ngraph::element::f32, test_case.B_shape); + input_B->set_friendly_name("input_B"); + + auto reshape_pattern = std::make_shared( + ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern); + reshape_pattern->set_friendly_name("reshape_pattern"); + auto reshape = std::make_shared(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true); + reshape->set_friendly_name("reshape"); + + auto mat_mul = std::make_shared(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0), + test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0), + test_case.transpose_a, test_case.transpose_b); + reshape->set_friendly_name("matmul"); + + auto result = std::make_shared(mat_mul); + ngraph::ParameterVector params = {input_A, input_B}; + ngraph::ResultVector results = {result}; + ngraph = std::make_shared(results, params); + } + + InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph); + const auto & resp = network.reshape(test_case.new_shapes, nullptr); + ASSERT_EQ(resp, StatusCode::OK); + } +}; + +TEST_P(SmartReshapeMatMulTests, ReshapeMatMul) { +} + +INSTANTIATE_TEST_CASE_P(NGraph, SmartReshapeMatMulTests, testing::Values( + ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}}, + ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}}, + ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}}, + ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}}, + ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}}, + ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}}, + ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}}, + ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}})); +} // namespace + +TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_A, order); + auto matmul = std::make_shared(transpose, data_B, false, false); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto matmul = std::make_shared(data_A, data_B, true, false); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_B, order); + auto matmul = std::make_shared(data_A, transpose, false, false); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto matmul = std::make_shared(data_A, data_B, false, true); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulWithAttrFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_A, order); + auto matmul = std::make_shared(transpose, data_B, true, false); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto matmul = std::make_shared(data_A, data_B, false, false); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulWithAttrFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_B, order); + auto matmul = std::make_shared(data_A, transpose, false, true); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto matmul = std::make_shared(data_A, data_B, false, false); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} +TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulSideAttrFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_A, order); + auto matmul = std::make_shared(transpose, data_B, true, true); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 2, 3}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto matmul = std::make_shared(data_A, data_B, false, true); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(SmartReshapeTransposeMatMulTests, TransposeBMatMulSideAttrFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose = std::make_shared(data_B, order); + auto matmul = std::make_shared(data_A, transpose, true, true); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto matmul = std::make_shared(data_A, data_B, true, false); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose_A = std::make_shared(data_A, order); + auto transpose_B = std::make_shared(data_B, order); + auto matmul = std::make_shared(transpose_A, transpose_B, false, false); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 5, 3}); + auto matmul = std::make_shared(data_A, data_B, true, true); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} +TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); + auto transpose_A = std::make_shared(data_A, order); + auto transpose_B = std::make_shared(data_B, order); + auto matmul = std::make_shared(transpose_A, transpose_B, false, true); + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto data_A = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 2}); + auto data_B = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 5}); + auto matmul = std::make_shared(data_A, data_B, true, false); + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data_A, data_B}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp deleted file mode 100644 index f51d1ae..0000000 --- a/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include - -#include -#include -#include - -#include "cnn_network_ngraph_impl.hpp" - -using namespace testing; -using namespace InferenceEngine; - -namespace { - -using reshape_map = std::map>; - -struct ReshapeMatMulTestCase { - bool reshape_is_A_input; - ngraph::PartialShape A_shape, B_shape; - std::vector reshape_pattern; - bool transpose_a, transpose_b; - reshape_map new_shapes; -}; - -class CNNNGraphImplSmartReshapeTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface> { -public: - void SetUp() override { - const auto& parameters = GetParam(); - const auto& test_case = std::get<0>(GetParam()); - - std::shared_ptr ngraph; - { - auto input_A = std::make_shared(ngraph::element::f32, test_case.A_shape); - input_A->set_friendly_name("input_A"); - auto input_B = std::make_shared(ngraph::element::f32, test_case.B_shape); - input_B->set_friendly_name("input_B"); - - auto reshape_pattern = std::make_shared( - ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern); - reshape_pattern->set_friendly_name("reshape_pattern"); - auto reshape = std::make_shared(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true); - reshape->set_friendly_name("reshape"); - - auto mat_mul = std::make_shared(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0), - test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0), - test_case.transpose_a, test_case.transpose_b); - reshape->set_friendly_name("matmul"); - - auto result = std::make_shared(mat_mul); - ngraph::ParameterVector params = {input_A, input_B}; - ngraph::ResultVector results = {result}; - ngraph = std::make_shared(results, params); - } - - InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph); - const auto & resp = network.reshape(test_case.new_shapes, nullptr); - ASSERT_EQ(resp, StatusCode::OK); - } -}; - -TEST_P(CNNNGraphImplSmartReshapeTests, ReshapeMatMul) { -} - -INSTANTIATE_TEST_CASE_P(NGraph, CNNNGraphImplSmartReshapeTests, testing::Values( - ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}}, - ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}}, - ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}}, - ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}}, - ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}}, - ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}}, - ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}}, - ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}})); -} // namespace \ No newline at end of file -- 2.7.4