#include <gtest/gtest.h>
+#include "common_test_utils/test_common.hpp"
#include <string>
+#include <sstream>
+#include <fstream>
#include <memory>
#include <queue>
+#include <map>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
-#include <ngraph/opsets/opset3.hpp>
-#include <transformations/op_conversions/convert_broadcast3.hpp>
-#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/constant_folding.hpp>
#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/visualize_tree.hpp>
+#include <transformations/op_conversions/convert_broadcast3.hpp>
+#include <ngraph_ops/convolution_ie.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
+using namespace ngraph;
+
+using InputShape = PartialShape;
+using TargetShape = Shape;
+
+void convert_broadcast3_test(std::shared_ptr<Function> f, std::shared_ptr<Function> f_ref) {
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertBroadcast3>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+class ConvertBroadcast3NUMPYTest: public CommonTestUtils::TestsCommon,
+ public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+ std::shared_ptr<Function> f, f_ref;
+
+ void SetUp() override {
+ const auto& input_shape = std::get<0>(GetParam());
+ const auto& target_shape = std::get<1>(GetParam());
+
+ f = get_initial_function(input_shape, target_shape);
+ f_ref = get_reference_broadcast(input_shape, target_shape);
+ }
+
+ std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::NUMPY);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+ }
+
+ std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+ }
+};
+
+class ConvertBroadcast3BIDIRECTMulTest: public CommonTestUtils::TestsCommon,
+ public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+ std::shared_ptr<Function> f, f_ref;
+
+ void SetUp() override {
+ const auto& input_shape = std::get<0>(GetParam());
+ const auto& target_shape = std::get<1>(GetParam());
+
+ f = get_initial_function(input_shape, target_shape);
+ f_ref = get_reference_broadcast(input_shape, target_shape);
+ }
+
+ std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+ }
+
+ std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto const_node = ngraph::opset1::Constant::create(ngraph::element::f32, Shape{target_shape}, {1});
+ auto mul = std::make_shared<ngraph::opset1::Multiply>(input, const_node);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+ }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastTest: public CommonTestUtils::TestsCommon,
+ public testing::WithParamInterface<std::tuple<InputShape, TargetShape, TargetShape>> {
+public:
+ std::shared_ptr<Function> f, f_ref;
+
+ void SetUp() override {
+ const auto& input_shape = std::get<0>(GetParam());
+ const auto& target_shape = std::get<1>(GetParam());
+ const auto& aligned_target_shape = std::get<2>(GetParam());
+
+ f = get_initial_function(input_shape, target_shape);
+ f_ref = get_reference_broadcast(input_shape, aligned_target_shape);
+ }
+
+ std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+ }
+
+ std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+ const TargetShape & aligned_target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, Shape{aligned_target_shape.size()}, aligned_target_shape);
+ auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+ }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastMultiplyTest: public CommonTestUtils::TestsCommon,
+ public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+ std::shared_ptr<Function> f, f_ref;
+
+ void SetUp() override {
+ const auto& input_shape = std::get<0>(GetParam());
+ const auto& target_shape = std::get<1>(GetParam());
+
+ f = get_initial_function(input_shape, target_shape);
+ f_ref = get_reference_broadcast(input_shape, target_shape);
+ }
+
+ std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
+ }
+
+ std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
+ auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+ auto constant_one = opset1::Constant::create(ngraph::element::f32, {1}, {1});
+ auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
+ auto mul = std::make_shared<ngraph::opset1::Multiply>(input, broadcast);
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
+ }
+};
+
+class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
+ public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
+public:
+ std::shared_ptr<Function> f, f_ref;
+
+ void SetUp() override {
+ const auto& input_shape = std::get<0>(GetParam());
+ const auto& target_shape = std::get<1>(GetParam());
+
+ f = get_initial_function(input_shape, target_shape);
+ f_ref = get_reference_broadcast(input_shape, target_shape);
+ }
+
+ std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
+ auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+ auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
+
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
+ }
+
+ std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
+ const TargetShape & target_shape) {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
+ auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
+ auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
+ auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
+ auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
+ }
+};
+
+TEST_P(ConvertBroadcast3NUMPYTest, CompareFunctions) {
+ convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTMulTest, CompareFunctions) {
+ convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastTest, CompareFunctions) {
+ convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
+ convert_broadcast3_test(f, f_ref);
+}
+
+TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
+ convert_broadcast3_test(f, f_ref);
+}
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3NUMPY, ConvertBroadcast3NUMPYTest,
+ testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{8, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
+ std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}),
+ std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{6, 3, 64, 64}),
+ std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{8, 3, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 3, 64}),
+ std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 4}),
+ std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
+ std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{3, 3, 10}),
+ std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 3, 9}),
+ std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 3})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTMulTest,
+ testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{1, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 1, 64, 64, 64}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 1, 64, 64}),
+ std::make_tuple(InputShape{DYN, 1, DYN, 64, DYN}, TargetShape{3, 3, 3, 64, 1}),
+ std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 1, 64}),
+ std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 1}),
+ std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{1, 3, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 1, 64, 64}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 1, 64}),
+ std::make_tuple(InputShape{DYN, 3, DYN, 64}, TargetShape{3, 3, 64}),
+ std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 1}),
+ std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
+ std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{1, 3, 10}),
+ std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{10}),
+ std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 1, 9}),
+ std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 1})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastTest,
+ testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{2, 2, 3, 4, 5}, TargetShape{2, 2, 3, 4, 5}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 1}, TargetShape{2, 3, 64, 64, 64}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{1, 3, 3, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, DYN}, TargetShape{1, 3, 3, 64, 3}, TargetShape{3, 3, 3, 64, 3}),
+ std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1, 1, 1, 2, 64}, TargetShape{3, 3, 64, 2, 64}),
+ std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}, TargetShape{3, 3, 64, 64, 3}),
+ std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 2, 3}, TargetShape{7, 3, 2, 3}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3, 3, 64, 64}, TargetShape{3, 3, 64, 64}),
+ std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}, TargetShape{2, 3, 64, 64}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1, 3, 1}, TargetShape{3, 3, 3, 64}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 64}, TargetShape{3, 3, 3, 64}),
+ std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{64}, TargetShape{3, 3, 64, 64}),
+ std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 3}, TargetShape{5, 3, 3}),
+ std::make_tuple(InputShape{1, 3, DYN}, TargetShape{3, 3, 10}, TargetShape{3, 3, 10}),
+ std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 2, 1}, TargetShape{2, 2, 9}),
+ std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3}, TargetShape{3, 3, 3})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastMultiplyTest,
+ testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
+ std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
+ std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
+ std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
+ std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
+ std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
+ std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
+ std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
+
+INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
+ testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
+ std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
+ std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
+ std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
+ std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
+ std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
+ std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
+ std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
+ std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
+ std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
+ std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
+ std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
+
// Broadcast-3 is converted directly to Broadcast-1 for modes NUMPY, NONE and PDPD
TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertBroadcast3().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertBroadcast3>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertBroadcast3().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertBroadcast3>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertBroadcast3().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertBroadcast3>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertBroadcast3().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertBroadcast3>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 2});
- auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 1});
- auto constant_one = std::make_shared<ngraph::opset1::Constant>(input->get_output_element_type(0), ngraph::Shape({1}), std::vector<int>{1});
- auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
- auto multiply = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
- multiply->set_friendly_name("broadcast");
+ auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
+ auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
+ broadcast->set_friendly_name("broadcast");
- f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{multiply}, ngraph::ParameterVector{input});
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);