Improve ConvertBroadcast3 pass to avoid extra Multiply operations for BIDIRECTIONAL...
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Fri, 13 Nov 2020 11:39:07 +0000 (14:39 +0300)
committerGitHub <noreply@github.com>
Fri, 13 Nov 2020 11:39:07 +0000 (14:39 +0300)
* Fixed ConvertBroadcast3 pass for BIDIRECTIONAL mode to avoid excess Multiply operations

* Added funcitonal tests for new decompositions

* Return false if mode is unknown; avoid usign node in replace_node

* Added functional tests for cases when TargetShape input is not a Constant

inference-engine/src/transformations/include/transformations/op_conversions/convert_broadcast3.hpp
inference-engine/src/transformations/src/transformations/op_conversions/convert_broadcast3.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_broadcast3_test.cpp

index 2fa2583..03c2b35 100644 (file)
@@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertBroadcast3;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertBroadcast3: public ngraph::pass::MatcherPass {
 public:
     NGRAPH_RTTI_DECLARATION;
-    ConvertBroadcast3() : GraphRewrite() {
-        convert_broadcast3();
-    }
-
-private:
-    void convert_broadcast3();
+    ConvertBroadcast3();
 };
index 2b251c3..c2641b9 100644 (file)
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcast3, "ConvertBroadcast3", 0);
 
-void ngraph::pass::ConvertBroadcast3::convert_broadcast3() {
-    auto broadcast = std::make_shared<pattern::op::Label>(element::f32, Shape {}, pattern::has_class<opset3::Broadcast>());
+bool make_compatible_shape(const ngraph::PartialShape & input_shape, std::vector<size_t> & target_shape) {
+    if (input_shape.rank().is_dynamic()) {
+        return false;
+    }
+    const int64_t & input_shape_rank = input_shape.rank().get_length();
+    if (input_shape_rank > target_shape.size()) {
+        // target_shape rank must greater or equal to input_shape rank, so in case when it's less we
+        // insert missing input_shape dimensions to the beginning of the target_shape.
+        const int64_t & dims_to_add_count = input_shape_rank - target_shape.size();
+        std::vector<size_t> dims_to_add(dims_to_add_count);
+        for (size_t dim = 0; dim < dims_to_add_count; ++dim) {
+            if (input_shape[dim].is_dynamic()) {
+                return false;
+            }
+            dims_to_add[dim] = input_shape[dim].get_length();
+        }
+        target_shape.insert(target_shape.begin(), dims_to_add.begin(), dims_to_add.end());
+    }
+    for (int64_t i_dim = input_shape_rank - 1, t_dim = target_shape.size() - 1; i_dim >= 0 && t_dim >= 0; --i_dim, --t_dim) {
+        if (input_shape[i_dim].is_static()) {
+            const auto & input_dim = input_shape[i_dim].get_length();
+            if (input_dim != target_shape[t_dim] && input_dim != 1 && target_shape[t_dim] != 1) {
+                // this dimensions are not broadcastable
+                return false;
+            }
+            target_shape[t_dim] = std::max(target_shape[t_dim], static_cast<size_t>(input_dim));
+        } else {
+            if (target_shape[t_dim] == 1) {
+                // For example:    |
+                //                \/
+                // input_shape  [DYN, 3, 4]
+                // target_shape [  1, 3, 4] - broadcasted first dimension is unknown
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+ngraph::pass::ConvertBroadcast3::ConvertBroadcast3() {
+    auto broadcast = pattern::wrap_type<opset3::Broadcast>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
-        auto broadcast = std::dynamic_pointer_cast<ngraph::opset3::Broadcast>(m.get_match_root());
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
+        auto broadcast = std::dynamic_pointer_cast<opset3::Broadcast>(m.get_match_root());
         if (!broadcast) {
             return false;
         }
 
         auto input = broadcast->input_value(0);
-        auto target_shape = broadcast->input_value(1);
-
-        auto last_node = input.get_node_shared_ptr();
-        auto broadcast_type = broadcast->get_broadcast_spec();
+        auto target_shape_input = broadcast->input_value(1);
+        const auto & broadcast_type = broadcast->get_broadcast_spec();
+        const auto & input_element_type = input.get_element_type();
 
         if (broadcast_type == op::BroadcastType::NUMPY) {
-            last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::NUMPY);
-            ngraph::copy_runtime_info(broadcast, last_node);
+            input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::NUMPY);
         } else if (broadcast_type == op::BroadcastType::PDPD) {
-            last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::PDPD);
-            ngraph::copy_runtime_info(broadcast, last_node);
+            input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::PDPD);
         } else if (broadcast_type == op::BroadcastType::NONE) {
-            last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, broadcast->input_value(2), op::AutoBroadcastType::NONE);
-            ngraph::copy_runtime_info(broadcast, last_node);
+            input = std::make_shared<opset1::Broadcast>(input, target_shape_input, broadcast->input_value(2), op::AutoBroadcastType::NONE);
         } else if (broadcast_type == op::BroadcastType::BIDIRECTIONAL) {
-            auto constant_one = std::make_shared<ngraph::opset1::Constant>(input.get_element_type(), Shape({1}), std::vector<int>{1});
-            auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, op::AutoBroadcastType::NUMPY);
-            if (input.get_element_type() == element::boolean) {
-                last_node = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
+            if (auto const_target_shape = std::dynamic_pointer_cast<opset1::Constant>(target_shape_input.get_node_shared_ptr())) {
+                const auto & input_shape = input.get_partial_shape();
+                const auto & target_shape = const_target_shape->cast_vector<size_t>();
+                std::vector<size_t> aligned_target_shape{target_shape};
+                if (make_compatible_shape(input_shape, aligned_target_shape)) {
+                    input = std::make_shared<opset1::Broadcast>(input,
+                            opset1::Constant::create(element::i64, Shape({aligned_target_shape.size()}), aligned_target_shape));
+                } else {
+                    input = std::make_shared<opset1::Multiply>(input,
+                            opset1::Constant::create(input_element_type, target_shape, {1}));
+                }
             } else {
-                last_node = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
+                auto constant_one = opset1::Constant::create(input_element_type, {1}, {1});
+                auto broadcast_ones = std::make_shared<opset1::Broadcast>(constant_one, target_shape_input);
+                if (input_element_type == element::boolean) {
+                    input = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
+                } else {
+                    input = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
+                }
+                copy_runtime_info(broadcast, broadcast_ones);
             }
-            ngraph::copy_runtime_info(broadcast, {last_node, broadcast_ones, constant_one});
+        } else {
+            return false;
         }
 
-        last_node->set_friendly_name(broadcast->get_friendly_name());
-
-        ngraph::replace_node(m.get_match_root(), last_node);
+        input.get_node_shared_ptr()->set_friendly_name(broadcast->get_friendly_name());
+        copy_runtime_info(broadcast, input.get_node_shared_ptr());
+        replace_node(broadcast, {input});
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(broadcast, "ConvertBroadcast3");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<pattern::Matcher>(broadcast, "ConvertBroadcast3");
+    register_matcher(m, callback);
 }
index 65ad9f1..45f1ff5 100644 (file)
 
 #include <gtest/gtest.h>
 
+#include "common_test_utils/test_common.hpp"
 #include <string>
+#include <sstream>
+#include <fstream>
 #include <memory>
 #include <queue>
+#include <map>
 
 #include <ngraph/function.hpp>
 #include <ngraph/opsets/opset1.hpp>
-#include <ngraph/opsets/opset3.hpp>
-#include <transformations/op_conversions/convert_broadcast3.hpp>
-#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/constant_folding.hpp>
 #include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/visualize_tree.hpp>
+#include <transformations/op_conversions/convert_broadcast3.hpp>
+#include <ngraph_ops/convolution_ie.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
 using namespace testing;
+using namespace ngraph;
+
+using InputShape = PartialShape;
+using TargetShape = Shape;
+
+void convert_broadcast3_test(std::shared_ptr<Function> f, std::shared_ptr<Function> f_ref) {
+    pass::Manager manager;
+    manager.register_pass<pass::InitNodeInfo>();
+    manager.register_pass<pass::ConvertBroadcast3>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+class ConvertBroadcast3NUMPYTest: public CommonTestUtils::TestsCommon,
+                                  public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+    std::shared_ptr<Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& input_shape = std::get<0>(GetParam());
+        const auto& target_shape = std::get<1>(GetParam());
+
+        f = get_initial_function(input_shape, target_shape);
+        f_ref = get_reference_broadcast(input_shape, target_shape);
+    }
+
+    std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+                                                   const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::NUMPY);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+    }
+
+    std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+                                                      const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+    }
+};
+
+class ConvertBroadcast3BIDIRECTMulTest: public CommonTestUtils::TestsCommon,
+                                        public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+    std::shared_ptr<Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& input_shape = std::get<0>(GetParam());
+        const auto& target_shape = std::get<1>(GetParam());
+
+        f = get_initial_function(input_shape, target_shape);
+        f_ref = get_reference_broadcast(input_shape, target_shape);
+    }
+
+    std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+                                                   const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+    }
+
+    std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+                                                      const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto const_node = ngraph::opset1::Constant::create(ngraph::element::f32, Shape{target_shape}, {1});
+        auto mul = std::make_shared<ngraph::opset1::Multiply>(input, const_node);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+    }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastTest: public CommonTestUtils::TestsCommon,
+                                              public testing::WithParamInterface<std::tuple<InputShape, TargetShape, TargetShape>> {
+public:
+    std::shared_ptr<Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& input_shape = std::get<0>(GetParam());
+        const auto& target_shape = std::get<1>(GetParam());
+        const auto& aligned_target_shape = std::get<2>(GetParam());
+
+        f = get_initial_function(input_shape, target_shape);
+        f_ref = get_reference_broadcast(input_shape, aligned_target_shape);
+    }
+
+    std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+                                                   const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+    }
+
+    std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+                                                      const TargetShape & aligned_target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, Shape{aligned_target_shape.size()}, aligned_target_shape);
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+    }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastMultiplyTest: public CommonTestUtils::TestsCommon,
+                                                      public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+    std::shared_ptr<Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& input_shape = std::get<0>(GetParam());
+        const auto& target_shape = std::get<1>(GetParam());
+
+        f = get_initial_function(input_shape, target_shape);
+        f_ref = get_reference_broadcast(input_shape, target_shape);
+    }
+
+    std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+                                                   const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
+    }
+
+    std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+                                                      const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+        auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+        auto constant_one = opset1::Constant::create(ngraph::element::f32, {1}, {1});
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
+        auto mul = std::make_shared<ngraph::opset1::Multiply>(input, broadcast);
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
+    }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
+                                                       public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+    std::shared_ptr<Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& input_shape = std::get<0>(GetParam());
+        const auto& target_shape = std::get<1>(GetParam());
+
+        f = get_initial_function(input_shape, target_shape);
+        f_ref = get_reference_broadcast(input_shape, target_shape);
+    }
+
+    std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+                                                   const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
+        auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+        auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
+    }
+
+    std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+                                                      const TargetShape & target_shape) {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
+        auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+        auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
+        auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
+        return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
+    }
+};
+
+TEST_P(ConvertBroadcast3NUMPYTest, CompareFunctions) {
+    convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTMulTest, CompareFunctions) {
+    convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastTest, CompareFunctions) {
+    convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
+    convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
+    convert_broadcast3_test(f, f_ref);
+}
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3NUMPY, ConvertBroadcast3NUMPYTest,
+        testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64, 64},      TargetShape{8, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64, 64},      TargetShape{2, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, 64},       TargetShape{3, 3, 3, 64, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN, 64},       TargetShape{3, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, 64, DYN},       TargetShape{3, 3, 64, 64, 3}),
+                        std::make_tuple(InputShape{1, 3, 64, 64},       TargetShape{6, 3, 64, 64}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64},     TargetShape{8, 3, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64},     TargetShape{2, 3, 64, 64}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{3, 3, 3, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN},      TargetShape{3, 3, 64, 4}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN},      TargetShape{5, 3, 1}),
+                        std::make_tuple(InputShape{DYN, 3, 10},         TargetShape{3, 3, 10}),
+                        std::make_tuple(InputShape{2, DYN, 9},          TargetShape{2, 3, 9}),
+                        std::make_tuple(InputShape{3, 3, DYN},          TargetShape{3, 3, 3})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTMulTest,
+        testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64, 64},      TargetShape{1, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64, 64},      TargetShape{2, 1, 64, 64, 64}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, 64},       TargetShape{3, 3, 1, 64, 64}),
+                        std::make_tuple(InputShape{DYN, 1, DYN, 64, DYN},    TargetShape{3, 3, 3, 64, 1}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN, 64},       TargetShape{3, 3, 64, 1, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, 64, DYN},       TargetShape{3, 3, 64, 64, 1}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64},     TargetShape{1, 3, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64},     TargetShape{2, 1, 64, 64}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{3, 3, 1, 64}),
+                        std::make_tuple(InputShape{DYN, 3, DYN, 64},    TargetShape{3, 3, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN},      TargetShape{3, 3, 64, 1}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN},      TargetShape{5, 3, 1}),
+                        std::make_tuple(InputShape{DYN, 3, 10},         TargetShape{1, 3, 10}),
+                        std::make_tuple(InputShape{DYN, 3, 10},         TargetShape{10}),
+                        std::make_tuple(InputShape{2, DYN, 9},          TargetShape{2, 1, 9}),
+                        std::make_tuple(InputShape{3, 3, DYN},          TargetShape{3, 3, 1})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastTest,
+        testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{2, 2, 3, 4, 5},    TargetShape{2, 2, 3, 4, 5}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64, 64},      TargetShape{3, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64, 64},      TargetShape{2, 3, 64, 64, 1},  TargetShape{2, 3, 64, 64, 64}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, 64},       TargetShape{1, 3, 3, 64, 64},  TargetShape{3, 3, 3, 64, 64}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, DYN},      TargetShape{1, 3, 3, 64, 3},   TargetShape{3, 3, 3, 64, 3}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN, 64},       TargetShape{1, 1, 1, 2, 64},   TargetShape{3, 3, 64, 2, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, 64, DYN},       TargetShape{3, 3, 64, 64, 3},  TargetShape{3, 3, 64, 64, 3}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 2, 3},    TargetShape{7, 3, 2, 3}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64},     TargetShape{3, 3, 64, 64},  TargetShape{3, 3, 64, 64}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64},     TargetShape{2, 3, 64, 64},  TargetShape{2, 3, 64, 64}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{1, 3, 1},       TargetShape{3, 3, 3, 64}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{3, 3, 64},      TargetShape{3, 3, 3, 64}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN},      TargetShape{64},            TargetShape{3, 3, 64, 64}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN},      TargetShape{5, 3, 3},       TargetShape{5, 3, 3}),
+                        std::make_tuple(InputShape{1, 3, DYN},          TargetShape{3, 3, 10},      TargetShape{3, 3, 10}),
+                        std::make_tuple(InputShape{2, DYN, 9},          TargetShape{2, 2, 1},       TargetShape{2, 2, 9}),
+                        std::make_tuple(InputShape{3, 3, DYN},          TargetShape{3},             TargetShape{3, 3, 3})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastMultiplyTest,
+        testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64, 64},      TargetShape{4}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64, 64},      TargetShape{3}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, 64},       TargetShape{2}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN, 64},       TargetShape{1}),
+                        std::make_tuple(InputShape{1, 3, 64, 64},       TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64},     TargetShape{3}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64},     TargetShape{2}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{1}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN},      TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, 3, 10},         TargetShape{4}),
+                        std::make_tuple(InputShape{2, DYN, 9},          TargetShape{3}),
+                        std::make_tuple(InputShape{3, 3, DYN},          TargetShape{2})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
+        testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64, 64},      TargetShape{4}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64, 64},      TargetShape{3}),
+                        std::make_tuple(InputShape{3, 1, DYN, 64, 64},       TargetShape{2}),
+                        std::make_tuple(InputShape{3, 3, 64, DYN, 64},       TargetShape{1}),
+                        std::make_tuple(InputShape{1, 3, 64, 64},       TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
+                        std::make_tuple(InputShape{DYN, 3, 64, 64},     TargetShape{3}),
+                        std::make_tuple(InputShape{2, DYN, 64, 64},     TargetShape{2}),
+                        std::make_tuple(InputShape{3, 3, DYN, 64},      TargetShape{1}),
+                        std::make_tuple(InputShape{DYN, DYN, DYN},      TargetShape{5}),
+                        std::make_tuple(InputShape{DYN, 3, 10},         TargetShape{4}),
+                        std::make_tuple(InputShape{2, DYN, 9},          TargetShape{3}),
+                        std::make_tuple(InputShape{3, 3, DYN},          TargetShape{2})));
+
 
 // Broadcast-3 is converted directly to Broadcast-1 for modes NUMPY, NONE and PDPD
 TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
@@ -30,8 +324,10 @@ TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertBroadcast3().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertBroadcast3>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -63,8 +359,10 @@ TEST(TransformationTests, ConvertBroadcast3WithPDPDModeToBroadcast1) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertBroadcast3().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertBroadcast3>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -97,8 +395,10 @@ TEST(TransformationTests, ConvertBroadcast3WithExplicitModeToBroadcast1) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertBroadcast3().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertBroadcast3>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -131,20 +431,20 @@ TEST(TransformationTests, ConvertBroadcast3WithBidirectionalModeToBroadcast1) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertBroadcast3().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertBroadcast3>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
     {
         auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 2});
-        auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 1});
-        auto constant_one = std::make_shared<ngraph::opset1::Constant>(input->get_output_element_type(0), ngraph::Shape({1}), std::vector<int>{1});
-        auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
-        auto multiply = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
-        multiply->set_friendly_name("broadcast");
+        auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
+        broadcast->set_friendly_name("broadcast");
 
-        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{multiply}, ngraph::ParameterVector{input});
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
     }
 
     auto res = compare_functions(f, f_ref);