54b238444d4a71bc2669434af5c822072d2a9aed
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / add.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/add.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 #include <utility>
11 #include <vector>
12
13 #include "ngraph_ops/type_relaxed.hpp"
14
15 #include "low_precision/common/ie_lpt_exception.hpp"
16 #include "low_precision/common/dequantization_op.hpp"
17 #include "low_precision/network_helper.hpp"
18
19 namespace ngraph {
20 namespace pass {
21 namespace low_precision {
22
23 std::shared_ptr<opset1::Subtract> replaceToSubtract(const std::shared_ptr<Node>& op) {
24     // TODO: separate this part to standalone transformation: AddToSubtractTransformation
25     // motivation:
26     //    - single responsibility
27     //    - keep AddTransformation and AddToSubtractTransformation transformations independent and optional
28     const auto add = as_type_ptr<opset1::Add>(op);
29     if (add == nullptr) {
30         return nullptr;
31     }
32
33     // TODO: use general way from getDequantization: is eltwise with Constant
34     const int constBranchIndex = is_type<opset1::Constant>(add->get_input_node_ptr(0)) ?
35         0 :
36         (is_type<opset1::Constant>(add->get_input_node_ptr(1)) ? 1 : -1);
37     if (constBranchIndex == -1) {
38         return nullptr;
39     }
40     const size_t dataBranchIndex = constBranchIndex == 0 ? 1ul : 0;
41
42     const auto parent = add->get_input_node_shared_ptr(dataBranchIndex);
43     if (is_type<opset1::Convolution>(parent) ||
44         is_type<opset1::GroupConvolution>(parent) ||
45         (is_type<opset1::MatMul>(parent) &&
46         (is_type<opset1::Constant>(parent->get_input_node_ptr(0)) || is_type<opset1::Constant>(parent->get_input_node_ptr(1))))) {
47         return nullptr;
48     }
49
50     auto constant = fold<opset1::Negative>(add->get_input_node_shared_ptr(constBranchIndex));
51     auto constOutput = constant->output(0);
52
53     const auto subtract = std::make_shared<DequantizationSubtract>(
54         add->get_input_node_shared_ptr(dataBranchIndex),
55         constOutput,
56         add->get_autob());
57     NetworkHelper::copyInfo(add, subtract);
58
59     replace_node(add, subtract);
60     return subtract;
61 }
62
63 std::shared_ptr<opset1::Subtract> fuseWithSubtract(const std::shared_ptr<Node>& op) {
64     const auto add = as_type_ptr<opset1::Add>(op);
65     if ((add == nullptr) ||
66         !is_type<opset1::Subtract>(add->get_input_node_shared_ptr(0)) ||
67         // TODO: use general way from getDequantization: is eltwise with Constant
68         !is_type<opset1::Constant>(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1))) {
69         return nullptr;
70     }
71
72     const auto newSubConst = fold<opset1::Subtract>(
73         add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1),
74         add->get_input_node_shared_ptr(1));
75
76     const auto newSubtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
77         std::vector<element::Type>{element::f32, element::f32},
78         std::vector<element::Type>{ element::f32 },
79         ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0), element::f32).get(),
80         ngraph::op::TemporaryReplaceOutputType(newSubConst, element::f32).get());
81     NetworkHelper::copyInfo(add, newSubtract);
82
83     replace_node(add, newSubtract);
84     return newSubtract;
85 }
86
87 void AddTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
88     addSingleNodePattern<opset1::Add>(pass, context);
89 }
90
91 bool AddTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
92     std::shared_ptr<opset1::Add> op = as_type_ptr<opset1::Add>(m.get_match_root());
93     if (!canBeTransformed(context, op)) {
94         return false;
95     }
96
97     std::shared_ptr<Node> addNode = separateInStandaloneBranch(op);
98     std::shared_ptr<opset1::Add> add = as_type_ptr<opset1::Add>(addNode);
99
100     const int fullPathIndex = getNotEmpty(add);
101     std::shared_ptr<Node> newMultiply;
102     std::shared_ptr<Node> newAddOrSubtract;
103
104     if (fullPathIndex == -1) {
105         // swap constant multiply and add and possibly fuse to subtract
106         const auto multiplyBranch = getMultiplyConstBranch(add);
107
108         if (multiplyBranch.first == -1) {
109             NetworkHelper::foldDequantization(addNode, 0);
110             NetworkHelper::foldDequantization(addNode, 1);
111             return false;
112         }
113
114         newMultiply = NetworkHelper::swapMultiplyAndAdd(add, multiplyBranch.first);
115
116         if (is_type<opset1::Add>(newMultiply->get_input_node_shared_ptr(0))) {
117             newAddOrSubtract = newMultiply->get_input_node_shared_ptr(0);
118
119             auto subtract = fuseWithSubtract(newAddOrSubtract);
120             if (subtract != nullptr) {
121                 newAddOrSubtract = subtract;
122             }
123
124             subtract = replaceToSubtract(newAddOrSubtract);
125             if (subtract != nullptr) {
126                 newAddOrSubtract = subtract;
127             }
128         } else {
129             newAddOrSubtract = newMultiply;
130         }
131     } else {
132         // dequantizations are on both branches
133         const int emptyPathIndex = fullPathIndex == 0 ? 1 : 0;
134
135         FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(add, emptyPathIndex);
136         if (updatePrecisions && !dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) {
137             return false;
138         }
139
140         std::shared_ptr<Node> subtractEmptyPathValues;
141         std::shared_ptr<Node> multiplyEmptyPathValues;
142         std::tie(subtractEmptyPathValues, multiplyEmptyPathValues) = NetworkHelper::createEmptyValues(dequantizationEmptyPath);
143
144         FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::getDequantization(add, fullPathIndex);
145         if (updatePrecisions && !dequantizationFullPath.empty() && !dequantizationFullPath.isLowPrecision()) {
146             return false;
147         }
148
149         std::shared_ptr<Node> subtractFullPathValues;
150         std::shared_ptr<Node> multiplyFullPathValues;
151         std::tie(subtractFullPathValues, multiplyFullPathValues) = NetworkHelper::createEmptyValues(dequantizationFullPath);
152
153         // calculation
154         // before: Y = (SC1 * (X1 - SH1)) + (SC2 * (X2 - SH2))
155         // after : Y = SC2 * ( SC1' * (X1 - SH1') + X2 ) , where :
156         //         SC1' = SC1 / SC2
157         //         SH1' = SH1 + SC2 * SH2 / SC1
158         std::shared_ptr<Node> newSubtractFullPathValues = fold<opset1::Add>(
159             subtractFullPathValues,
160             fold<opset1::Divide>(
161                 fold<opset1::Multiply>(subtractEmptyPathValues, multiplyEmptyPathValues),
162                 multiplyFullPathValues));
163
164         std::shared_ptr<Node> newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues);
165
166         if (NetworkHelper::isZeroConst(newSubtractFullPathValues)) {
167             newSubtractFullPathValues = nullptr;
168         }
169
170         // graph update
171         std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
172         auto fullPathInput = dequantizationFullPath.convert == nullptr ? dequantizationFullPath.data : dequantizationFullPath.convert;
173
174         inputs[emptyPathIndex] = dequantizationEmptyPath.data.get_node_shared_ptr();
175         inputs[fullPathIndex] = std::make_shared<DequantizationMultiply>(
176             newSubtractFullPathValues == nullptr ?
177                 fullPathInput :
178                 std::make_shared<DequantizationSubtract>(fullPathInput, newSubtractFullPathValues),
179             newMultiplyFullPathValues);
180
181         newAddOrSubtract = std::make_shared<op::TypeRelaxed<opset1::Add>>(
182             std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
183             ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
184             ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
185         newMultiply = std::make_shared<DequantizationMultiply>(newAddOrSubtract, multiplyEmptyPathValues);
186
187         replace_node(add, newMultiply);
188         NetworkHelper::copyInfo(add, newAddOrSubtract);
189     }
190
191     updateOutput(context, newMultiply, newAddOrSubtract);
192
193     if (fullPathIndex != -1) {
194         std::shared_ptr<Node> node = add;
195         NetworkHelper::foldDequantization(node, fullPathIndex);
196     }
197
198     return true;
199 }
200
201 } // namespace low_precision
202 } // namespace pass
203 } // namespace ngraph