1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
11 #include <ngraph/opsets/opset1.hpp>
12 #include <ngraph/rt_info.hpp>
14 #include "transformations/utils/utils.hpp"
16 #include "legacy/ngraph_ops/power.hpp"
17 #include "legacy/ngraph_ops/scaleshift.hpp"
19 CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>& constant,
20 const ngraph::PartialShape& shape) {
21 if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE;
23 auto const_shape = constant->get_shape();
24 std::vector<ngraph::Dimension> input_shape(shape);
26 // In case of scalar we will convert it to Power
27 if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) {
28 return CONVERSION_RESULT::POWER;
32 size_t max_shape_len = std::max(input_shape.size(), const_shape.size());
33 while (const_shape.size() < max_shape_len) const_shape.insert(const_shape.begin(), 1);
34 while (input_shape.size() < max_shape_len) input_shape.insert(input_shape.begin(), 1);
36 // This is feature dimension index from right side (ex. for NCDHW it's equal to 3).
37 const size_t feature_index = input_shape.size() - 2;
38 if (const_shape.size() < feature_index) return CONVERSION_RESULT::NONE;
40 bool is_power = false;
41 auto in_it = const_shape.rbegin();
42 auto out_it = input_shape.rbegin();
43 for (int idx = 0; in_it != const_shape.rend() && out_it != input_shape.rend(); ++in_it, ++out_it, ++idx) {
44 if (idx != feature_index && *in_it != 1) {
45 return CONVERSION_RESULT::NONE;
48 if (idx == feature_index && *in_it == 1) {
50 } else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) {
51 return CONVERSION_RESULT::NONE;
55 return is_power ? CONVERSION_RESULT::POWER : CONVERSION_RESULT::SCALE_SHIFT;
58 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulAddToScaleShiftOrPower, "ConvertMulAddToScaleShiftOrPower", 0);
60 void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshift_or_power() {
61 auto data_batch = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
63 auto weights = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
64 auto bias = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
66 auto mul = std::make_shared<ngraph::opset1::Multiply>(data_batch, weights);
67 auto add = std::make_shared<ngraph::opset1::Add>(mul, bias);
69 ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
70 auto add_node = ngraph::as_type_ptr<ngraph::opset1::Add>(m.get_match_root());
76 if (!add_node->get_element_type().is_real()) {
80 auto add_input_0 = add_node->input(0).get_source_output().get_node_shared_ptr();
81 auto add_input_1 = add_node->input(1).get_source_output().get_node_shared_ptr();
83 auto mul_node = ngraph::as_type_ptr<ngraph::opset1::Multiply>(add_input_0);
84 auto const_bias_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(add_input_1);
86 mul_node = ngraph::as_type_ptr<ngraph::opset1::Multiply>(add_input_1);
87 const_bias_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(add_input_0);
90 if (const_bias_node->output(0).get_element_type() != add_node->output(0).get_element_type()) {
94 auto mul_input_0 = mul_node->input(0).get_source_output().get_node_shared_ptr();
95 auto mul_input_1 = mul_node->input(1).get_source_output().get_node_shared_ptr();
97 auto data_node = mul_node->input(0).get_source_output();
98 auto const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_1);
99 if (!const_weights_node) {
100 data_node = mul_node->input(1).get_source_output();
101 const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_0);
104 if (const_weights_node->output(0).get_element_type() != mul_node->output(0).get_element_type()) {
108 if (add_node->get_output_partial_shape(0).rank().is_dynamic() ||
109 mul_node->get_output_partial_shape(0).rank().is_dynamic()) {
113 // Check that eltwise is not useless otherwise we remove it
114 if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) &&
115 ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) {
116 bool has_result_output = false;
117 for (const auto & output : add_node->output(0).get_target_inputs()) {
118 if (dynamic_cast<ngraph::opset1::Result*>(output.get_node())) {
119 has_result_output = true;
123 auto parent = data_node.get_node_shared_ptr();
124 size_t consumers_count = 0;
125 for (const auto &output : parent->outputs()) {
126 consumers_count += output.get_target_inputs().size();
129 if (!has_result_output || consumers_count == 1) {
130 if (!std::dynamic_pointer_cast<ngraph::opset1::Parameter>(parent)) {
131 parent->set_friendly_name(add_node->get_friendly_name());
133 // TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input
134 // using replace_source_output method
135 for (auto &input : add_node->output(0).get_target_inputs()) {
136 input.replace_source_output(data_node);
142 auto res1 = check_constant(const_weights_node, data_node.get_partial_shape());
143 auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0));
145 const auto output_shape = add_node->get_output_partial_shape(0);
146 const auto output_shape_rank = output_shape.rank().get_length();
148 bool is_dequantization =
149 (add_node->get_rt_info().count("DEQUANTIZATION") != 0 || mul_node->get_rt_info().count("DEQUANTIZATION") != 0);
151 if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
152 ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && !is_dequantization && output_shape_rank < 4)) {
156 // TODO: in case if scale and shift constants has equal values the best way is to convert them to Power
157 if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
160 auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape);
161 auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape);
162 new_ops.push_back(weights_in);
163 new_ops.push_back(biases_in);
165 if (is_dequantization) {
166 const Shape data_shape = data_node.get_shape();
167 Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
168 broadcasted_shape[1] = data_shape[1];
170 weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
171 new_ops.push_back(weights_in);
173 biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
174 new_ops.push_back(biases_in);
177 if (res1 == CONVERSION_RESULT::POWER && !is_dequantization) {
178 weights_in = ngraph::op::util::broadcastTo(weights_in, biases_in->get_shape());
179 new_ops.push_back(weights_in);
181 if (res2 == CONVERSION_RESULT::POWER && !is_dequantization) {
182 biases_in = ngraph::op::util::broadcastTo(biases_in, weights_in->get_shape());
183 new_ops.push_back(biases_in);
186 auto output_type = m.get_match_root()->get_output_element_type(0);
187 auto scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in, output_type);
188 new_ops.push_back(scaleshift);
190 scaleshift->set_friendly_name(add_node->get_friendly_name());
191 ngraph::copy_runtime_info({mul_node, add_node}, new_ops);
192 ngraph::replace_node(m.get_match_root(), scaleshift);
194 float scale = 0.f, shift = 0.f;
195 if (!op::util::get_single_value(const_weights_node, scale)) {
198 if (!op::util::get_single_value(const_bias_node, shift)) {
202 auto output_type = m.get_match_root()->get_output_element_type(0);
203 auto power = std::make_shared<ngraph::op::PowerIE>(data_node, 1., scale, shift, output_type);
204 power->set_friendly_name(add_node->get_friendly_name());
205 ngraph::copy_runtime_info({mul_node, add_node}, power);
206 ngraph::replace_node(m.get_match_root(), power);
212 auto m = std::make_shared<ngraph::pattern::Matcher>(add, "CPUFusion.MulAddToScaleShiftOrPower");
213 this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);