Reshape-Permute-Reshape pattern to DepthToSpace layer transformation (#601)
authorIvan Tikhonov <ivan.tikhonov@intel.com>
Mon, 1 Jun 2020 06:24:16 +0000 (09:24 +0300)
committerGitHub <noreply@github.com>
Mon, 1 Jun 2020 06:24:16 +0000 (09:24 +0300)
* implemented depth_to_space transformation

* renaming

* added functional tests, fixed mistakes in implementation of the transformation

* disable ConvertSpaceToDepth/ConvertDepthToSpace transformation for CPU plugin, enable DepthToSpaceFusion for CPU plugin only, add specific creators

* fix wrong include

* fix for functional tests: set transformation callback

* revert callback calls for CPU plugin

* move functions to .cpp file

* Apply review comments

* Apply additional review comments

* fix cast to bool type

inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations_tbl.hpp
inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp
inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp
inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp
inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp
inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp [new file with mode: 0644]

index 22db934..e56ed03 100644 (file)
@@ -332,7 +332,7 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
         res->params = params;
         return res;
     });
-
+    
     addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                             const std::map<std::string, std::string> params) -> CNNLayerPtr {
         LayerParams attrs = {node->get_friendly_name(), "Memory",
@@ -355,6 +355,24 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
         return res;
     });
 
+    addSpecificCreator({"DepthToSpace"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                            const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), node->description(),
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<DepthToSpaceLayer>(attrs);
+        res->params = params;
+        return res;
+    });
+
+    addSpecificCreator({"SpaceToDepth"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                            const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), node->description(),
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<SpaceToDepthLayer>(attrs);
+        res->params = params;
+        return res;
+    });
+
     addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                             const std::map<std::string, std::string> params) -> CNNLayerPtr {
         THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
index c0a53fc..f2da6d7 100644 (file)
@@ -25,3 +25,4 @@ NGRAPH_PASS(NopElimination, ::ngraph::pass) // may introduce fake dynamism
 NGRAPH_PASS(AlgebraicSimplification, ::ngraph::pass) // may introduce fake dynamism
 NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
 NGRAPH_PASS(ConvertScatterElementsToScatter, ::ngraph::pass) // partially depends on CF
+NGRAPH_PASS(DepthToSpaceFusion, ::ngraph::pass)
index 7ad3f6c..8ac0cba 100644 (file)
@@ -10,6 +10,7 @@
 #include <ie_api.h>
 
 #include <ngraph/pass/graph_rewrite.hpp>
+#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertDepthToSpace);
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
 public:
-    ConvertDepthToSpace() : GraphRewrite() {
+    ConvertDepthToSpace() : GraphRewrite(), PassParam() {
         convert_depth_to_space();
     }
 
index 75058b3..f148c90 100644 (file)
@@ -10,6 +10,7 @@
 #include <ie_api.h>
 
 #include <ngraph/pass/graph_rewrite.hpp>
+#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertSpaceToDepth);
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam  {
 public:
-    ConvertSpaceToDepth() : GraphRewrite() {
+    ConvertSpaceToDepth() : GraphRewrite(), PassParam() {
         convert();
     }
 
diff --git a/inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp b/inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp
new file mode 100644 (file)
index 0000000..e9e0fc1
--- /dev/null
@@ -0,0 +1,54 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <ie_api.h>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+#include "transformations/utils/pass_param.hpp"
+
+namespace ngraph {
+namespace pass {
+
+    class INFERENCE_ENGINE_API_CLASS(DepthToSpaceFusion);
+
+}  // namespace pass
+}  // namespace ngraph
+
+/*
+ * Description:
+ *     DepthToSpaceFusion transformation detects Reshape-Transpose-Reshape pattern and
+ *     tries to fuse it into a single DepthToSpace layer.
+ *
+ * Usage:
+ *     DepthToSpaceFusion transformation is optional and disabled by default.
+ *     The transformation can be enabled with callback using setCallback method.
+ *     See the example below.
+ *
+ * Callback example:
+ *
+ *     // This callback enables DepthToSpaceFusion transformation
+ *     auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+ *         return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+ *     };
+ *
+ *     auto p = ngraph::pass::DepthToSpaceFusion();
+ *     p.setCallback(callback);
+ *     p.run_on_function(f);
+ *
+ */
+
+class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+public:
+    DepthToSpaceFusion() : GraphRewrite(), PassParam() {
+        depth_to_space_fusion();
+    }
+
+private:
+    void depth_to_space_fusion();
+};
index 467f076..5b1a956 100644 (file)
@@ -6,6 +6,7 @@
 
 #include "transformations/common_optimizations/common_optimizations.hpp"
 #include "transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp"
+#include "transformations/depth_to_space_fusion.hpp"
 #include "transformations/optimize_strided_slice.hpp"
 #include "transformations/convert_scatter_elements_to_scatter.hpp"
 #include "transformations/remove_filtering_boxes_by_size.hpp"
index aea1555..959b459 100644 (file)
@@ -14,9 +14,9 @@ void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
     auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
-        if (!dts_node) {
+        if (!dts_node || transformation_callback(dts_node)) {
             return false;
         }
 
index f4ff7ec..5eb1655 100644 (file)
@@ -14,9 +14,9 @@ void ngraph::pass::ConvertSpaceToDepth::convert() {
     auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
-        if (!std_node) {
+        if (!std_node || transformation_callback(std_node)) {
             return false;
         }
 
diff --git a/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp b/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp
new file mode 100644 (file)
index 0000000..a2323b1
--- /dev/null
@@ -0,0 +1,166 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/depth_to_space_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/rt_info.hpp>
+
+bool check_block_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
+                       const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
+                       size_t& possible_block_size) {
+    bool is_transformation_valid = true;
+    uint64_t spatial_dims = shape_input.size() - 2;
+    possible_block_size = shape_reshape_before[1];
+    if (possible_block_size == 0)
+        return false;
+    uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
+
+    // x' = reshape(data, [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., DK])
+    ngraph::Shape expected_shape = {shape_input[0]};
+    for (uint64_t i = 0; i < spatial_dims; ++i)
+        expected_shape.push_back(possible_block_size);
+    expected_shape.push_back(c_dim);
+    for (uint64_t i = 2; i < shape_input.size(); ++i)
+        expected_shape.push_back(shape_input[i]);
+    is_transformation_valid &= (expected_shape == shape_reshape_before);
+
+    // x'' = transpose(x', [0,  K + 1,  K + 2, 1, K + 3, 2, K + 4, 3, ..., K + (K + 1), K])
+    ngraph::AxisVector expected_permutation = {0, spatial_dims + 1};
+    for (uint64_t i = 2; i < shape_input.size(); ++i) {
+        expected_permutation.push_back(spatial_dims + i);
+        expected_permutation.push_back(i - 1);
+    }
+    is_transformation_valid &= (expected_permutation == permutation);
+
+    // y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
+    expected_shape = {shape_input[0], c_dim};
+    for (uint64_t i = 2; i < shape_input.size(); ++i)
+        expected_shape.push_back(shape_input[i] * possible_block_size);
+    is_transformation_valid &= (expected_shape == shape_reshape_after);
+
+    return is_transformation_valid;
+}
+
+bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
+                       const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
+                       size_t& possible_block_size) {
+    bool is_transformation_valid = true;
+    uint64_t spatial_dims = shape_input.size() - 2;
+    possible_block_size = shape_reshape_before[2];
+    if (possible_block_size == 0)
+        return false;
+    uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
+
+    // x' = reshape(data, [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., DK])
+    ngraph::Shape expected_shape = {shape_input[0], c_dim};
+    for (uint64_t i = 0; i < spatial_dims; ++i)
+        expected_shape.push_back(possible_block_size);
+    for (uint64_t i = 2; i < shape_input.size(); ++i)
+        expected_shape.push_back(shape_input[i]);
+    is_transformation_valid &= (expected_shape == shape_reshape_before);
+
+    // x'' = transpose(x', [0,  1,  K + 2, 2, K + 3, 3, K + 4, 4, ..., K + (K + 1), K + 1])
+    ngraph::AxisVector expected_permutation = {0, 1};
+    for (uint64_t i = 2; i < shape_input.size(); ++i) {
+        expected_permutation.push_back(spatial_dims + i);
+        expected_permutation.push_back(i);
+    }
+    is_transformation_valid &= (expected_permutation == permutation);
+
+    // y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
+    expected_shape = {shape_input[0], c_dim};
+    for (uint64_t i = 2; i < shape_input.size(); ++i)
+        expected_shape.push_back(shape_input[i] * possible_block_size);
+    is_transformation_valid &= (expected_shape == shape_reshape_after);
+
+    return is_transformation_valid;
+}
+
+void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
+    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
+    auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
+    auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
+    auto input3 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
+    auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, input1, false);
+    auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, input2);
+    auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
+
+    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
+        if (!transformation_callback(std::make_shared<ngraph::opset3::DepthToSpace>())) {
+            return false;
+        }
+
+        auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
+        if (!reshape_after) {
+            return false;
+        }
+
+        auto permute = std::dynamic_pointer_cast<ngraph::opset3::Transpose>(reshape_after->input_value(0).get_node_shared_ptr());
+        if (!permute || permute->get_output_target_inputs(0).size() != 1) {
+            return false;
+        }
+
+        auto reshape_before = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(permute->input_value(0).get_node_shared_ptr());
+        if (!reshape_before || reshape_before->get_output_target_inputs(0).size() != 1) {
+            return false;
+        }
+
+        auto p_shape_input = reshape_before->get_input_partial_shape(0);
+        auto p_shape_reshape_before = reshape_before->get_output_partial_shape(0);
+        auto p_shape_permute = permute->get_output_partial_shape(0);
+        auto p_shape_reshape_after = reshape_after->get_output_partial_shape(0);
+
+        if (p_shape_input.is_dynamic() || p_shape_reshape_before.is_dynamic() ||
+            p_shape_permute.is_dynamic() || p_shape_reshape_after.is_dynamic()) {
+            return false;
+        }
+
+        auto shape_input = p_shape_input.get_shape();
+        auto shape_reshape_before = p_shape_reshape_before.get_shape();
+        auto shape_permute = p_shape_permute.get_shape();
+        auto shape_reshape_after = p_shape_reshape_after.get_shape();
+
+        if (shape_input.size() < 3) {
+            return false;
+        }
+
+        // input shape: [ batch, C, spatial_dims], expected_shape = spatial_dims.size() * 2 + 2
+        size_t expected_shape_size = (shape_input.size() - 2) * 2 + 2;
+        if (shape_input.size() != shape_reshape_after.size() || shape_reshape_before.size() != expected_shape_size ||
+            shape_permute.size() != expected_shape_size) {
+            return false;
+        }
+
+        ngraph::AxisVector permutation;
+        if (auto input_const = std::dynamic_pointer_cast<opset3::Constant>(permute->input_value(1).get_node_shared_ptr())) {
+            permutation = input_const->get_axis_vector_val();
+        } else {
+            return false;
+        }
+
+        ngraph::opset3::DepthToSpace::DepthToSpaceMode mode;
+        size_t block_size;
+        if (check_depth_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
+            mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
+        } else if (check_block_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
+            mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST;
+        } else {
+            return false;
+        }
+
+        auto depth_to_space =
+                std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
+        depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
+        ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
+        ngraph::replace_node(reshape_after, depth_to_space);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_after, "DepthToSpaceFusion");
+    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+}
\ No newline at end of file
diff --git a/inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp
new file mode 100644 (file)
index 0000000..55a9187
--- /dev/null
@@ -0,0 +1,184 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/test_common.hpp"
+#include <string>
+#include <sstream>
+#include <fstream>
+#include <memory>
+#include <queue>
+#include <map>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/constant_folding.hpp>
+#include <ngraph_ops/fully_connected.hpp>
+#include <transformations/depth_to_space_fusion.hpp>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 32, 2, 2, 720, 480});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 1, 4, 2, 5, 3});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+            return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+        };
+
+        auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
+        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+            return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+        };
+
+        auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
+        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+            return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+        };
+
+        // transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
+        auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
+        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        // additional consumers, not output of the function
+        auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
+        auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+            return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+        };
+
+        // transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
+        auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
+        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
+        auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
+        auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
+        auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
+
+        auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
+        auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
+        auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
+
+        // additional consumers, not output of the function
+        auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
+        auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}