[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / subtract.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/subtract.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 #include <unordered_set>
11 #include <utility>
12 #include <vector>
13
14 #include "low_precision/common/ie_lpt_exception.hpp"
15 #include "low_precision/network_helper.hpp"
16
17 namespace ngraph {
18 namespace pass {
19 namespace low_precision {
20
21 void SubtractTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
22     addPattern(
23         pass,
24         context,
25         make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
26
27     addPattern(
28         pass,
29         context,
30         make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
31 }
32
33 bool SubtractTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
34     std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<opset1::Subtract>(m.get_match_root());
35     if (!canBeTransformed(context, subtract)) {
36         return false;
37     }
38
39     const ngraph::element::Type originalPrecision = subtract->get_output_element_type(0);
40
41     const FakeQuantizeDequantization dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(subtract);
42     if (dequantization.multiply != nullptr) {
43         // before: Y = X * SC - SH, after:  Y = (X - SH') * SC
44         //    X * SC - SH = X * SC - SH' * SC
45         //    SH' = SH / SC
46         std::shared_ptr<opset1::Subtract> newSubtract = as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
47             dequantization.multiply->get_input_node_shared_ptr(0),
48             ngraph::pass::low_precision::fold<ngraph::opset1::Divide>(
49                 subtract->get_input_node_shared_ptr(1),
50                 dequantization.multiply->get_input_node_shared_ptr(1))
51         }));
52
53         std::shared_ptr<Node> newMultiply = dequantization.multiply->copy_with_new_inputs({
54             newSubtract,
55             dequantization.multiply->input_value(1)
56         });
57
58         replace_node(subtract, newMultiply);
59         subtract = newSubtract;
60     }
61
62     if (dequantization.subtract != nullptr) {
63         std::shared_ptr<opset1::Subtract> newSubtract = as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
64             dequantization.subtract->get_input_node_shared_ptr(0),
65             ngraph::pass::low_precision::fold<ngraph::opset1::Add>(
66                 subtract->get_input_node_shared_ptr(1),
67                 dequantization.subtract->get_input_node_shared_ptr(1))
68         }));
69
70         replace_node(subtract, newSubtract);
71         subtract = newSubtract;
72     }
73
74     if (dequantization.convert != nullptr) {
75         // issue #43088
76         // std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeElementwise(subtract);
77         subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
78
79         replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
80             subtract->get_input_node_shared_ptr(0),
81             subtract->get_input_node_shared_ptr(1)));
82     }
83     return true;
84 }
85
86 } // namespace low_precision
87 } // namespace pass
88 } // namespace ngraph