From: Gleb Kazantaev Date: Tue, 26 May 2020 07:24:52 +0000 (+0300) Subject: Updated Mul->Add conversion to support dynamic shapes (#512) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d3764a75635c709e5e0b390cd09482a392b0a5f5;p=platform%2Fupstream%2Fdldt.git Updated Mul->Add conversion to support dynamic shapes (#512) * Updated Mul Add conversion to support dynamic shapes * Keep changes * Fix for cases when eltwise performs broadcasting via Constant * Added comments;Fixed eltwise shape infer; Updated tests --- diff --git a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp index 50d29e2..ce8107d 100644 --- a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp @@ -35,5 +35,13 @@ enum class CONVERSION_RESULT { NONE }; +/* + * check_constant function checks how given constant performs elementwise operation with given input + * CONVERSION_RESULT has several types: + * SCALE_SHIFT - constant applies only per-channel + * POWER - constant applies as single value + * NONE - default return value + */ + INFERENCE_ENGINE_API_CPP(CONVERSION_RESULT) -check_constant(const std::shared_ptr & constant, const ngraph::Shape & shape); +check_constant(const std::shared_ptr & constant, const ngraph::PartialShape & shape); diff --git a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp index adc6d47..b85d974 100644 --- a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp @@ -70,10 +70,13 @@ ngraph::graph_rewrite_callback get_callback() { "Unsupported template parameter. Only Add or Multiply allowed!"); auto lin_op = std::dynamic_pointer_cast (m.get_match_root()); - if (!lin_op) { + if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) { return false; } + const auto output_shape = lin_op->output(0).get_partial_shape(); + const auto output_shape_rank = output_shape.rank().get_length(); + if (!lin_op->get_element_type().is_real()) { return convert_to_eltwise(lin_op, lin_op->input(0).get_source_output(), @@ -93,39 +96,58 @@ ngraph::graph_rewrite_callback get_callback() { } } - // Check that eltwise is not useless otherwise we remove it - if ((std::is_same() && ngraph::op::util::constantIsEqualTo(const_node, 0)) || - (std::is_same() && ngraph::op::util::constantIsEqualTo(const_node, 1))) { - bool has_result_output = false; - for (const auto & output : lin_op->output(0).get_target_inputs()) { - if (dynamic_cast(output.get_node())) { - has_result_output = true; - } + /* This lambda checks data and constant shapes for broadcasting + For example: + 1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension + 2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant do not broadcasts data_shape + 3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension + */ + auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool { + if (data_pshape.rank().is_dynamic() || const_shape.size() > data_pshape.rank().get_length()) { + return true; } - auto parent = data_node.get_node_shared_ptr(); - size_t consumers_count = 0; - for (const auto &output : parent->outputs()) { - consumers_count += output.get_target_inputs().size(); + std::vector data_shape(data_pshape); + + auto const_shape_it = const_shape.rbegin(); + auto data_shape_it = data_shape.rbegin(); + + while (const_shape_it != const_shape.rend()) { + auto data_dim = *data_shape_it; + auto const_dim = *const_shape_it; + + /* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT + DYN - 64 - TRUE + DYN - 1 - FALSE + 64 - 1 - FALSE + 1 - 64 - TRUE + 64 - 64 - FALSE + */ + if ((data_dim.is_dynamic() && const_dim != 1) || + (data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) { + return true; + } + + ++const_shape_it; + ++data_shape_it; } - if (!has_result_output || consumers_count == 1) { - if (!std::dynamic_pointer_cast(parent)) { - parent->set_friendly_name(lin_op->get_friendly_name()); - } - // TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input - // using replace_source_output method - for (auto &input : lin_op->output(0).get_target_inputs()) { - input.replace_source_output(data_node); - } + return false; + }; + + // Check that eltwise is not useless and do not broadcast output otherwise we remove it + if (((std::is_same() && ngraph::op::util::constantIsEqualTo(const_node, 0)) || + (std::is_same() && ngraph::op::util::constantIsEqualTo(const_node, 1))) && + !constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) { + bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node); + if (ret_status) { return true; } } + auto res = check_constant(const_node, data_node.get_partial_shape()); - auto res = check_constant(const_node, data_node.get_shape()); - - if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && lin_op->get_shape().size() < 4)) { + if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4)) { return convert_to_eltwise(lin_op, lin_op->input(0).get_source_output(), lin_op->input(1).get_source_output()); @@ -140,12 +162,12 @@ ngraph::graph_rewrite_callback get_callback() { std::shared_ptr scaleshift; if (std::is_same()) { auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1}); - scaleshift = std::make_shared(data_node, ngraph::op::util::normalize_constant(weights, lin_op->get_shape()), - ngraph::op::util::normalize_constant(const_node, lin_op->get_shape())); + scaleshift = std::make_shared(data_node, ngraph::op::util::normalize_constant(weights, output_shape), + ngraph::op::util::normalize_constant(const_node, output_shape)); } else { auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0}); - scaleshift = std::make_shared(data_node, ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()), - ngraph::op::util::normalize_constant(bias, lin_op->get_shape())); + scaleshift = std::make_shared(data_node, ngraph::op::util::normalize_constant(const_node, output_shape), + ngraph::op::util::normalize_constant(bias, output_shape)); } scaleshift->set_friendly_name(lin_op->get_friendly_name()); diff --git a/inference-engine/src/transformations/include/transformations/utils/utils.hpp b/inference-engine/src/transformations/include/transformations/utils/utils.hpp index c6169c1..d1edf9c 100644 --- a/inference-engine/src/transformations/include/transformations/utils/utils.hpp +++ b/inference-engine/src/transformations/include/transformations/utils/utils.hpp @@ -47,7 +47,7 @@ bool has_op_with_type(const std::shared_ptr &function) { INFERENCE_ENGINE_API_CPP(bool) get_single_value(const std::shared_ptr & const_node, float & value); INFERENCE_ENGINE_API_CPP(std::shared_ptr) normalize_constant(const std::shared_ptr & constant, - const Shape & shape); + const PartialShape & shape); INFERENCE_ENGINE_API_CPP(std::shared_ptr) broadcastTo(const Output& input, const Shape& shape); diff --git a/inference-engine/src/transformations/src/ngraph_ops/eltwise.cpp b/inference-engine/src/transformations/src/ngraph_ops/eltwise.cpp index e4c4d4c..ed28654 100644 --- a/inference-engine/src/transformations/src/ngraph_ops/eltwise.cpp +++ b/inference-engine/src/transformations/src/ngraph_ops/eltwise.cpp @@ -37,16 +37,24 @@ void op::Eltwise::validate_and_infer_types() { NODE_VALIDATION_CHECK(this, element::Type::merge(et_result, data1_et, data2_et), "Element types for first and second do not match :", data1_et, " and ", data2_et); - auto shape1 = get_input_partial_shape(0).to_shape(); - auto shape2 = get_input_partial_shape(1).to_shape(); + if (get_input_partial_shape(0).rank().is_dynamic() || + get_input_partial_shape(1).rank().is_dynamic()) { + set_output_type(0, et_result, PartialShape::dynamic()); + return; + } + + std::vector shape1(get_input_partial_shape(0)); + std::vector shape2(get_input_partial_shape(1)); - ngraph::Shape output_shape(std::max(shape1.size(), shape2.size())); + std::vector output_shape(PartialShape::dynamic(std::max(shape1.size(), shape2.size()))); auto output_shape_it = output_shape.rbegin(); auto shape1_it = shape1.rbegin(), shape2_it = shape2.rbegin(); while (shape1_it != shape1.rend() || shape2_it != shape2.rend()) { if (shape1_it != shape1.rend() && shape2_it != shape2.rend()) { - *output_shape_it = std::max(*shape1_it, *shape2_it); + if (shape1_it->is_static() && shape2_it->is_static()) { + *output_shape_it = (shape1_it->get_length() > shape2_it->get_length() ? *shape1_it : *shape2_it); + } } else if (shape1_it != shape1.rend()) { *output_shape_it = *shape1_it; } else if (shape2_it != shape2.rend()) { @@ -61,5 +69,5 @@ void op::Eltwise::validate_and_infer_types() { } } - set_output_type(0, data1_et, PartialShape(output_shape)); + set_output_type(0, et_result, output_shape); } diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.cpp index 4bb60be..018f797 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.cpp @@ -17,11 +17,11 @@ #include "ngraph_ops/scaleshift.hpp" CONVERSION_RESULT check_constant(const std::shared_ptr& constant, - const ngraph::Shape& shape) { - if (!constant) return CONVERSION_RESULT::NONE; + const ngraph::PartialShape& shape) { + if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE; auto const_shape = constant->get_shape(); - auto input_shape = shape; + std::vector input_shape(shape); // In case of scalar we will convert it to Power if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) { @@ -47,7 +47,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr if (idx == feature_index && *in_it == 1) { is_power = true; - } else if (idx == feature_index && *in_it != *out_it) { + } else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) { return CONVERSION_RESULT::NONE; } } @@ -95,6 +95,11 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi const_weights_node = ngraph::as_type_ptr(mul_input_0); } + if (add_node->get_output_partial_shape(0).rank().is_dynamic() || + mul_node->get_output_partial_shape(0).rank().is_dynamic()) { + return false; + } + // Check that eltwise is not useless otherwise we remove it if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) && ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) { @@ -124,11 +129,14 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi } } - auto res1 = check_constant(const_weights_node, data_node.get_shape()); - auto res2 = check_constant(const_bias_node, mul_node->get_output_shape(0)); + auto res1 = check_constant(const_weights_node, data_node.get_partial_shape()); + auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0)); + + const auto output_shape = add_node->get_output_partial_shape(0); + const auto output_shape_rank = output_shape.rank().get_length(); if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE || - ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && add_node->get_shape().size() < 4)) { + ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && output_shape_rank < 4)) { return false; } @@ -136,8 +144,8 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) { NodeVector new_ops; - auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, add_node->get_shape()); - auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, add_node->get_shape()); + auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape); + auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape); new_ops.push_back(weights_in); new_ops.push_back(biases_in); diff --git a/inference-engine/src/transformations/src/transformations/utils/utils.cpp b/inference-engine/src/transformations/src/transformations/utils/utils.cpp index 6c9d04f..76a5797 100644 --- a/inference-engine/src/transformations/src/transformations/utils/utils.cpp +++ b/inference-engine/src/transformations/src/transformations/utils/utils.cpp @@ -49,12 +49,12 @@ bool get_single_value(const std::shared_ptr& const_node, float& va } std::shared_ptr normalize_constant(const std::shared_ptr& constant, - const Shape& shape) { + const PartialShape& shape) { auto const_shape = constant->get_shape(); - if (const_shape.size() == shape.size()) { + if (const_shape.size() == shape.rank().get_length()) { return constant; } - int cnt = shape.size() - const_shape.size(); + int64_t cnt = shape.rank().get_length() - const_shape.size(); for (int i = 0; i < cnt; ++i) { const_shape.insert(const_shape.begin(), 1); } diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp index d17b602..046c62c 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp @@ -1757,7 +1757,7 @@ TEST_F(NGraphReaderTests, RemoveAdd2) { - + 1 diff --git a/inference-engine/tests/functional/inference_engine/transformations/mul_add_conversion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/mul_add_conversion_test.cpp new file mode 100644 index 0000000..27cfe97 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/mul_add_conversion_test.cpp @@ -0,0 +1,315 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_common.hpp" +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ngraph_test_utils.hpp" + +using namespace testing; + +using InputShape = ngraph::PartialShape; +struct ConstantParams { + ngraph::Shape shape; + float value; + bool skip; + ConstantParams() : skip(true) {} + ConstantParams(const ngraph::Shape & shape, float value) + : shape(shape), value(value), skip(false) {} +}; +using MulConstant = ConstantParams; +using AddConstant = ConstantParams; +using RefFunction = std::function(const InputShape&, const MulConstant&, const AddConstant&)>; + +class MulAddConversionTests: public CommonTestUtils::TestsCommon, +public testing::WithParamInterface, RefFunction> > { +public: + std::shared_ptr f, f_ref; + + void SetUp() override { + const auto& attrs = std::get<0>(GetParam()); + const auto& input_shape = std::get<0>(attrs); + const auto& mul_const = std::get<1>(attrs); + const auto& add_const = std::get<2>(attrs); + const auto& get_ref_function = std::get<1>(GetParam()); + + f = get_initial_function(input_shape, mul_const, add_const); + f_ref = get_ref_function(input_shape, mul_const, add_const); + } + + static + std::shared_ptr get_initial_function(const InputShape& input_shape, + const MulConstant& mul_const, + const AddConstant& add_const) { + auto input = std::make_shared(ngraph::element::f32, input_shape); + ngraph::Output last = input; + if (!mul_const.skip) { + last = std::make_shared(last, create_constant(mul_const.shape, mul_const.value)); + } + if (!add_const.skip) { + last = std::make_shared(last, create_constant(add_const.shape, add_const.value)); + } + last = std::make_shared(last); + return std::make_shared(ngraph::NodeVector{last.get_node_shared_ptr()}, ngraph::ParameterVector{input}); + } + + static + std::shared_ptr get_scale_shift_reference(const InputShape& input_shape, + const MulConstant& mul_const, + const AddConstant& add_const) { + if (mul_const.skip && add_const.skip) { + throw ngraph::ngraph_error("Invalid arguments"); + } + + auto input = std::make_shared(ngraph::element::f32, input_shape); + auto scsh = std::make_shared(input, (!mul_const.skip ? create_constant(mul_const.shape, mul_const.value) + : create_constant(add_const.shape, 1)), + (!add_const.skip ? create_constant(add_const.shape, add_const.value) + : create_constant(mul_const.shape, 0))); + auto relu = std::make_shared(scsh); + return std::make_shared(ngraph::NodeVector{relu}, ngraph::ParameterVector{input}); + } + + static + std::shared_ptr get_power_reference(const InputShape& input_shape, + const MulConstant& mul_const, + const AddConstant& add_const) { + auto input = std::make_shared(ngraph::element::f32, input_shape); + float scale(1), shift(0); + if (!mul_const.skip) scale = mul_const.value; + if (!add_const.skip) shift = add_const.value; + auto pow = std::make_shared(input, 1., scale, shift); + auto relu = std::make_shared(pow); + return std::make_shared(ngraph::NodeVector{relu}, ngraph::ParameterVector{input}); + } + + static + std::shared_ptr get_eltwise_add_reference(const InputShape& input_shape, + const MulConstant& mul_const, + const AddConstant& add_const) { + auto input = std::make_shared(ngraph::element::f32, input_shape); + auto add = std::make_shared(input, create_constant(add_const.shape, add_const.value), ELTWISE_TYPE::Sum); + auto relu = std::make_shared(add); + return std::make_shared(ngraph::NodeVector{relu}, ngraph::ParameterVector{input}); + } + + static + std::shared_ptr get_eltwise_mul_reference(const InputShape& input_shape, + const MulConstant& mul_const, + const AddConstant& add_const) { + auto input = std::make_shared(ngraph::element::f32, input_shape); + auto mul = std::make_shared(input, create_constant(mul_const.shape, mul_const.value), ELTWISE_TYPE::Prod); + auto relu = std::make_shared(mul); + return std::make_shared(ngraph::NodeVector{relu}, ngraph::ParameterVector{input}); + } + + static + std::shared_ptr create_constant(const ngraph::Shape & shape, float init_value) { + return ngraph::opset1::Constant::create(ngraph::element::f32, shape, {init_value}); + } +}; + +class MulOrAddConversionTests: public MulAddConversionTests {}; + +TEST_P(MulAddConversionTests, CompareFunctions) { + ngraph::pass::InitNodeInfo().run_on_function(f); + ngraph::pass::ConvertMulAddToScaleShiftOrPower().run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + ngraph::pass::ConstantFolding().run_on_function(f); + f->validate_nodes_and_infer_types(); + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST_P(MulOrAddConversionTests, CompareFunctions) { + ngraph::pass::InitNodeInfo().run_on_function(f); + ngraph::pass::ConvertMulOrAddFinally().run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + ngraph::pass::ConstantFolding().run_on_function(f); + f->validate_nodes_and_infer_types(); + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +#define CONST(A, B) ConstantParams(A, B) +#define NONE ConstantParams() +#define SCALESHIFT MulAddConversionTests::get_scale_shift_reference +#define POWER MulAddConversionTests::get_power_reference +#define SAME MulAddConversionTests::get_initial_function +#define ELTWISE_SUM MulAddConversionTests::get_eltwise_add_reference +#define ELTWISE_PROD MulAddConversionTests::get_eltwise_mul_reference + +INSTANTIATE_TEST_CASE_P(MulAddToScaleShift, MulAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))), + testing::Values(SCALESHIFT))); + +INSTANTIATE_TEST_CASE_P(MulToScaleShift, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + NONE)), + testing::Values(SCALESHIFT))); + +INSTANTIATE_TEST_CASE_P(AddToScaleShift, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + NONE, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + NONE, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + NONE, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))), + testing::Values(SCALESHIFT))); + +INSTANTIATE_TEST_CASE_P(MulAddToPower, MulAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + CONST(ngraph::Shape({1}), 0.5), + CONST(ngraph::Shape({1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + CONST(ngraph::Shape({1}), 0.5), + CONST(ngraph::Shape({1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + CONST(ngraph::Shape({1}), 0.5), + CONST(ngraph::Shape({1}), 0.5))), + testing::Values(POWER))); + +INSTANTIATE_TEST_CASE_P(MulToPower, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + CONST(ngraph::Shape({1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + CONST(ngraph::Shape({1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + CONST(ngraph::Shape({1}), 0.5), + NONE)), + testing::Values(POWER))); + +INSTANTIATE_TEST_CASE_P(AddToPower, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64}, + NONE, + CONST(ngraph::Shape({1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, 64}, + NONE, + CONST(ngraph::Shape({1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + NONE, + CONST(ngraph::Shape({1}), 0.5))), + testing::Values(POWER))); + + +INSTANTIATE_TEST_CASE_P(MulAddNegative, MulAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64}, + CONST(ngraph::Shape({1, 3, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1}), 0.5)/*ScaleShift must always be 4D*/), + std::make_tuple(InputShape{DYN, 3, DYN}, + CONST(ngraph::Shape({1, 1, 3, 1}), 0.5), + CONST(ngraph::Shape({3, 1}), 0.5)/*detect broadcast case*/), + std::make_tuple(InputShape{DYN, 3, DYN}, + CONST(ngraph::Shape({3, 1}), 0.5), + CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)/*detect broadcast case*/), + std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 2, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{1, 3, 2}, + CONST(ngraph::Shape({1, 3, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 2}), 0.5)), + std::make_tuple(InputShape{1, DYN, 64, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))), + testing::Values(SAME))); + +INSTANTIATE_TEST_CASE_P(MulToEltwise, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64}, + CONST(ngraph::Shape({1, 1, 64}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN}, + CONST(ngraph::Shape({1, 1, 3, 1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + NONE), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + CONST(ngraph::Shape({1, 3, 2, 1}), 0.5), + NONE), + std::make_tuple(InputShape{1, 3, 2}, + CONST(ngraph::Shape({1, 3, 2}), 0.5), + NONE), + std::make_tuple(InputShape{1, DYN, 64, 64}, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5), + NONE), + std::make_tuple(InputShape{64, 1, 64}, + CONST(ngraph::Shape({64, 64, 64}), 1), + NONE), + std::make_tuple(InputShape{64, 64, 1}, + CONST(ngraph::Shape({1, 1, 64}), 1), + NONE), + std::make_tuple(InputShape{DYN, 1, 64}, + CONST(ngraph::Shape({64, 1, 64}), 1), + NONE)), + testing::Values(ELTWISE_PROD))); + +INSTANTIATE_TEST_CASE_P(AddToEltwise, MulOrAddConversionTests, testing::Combine( + testing::Values(std::make_tuple(InputShape{DYN, 3, 64}, + NONE, + CONST(ngraph::Shape({1, 1, 64}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN}, + NONE, + CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)), + std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, + NONE, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)), + std::make_tuple(InputShape{DYN, 3, DYN, DYN}, + NONE, + CONST(ngraph::Shape({1, 3, 2, 1}), 0.5)), + std::make_tuple(InputShape{1, 3, 2}, + NONE, + CONST(ngraph::Shape({1, 3, 2}), 0.5)), + std::make_tuple(InputShape{1, DYN, 64, 64}, + NONE, + CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))), + testing::Values(ELTWISE_SUM))); + +#undef CONST +#undef SCALESHIFT +#undef POWER +#undef SAME +#undef ELTWISE_PROD +#undef ELTWISE_SUM