8bd84d5d1f00d780e1020d021f4b8fb7337cd94c
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / network_helper.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <low_precision/network_helper.hpp>
6
7 #include <algorithm>
8 #include <cmath>
9 #include <limits>
10 #include <map>
11 #include <memory>
12 #include <string>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16 #include <queue>
17
18 #include <ngraph/rt_info.hpp>
19 #include "low_precision/common/ie_lpt_exception.hpp"
20 #include "low_precision/common/dequantization_op.hpp"
21
22 namespace ngraph {
23 namespace pass {
24 namespace low_precision {
25
26 // Return true if `type` can be castable to at least one of `type`
27 bool NetworkHelper::is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types) {
28     for (auto another : types) {
29         if (type.is_castable(another)) {
30             return true;
31         }
32     }
33     return false;
34 }
35
36 // Collect and return a vector with all nodes that consumes any of the `node` output
37 std::vector<Input<Node>> NetworkHelper::consumer_inputs(std::shared_ptr<Node> node) {
38     std::vector<Input<Node>> result;
39     for (const auto& output_port : node->outputs()) {
40         for (const auto &input : output_port.get_target_inputs()) {
41             result.push_back(input);
42         }
43     }
44     return result;
45 }
46
47 std::vector<std::shared_ptr<Node>> NetworkHelper::consumers(std::shared_ptr<Node> node) {
48     auto inputs = consumer_inputs(node);
49     std::vector<std::shared_ptr<Node>> result(inputs.size());
50     std::transform(inputs.begin(), inputs.end(), result.begin(), [](Input<Node> input){ return input.get_node()->shared_from_this(); });
51     return result;
52 }
53
54 int NetworkHelper::onWeightsInDepth(std::shared_ptr<Node> layer) {
55     const std::vector<std::shared_ptr<Node>> children = consumers(layer);
56     for (std::shared_ptr<Node> child : children) {
57         if ((is_type<opset1::Convolution>(child) ||
58             is_type<opset1::GroupConvolution>(child) ||
59             is_type<opset1::MatMul>(child)) &&
60             (child->inputs().size() >= 2lu)) {
61             const std::vector<std::shared_ptr<Node>> parents = getParentsRecursivelyExceptTypes(child, {}, 1);
62             for (const std::shared_ptr<Node>& parent : parents) {
63                 if (parent.get() == layer.get()) {
64                     return 1;
65                 }
66             }
67             return -1;
68         }
69
70         const int result = onWeightsInDepth(child);
71         if (result != 0) {
72             return result;
73         }
74     }
75     return 0;
76 }
77
78 bool NetworkHelper::onWeights(std::shared_ptr<Node> layer) {
79     const int result = onWeightsInDepth(layer);
80     return result == 1;
81 }
82
83 size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
84     if (layer->outputs().size() == 0) {
85         THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";
86     }
87
88     if (layer->outputs().size() > 1) {
89         THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " has too many output tensors, expected one";
90     }
91
92     PartialShape shape = layer->get_output_partial_shape(0);
93     if (shape.rank() == 0) {
94         THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in output of " << layer->get_friendly_name() << " layer on weights";
95     }
96     if (isOnWeights) {
97         return shape[0].get_length();
98     } else {
99         if (shape.rank() == 1) {
100             return shape[0].get_length();
101         }
102         return shape[1].get_length();
103     }
104 }
105
106 std::vector<std::shared_ptr<Node>> NetworkHelper::getParentsRecursivelyExceptTypes(
107         std::shared_ptr<Node> layer,
108         const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes,
109         const int portIndex) {
110     std::vector<std::shared_ptr<Node>> parents;
111     size_t i = 0ul;
112     for (auto input : layer->inputs()) {
113         if ((portIndex == -1) || (portIndex == i)) {
114             auto parent = input.get_source_output().get_node_shared_ptr();
115             if (is_castable_to_one_of(parent->get_type_info(), exceptionLayerTypes)) {
116                 const std::vector<std::shared_ptr<Node>> tmpParents = getParentsRecursivelyExceptTypes(parent, exceptionLayerTypes);
117                 parents.insert(parents.end(), tmpParents.begin(), tmpParents.end());
118             } else {
119                 parents.push_back(parent);
120             }
121         }
122
123         i++;
124     }
125     return parents;
126 }
127
128 size_t NetworkHelper::getInputChannelsCount(std::shared_ptr<Node> layer) {
129     if (layer->get_input_size() == 0) {
130         THROW_TRANSFORMATION_EXCEPTION << "There are no input layers";
131     }
132
133     PartialShape shape = layer->get_input_partial_shape(0);
134     if (shape.rank().get_length() <= 1) {
135         THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in input of " << layer->get_friendly_name();
136     }
137
138     return shape[1].get_length();
139 }
140
141 size_t NetworkHelper::getGroupsCount(std::shared_ptr<Node> layer) {
142     if (as_type_ptr<opset1::Convolution>(layer)) {
143         return 1;
144     } else if (auto group_convolution = as_type_ptr<opset1::GroupConvolution>(layer)) {
145         return layer->get_input_shape(1)[0];    // input weights for opset1::GC is in format GOI..., see the specification
146     } else {
147         THROW_TRANSFORMATION_EXCEPTION << "Invalid layer type of " << layer->get_friendly_name() << "; expected Convolutino or GroupConvolution";
148     }
149 }
150
151 // Assumin tensor in NC... layout, append necessary number of 1s to shape to align it to a give rank
152 Shape NetworkHelper::alignShapeForChannelDim(const Shape& shape, Rank rank) {
153     assert(shape.size() == 1);
154     assert(rank.is_static());
155     Shape result = shape;
156     result.resize(rank.get_length() - 1, 1);
157     return result;
158 }
159
160 void NetworkHelper::removeLayer(std::shared_ptr<Node> layer) {
161     ngraph::replace_output_update_name(layer->output(0), layer->input_value(0));
162 }
163
164 std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch) {
165     // Multiply --> Add(addAfterMultiply)  ==>  Add(new) --> Multiply(new)
166     // That means x*a + b ==> (x + b/a)*a; tries to fold b/a
167     const auto multiply = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch);
168
169     const auto multiplyParent1 = multiply->get_input_node_shared_ptr(0);
170     const auto multiplyParent2 = multiply->get_input_node_shared_ptr(1);
171
172     auto multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent1);
173     auto multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent2);
174     int multiplyInputBranch = 0;
175
176     if (multiplyConst == nullptr) {
177         multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent2);
178         multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent1);
179         multiplyInputBranch = 1;
180     }
181
182     if (multiplyConst == nullptr)
183         return addAfterMultiply;
184
185     const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
186     const auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
187     const auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
188     std::shared_ptr<Node> bDivA;
189
190     if (shape_size(b->get_output_shape(0)) == 1 ||
191         shape_size(a->get_output_shape(0)) == 1 ||
192         shape_size(b->get_output_shape(0)) == shape_size(a->get_output_shape(0))) {
193         // safely division to avoid NaN
194         const std::vector<float> bValues = as_type_ptr<opset1::Constant>(b)->cast_vector<float>();
195         const std::vector<float> aValues = as_type_ptr<opset1::Constant>(a)->cast_vector<float>();
196         const bool aBroadcasted = bValues.size() > aValues.size();
197         const bool bBroadcasted = bValues.size() < aValues.size();
198         std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size());
199
200         for (int i = 0; i < bDivAValues.size(); ++i) {
201             const auto bi = bValues[bBroadcasted ? 0 : i];
202             const auto ai = aValues[aBroadcasted ? 0 : i];
203             if (bi != 0.f || ai != 0.f) {
204                 bDivAValues[i] = bi / ai;
205             } else {
206                 bDivAValues[i] = 0.f;
207             }
208         }
209
210         bDivA = std::make_shared<opset1::Constant>(
211                 b->get_output_element_type(0),
212                 aBroadcasted ? b->get_output_shape(0) : a->get_output_shape(0),
213                 bDivAValues);
214     } else {
215         bDivA = fold<opset1::Divide>(b, a);
216     }
217
218     std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
219
220     inputs[0] = x;
221     inputs[1] = bDivA;
222
223     std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
224         std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
225         ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
226         ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
227     copyInfo(addAfterMultiply, newAdd);
228
229     NetworkHelper::setOutDataPrecision(newAdd, addAfterMultiply->get_output_element_type(0));
230
231     auto newMultiply = std::make_shared<DequantizationMultiply>(newAdd, a);
232     copyInfo(multiply, newMultiply);
233
234     replace_node(addAfterMultiply, newMultiply);
235     return newMultiply;
236 }
237
238 void NetworkHelper::copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target) {
239     // TODO: merge_runtime_info with correctly defined DEQUANTIZATION
240     const auto& sourceAttributes = source->get_rt_info();
241     auto& targetAttrubutes = target->get_rt_info();
242     for (auto attribute : sourceAttributes) {
243         targetAttrubutes[attribute.first] = attribute.second;
244     }
245
246     const std::string friendlyName = source->get_friendly_name();
247     if (!friendlyName.empty()) {
248         target->set_friendly_name(friendlyName);
249     }
250 }
251
252 void NetworkHelper::cleanRunTimeInfo(const std::shared_ptr<Node>& layer) {
253     auto& rt_info = layer->get_rt_info();
254     auto attributeIter = rt_info.find("DEQUANTIZATION");
255     if (rt_info.find("DEQUANTIZATION") != rt_info.end()) {
256         rt_info.erase(attributeIter);
257     }
258 }
259
260 bool NetworkHelper::isScalarLike(std::shared_ptr<opset1::Constant> constant) {
261     return constant->get_all_data_elements_bitwise_identical();
262 }
263
264 bool NetworkHelper::isZero(std::shared_ptr<opset1::Constant> constant) {
265     static const float minQuantizationShift = 1e-32f;
266
267     auto values = constant->cast_vector<float>();
268     for (size_t i = 0; i < values.size(); ++i) {
269         if (fabs(values[i]) > minQuantizationShift) {
270             return false;
271         }
272     }
273
274     return true;
275 }
276
277 std::shared_ptr<opset1::Constant> NetworkHelper::toScalar(std::shared_ptr<opset1::Constant> constant) {
278     assert(isScalarLike(constant));
279     return std::make_shared<opset1::Constant>(constant->get_element_type(), Shape{}, constant->get_data_ptr());
280 }
281
282 std::shared_ptr<Node> NetworkHelper::getConstantInput(std::shared_ptr<Node> node) {
283     std::shared_ptr<Node> constant1 = as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr());
284     if (!constant1) {
285         constant1 = as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr());
286     }
287     return constant1;
288 }
289
290 std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> node) {
291     std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(node);
292     if (!multiply) {
293         THROW_IE_LPT_EXCEPTION(*multiply) << "Unexpected operation type";
294     }
295
296     if (multiply->output(0).get_target_inputs().size() == 1) {
297         auto constant1 = getConstantInput(multiply);
298         if (!constant1 || constant1->output(0).get_target_inputs().size() != 1) {
299             return multiply;
300         }
301         auto nextMultiplyInput = *multiply->output(0).get_target_inputs().begin();
302         auto nextMultiply = as_type_ptr<opset1::Multiply>(nextMultiplyInput.get_node()->shared_from_this());
303         if (nextMultiply) {
304             auto constant2 = getConstantInput(nextMultiply);
305             auto constant2Inputs = constant2->output(0).get_target_inputs().size();
306             if (!constant2 || constant2->output(0).get_target_inputs().size() != 1) {
307                 return multiply;
308             }
309
310             auto newConst = fold<opset1::Multiply>(constant1, constant2);
311             auto newMultiply =
312                     std::make_shared<opset1::Multiply>(
313                             multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index()),
314                             newConst->output(0));
315             copy_runtime_info(multiply, newMultiply);
316             replace_node(nextMultiply, newMultiply);
317             return newMultiply;
318         }
319     }
320
321     return nullptr;
322 }
323
324 std::shared_ptr<opset1::Constant> NetworkHelper::roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance) {
325     auto constant = as_type_ptr<opset1::Constant>(node);
326     assert(constant);
327     auto values = constant->cast_vector<float>();
328
329     auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(constant, target_type));
330     auto castedValues = castedConstant->cast_vector<float>();
331
332     // TODO: implement with constant folding when ReduceAnd constant folding is ready
333     if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
334         return castedConstant;
335     }
336
337     auto round = [](
338         const std::shared_ptr<opset1::Constant>& constant,
339         element::Type target_type,
340         float tolerance,
341         std::vector<float>& values,
342         float increaseValue) -> std::shared_ptr<opset1::Constant> {
343         const auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(
344             fold<opset1::Add>(constant, std::make_shared<opset1::Constant>(constant->get_output_element_type(0), Shape{ 1 }, increaseValue)),
345             target_type));
346         const auto castedValues = castedConstant->cast_vector<float>();
347         if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
348             return castedConstant;
349         }
350
351         return nullptr;
352     };
353
354     castedConstant = round(constant, target_type, tolerance, values, 0.5f);
355     if (castedConstant != nullptr) {
356         return castedConstant;
357     }
358
359     castedConstant = round(constant, target_type, tolerance, values, -0.5f);
360     if (castedConstant != nullptr) {
361         return castedConstant;
362     }
363
364     castedConstant = round(constant, target_type, tolerance, values, 1.f);
365     if (castedConstant != nullptr) {
366         return castedConstant;
367     }
368
369     return constant;
370 }
371
372 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
373     return foldFakeQuantize(fq, false, false);
374 }
375
376 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues) {
377     return foldFakeQuantize(fq, roundValues, true);
378 }
379
380 void NetworkHelper::foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace) {
381     FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
382     if (dequantization.empty() || (dequantization.multiply == nullptr)) {
383         return;
384     }
385
386     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.data.get_node_shared_ptr());
387     if ((constant == nullptr) || (constant->output(0).get_target_inputs().size() != 1ul)) {
388         return;
389     }
390
391     if (dequantization.convert != nullptr) {
392         const std::shared_ptr<Node> result = fold<opset1::Convert>(dequantization.data, dequantization.convert->get_element_type());
393         if (!is_type<opset1::Constant>(result)) {
394             return;
395         }
396         if (inPlace) {
397             copyInfo(dequantization.convert, result);
398         }
399         replace_node(dequantization.convert, result);
400         dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
401     }
402
403     if (dequantization.subtract != nullptr) {
404         if (dequantization.data.get_element_type() != dequantization.subtract->input(1).get_element_type()) {
405             return;
406         }
407         const std::shared_ptr<Node> result = fold<opset1::Subtract>(dequantization.data, dequantization.subtract->get_input_node_shared_ptr(1));
408         if (!is_type<opset1::Constant>(result)) {
409             return;
410         }
411         if (inPlace) {
412             copyInfo(dequantization.subtract, result);
413         }
414         replace_node(dequantization.subtract, result);
415         dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
416     }
417
418     if (dequantization.multiply != nullptr) {
419         if (dequantization.data.get_element_type() != dequantization.multiply->input(1).get_element_type()) {
420             return;
421         }
422         const std::shared_ptr<Node> result = fold<opset1::Multiply>(dequantization.data, dequantization.multiply->get_input_node_shared_ptr(1));
423         if (!is_type<opset1::Constant>(result)) {
424             return;
425         }
426         if (inPlace) {
427             copyInfo(dequantization.multiply, result);
428         }
429         replace_node(dequantization.multiply, result);
430         dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
431     }
432 }
433
434 std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
435     const std::shared_ptr<opset1::FakeQuantize>& fq,
436     const bool roundValuesArg,
437     const bool roundValuesWasSet) {
438     if (is_type<opset1::Constant>(fq->get_input_node_shared_ptr(0)) &&
439         is_type<opset1::Constant>(fq->get_input_node_shared_ptr(1)) &&
440         is_type<opset1::Constant>(fq->get_input_node_shared_ptr(2)) &&
441         is_type<opset1::Constant>(fq->get_input_node_shared_ptr(3)) &&
442         is_type<opset1::Constant>(fq->get_input_node_shared_ptr(4)) &&
443         op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1)), 0.f) &&
444         op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2)), 254.f) &&
445         op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3)), -127.f) &&
446         op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4)), 127.f)) {
447         const auto type1 = fq->input_value(0).get_element_type();
448         const auto type2 = fq->input_value(3).get_element_type();
449         if (type1.is_real() && type2.is_real()) {
450             return fold<opset1::Add>(fq->input_value(0), fq->input_value(3));
451         }
452         if (type1.is_real() && !type2.is_real()) {
453             return fold<opset1::Add>(
454                 fq->input_value(0),
455                 fold<opset1::Convert>(fq->input_value(3), type1));
456         }
457         if (!type1.is_real() && type2.is_real()) {
458             return fold<opset1::Add>(
459                 fold<opset1::Convert>(fq->input_value(0), type2),
460                 fq->input_value(3));
461         }
462         return fold<opset1::Add>(
463             fold<opset1::Convert>(fq->input_value(0), element::f32),
464             fold<opset1::Convert>(fq->input_value(3), element::f32));
465     }
466
467     auto constant = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(0));
468
469     if (constant) {
470         const bool roundValues = roundValuesWasSet ? roundValuesArg : fq->output(0).get_element_type().is_integral();
471
472         Shape constShape = fq->get_output_shape(0);
473         if (constShape.empty() || constShape.size() > 5lu) {
474             THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected dimensions count " << constShape.size();
475         }
476
477         // OIDHW
478         const size_t OC = constShape[0];
479         const size_t IC = constShape.size() > 1lu ? constShape[1] : 1;
480         const size_t D = constShape.size() > 4lu ? constShape[constShape.size() - 3] : 1;
481         const size_t H = constShape.size() > 2lu ? constShape.size() == 3lu ? constShape[2] : constShape[constShape.size() - 2] : 1;
482         const size_t W = constShape.size() > 3lu ? constShape[constShape.size() - 1] : 1;
483
484         const auto inputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1))->cast_vector<float>();
485         const auto inputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2))->cast_vector<float>();
486         const auto outputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3))->cast_vector<float>();
487         const auto outputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4))->cast_vector<float>();
488
489         const size_t inputLowSize = inputLowValues.size();
490         const size_t inputHighSize = inputHighValues.size();
491         const size_t outputLowSize = outputLowValues.size();
492         const size_t outputHighSize = outputHighValues.size();
493
494         const bool isInputLowBroadcasted = inputLowSize != OC;
495         if ((inputLowSize != 1) && (inputLowSize != OC)) {
496             THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input low values count " << inputLowSize << " for " << OC << " channels";
497         }
498         const bool isInputHighBroadcasted = inputHighSize != OC;
499         if ((inputHighSize != 1) && (inputHighSize != OC)) {
500             THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input high values count " << inputHighSize << " for " << OC << " channels";
501         }
502         const bool isOutputLowBroadcasted = outputLowSize != OC;
503         if ((outputLowSize != 1) && (outputLowSize != OC)) {
504             THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output low values count " << outputLowSize << " for " << OC << " channels";
505         }
506         const bool isOutputHighBroadcasted = outputHighSize != OC;
507         if ((outputHighSize != 1) && (outputHighSize != OC)) {
508             THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output high values count " << outputHighSize << " for " << OC << " channels";
509         }
510
511         auto levels_1 = fq->get_levels() - 1.f;
512
513         //const size_t DHW = D * H * W;
514         const size_t IDHW = IC * D * H * W;
515
516         const auto values = constant->cast_vector<float>();
517         std::vector<float> quantizedValues(OC * IC * D * H * W);
518
519         for (int oc = 0; oc < OC; ++oc) {
520             for (int iidx = 0; iidx < IDHW; ++iidx) {
521                 const float inputLow = inputLowValues[isInputLowBroadcasted ? 0 : oc];
522                 const float inputHigh = inputHighValues[isInputHighBroadcasted ? 0 : oc];
523                 const float outputLow = outputLowValues[isOutputLowBroadcasted ? 0 : oc];
524                 const float outputHigh = outputHighValues[isOutputHighBroadcasted ? 0 : oc];
525
526                 const size_t idx = oc * IDHW + iidx;
527
528                 if (values[idx] <= inputLow) {
529                     quantizedValues[idx] = roundValues ? std::roundf(outputLow) : outputLow;
530                 } else if (values[idx] > inputHigh) {
531                     quantizedValues[idx] = roundValues ? std::roundf(outputHigh) : outputHigh;
532                 } else {
533                     const float value = std::roundf((values[idx] - inputLow) / (inputHigh - inputLow) * levels_1) /
534                         levels_1 * (outputHigh - outputLow) + outputLow;
535                     quantizedValues[idx] = roundValues ? std::roundf(value) : value;
536                 }
537             }
538         }
539
540         return std::make_shared<opset1::Constant>(fq->get_output_element_type(0), constShape, quantizedValues);
541     }
542
543     return fq;
544 }
545
546 // Decompose FakeQuantize to FakeQuantize with output integer limits (quantize), dequatized MultiplyAdd
547 // To align types the resulting sequence is FakeQuantize -> Convert -> Convert -> MultiplyAdd
548 std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decomposeFakeQuantize(
549     std::shared_ptr<opset1::FakeQuantize> fq,
550     const element::Type precision,
551     const float min,
552     const float max,
553     const bool hasZeroPoint,
554     const bool updatePrecision) {
555     using std::make_shared;
556
557     const auto outputLow = fq->input_value(3);
558     const auto outputHigh = fq->input_value(4);
559
560     std::vector<float> outputLowValues = as_type_ptr<opset1::Constant>(outputLow.get_node_shared_ptr())->cast_vector<float>();
561     std::vector<float> outputHighValues = as_type_ptr<opset1::Constant>(outputHigh.get_node_shared_ptr())->cast_vector<float>();
562     size_t outputSize = outputLowValues.size();
563     std::vector<float> minValues(outputSize, min);
564     std::vector<float> maxValues(outputSize, max);
565     std::vector<float> shifts(outputSize, 0.f);
566     std::vector<float> scales(outputSize);
567
568     for (int i = 0; i < outputSize; ++i) {
569         if (outputHighValues[i] != outputLowValues[i]) {
570             shifts[i] = (min*outputHighValues[i] - max*outputLowValues[i]) / (outputHighValues[i] - outputLowValues[i]);
571             scales[i] = (outputHighValues[i] - outputLowValues[i]) / (max - min);
572             if (shifts[i] == -0.f) {
573                 shifts[i] = 0.f;
574             }
575         } else {
576             scales[i] = outputHighValues[i];
577             minValues[i] = 1.f;
578             maxValues[i] = 1.f;
579         }
580     }
581
582     std::shared_ptr<Node> shift = hasZeroPoint ?
583         std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), shifts) :
584         nullptr;
585     std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), scales);
586
587     auto newMin = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues);
588     auto newMax = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues);
589
590     if (isScalarLike(newMin)) {
591         newMin = toScalar(newMin);
592     }
593     if (isScalarLike(newMax)) {
594         newMax = toScalar(newMax);
595     }
596
597     {
598         static const float minQuantizationScale = 1e-32f;
599         static const float maxQuantizationScale = 1e32f;
600
601         auto scaleValues = scales;
602         bool wasChanged = false;
603         for (size_t i = 0; i < scaleValues.size(); ++i) {
604             const float scale = scaleValues[i];
605             if (fabs(scale) < minQuantizationScale) {
606                 scaleValues[i] = minQuantizationScale;
607                 wasChanged = true;
608             } else if (fabs(scale) > maxQuantizationScale) {
609                 scaleValues[i] = scale > 0.f ? maxQuantizationScale : -maxQuantizationScale;
610                 wasChanged = true;
611             }
612         }
613
614         if (wasChanged) {
615             scale = std::make_shared<opset1::Constant>(scale->output(0).get_element_type(), scale->output(0).get_shape(), scaleValues);
616         }
617     }
618
619     if ((shift != nullptr) && isZero(as_type_ptr<opset1::Constant>(shift))) {
620         shift = nullptr;
621     }
622
623     // Build a substitution sub-graph:
624
625     std::shared_ptr<ngraph::Node> newFQ = fold_fake_quantize(
626         std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
627             fq->input_value(0),
628             fq->input_value(1),
629             fq->input_value(2),
630             newMin->output(0),
631             newMax->output(0),
632             fq->get_levels(),
633             fq->get_auto_broadcast()),
634         true);
635     NetworkHelper::copyInfo(fq, newFQ);
636
637     std::shared_ptr<ngraph::Node> convert2;
638     if (updatePrecision) {
639         std::shared_ptr<Node> convert;
640         std::shared_ptr<opset1::Constant> newFqConstant = as_type_ptr<opset1::Constant>(newFQ);
641
642         if (is_type<opset1::Constant>(newFQ)) {
643             convert = fold<opset1::Convert>(newFQ, precision);
644         } else if (is_type<opset1::FakeQuantize>(newFQ)) {
645             newFQ = setOutDataPrecision(as_type_ptr<opset1::FakeQuantize>(newFQ), precision);
646             convert = newFQ;
647         } else {
648             THROW_IE_LPT_EXCEPTION(*newFQ) << "unexpected operation type";
649         }
650
651         convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
652         convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
653         ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
654     } else {
655         if (newFQ->get_output_element_type(0) != element::f32) {
656             convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
657             convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
658             ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
659         }
660     }
661
662     // TODO: why type relaxed?
663     const std::shared_ptr<ngraph::Node> sub = shift == nullptr ?
664         nullptr :
665         std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
666     if (sub != nullptr) {
667         sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
668         ngraph::copy_runtime_info({ newFQ, sub }, sub);
669     }
670
671     const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
672         sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
673         scale);
674     dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
675     ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);
676
677     replace_node(fq, dequantize);
678
679     return std::make_tuple(newFQ, dequantize);
680 }
681
682 std::shared_ptr<opset1::FakeQuantize> NetworkHelper::updateFakeQuantize(
683     std::shared_ptr<opset1::FakeQuantize> fq,
684     element::Type precision,
685     float min,
686     float max) {
687     auto newMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
688     auto newMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
689
690     std::shared_ptr<opset1::FakeQuantize> newFQ = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
691             fq->input_value(0),
692             fq->input_value(1),
693             fq->input_value(2),
694             newMin->output(0),
695             newMax->output(0),
696             fq->get_levels(),
697             fq->get_auto_broadcast());
698
699     NetworkHelper::setOutDataPrecision(newFQ, precision);
700     replace_node(fq, newFQ);
701
702     newFQ->set_friendly_name(fq->get_friendly_name());
703     return newFQ;
704 }
705
706 FakeQuantizeDequantization NetworkHelper::makeDequantization(
707     const float dequantizationMul,
708     const float dequantizationSub,
709     const ngraph::element::Type originalPrecision,
710     const ngraph::Shape dataNodeOutputShape,
711     element::Type precision,
712     float min,
713     float max) {
714     // TODO: we create input here! we really need it here?
715     const std::shared_ptr<opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(precision, dataNodeOutputShape);
716     std::shared_ptr<ngraph::Node> parent = input;
717
718     // TODO: convert should be optional: where is updatePrecision?
719     std::shared_ptr<DequantizationConvert> convert;
720     {
721         convert = std::make_shared<DequantizationConvert>(
722             input,
723             originalPrecision);
724         parent = convert;
725     }
726
727     std::shared_ptr<DequantizationSubtract> subtract;
728     if (std::abs(dequantizationSub) > 1e-6) {
729         subtract = std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(
730             parent,
731             std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationSub })));
732         subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
733         parent = subtract;
734     }
735
736     // mandatory
737     std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
738         parent,
739         std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationMul })));
740
741     return FakeQuantizeDequantization(input, convert, subtract, multiply);
742 }
743
744 FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
745     std::shared_ptr<opset1::FakeQuantize> fq,
746     element::Type precision,
747     float min,
748     float max,
749     const bool hasZeroPoint,
750     const bool updatePrecision) {
751     using std::make_shared;
752
753     const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
754     auto newMin = make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
755     auto newMax = make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
756
757     auto outputLow = fq->input_value(3);
758     auto outputHigh = fq->input_value(4);
759
760     // TODO: threshold values have to used here to avoid shifts
761
762     const std::shared_ptr<Node> scale = fold<opset1::Divide>(
763         fold<opset1::Subtract>(outputHigh, outputLow),
764         fold<opset1::Subtract>(newMax, newMin));
765
766     std::shared_ptr<Node> shift = hasZeroPoint ?
767         fold<opset1::Divide>(
768             fold<opset1::Subtract>(fold<opset1::Multiply>(newMin, outputHigh), fold<opset1::Multiply>(newMax, outputLow)),
769             fold<opset1::Subtract>(outputHigh, outputLow)) :
770         nullptr;
771
772     if (shift != nullptr) {
773         std::shared_ptr<opset1::Constant> shiftConst = as_type_ptr<opset1::Constant>(shift);
774         if (isScalarLike(shiftConst)) {
775             auto scalar = toScalar(shiftConst);
776             if (op::util::constantIsEqualTo(scalar, 0)) {
777                 shift = nullptr;
778             }
779         }
780     }
781
782     const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, fq->get_output_shape(0));
783     const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(
784         input,
785         fq->get_output_element_type(0));
786
787     const std::shared_ptr<ngraph::opset1::Subtract> subtract = shift == nullptr ?
788         nullptr :
789         make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert, shift);
790     if (subtract != nullptr) {
791         subtract->set_output_type(0, fq->get_output_element_type(0), subtract->get_output_partial_shape(0));
792     }
793
794     const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
795         subtract == nullptr ? static_cast<std::shared_ptr<Node>>(convert) : subtract,
796         scale);
797
798     return FakeQuantizeDequantization(fq, convert, subtract, multiply);
799 }
800
801 FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex, const bool inPlace) {
802     auto getDataIndex = [](const std::shared_ptr<ngraph::Node>& node) {
803         if (is_type<opset1::Constant>(node->get_input_node_ptr(1))) {
804             return 0ul;
805         } else {
806             return 1ul;
807         }
808     };
809
810     Output<Node> dataNode = inPlace ? node : node->input_value(parentIndex);
811
812     const std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<ngraph::opset1::Multiply>(dataNode.get_node_shared_ptr());
813     if (multiply != nullptr) {
814         if (!is_type<opset1::Constant>(multiply->get_input_node_ptr(0)) && !is_type<opset1::Constant>(multiply->get_input_node_ptr(1))) {
815             return FakeQuantizeDequantization();
816         }
817         dataNode = multiply->get_input_source_output(getDataIndex(multiply));
818     }
819
820     const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<ngraph::opset1::Subtract>(dataNode.get_node_shared_ptr());
821     if (subtract != nullptr) {
822         if (!is_type<opset1::Constant>(subtract->get_input_node_ptr(0)) && !is_type<opset1::Constant>(subtract->get_input_node_ptr(1))) {
823             return FakeQuantizeDequantization(dataNode, nullptr, nullptr, multiply);
824         }
825         dataNode = subtract->get_input_source_output(getDataIndex(subtract));
826     }
827
828     const std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(dataNode.get_node_shared_ptr());
829     if (convert != nullptr) {
830         if ((convert->input(0).get_element_type() != element::i8) && (convert->input(0).get_element_type() != element::u8) &&
831             (convert->output(0).get_element_type() != element::f32)) {
832             return FakeQuantizeDequantization(dataNode, nullptr, subtract, multiply);
833         }
834         dataNode = convert->get_input_source_output(0);
835     }
836
837     return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
838 }
839
840 FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
841     std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();
842
843     std::shared_ptr<Node> multiply1Const = dequantization.multiply ?
844         dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
845         std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 1.f }));
846
847     std::shared_ptr<Node> subtract1Const = dequantization.subtract ?
848         dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
849         std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 0.f }));
850
851     subtract1Const->set_output_type(0, multiply1Const->get_output_element_type(0), subtract1Const->get_output_partial_shape(0));
852
853     return FakeQuantizeDequantizationValues(subtract1Const, multiply1Const);
854 }
855
856 bool NetworkHelper::isZeroConst(const std::shared_ptr<Node>& node) {
857     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
858
859     if (constant == nullptr)
860         return false;
861
862     if (NetworkHelper::isScalarLike(constant)) {
863         auto scalar = NetworkHelper::toScalar(constant);
864         if (op::util::constantIsEqualTo(scalar, 0)) {
865             return true;
866         } else {
867             return false;
868         }
869     } else {
870         return false;
871     }
872 }
873
874 std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract) {
875     auto convertOnSubtract = subtract->input_value(0).get_node_shared_ptr();
876     if (as_type_ptr<opset1::Convert>(convertOnSubtract) == nullptr) {
877         return subtract;
878     }
879
880     // TODO: replace assert to condition and omit conversion part if there is no convert
881     // TODO: also check convertInputType to understand if we really want to propagate type
882     assert(as_type_ptr<opset1::Convert>(convertOnSubtract));
883     const element::Type convertInputType = convertOnSubtract->get_input_element_type(0);
884     const element::Type convertOutputType = convertOnSubtract->get_output_element_type(0);
885
886     if (!convertOutputType.is_real()) {
887         return subtract;
888     }
889
890     auto data = convertOnSubtract->input_value(0);
891     auto shift = subtract->input_value(1).get_node_shared_ptr();
892     auto roundedShift = NetworkHelper::roundWithTolerance(shift, convertInputType);
893
894     std::shared_ptr<Node> replacement;
895     if (roundedShift->get_element_type() == convertInputType) {
896         // Propagate convertInputType down
897         replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
898         NetworkHelper::copyInfo(subtract, replacement);
899         NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
900         replace_node(subtract, replacement);
901     }
902
903     // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
904     // TODO: check cases when Convert should be preserved
905
906     // Try to optimize Add out if constant is zero
907     // TODO: don't remove operation here: don't create this Subtraction operation in FQ decomposition
908     // if (isScalarLike(roundedShift)) {
909     //    auto scalar = distillToScalar(roundedShift);
910     //    if (op::util::constantIsEqualTo(scalar, 0)) {
911     //        replace_node(replacement, replacement->input_value(0).get_node_shared_ptr());
912     //        replacement = nullptr;
913     //    }
914     // }
915
916     return replacement;
917 }
918
919 NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter(
920     const std::shared_ptr<ngraph::Node>& operation,
921     const FakeQuantizeDequantization& dequantization,
922     const bool updatePrecision,
923     const bool moveSubtract) {
924     std::vector<Output<Node>> inputs(operation->get_input_size());
925     for (size_t i = 0; i < operation->get_input_size(); ++i) {
926         inputs[i] = operation->get_input_node_shared_ptr(i);
927     }
928
929     const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation);
930     inputs[dequantizationIndex] = moveSubtract ?
931         dequantization.data :
932         (dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract);
933
934     const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
935     newOperation->set_friendly_name(operation->get_friendly_name());
936     ngraph::copy_runtime_info(operation, newOperation);
937
938     if (updatePrecision) {
939         auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
940         if (op == nullptr) {
941             THROW_IE_LPT_EXCEPTION(*newOperation) << "not possible to update precision for not TypeRelaxedBase operation";
942         }
943         op->set_overridden_output_type(newOperation->get_input_element_type(0));
944         std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
945     }
946
947     const bool shouldConvert = (newOperation->get_output_element_type(0) != dequantization.multiply->get_output_element_type(0));
948
949     auto parent = newOperation;
950     if (shouldConvert) {
951         const auto convertOutputPrecision = dequantization.convert != nullptr ?
952             dequantization.convert->get_output_element_type(0) :
953             dequantization.multiply->get_output_element_type(0);
954         parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
955         ngraph::copy_runtime_info({ newOperation, parent }, parent);
956     }
957
958     if (moveSubtract && (dequantization.subtract != nullptr)) {
959         auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
960         const element::Type parentPrecision = parent->get_output_element_type(0);
961         if (parentPrecision.bitwidth() < subtractConstant->output(0).get_element_type().bitwidth()) {
962             THROW_IE_LPT_EXCEPTION(*parent) <<
963                 "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
964                 ", subtract dequantization constant " << subtractConstant->get_friendly_name() << ":" << subtractConstant->output(0).get_element_type();
965         }
966
967         parent = std::make_shared<DequantizationSubtract>(
968             parent,
969             subtractConstant->output(0).get_element_type() == parentPrecision ?
970                 subtractConstant :
971                 fold<opset1::Convert>(subtractConstant->output(0), parentPrecision));
972         ngraph::copy_runtime_info({ newOperation, parent }, parent);
973     }
974
975     if (dequantization.multiply != nullptr) {
976         auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
977         const element::Type parentPrecision = parent->get_output_element_type(0);
978         if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
979             THROW_IE_LPT_EXCEPTION(*parent) <<
980                 "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
981                 ", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
982         }
983
984         parent = std::make_shared<DequantizationMultiply>(
985             parent,
986             multiplyConstant->output(0).get_element_type() == parentPrecision ?
987                 multiplyConstant :
988                 fold<opset1::Convert>(multiplyConstant->output(0), parentPrecision));
989         ngraph::copy_runtime_info({ newOperation, parent }, parent);
990     }
991     replace_node(operation, parent);
992
993     if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
994         NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
995         optimizeSubtract(dequantization.subtract);
996     }
997
998     return InsertDequantizationResult(newOperation, parent);
999 }
1000
1001 void NetworkHelper::removeConvertIfPossible(
1002     const std::shared_ptr<ngraph::Node>& operation,
1003     const FakeQuantizeDequantization& dequantization) {
1004     const element::Type precisionBeforeConvert = dequantization.convert->input(0).get_element_type();
1005
1006     if (checkConstantValuePrecision(precisionBeforeConvert, dequantization.subtract->get_input_node_shared_ptr(1))) {
1007         auto newSubtract = dequantization.subtract->clone_with_new_inputs({
1008             dequantization.convert->get_input_node_shared_ptr(0),
1009             fold<opset1::Convert>(dequantization.subtract->get_input_node_shared_ptr(1), precisionBeforeConvert) });
1010         replace_node(dequantization.subtract, newSubtract);
1011     }
1012 }
1013
1014 bool NetworkHelper::checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant) {
1015     if (expectedPrecision.is_signed()) {
1016         return true;
1017     }
1018
1019     std::shared_ptr<opset1::Constant> constantOp = as_type_ptr<opset1::Constant>(constant);
1020     if (constantOp == nullptr) {
1021         return false;
1022     }
1023
1024     const auto values = constantOp->cast_vector<float>();
1025     const bool convertCanBeRemoved =
1026         (expectedPrecision.is_signed() || (std::all_of(values.begin(), values.end(), [](const float value) { return value >= 0.f; })));
1027     return convertCanBeRemoved;
1028 }
1029
1030 size_t NetworkHelper::getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1031     for (size_t i = 0; i < child->get_input_size(); ++i) {
1032         if (parent.get() == child->get_input_node_ptr(i)) {
1033             return i;
1034         }
1035     }
1036     THROW_IE_LPT_EXCEPTION(*child) << "child input index between " <<
1037         parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
1038 }
1039
1040 size_t NetworkHelper::getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1041     for (size_t i = 0; i < parent->get_output_size(); ++i) {
1042         const auto& targetInputs = parent->output(i).get_target_inputs();
1043         for (const auto& targetInput : targetInputs) {
1044             if (targetInput.get_node() == child.get()) {
1045                 return i;
1046             }
1047         }
1048     }
1049     THROW_IE_LPT_EXCEPTION(*child) << "parent output index between " <<
1050         parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
1051 }
1052
1053 std::vector<Output<Node>> NetworkHelper::getInputs(const std::shared_ptr<ngraph::Node>& node) {
1054     std::vector<Output<Node>> inputs(node->get_input_size());
1055     for (size_t i = 0; i < node->get_input_size(); ++i) {
1056         inputs[i] = node->get_input_node_shared_ptr(i);
1057     }
1058     return inputs;
1059 }
1060
1061 std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> node) {
1062     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
1063     if (constant == nullptr) {
1064         return node;
1065     }
1066
1067     if (!NetworkHelper::isScalarLike(constant)) {
1068         return node;
1069     }
1070
1071     return NetworkHelper::toScalar(constant);
1072 }
1073
1074 }  // namespace low_precision
1075 }  // namespace pass
1076 }  // namespace ngraph