Es/lpt/lpt to ngraph fixes2 with master (#2671)
[platform/upstream/dldt.git] / inference-engine / src / legacy_api / src / transformations / convert_opset1_to_legacy / convert_mul_add_to_scaleshift_or_power.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
6
7 #include <memory>
8 #include <vector>
9 #include <algorithm>
10
11 #include <ngraph/opsets/opset1.hpp>
12 #include <ngraph/rt_info.hpp>
13
14 #include "transformations/utils/utils.hpp"
15
16 #include "legacy/ngraph_ops/power.hpp"
17 #include "legacy/ngraph_ops/scaleshift.hpp"
18
19 CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>& constant,
20                                  const ngraph::PartialShape& shape) {
21     if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE;
22
23     auto const_shape = constant->get_shape();
24     std::vector<ngraph::Dimension> input_shape(shape);
25
26     // In case of scalar we will convert it to Power
27     if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) {
28         return CONVERSION_RESULT::POWER;
29     }
30
31     // Align shapes
32     size_t max_shape_len = std::max(input_shape.size(), const_shape.size());
33     while (const_shape.size() < max_shape_len) const_shape.insert(const_shape.begin(), 1);
34     while (input_shape.size() < max_shape_len) input_shape.insert(input_shape.begin(), 1);
35
36     // This is feature dimension index from right side (ex. for NCDHW it's equal to 3).
37     const size_t feature_index = input_shape.size() - 2;
38     if (const_shape.size() < feature_index) return CONVERSION_RESULT::NONE;
39
40     bool is_power = false;
41     auto in_it = const_shape.rbegin();
42     auto out_it = input_shape.rbegin();
43     for (int idx = 0; in_it != const_shape.rend() && out_it != input_shape.rend(); ++in_it, ++out_it, ++idx) {
44         if (idx != feature_index && *in_it != 1) {
45             return CONVERSION_RESULT::NONE;
46         }
47
48         if (idx == feature_index && *in_it == 1) {
49             is_power = true;
50         } else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) {
51             return CONVERSION_RESULT::NONE;
52         }
53     }
54
55     return is_power ? CONVERSION_RESULT::POWER : CONVERSION_RESULT::SCALE_SHIFT;
56 }
57
58 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulAddToScaleShiftOrPower, "ConvertMulAddToScaleShiftOrPower", 0);
59
60 void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshift_or_power() {
61     auto data_batch = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
62
63     auto weights = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
64     auto bias = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
65
66     auto mul = std::make_shared<ngraph::opset1::Multiply>(data_batch, weights);
67     auto add = std::make_shared<ngraph::opset1::Add>(mul, bias);
68
69     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
70         auto add_node = ngraph::as_type_ptr<ngraph::opset1::Add>(m.get_match_root());
71
72         if (!add_node) {
73             return false;
74         }
75
76         if (!add_node->get_element_type().is_real()) {
77             return false;
78         }
79
80         auto add_input_0 = add_node->input(0).get_source_output().get_node_shared_ptr();
81         auto add_input_1 = add_node->input(1).get_source_output().get_node_shared_ptr();
82
83         auto mul_node = ngraph::as_type_ptr<ngraph::opset1::Multiply>(add_input_0);
84         auto const_bias_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(add_input_1);
85         if (!mul_node) {
86             mul_node = ngraph::as_type_ptr<ngraph::opset1::Multiply>(add_input_1);
87             const_bias_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(add_input_0);
88         }
89
90         if (const_bias_node->output(0).get_element_type() != add_node->output(0).get_element_type()) {
91             return false;
92         }
93
94         auto mul_input_0 = mul_node->input(0).get_source_output().get_node_shared_ptr();
95         auto mul_input_1 = mul_node->input(1).get_source_output().get_node_shared_ptr();
96
97         auto data_node = mul_node->input(0).get_source_output();
98         auto const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_1);
99         if (!const_weights_node) {
100             data_node = mul_node->input(1).get_source_output();
101             const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_0);
102         }
103
104         if (const_weights_node->output(0).get_element_type() != mul_node->output(0).get_element_type()) {
105             return false;
106         }
107
108         if (add_node->get_output_partial_shape(0).rank().is_dynamic() ||
109             mul_node->get_output_partial_shape(0).rank().is_dynamic()) {
110             return false;
111         }
112
113         // Check that eltwise is not useless otherwise we remove it
114         if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) &&
115             ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) {
116             bool has_result_output = false;
117             for (const auto & output : add_node->output(0).get_target_inputs()) {
118                 if (dynamic_cast<ngraph::opset1::Result*>(output.get_node())) {
119                     has_result_output = true;
120                 }
121             }
122
123             auto parent = data_node.get_node_shared_ptr();
124             size_t consumers_count = 0;
125             for (const auto &output : parent->outputs()) {
126                 consumers_count += output.get_target_inputs().size();
127             }
128
129             if (!has_result_output || consumers_count == 1) {
130                 if (!std::dynamic_pointer_cast<ngraph::opset1::Parameter>(parent)) {
131                     parent->set_friendly_name(add_node->get_friendly_name());
132                 }
133                 // TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input
134                 // using replace_source_output method
135                 for (auto &input : add_node->output(0).get_target_inputs()) {
136                     input.replace_source_output(data_node);
137                 }
138                 return true;
139             }
140         }
141
142         auto res1 = check_constant(const_weights_node, data_node.get_partial_shape());
143         auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0));
144
145         const auto output_shape = add_node->get_output_partial_shape(0);
146         const auto output_shape_rank = output_shape.rank().get_length();
147
148         bool is_dequantization =
149                 (add_node->get_rt_info().count("DEQUANTIZATION") != 0 || mul_node->get_rt_info().count("DEQUANTIZATION") != 0);
150
151         if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
152             ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && !is_dequantization && output_shape_rank < 4)) {
153             return false;
154         }
155
156         // TODO: in case if scale and shift constants has equal values the best way is to convert them to Power
157         if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
158             NodeVector new_ops;
159
160             auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape);
161             auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape);
162             new_ops.push_back(weights_in);
163             new_ops.push_back(biases_in);
164
165             if (is_dequantization) {
166                 const Shape data_shape = data_node.get_shape();
167                 Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
168                 broadcasted_shape[1] = data_shape[1];
169
170                 weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
171                 new_ops.push_back(weights_in);
172
173                 biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
174                 new_ops.push_back(biases_in);
175             }
176
177             if (res1 == CONVERSION_RESULT::POWER && !is_dequantization) {
178                 weights_in = ngraph::op::util::broadcastTo(weights_in, biases_in->get_shape());
179                 new_ops.push_back(weights_in);
180             }
181             if (res2 == CONVERSION_RESULT::POWER && !is_dequantization) {
182                 biases_in = ngraph::op::util::broadcastTo(biases_in, weights_in->get_shape());
183                 new_ops.push_back(biases_in);
184             }
185
186             auto output_type = m.get_match_root()->get_output_element_type(0);
187             auto scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in, output_type);
188             new_ops.push_back(scaleshift);
189
190             scaleshift->set_friendly_name(add_node->get_friendly_name());
191             ngraph::copy_runtime_info({mul_node, add_node}, new_ops);
192             ngraph::replace_node(m.get_match_root(), scaleshift);
193         } else {
194             float scale = 0.f, shift = 0.f;
195             if (!op::util::get_single_value(const_weights_node, scale)) {
196                 return false;
197             }
198             if (!op::util::get_single_value(const_bias_node, shift)) {
199                 return false;
200             }
201
202             auto output_type = m.get_match_root()->get_output_element_type(0);
203             auto power = std::make_shared<ngraph::op::PowerIE>(data_node, 1., scale, shift, output_type);
204             power->set_friendly_name(add_node->get_friendly_name());
205             ngraph::copy_runtime_info({mul_node, add_node}, power);
206             ngraph::replace_node(m.get_match_root(), power);
207         }
208
209         return true;
210     };
211
212     auto m = std::make_shared<ngraph::pattern::Matcher>(add, "CPUFusion.MulAddToScaleShiftOrPower");
213     this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
214 }