From 75601e62ed1d2015ae3572f2c6175a07b163cc32 Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Mon, 14 Sep 2020 13:45:27 +0300 Subject: [PATCH] Super smart reshape: HC Reshape to 2D followed by MatMul (#2183) * Initial commit * [SSR] Reshape(2D)->MatMul constrain relaxation * Moved common pattern mechanics to the common function * Moving SmartReshape to CNNNetworkNgraphImpl ctors * Review comment * Tests --- .../inference_engine/cnn_network_ngraph_impl.cpp | 5 ++ .../smart_reshape/reshape_with_hc_output.hpp | 30 ++++++++ .../smart_reshape/smart_reshape.hpp | 26 +++++++ .../transformations/smart_reshape/utils.hpp | 38 ++++++++++ .../smart_reshape/reshape_with_hc_output.cpp | 79 +++++++++++++++++++++ .../smart_reshape/smart_reshape.cpp | 27 ++++++++ .../cnn_network/smart_reshape_tests.cpp | 80 ++++++++++++++++++++++ 7 files changed, 285 insertions(+) create mode 100644 inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp create mode 100644 inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp create mode 100644 inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp create mode 100644 inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp create mode 100644 inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp create mode 100644 inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp index a920eac..3f3dae7 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp @@ -22,6 +22,7 @@ #include #include +#include #include "ngraph_ops/eltwise.hpp" #include "exec_graph_info.hpp" @@ -126,6 +127,10 @@ 
CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr& nGra // Add shape infer method for old operations which are not included to opset1, opset2 and opset3 ::ngraph::op::GenericIE::addExtension(_ngraph_function, std::make_shared()); + ngraph::pass::Manager ssr_manager; + ssr_manager.register_pass(); + ssr_manager.run_passes(_ngraph_function); + reshape(); for (const auto& layer : _ngraph_function->get_parameters()) { std::string outName = layer->get_friendly_name(); diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp new file mode 100644 index 0000000..c925928 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/reshape_with_hc_output.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API ReshapeAMatMul; +class TRANSFORMATIONS_API ReshapeBMatMul; + +} // namespace pass +} // namespace ngraph + +class ngraph::pass::ReshapeAMatMul: public ngraph::pass::MatcherPass { +public: + ReshapeAMatMul(); +}; +class ngraph::pass::ReshapeBMatMul: public ngraph::pass::MatcherPass { +public: + ReshapeBMatMul(); +}; \ No newline at end of file diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp new file mode 100644 index 0000000..344787e --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/smart_reshape.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + + +namespace ngraph { 
+namespace pass { + +class TRANSFORMATIONS_API SmartReshape; + +} // namespace pass +} // namespace ngraph + +class ngraph::pass::SmartReshape: public ngraph::pass::FunctionPass { +public: + bool run_on_function(std::shared_ptr f) override; +}; diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp new file mode 100644 index 0000000..a9d0888 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace ngraph { +namespace op { +namespace util { + +inline std::shared_ptr node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr& shape_node, + const std::vector& indices) { + return std::make_shared(shape_node, + ngraph::opset4::Constant::create(ngraph::element::i64, {indices.size()}, indices), + ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0})); +} + +inline std::shared_ptr node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output& shape_source, + const std::vector& indices) { + const auto & shape_node = std::make_shared(shape_source); + return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices); +} + +} // namespace util +} // namespace op +} // namespace ngraph \ No newline at end of file diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp new file mode 100644 index 0000000..2ee2d3c --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/reshape_with_hc_output.cpp @@ -0,0 +1,79 @@ +// Copyright (C) 2018-2020 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/smart_reshape/reshape_with_hc_output.hpp" +#include "transformations/smart_reshape/utils.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap & pattern_to_output, + const std::shared_ptr & matmul_label, + const std::shared_ptr & reshape_label, + const std::shared_ptr & other_input_label, + const std::shared_ptr & reshape_pattern_label, + bool reshape_is_A_input) { + auto reshape_pattern = std::dynamic_pointer_cast(pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr()); + const auto & matmul = std::dynamic_pointer_cast(pattern_to_output.at(matmul_label).get_node_shared_ptr()); + if (!reshape_pattern || !matmul || reshape_pattern->get_shape() != ngraph::Shape{2}) + return false; + const auto &shape_source = pattern_to_output.at(other_input_label); + if (ngraph::is_type(shape_source.get_node_shared_ptr()) || + ngraph::is_type(shape_source.get_node_shared_ptr())) + // avoiding loop creation + return false; + const auto & reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr(); + + const auto & raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1); + const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape->get_output_partial_shape(0).rank()); + const auto & C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx); + const auto & N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1}); + const auto & pattern_vector = reshape_is_A_input ? + (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C})) : + (matmul->get_transpose_b() ? 
ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N})); + const auto & new_reshape_pattern = std::make_shared(pattern_vector, 0); + + new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name()); + copy_runtime_info(reshape_pattern, new_reshape_pattern); + replace_node(reshape_pattern, new_reshape_pattern); + return true; +} + +ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() { + auto other_input_label = pattern::any_input(); + auto reshape_input_label = pattern::any_input(); + auto reshape_pattern_label = ngraph::pattern::wrap_type(); + auto reshape_label = ngraph::pattern::wrap_type({reshape_input_label, reshape_pattern_label}); + auto matmul_label = ngraph::pattern::wrap_type({reshape_label, other_input_label}); + + matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { + const auto & pattern_to_output = m.get_pattern_value_map(); + return relax_hc_reshape_followed_by_matmul( + pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, true); + }; + auto m = std::make_shared(matmul_label, "ReshapeMatMul_A"); + register_matcher(m, callback); +} + +ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() { + auto other_input_label = pattern::any_input(); + auto reshape_input_label = pattern::any_input(); + auto reshape_pattern_label = ngraph::pattern::wrap_type(); + auto reshape_label = ngraph::pattern::wrap_type({reshape_input_label, reshape_pattern_label}); + auto matmul_label = ngraph::pattern::wrap_type({other_input_label, reshape_label}); + + matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { + const auto & pattern_to_output = m.get_pattern_value_map(); + return relax_hc_reshape_followed_by_matmul( + pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, false); + }; + auto m = std::make_shared(matmul_label, "ReshapeMatMul_B"); + register_matcher(m, callback); +} \ No newline at end of file diff --git 
a/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp new file mode 100644 index 0000000..52a903a --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/smart_reshape.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/smart_reshape/smart_reshape.hpp" +#include "transformations/smart_reshape/reshape_with_hc_output.hpp" +#include "transformations/itt.hpp" + +#include +#include +#include + +bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr f) { + OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::SmartReshape"); + + ngraph::pass::Manager manager; + // This pass must be called first in pipeline + manager.register_pass(); + + manager.register_pass(); + manager.register_pass(); + + manager.run_passes(f); + return true; +} diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp new file mode 100644 index 0000000..f51d1ae --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/cnn_network/smart_reshape_tests.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include + +#include "cnn_network_ngraph_impl.hpp" + +using namespace testing; +using namespace InferenceEngine; + +namespace { + +using reshape_map = std::map>; + +struct ReshapeMatMulTestCase { + bool reshape_is_A_input; + ngraph::PartialShape A_shape, B_shape; + std::vector reshape_pattern; + bool transpose_a, transpose_b; + reshape_map new_shapes; +}; + +class CNNNGraphImplSmartReshapeTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface> { 
+public: + void SetUp() override { + const auto& parameters = GetParam(); + const auto& test_case = std::get<0>(GetParam()); + + std::shared_ptr ngraph; + { + auto input_A = std::make_shared(ngraph::element::f32, test_case.A_shape); + input_A->set_friendly_name("input_A"); + auto input_B = std::make_shared(ngraph::element::f32, test_case.B_shape); + input_B->set_friendly_name("input_B"); + + auto reshape_pattern = std::make_shared( + ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern); + reshape_pattern->set_friendly_name("reshape_pattern"); + auto reshape = std::make_shared(test_case.reshape_is_A_input ? input_A : input_B, reshape_pattern, true); + reshape->set_friendly_name("reshape"); + + auto mat_mul = std::make_shared(test_case.reshape_is_A_input ? reshape->output(0) : input_A->output(0), + test_case.reshape_is_A_input ? input_B->output(0) : reshape->output(0), + test_case.transpose_a, test_case.transpose_b); + mat_mul->set_friendly_name("matmul"); + + auto result = std::make_shared(mat_mul); + ngraph::ParameterVector params = {input_A, input_B}; + ngraph::ResultVector results = {result}; + ngraph = std::make_shared(results, params); + } + + InferenceEngine::details::CNNNetworkNGraphImpl network(ngraph); + const auto & resp = network.reshape(test_case.new_shapes, nullptr); + ASSERT_EQ(resp, StatusCode::OK); + } +}; + +TEST_P(CNNNGraphImplSmartReshapeTests, ReshapeMatMul) { +} + +INSTANTIATE_TEST_CASE_P(NGraph, CNNNGraphImplSmartReshapeTests, testing::Values( + ReshapeMatMulTestCase{true, {1, 20, 30}, {30, 40}, {20, -1}, false, false, {{"input_A", {2, 20, 30}}}}, + ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}}, + ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}}, + ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}}, + ReshapeMatMulTestCase{false, {20, 30}, {1, 
30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}}, + ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}}, + ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}}, + ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}})); +} // namespace \ No newline at end of file -- 2.7.4