1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/add.hpp"
13 #include "ngraph_ops/type_relaxed.hpp"
15 #include "low_precision/common/ie_lpt_exception.hpp"
16 #include "low_precision/common/dequantization_op.hpp"
17 #include "low_precision/network_helper.hpp"
21 namespace low_precision {
// Tries to turn an Add that has a Constant on one of its two inputs into a
// DequantizationSubtract fed by the other (data) input.
// `op` is cast to opset1::Add; the Constant input is negated through
// fold<opset1::Negative>, which matches the identity x + c == x - (-c).
// Runtime info is copied from the Add to the new node and the Add is replaced
// in the graph with replace_node.
// Callers in this file compare the result against nullptr, so a null result
// means "nothing was replaced".
23 std::shared_ptr<opset1::Subtract> replaceToSubtract(const std::shared_ptr<Node>& op) {
24 // TODO: separate this part to standalone transformation: AddToSubtractTransformation
26 // - single responsibility
27 // - keep AddTransformation and AddToSubtractTransformation transformations independent and optional
28 const auto add = as_type_ptr<opset1::Add>(op);
33 // TODO: use general way from getDequantization: is eltwise with Constant
// Index of the input that is a Constant: input 0 is tested first, then input 1;
// -1 is the value used when input 1 is not a Constant either.
34 const int constBranchIndex = is_type<opset1::Constant>(add->get_input_node_ptr(0)) ?
36 (is_type<opset1::Constant>(add->get_input_node_ptr(1)) ? 1 : -1);
// No Constant input found.
// NOTE(review): presumably an early exit with a null result - confirm.
37 if (constBranchIndex == -1) {
// The data branch is simply the input that is not the Constant one.
40 const size_t dataBranchIndex = constBranchIndex == 0 ? 1ul : 0;
// Guard on the producer of the data branch: Convolution, GroupConvolution, or a
// MatMul that has a Constant on either of its own inputs.
// NOTE(review): presumably the Add is left untouched for these producers - confirm.
42 const auto parent = add->get_input_node_shared_ptr(dataBranchIndex);
43 if (is_type<opset1::Convolution>(parent) ||
44 is_type<opset1::GroupConvolution>(parent) ||
45 (is_type<opset1::MatMul>(parent) &&
46 (is_type<opset1::Constant>(parent->get_input_node_ptr(0)) || is_type<opset1::Constant>(parent->get_input_node_ptr(1))))) {
// Negate the constant so that the addition can be expressed as a subtraction.
50 auto constant = fold<opset1::Negative>(add->get_input_node_shared_ptr(constBranchIndex));
51 auto constOutput = constant->output(0);
// The data branch is the first operand of the new subtract.
// NOTE(review): presumably the negated constant (constOutput) is the second operand - confirm.
53 const auto subtract = std::make_shared<DequantizationSubtract>(
54 add->get_input_node_shared_ptr(dataBranchIndex),
// Carry the Add's info over to the new node before swapping it into the graph.
57 NetworkHelper::copyInfo(add, subtract);
59 replace_node(add, subtract);
// Tries to merge an Add with a Subtract that feeds its input 0.
// With the Subtract computing (x - a) and the Add contributing b on input 1, the
// new constant is folded as (a - b) and a single DequantizationSubtract
// x - (a - b) replaces the Add, i.e. (x - a) + b == x - (a - b).
// Callers in this file compare the result against nullptr, so a null result
// means "nothing was fused".
63 std::shared_ptr<opset1::Subtract> fuseWithSubtract(const std::shared_ptr<Node>& op) {
64 const auto add = as_type_ptr<opset1::Add>(op);
// Guard: `op` must be an Add whose input 0 is a Subtract with a Constant on the
// Subtract's input 1.
// NOTE(review): presumably an early exit with a null result when this fails - confirm.
// NOTE(review): input 1 of the Add itself is not tested for being a Constant
// here - confirm fold<> copes with a non-constant operand.
65 if ((add == nullptr) ||
66 !is_type<opset1::Subtract>(add->get_input_node_shared_ptr(0)) ||
67 // TODO: use general way from getDequantization: is eltwise with Constant
68 !is_type<opset1::Constant>(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1))) {
// New subtract constant: (Subtract's constant) - (Add's input 1).
72 const auto newSubConst = fold<opset1::Subtract>(
73 add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1),
74 add->get_input_node_shared_ptr(1));
// Build the fused subtract on top of the original Subtract's data input.
// The TypeRelaxed wrapper is given f32 for both input types and for the output
// type; TemporaryReplaceOutputType presents both operands as f32 while the node
// is being constructed.
76 const auto newSubtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
77 std::vector<element::Type>{element::f32, element::f32},
78 std::vector<element::Type>{ element::f32 },
79 ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0), element::f32).get(),
80 ngraph::op::TemporaryReplaceOutputType(newSubConst, element::f32).get());
// Carry the Add's info over to the new node before swapping it into the graph.
81 NetworkHelper::copyInfo(add, newSubtract);
83 replace_node(add, newSubtract);
// Adds a single-node opset1::Add pattern to `pass` through addSingleNodePattern,
// forwarding the transformation `context`.
87 void AddTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
88 addSingleNodePattern<opset1::Add>(pass, context);
// Handles one matched opset1::Add (the matcher root of `m`).
// Two cases are distinguished by the value returned from getNotEmpty(add):
//   * -1: a constant Multiply is swapped with the Add, and the resulting Add is
//         then offered to fuseWithSubtract and replaceToSubtract;
//   * otherwise: dequantization is present on both inputs and is rewritten as
//         SC2 * (SC1' * (X1 - SH1') + X2), see the formula comments below.
// In both cases updateOutput is called with the new Multiply and the new
// Add/Subtract node.
91 bool AddTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
92 std::shared_ptr<opset1::Add> op = as_type_ptr<opset1::Add>(m.get_match_root());
// NOTE(review): presumably returns early when the layer is rejected - confirm.
93 if (!canBeTransformed(context, op)) {
// Work on the node returned by separateInStandaloneBranch from here on.
97 std::shared_ptr<Node> addNode = separateInStandaloneBranch(op);
98 std::shared_ptr<opset1::Add> add = as_type_ptr<opset1::Add>(addNode);
// Input index (0 or 1) selected as the "full path", or -1.
// NOTE(review): exact selection rule lives in getNotEmpty - confirm there.
100 const int fullPathIndex = getNotEmpty(add);
101 std::shared_ptr<Node> newMultiply;
102 std::shared_ptr<Node> newAddOrSubtract;
104 if (fullPathIndex == -1) {
105 // swap constant multiply and add and possibly fuse to subtract
106 const auto multiplyBranch = getMultiplyConstBranch(add);
// No branch with a constant Multiply: only fold the dequantization on both inputs.
// NOTE(review): presumably followed by an early return - confirm.
108 if (multiplyBranch.first == -1) {
109 NetworkHelper::foldDequantization(addNode, 0);
110 NetworkHelper::foldDequantization(addNode, 1);
// Move the constant Multiply from the selected branch to after the Add.
114 newMultiply = NetworkHelper::swapMultiplyAndAdd(add, multiplyBranch.first);
115 ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
// If the swapped Multiply is fed by an Add on input 0, try to simplify that Add:
// first by fusing it with a preceding Subtract, then by rewriting it as a
// Subtract. A non-null result replaces newAddOrSubtract each time.
116 if (is_type<opset1::Add>(newMultiply->get_input_node_shared_ptr(0))) {
117 newAddOrSubtract = newMultiply->get_input_node_shared_ptr(0);
119 auto subtract = fuseWithSubtract(newAddOrSubtract);
120 if (subtract != nullptr) {
121 newAddOrSubtract = subtract;
124 subtract = replaceToSubtract(newAddOrSubtract);
125 if (subtract != nullptr) {
126 newAddOrSubtract = subtract;
// NOTE(review): presumably the fallback taken when input 0 of the Multiply is
// not an Add - confirm.
129 newAddOrSubtract = newMultiply;
132 // dequantizations are on both branches
// The "empty path" is the input that is not the full path.
133 const int emptyPathIndex = fullPathIndex == 0 ? 1 : 0;
135 FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(add, emptyPathIndex);
// Guard: precisions are being updated and this path has a dequantization that
// is not low precision.
// NOTE(review): presumably aborts the transformation - confirm.
136 if (updatePrecisions && !dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) {
// Subtract (SH2) and Multiply (SC2) values of the empty path, as produced by
// createEmptyValues.
140 std::shared_ptr<Node> subtractEmptyPathValues;
141 std::shared_ptr<Node> multiplyEmptyPathValues;
142 std::tie(subtractEmptyPathValues, multiplyEmptyPathValues) = NetworkHelper::createEmptyValues(dequantizationEmptyPath);
144 FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::getDequantization(add, fullPathIndex);
// Same guard as above, applied to the full path.
// NOTE(review): presumably aborts the transformation - confirm.
145 if (updatePrecisions && !dequantizationFullPath.empty() && !dequantizationFullPath.isLowPrecision()) {
// Subtract (SH1) and Multiply (SC1) values of the full path.
149 std::shared_ptr<Node> subtractFullPathValues;
150 std::shared_ptr<Node> multiplyFullPathValues;
151 std::tie(subtractFullPathValues, multiplyFullPathValues) = NetworkHelper::createEmptyValues(dequantizationFullPath);
154 // before: Y = (SC1 * (X1 - SH1)) + (SC2 * (X2 - SH2))
155 // after : Y = SC2 * ( SC1' * (X1 - SH1') + X2 ) , where :
157 // SH1' = SH1 + SC2 * SH2 / SC1
// Here index 1 denotes the full path and index 2 the empty path.
158 std::shared_ptr<Node> newSubtractFullPathValues = fold<opset1::Add>(
159 subtractFullPathValues,
160 fold<opset1::Divide>(
161 fold<opset1::Multiply>(subtractEmptyPathValues, multiplyEmptyPathValues),
162 multiplyFullPathValues));
// SC1' = SC1 / SC2
164 std::shared_ptr<Node> newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues);
// An all-zero shift is dropped; nullptr marks "no Subtract" for the code below.
166 if (NetworkHelper::isZeroConst(newSubtractFullPathValues)) {
167 newSubtractFullPathValues = nullptr;
171 std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
// Full-path source: the dequantization's Convert when it exists, otherwise its data.
172 auto fullPathInput = dequantizationFullPath.convert == nullptr ? dequantizationFullPath.data : dequantizationFullPath.convert;
// The empty path feeds the new Add with its dequantization data node directly.
174 inputs[emptyPathIndex] = dequantizationEmptyPath.data.get_node_shared_ptr();
// The full path becomes SC1' * (X1 - SH1').
// NOTE(review): presumably the Subtract is skipped (fullPathInput used as is)
// when newSubtractFullPathValues is nullptr - confirm.
175 inputs[fullPathIndex] = std::make_shared<DequantizationMultiply>(
176 newSubtractFullPathValues == nullptr ?
178 std::make_shared<DequantizationSubtract>(fullPathInput, newSubtractFullPathValues),
179 newMultiplyFullPathValues);
// New Add over both rewritten inputs; the TypeRelaxed wrapper is given f32 for
// both input types and for the output type.
181 newAddOrSubtract = std::make_shared<op::TypeRelaxed<opset1::Add>>(
182 std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
183 ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
184 ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
// Common scale SC2 applied once after the Add.
185 newMultiply = std::make_shared<DequantizationMultiply>(newAddOrSubtract, multiplyEmptyPathValues);
// Swap the original Add for the new Multiply and carry info over to the new nodes.
187 replace_node(add, newMultiply);
188 NetworkHelper::copyInfo(add, newAddOrSubtract);
189 ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
192 updateOutput(context, newMultiply, newAddOrSubtract);
// When a full path was selected, fold the dequantization on that input of `add`.
194 if (fullPathIndex != -1) {
195 std::shared_ptr<Node> node = add;
196 NetworkHelper::foldDequantization(node, fullPathIndex);
// Decides whether `layer` may be handled by this transformation.
// The dequantization on input 0 and on input 1 is each checked with
// multiplyHasZero(); if neither check ends the function, the decision is
// delegated to EltwiseBaseTransformation::canBeTransformed.
202 bool AddTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
203 const FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, 0ul);
// NOTE(review): presumably rejects the layer (returns false) when the Multiply
// constant contains a zero - confirm. The new scale in transform() is computed
// by dividing by these values.
204 if (dequantization1.multiplyHasZero()) {
208 const FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, 1ul);
// Same check for the second input.
209 if (dequantization2.multiplyHasZero()) {
213 return EltwiseBaseTransformation::canBeTransformed(context, layer);
216 } // namespace low_precision
218 } // namespace ngraph