// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/concat.hpp"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ngraph/opsets/opset1.hpp>

#include "low_precision/common/fake_quantize_dequantization.hpp"
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/common/subgraph.hpp"
#include "low_precision/common/dequantization_op.hpp"
#include "low_precision/network_helper.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

void ConcatTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
    addSingleNodePattern<opset1::Concat>(pass, context);
}

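// Replaces per-branch quantization with a single shared quantization and moves dequantization
// after the concatenation:
// 1. collect the Concat subgraph together with the FakeQuantize operations that feed it;
// 2. select a common data precision and a common quantization interval for all branches;
// 3. re-express each FakeQuantize on the shared integer grid (see QuantizedTensorAlignment);
// 4. insert shared dequantization operations on every edge leaving the subgraph
//    (see addDequantizationLayers below).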
bool ConcatTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) const {
    std::shared_ptr<ngraph::opset1::Concat> concat = ngraph::as_type_ptr<ngraph::opset1::Concat>(m.get_match_root());
    if (!canBeTransformed(context, concat)) {
        return false;
    }

    ngraph::pass::low_precision::Subgraph subgraph(layerTransformationsManager);
    std::unordered_set<std::string> handledLayers;
    if (!subgraph.fillSubgraphForConcat(concat, handledLayers)) {
        return false;
    }

    if (subgraph.quantizationLayers.empty() || isHandled(context, subgraph.quantizationLayers)) {
        return false;
    }

    // precisions of the FakeQuantize operations can be different
    ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
    std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
    DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
    if (dataPrecision.precision == ngraph::element::undefined) {
        return false;
    }

    std::vector<QuantizationDetails> quantizationLayersDetails;

    for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
        const std::shared_ptr<ngraph::Node> fakeQuantizeLayer = subgraph.quantizationLayers[i];

        const ngraph::Shape shape = fakeQuantizeLayer->get_output_shape(0);
        if (shape.size() < 4ul) {
            return false;
        }

        const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(fakeQuantizeLayer->shared_from_this());
        if (fq == nullptr) {
            return false;
        }

        const QuantizationDetails& quantizationDetails = QuantizationDetails::getDetails(fq);
        quantizationLayersDetails.push_back(quantizationDetails);

        const DataPrecision dataPrecision2 = getDataPrecision(subgraph.quantizationLayers[i]->shared_from_this(), quantizationDetails, false);
        if (dataPrecision2.precision == ngraph::element::undefined) {
            return false;
        }

        if (dataPrecision.precision != dataPrecision2.precision) {
            // quantization levels are the same, the difference can be in the sign only:
            // the wider (signed) interval is preferable, so use signed if at least one interval is signed
            dataPrecision = dataPrecision.precision.is_signed() ? dataPrecision : dataPrecision2;
        }
    }

    if (dataPrecision.precision == ngraph::element::undefined) {
        return false;
    }

    // only a per-tensor scale is supported
    if (quantizationLayersDetails.empty() || (quantizationLayersDetails[0].inputHighValues.size() != 1ul)) {
        return false;
    }

    FakeQuantizeDequantization dequantization;

    if (quantizationLayersDetails[0].inputHighValues.size() == 1) {
        float outputLowValue = quantizationLayersDetails[0].outputLowValues[0];
        float outputHighValue = quantizationLayersDetails[0].outputHighValues[0];

        // merge the output intervals of all branches into one common interval
        for (size_t index = 0lu; index < subgraph.quantizationLayers.size(); index++) {
            const QuantizationDetails& quantizationDetails = quantizationLayersDetails[index];
            if (outputLowValue > quantizationDetails.outputLowValues[0]) {
                outputLowValue = quantizationDetails.outputLowValues[0];
            }
            if (outputHighValue < quantizationDetails.outputHighValues[0]) {
                outputHighValue = quantizationDetails.outputHighValues[0];
            }
        }

        if ((outputLowValue == 0.f) && (outputHighValue == 0.f)) {
            return false;
        }

        const float maxOutputInterval = outputHighValue - outputLowValue;
        if (quantizedTensorAlignmentOnActivations == QuantizedTensorAlignment::UpdateLevel) {
            const size_t minLevels = getMinQuantizationLevels(
                dataPrecision,
                maxOutputInterval,
                quantizationLayersDetails,
                outputLowValue,
                outputHighValue);
            if (minLevels < this->minQuantizationLevels) {
                return false;
            }
        }

        // FQ -> SUB_quantization -> MUL_quantization -[INT8]-> SUB_dequantization -> MUL_dequantization ->
        const float quantizationMul = (dataPrecision.max - dataPrecision.min) / maxOutputInterval;
        const float dequantizationMul = maxOutputInterval / (dataPrecision.max - dataPrecision.min);

        // FQ outputLowValue = quantizationSub + dataPrecision.min * dequantizationMul
        const float quantizationSub = outputLowValue - dataPrecision.min * dequantizationMul;
        const float dequantizationSub = std::round(-quantizationSub * quantizationMul);
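        // Illustrative (assumed) values: for I8 (min = -128, max = 127) and a merged interval
        // [-0.64, 0.635]: maxOutputInterval = 1.275, quantizationMul = 255 / 1.275 = 200,
        // dequantizationMul = 0.005, quantizationSub = -0.64 - (-128 * 0.005) = 0,
        // dequantizationSub = round(-0 * 200) = 0.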

        // 1. get data for the dequantization; it will be reused several times below
        dequantization = ngraph::pass::low_precision::NetworkHelper::makeDequantization(
            dequantizationMul,
            dequantizationSub,
            subgraph.quantizationLayers[0]->get_output_element_type(0),
            subgraph.quantizationLayers[0]->get_output_shape(0),
            dataPrecision.precision,
            dataPrecision.min,
            dataPrecision.max);

        for (size_t index = 0; index < subgraph.quantizationLayers.size(); index++) {
            std::shared_ptr<ngraph::opset1::FakeQuantize> fakeQuantizeLayer = as_type_ptr<ngraph::opset1::FakeQuantize>(
                subgraph.quantizationLayers[index]->shared_from_this());

            const QuantizationDetails& quantizationDetails = quantizationLayersDetails[index];

            switch (quantizedTensorAlignmentOnActivations) {
                case QuantizedTensorAlignment::None: {
                    THROW_TRANSFORMATION_EXCEPTION << "not implemented: " << quantizedTensorAlignmentOnActivations;
                }
                case QuantizedTensorAlignment::UpdateLevel: {
                    // map the branch interval onto the common integer grid
                    const float updatedOutputLowValue = (quantizationDetails.outputLowValues[0] - quantizationSub) * quantizationMul;
                    const float updatedOutputHighValue = (quantizationDetails.outputHighValues[0] - quantizationSub) * quantizationMul;

                    // 2. update FakeQuantize: a one-time action
                    std::shared_ptr<opset1::FakeQuantize> newFakeQuantizeLayer = ngraph::pass::low_precision::NetworkHelper::updateFakeQuantize(
                        fakeQuantizeLayer,
                        updatePrecisions ? dataPrecision.precision : fakeQuantizeLayer->get_output_element_type(0),
                        roundf(updatedOutputLowValue),
                        roundf(updatedOutputHighValue));

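                    // the branch keeps its own (narrower) range, but the level count is reduced
                    // so that the branch quantization step matches the shared grid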
                    const size_t levels = static_cast<size_t>(fabs(roundf(updatedOutputHighValue) - roundf(updatedOutputLowValue)) + 1.0);
                    newFakeQuantizeLayer->set_levels(levels);

                    subgraph.quantizationLayers[index] = newFakeQuantizeLayer;
                    subgraph.layers[fakeQuantizeLayer->get_friendly_name()] = newFakeQuantizeLayer;
                    break;
                }
                default: {
                    THROW_TRANSFORMATION_EXCEPTION << "unexpected value " << quantizedTensorAlignmentOnActivations;
                }
            }
        }
    } else {
        return false;
    }

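    // every edge leaving the subgraph receives the same shared dequantization computed above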
    auto dequantizationValuesCallback = [&](
        std::shared_ptr<ngraph::Node> layer,
        const std::string originalLayerName,
        std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate) {
        dequantizationsToConcatenate.push_back(dequantization);
    };

    addDequantizationLayers(context, subgraph, dequantizationValuesCallback);

    if (updatePrecisions) {
        for (const auto& it : subgraph.layers) {
            const std::shared_ptr<ngraph::Node>& node = it.second;
            if (std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(node) != nullptr) {
                ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(node->shared_from_this(), dataPrecision.precision);
            } else {
                // set the precision explicitly to have the updated precision during the transformation
                for (size_t i = 0; i < node->get_output_size(); ++i) {
                    node->set_output_type(i, dataPrecision.precision, node->get_output_partial_shape(i));
                }
            }
        }
    }

    for (const std::shared_ptr<ngraph::Node>& quantizationLayer : subgraph.quantizationLayers) {
        context.quantizedFakeQuantizeNames.insert(quantizationLayer->get_friendly_name());
    }
    return true;
}

bool ConcatTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noexcept {
    return true;
}

bool ConcatTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
    std::shared_ptr<opset1::Concat> concat = as_type_ptr<opset1::Concat>(layer);
    // only concatenation along the channel axis is supported
    return (concat != nullptr) && (concat->get_axis() == 1ul);
}

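// Inserts dequantization operations on every edge that leaves the subgraph: for each output of a
// subgraph layer that is consumed by a node outside the subgraph, Convert/Subtract/Multiply
// operations are built from the dequantization values supplied by the callback.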
void ConcatTransformation::addDequantizationLayers(
    TransformationContext& context,
    ngraph::pass::low_precision::Subgraph& subgraph,
    std::function<void(
        std::shared_ptr<ngraph::Node> layer,
        const std::string originalLayerName,
        std::vector<FakeQuantizeDequantization>& dequantizationsToConcatenate)> getLayerDequantizationCallback) const {
    // map from the friendly name of a node that feeds a Result to the Result node itself
    std::unordered_map<std::string, ngraph::Node*> outputs;
    for (size_t i = 0; i < context.function->get_output_size(); ++i) {
        ngraph::Node* node = context.function->get_output_op(i).get();
        if (node->get_input_size() != 1ul) {
            THROW_IE_LPT_EXCEPTION(*node) << "unexpected input count for the result node";
        }

        outputs.emplace(node->get_input_node_shared_ptr(0)->get_friendly_name(), node);
    }

    std::unordered_map<std::string, std::shared_ptr<ngraph::Node>> notHandledSubgraphLayers = subgraph.layers;
    while (!notHandledSubgraphLayers.empty()) {
        const auto layerIt = notHandledSubgraphLayers.begin();
        std::shared_ptr<ngraph::Node> layer = layerIt->second;
        notHandledSubgraphLayers.erase(layerIt);

        std::vector<FakeQuantizeDequantization> layerDequantizations;

        for (size_t i = 0; i < layer->get_output_size(); ++i) {
            const auto childInputs = layer->get_output_target_inputs(i);
            for (const auto& childInput : childInputs) {
                ngraph::Node& child = *childInput.get_node();

                if (subgraph.layers.find(child.get_friendly_name()) == subgraph.layers.end()) {
                    if (layerDequantizations.empty()) {
                        getLayerDequantizationCallback(layer, layer->get_friendly_name(), layerDequantizations);
                    }

                    std::shared_ptr<ngraph::Node> source = layer->shared_from_this();
                    {
                        std::vector<std::shared_ptr<ngraph::Node>> convertNodes;
                        std::vector<std::shared_ptr<ngraph::Node>> subtractNodes;
                        std::vector<std::shared_ptr<ngraph::Node>> multiplyNodes;

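                        // two cases: several per-branch dequantizations are merged into
                        // per-channel constants; a single dequantization is reused as-is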
                        if (layerDequantizations.size() > 1ul) {
                            // the FakeQuantize constant shape must be broadcastable to the shape of the data
                            auto broadcastElementWiseConst = [](
                                std::shared_ptr<ngraph::opset1::Constant> operation,
                                const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
                                auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
                                    element::i64, ngraph::Shape{ targetShape.size() },
                                    targetShape);

                                auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
                                    operation,
                                    targetShapeConst,
                                    ngraph::op::AutoBroadcastType::NUMPY);

                                return broadcast;
                            };

                            // track whether any branch actually has a shift or a scale
                            bool allDequantizationShiftAreZero = true;
                            bool allDequantizationMultiplyAreZero = true;
                            for (const FakeQuantizeDequantization& dequantization : layerDequantizations) {
                                if (dequantization.subtract != nullptr) {
                                    allDequantizationShiftAreZero = false;
                                }
                                if (dequantization.multiply != nullptr) {
                                    allDequantizationMultiplyAreZero = false;
                                }
                            }

                            for (size_t index = 0; index < layerDequantizations.size(); ++index) {
                                const auto& dequantization = layerDequantizations[index];

                                convertNodes.push_back(dequantization.convert);

                                const ngraph::element::Type precision = dequantization.data.get_element_type();
                                ngraph::Shape targetShape = dequantization.data.get_shape();

                                // keep only the channel dimension: the constants are concatenated along axis 1
                                targetShape[0] = 1ul;
                                for (size_t dimension = 2; dimension < targetShape.size(); ++dimension) {
                                    targetShape[dimension] = 1ul;
                                }

                                if (!allDequantizationShiftAreZero) {
                                    // branches without a shift contribute a neutral 0.f constant
                                    subtractNodes.push_back(dequantization.subtract == nullptr ?
                                        std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 0.f })) :
                                        broadcastElementWiseConst(
                                            as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->input_value(1).get_node_shared_ptr()),
                                            targetShape));
                                }

                                if (!allDequantizationMultiplyAreZero) {
                                    // branches without a scale contribute a neutral 1.f constant
                                    multiplyNodes.push_back(dequantization.multiply == nullptr ?
                                        std::make_shared<ngraph::opset1::Constant>(precision, targetShape, std::vector<float>({ 1.0f })) :
                                        broadcastElementWiseConst(
                                            as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->input_value(1).get_node_shared_ptr()),
                                            targetShape));
                                }
                            }
                        } else {
                            // TODO: check constant shapes here - has to be scalar
                            if (layerDequantizations[0].convert != nullptr) {
                                convertNodes.push_back(layerDequantizations[0].convert);
                            }

                            if (layerDequantizations[0].subtract != nullptr) {
                                subtractNodes.push_back(layerDequantizations[0].subtract->input_value(1).get_node_shared_ptr());
                            }

                            if (layerDequantizations[0].multiply != nullptr) {
                                multiplyNodes.push_back(layerDequantizations[0].multiply->input_value(1).get_node_shared_ptr());
                            }
                        }

                        // TODO: the second place (first is FQ decomposition) where dequantization operations are inserted
                        const std::shared_ptr<ngraph::Node> destination = child.shared_from_this();

                        if (!convertNodes.empty()) {
                            const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
                            std::shared_ptr<ngraph::Node> convert =
                                convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) });
                            insert_new_node_between(source, destination, convert);
                            ngraph::copy_runtime_info({ layer, convert }, convert);
                            source = convert;
                        }

                        // the concatenation axis is 1: per-branch constants are folded into one per-channel constant
                        if (!subtractNodes.empty()) {
                            const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
                            std::shared_ptr<ngraph::opset1::Subtract> subtract = std::make_shared<DequantizationSubtract>(
                                destination->get_input_source_output(sourceOutputIdx),
                                NetworkHelper::toScalarIfPossible(subtractNodes.size() == 1ul ?
                                    subtractNodes[0] :
                                    ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));
                            insert_new_node_between(source, destination, subtract);
                            ngraph::copy_runtime_info({ layer, subtract }, subtract);
                            source = subtract;
                        }

                        if (!multiplyNodes.empty()) {
                            const size_t sourceOutputIdx = NetworkHelper::getChildInputIndex(source, destination);
                            std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
                                destination->get_input_source_output(sourceOutputIdx),
                                NetworkHelper::toScalarIfPossible(multiplyNodes.size() == 1ul ?
                                    multiplyNodes[0] :
                                    ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1)));
                            insert_new_node_between(source, destination, multiply);
                            ngraph::copy_runtime_info({ layer, multiply }, multiply);
                            source = multiply;
                        }
                    }

                    // the element type of the first input is used
                    const ngraph::element::Type precision = layerDequantizations[0].data.get_element_type();
                    layer->set_output_type(0, precision, layer->get_output_partial_shape(0));

                    const auto it = outputs.find(layer->get_friendly_name());
                    if (it != outputs.end()) {
                        // the layer feeds a network output: move the original name to the last
                        // dequantization operation so the output name is preserved
                        const std::string originalName = layer->get_friendly_name();
                        const std::string newName = layer->get_friendly_name() + LayerTransformation::originalLayerPostfix;
                        layer->set_friendly_name(newName);
                        source->set_friendly_name(originalName);
                        subgraph.layers[layer->get_friendly_name()] = layer;
                    }
                }
            }
        }
    }
}

bool ConcatTransformation::isHandled(const TransformationContext& context, const std::vector<std::shared_ptr<ngraph::Node>>& quantizationOperations) {
    for (const std::shared_ptr<ngraph::Node>& quantizationLayer : quantizationOperations) {
        if (context.quantizedFakeQuantizeNames.find(quantizationLayer->get_friendly_name()) != context.quantizedFakeQuantizeNames.end()) {
            return true;
        }
    }

    return false;
}

size_t ConcatTransformation::getMinQuantizationLevels(
    const DataPrecision& dataPrecision,
    const float maxOutputInterval,
    const std::vector<QuantizationDetails>& quantizationLayersDetails,
    const float outputLowValue,
    const float outputHighValue) const {
    size_t minLevels = std::numeric_limits<std::size_t>::max();
    for (const QuantizationDetails& quantizationDetails : quantizationLayersDetails) {
        // if there is a negative part, the calculation is based on `outputLowValue`; if not, on `outputHighValue` only
        const float updatedOutputLowValue = outputLowValue != 0.f ?
            (quantizationDetails.outputLowValues[0] / outputLowValue) * dataPrecision.min :
            (quantizationDetails.outputLowValues[0] / outputHighValue) * dataPrecision.max;

        // if there is a positive part, the calculation is based on `outputHighValue`; if not, on `outputLowValue` only
        const float updatedOutputHighValue = outputHighValue != 0.f ?
            (quantizationDetails.outputHighValues[0] / outputHighValue) * dataPrecision.max :
            (quantizationDetails.outputHighValues[0] / outputLowValue) * dataPrecision.min;

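        // Illustrative (assumed) values: with I8 (min = -128, max = 127), a merged interval
        // [-0.64, 0.635], and a branch quantized to [-0.32, 0.3175], the branch maps to
        // [-64, 63.5] on the integer grid, so levels = |64 - (-64)| + 1 = 129 of the 256 available.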
        const size_t levels = static_cast<size_t>(fabs(roundf(updatedOutputHighValue) - roundf(updatedOutputLowValue)) + 1.0);
        if (minLevels > levels) {
            minLevels = levels;
        }
    }
    return minLevels;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph