ed56764858b5ca7cf8ebe86c507827e122bbb7c5
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / low_precision / fake_quantize.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/low_precision/fake_quantize.hpp"
6
7 #include <algorithm>
8 #include <cmath>
9 #include <limits>
10 #include <map>
11 #include <memory>
12 #include <string>
13 #include <utility>
14 #include <vector>
15
16 #include <ngraph/opsets/opset1.hpp>
17
18 #include "transformations/low_precision/common/ie_lpt_exception.hpp"
19 #include "transformations/low_precision/network_helper.hpp"
20
21 namespace ngraph {
22 namespace pass {
23 namespace low_precision {
24
25 void FakeQuantizeTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
26     addSingleNodePattern<opset1::FakeQuantize>(pass, context);
27 }
28
29 bool FakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
30     std::shared_ptr<opset1::FakeQuantize> layer = std::dynamic_pointer_cast<opset1::FakeQuantize>(m.get_match_root());
31
32     std::shared_ptr<opset1::FakeQuantize> fakeQuantize = layer;
33
34     do {
35         layer = fakeQuantize;
36         fakeQuantize = fuseElementwise(context, fakeQuantize);
37     } while (fakeQuantize != nullptr);
38
39     const ngraph::element::Type precision = layer->get_output_element_type(0);
40     if ((precision == ngraph::element::i8) || (precision == ngraph::element::u8)) {
41         return false;
42     }
43
44     // FakeQuantize on weights are used without dequantization ScaleShifts
45     if (NetworkHelper::onWeights(layer)) {
46         return false;
47     }
48
49     if (as_type<opset1::Constant>(layer->get_input_node_ptr(0))) {
50         bool nextOpearionsWillBeNotHandled = true;
51         for (auto output : layer->outputs()) {
52             for (auto input : output.get_target_inputs()) {
53                 auto activations = paramsManager->getPrecisionsOnActivations(*input.get_node());
54                 if (paramsManager->getPrecisionsOnActivations(*input.get_node()).size() != 0ul) {
55                     nextOpearionsWillBeNotHandled = false;
56                     break;
57                 }
58             }
59
60             if (!nextOpearionsWillBeNotHandled) {
61                 break;
62             }
63         }
64
65         if (nextOpearionsWillBeNotHandled) {
66             const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(layer);
67             if (as_type_ptr<opset1::Constant>(resultConstant)) {
68                 replace_node(layer, resultConstant);
69                 return true;
70             }
71         }
72     }
73
74     if (!QuantizationDetails::outputLayoutIsSupported(layer)) {
75         return false;
76     }
77
78     if (!QuantizationDetails::isSupportedLevel(layer->get_levels())) {
79         return false;
80     }
81
82     const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(layer);
83     const DataPrecision dataPrecision = getDataPrecision(layer, quantizationDetails, false);
84     if (dataPrecision.precision == element::undefined) {
85         return false;
86     }
87
88     // Split FakeQuantize to two parts: Quantize and Dequantize
89     auto QDQ = NetworkHelper::decomposeFakeQuantize(
90         as_type_ptr<opset1::FakeQuantize>(layer),
91         dataPrecision.precision,
92         dataPrecision.min,
93         dataPrecision.max,
94         dataPrecision.hasZeroPoint,
95         updatePrecisions);
96
97 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
98     {
99         const std::shared_ptr<opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(std::get<1>(QDQ));
100         const std::shared_ptr<opset1::Constant> multiplyConst = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
101         const std::vector<float> dequantizationScales = multiplyConst->cast_vector<float>();
102
103         const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<opset1::Subtract>(multiply->get_input_node_shared_ptr(0));
104         std::vector<float> dequantizationShifts;
105         if (subtract != nullptr) {
106             const std::shared_ptr<opset1::Constant> subtractConst = as_type_ptr<opset1::Constant>(subtract->get_input_node_shared_ptr(1));
107             dequantizationShifts = subtractConst->cast_vector<float>();
108         } else {
109             dequantizationShifts = std::vector<float>(dequantizationScales.size());
110         }
111
112         printDequantizationValues(dequantizationScales, dequantizationShifts);
113     }
114 #endif
115
116     std::shared_ptr<ngraph::Node> dequantize = std::get<1>(QDQ);
117     updateOutput(context, dequantize, layer);
118
119     return true;
120 }
121
122 static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
123     const Shape shape = op->get_output_shape(0);
124     if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
125         op = fold<opset1::Unsqueeze>(
126             op,
127             std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
128     }
129     return op;
130 }
131
132 static std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) {
133     if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
134         return eltwise->get_input_node_shared_ptr(0);
135     }
136
137     if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
138         return eltwise->get_input_node_shared_ptr(1);
139     }
140
141     return nullptr;
142 }
143
144 static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
145     if (eltwise->get_input_size() != 2) {
146         return nullptr;
147     }
148
149     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
150     if (constant != nullptr) {
151         return constant;
152     }
153
154     return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
155 }
156
157 bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
158     std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
159     if (constant == nullptr) {
160         return false;
161     }
162
163     Shape shape = constant->get_output_shape(0);
164     if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
165         const Shape eltwiseShape = eltwise->get_output_shape(0);
166         if ((eltwiseShape.size() - shape.size()) > 1) {
167             return false;
168         }
169
170         if ((eltwiseShape.size() - shape.size()) == 1ul) {
171             shape.insert(shape.begin(), 1ul);
172         }
173
174         for (size_t i = 2ul; i < shape.size(); ++i) {
175             if (shape[i] != 1ul) {
176                 return false;
177             }
178         }
179     }
180
181     return getData(eltwise) != nullptr;
182 }
183
184 std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
185     TransformationContext& context,
186     const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
187     const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
188
189     std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
190     std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
191
192     std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
193     if (is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
194         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
195             constant :
196             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
197
198         const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
199         // TODO: temporary fix for GPU Plugin (inverted intervals)
200         for (const float& val : valueVec) {
201             if (val < 0) {
202                 return nullptr;
203             }
204         }
205
206         inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
207         inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
208     } else if (is_type<opset1::Divide>(eltwise) && checkElementwise(eltwise)) {
209         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
210             constant :
211             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
212
213         const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
214         // TODO: temporary fix for GPU Plugin (inverted intervals)
215         for (const float& val : valueVec) {
216             if (val < 0) {
217                 return nullptr;
218             }
219         }
220
221         inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
222         inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
223     } else if (is_type<opset1::Subtract>(eltwise) && checkElementwise(eltwise)) {
224         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
225             constant :
226             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
227
228         inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
229         inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
230     } else if (is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
231         if (is_type<opset1::Convolution>(getData(eltwise)) ||
232             is_type<opset1::GroupConvolution>(getData(eltwise))) {
233             return nullptr;
234         }
235
236         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
237             constant :
238             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
239
240         inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
241         inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
242     } else if (is_type<opset1::Convert>(eltwise)) {
243         // issue #40611
244         if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
245             return nullptr;
246         }
247     } else {
248         return nullptr;
249     }
250
251     std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
252         getData(eltwise),
253         inputLowConst,
254         inputHightConst,
255         fakeQuantize->input_value(3),
256         fakeQuantize->input_value(4) }));
257
258     replace_node(fakeQuantize, newFakeQuantize);
259     NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
260
261     return newFakeQuantize;
262 }
263
264 bool FakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
265     return false;
266 }
267 } // namespace low_precision
268 } // namespace pass
269 } // namespace ngraph