[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / network_helper.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <low_precision/network_helper.hpp>
6
7 #include <algorithm>
8 #include <cmath>
9 #include <limits>
10 #include <map>
11 #include <memory>
12 #include <string>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16 #include <queue>
17
18 #include <ngraph/rt_info.hpp>
19 #include "low_precision/common/ie_lpt_exception.hpp"
20 #include "low_precision/common/dequantization_op.hpp"
21
22 namespace ngraph {
23 namespace pass {
24 namespace low_precision {
25
26 // Return true if `type` can be castable to at least one of `type`
27 bool NetworkHelper::is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types) {
28     for (auto another : types) {
29         if (type.is_castable(another)) {
30             return true;
31         }
32     }
33     return false;
34 }
35
36 // Collect and return a vector with all nodes that consumes any of the `node` output
37 std::vector<Input<Node>> NetworkHelper::consumer_inputs(std::shared_ptr<Node> node) {
38     std::vector<Input<Node>> result;
39     for (const auto& output_port : node->outputs()) {
40         for (const auto &input : output_port.get_target_inputs()) {
41             result.push_back(input);
42         }
43     }
44     return result;
45 }
46
47 std::vector<std::shared_ptr<Node>> NetworkHelper::consumers(std::shared_ptr<Node> node) {
48     auto inputs = consumer_inputs(node);
49     std::vector<std::shared_ptr<Node>> result(inputs.size());
50     std::transform(inputs.begin(), inputs.end(), result.begin(), [](Input<Node> input){ return input.get_node()->shared_from_this(); });
51     return result;
52 }
53
54 int NetworkHelper::onWeightsInDepth(std::shared_ptr<Node> layer) {
55     const std::vector<std::shared_ptr<Node>> children = consumers(layer);
56     for (std::shared_ptr<Node> child : children) {
57         if ((is_type<opset1::Convolution>(child) ||
58             is_type<opset1::GroupConvolution>(child) ||
59             is_type<opset1::MatMul>(child)) &&
60             (child->inputs().size() >= 2lu)) {
61             const std::vector<std::shared_ptr<Node>> parents = getParentsRecursivelyExceptTypes(child, {}, 1);
62             for (const std::shared_ptr<Node>& parent : parents) {
63                 if (parent.get() == layer.get()) {
64                     return 1;
65                 }
66             }
67             return -1;
68         }
69
70         const int result = onWeightsInDepth(child);
71         if (result != 0) {
72             return result;
73         }
74     }
75     return 0;
76 }
77
78 bool NetworkHelper::onWeights(std::shared_ptr<Node> layer) {
79     const int result = onWeightsInDepth(layer);
80     return result == 1;
81 }
82
83 size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
84     if (layer->outputs().size() == 0) {
85         THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";
86     }
87
88     if (layer->outputs().size() > 1) {
89         THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " has too many output tensors, expected one";
90     }
91
92     PartialShape shape = layer->get_output_partial_shape(0);
93     if (shape.rank() == 0) {
94         THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in output of " << layer->get_friendly_name() << " layer on weights";
95     }
96     if (isOnWeights) {
97         return shape[0].get_length();
98     } else {
99         if (shape.rank() == 1) {
100             return shape[0].get_length();
101         }
102         return shape[1].get_length();
103     }
104 }
105
106 std::vector<std::shared_ptr<Node>> NetworkHelper::getParentsRecursivelyExceptTypes(
107         std::shared_ptr<Node> layer,
108         const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes,
109         const int portIndex) {
110     std::vector<std::shared_ptr<Node>> parents;
111     size_t i = 0ul;
112     for (auto input : layer->inputs()) {
113         if ((portIndex == -1) || (portIndex == i)) {
114             auto parent = input.get_source_output().get_node_shared_ptr();
115             if (is_castable_to_one_of(parent->get_type_info(), exceptionLayerTypes)) {
116                 const std::vector<std::shared_ptr<Node>> tmpParents = getParentsRecursivelyExceptTypes(parent, exceptionLayerTypes);
117                 parents.insert(parents.end(), tmpParents.begin(), tmpParents.end());
118             } else {
119                 parents.push_back(parent);
120             }
121         }
122
123         i++;
124     }
125     return parents;
126 }
127
128 size_t NetworkHelper::getInputChannelsCount(std::shared_ptr<Node> layer) {
129     if (layer->get_input_size() == 0) {
130         THROW_TRANSFORMATION_EXCEPTION << "There are no input layers";
131     }
132
133     PartialShape shape = layer->get_input_partial_shape(0);
134     if (shape.rank().get_length() <= 1) {
135         THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in input of " << layer->get_friendly_name();
136     }
137
138     return shape[1].get_length();
139 }
140
141 size_t NetworkHelper::getGroupsCount(std::shared_ptr<Node> layer) {
142     if (as_type_ptr<opset1::Convolution>(layer)) {
143         return 1;
144     } else if (auto group_convolution = as_type_ptr<opset1::GroupConvolution>(layer)) {
145         return layer->get_input_shape(1)[0];    // input weights for opset1::GC is in format GOI..., see the specification
146     } else {
147         THROW_TRANSFORMATION_EXCEPTION << "Invalid layer type of " << layer->get_friendly_name() << "; expected Convolutino or GroupConvolution";
148     }
149 }
150
151 // Assumin tensor in NC... layout, append necessary number of 1s to shape to align it to a give rank
152 Shape NetworkHelper::alignShapeForChannelDim(const Shape& shape, Rank rank) {
153     assert(shape.size() == 1);
154     assert(rank.is_static());
155     Shape result = shape;
156     result.resize(rank.get_length() - 1, 1);
157     return result;
158 }
159
// Bypasses `layer` in the graph: rewires its output 0 to its input 0 and propagates the name.
void NetworkHelper::removeLayer(std::shared_ptr<Node> layer) {
    ngraph::replace_output_update_name(layer->output(0), layer->input_value(0));
}
163
// Rewrites Multiply -> Add as Add -> Multiply: x*a + b  ==>  (x + b/a)*a, folding b/a
// into a constant. `multiplyBranch` selects which input of `addAfterMultiply` is the
// Multiply. Returns the new Multiply on success, or the original Add unchanged when
// the Multiply has no constant operand.
std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch) {
    // Multiply --> Add(addAfterMultiply)  ==>  Add(new) --> Multiply(new)
    // That means x*a + b ==> (x + b/a)*a; tries to fold b/a
    const auto multiply = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch);

    const auto multiplyParent1 = multiply->get_input_node_shared_ptr(0);
    const auto multiplyParent2 = multiply->get_input_node_shared_ptr(1);

    // Identify which Multiply input is the constant `a`; try port 1 first, then port 0.
    auto multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent1);
    auto multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent2);
    int multiplyInputBranch = 0;

    if (multiplyConst == nullptr) {
        multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent2);
        multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent1);
        multiplyInputBranch = 1;
    }

    // No constant operand: the swap is not applicable.
    if (multiplyConst == nullptr)
        return addAfterMultiply;

    const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
    const auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
    const auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
    std::shared_ptr<Node> bDivA;

    // When `a` and `b` are scalar or equally sized, compute b/a element-wise here with a
    // guard against 0/0; otherwise fall back to generic constant folding of Divide.
    if (shape_size(b->get_output_shape(0)) == 1 ||
        shape_size(a->get_output_shape(0)) == 1 ||
        shape_size(b->get_output_shape(0)) == shape_size(a->get_output_shape(0))) {
        // safely division to avoid NaN
        const std::vector<float> bValues = as_type_ptr<opset1::Constant>(b)->cast_vector<float>();
        const std::vector<float> aValues = as_type_ptr<opset1::Constant>(a)->cast_vector<float>();
        const bool aBroadcasted = bValues.size() > aValues.size();
        const bool bBroadcasted = bValues.size() < aValues.size();
        std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size());

        for (int i = 0; i < bDivAValues.size(); ++i) {
            const auto bi = bValues[bBroadcasted ? 0 : i];
            const auto ai = aValues[aBroadcasted ? 0 : i];
            // Only 0/0 is remapped to 0 (would be NaN); non-zero over zero still divides.
            if (bi != 0.f || ai != 0.f) {
                bDivAValues[i] = bi / ai;
            } else {
                bDivAValues[i] = 0.f;
            }
        }

        bDivA = std::make_shared<opset1::Constant>(
                b->get_output_element_type(0),
                aBroadcasted ? b->get_output_shape(0) : a->get_output_shape(0),
                bDivAValues);
    } else {
        bDivA = fold<opset1::Divide>(b, a);
    }

    std::vector<std::shared_ptr<Node>> inputs{ {}, {} };

    inputs[0] = x;
    inputs[1] = bDivA;

    // TypeRelaxed: force the Add to compute in f32 regardless of the input element types.
    std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
        std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
        ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
        ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
    copyInfo(addAfterMultiply, newAdd);

    // Preserve the original Add's output precision on the replacement.
    NetworkHelper::setOutDataPrecision(newAdd, addAfterMultiply->get_output_element_type(0));

    auto newMultiply = std::make_shared<DequantizationMultiply>(newAdd, a);
    copyInfo(multiply, newMultiply);

    replace_node(addAfterMultiply, newMultiply);
    return newMultiply;
}
237
238 void NetworkHelper::copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target) {
239     // TODO: merge_runtime_info with correctly defined DEQUANTIZATION
240     const auto& sourceAttributes = source->get_rt_info();
241     auto& targetAttrubutes = target->get_rt_info();
242     for (auto attribute : sourceAttributes) {
243         targetAttrubutes[attribute.first] = attribute.second;
244     }
245
246     const std::string friendlyName = source->get_friendly_name();
247     if (!friendlyName.empty()) {
248         target->set_friendly_name(friendlyName);
249     }
250 }
251
252 void NetworkHelper::cleanRunTimeInfo(const std::shared_ptr<Node>& layer) {
253     auto& rt_info = layer->get_rt_info();
254     auto attributeIter = rt_info.find("DEQUANTIZATION");
255     if (rt_info.find("DEQUANTIZATION") != rt_info.end()) {
256         rt_info.erase(attributeIter);
257     }
258 }
259
// A constant is "scalar-like" when all of its elements are bitwise identical,
// so it can be safely collapsed to a true scalar (see toScalar).
bool NetworkHelper::isScalarLike(std::shared_ptr<opset1::Constant> constant) {
    return constant->get_all_data_elements_bitwise_identical();
}
263
264 bool NetworkHelper::isZero(std::shared_ptr<opset1::Constant> constant) {
265     static const float minQuantizationShift = 1e-32f;
266
267     auto values = constant->cast_vector<float>();
268     for (size_t i = 0; i < values.size(); ++i) {
269         if (fabs(values[i]) > minQuantizationShift) {
270             return false;
271         }
272     }
273
274     return true;
275 }
276
// Collapses a scalar-like constant (see isScalarLike) into a true scalar constant (Shape{}),
// reusing the first element's raw data.
std::shared_ptr<opset1::Constant> NetworkHelper::toScalar(std::shared_ptr<opset1::Constant> constant) {
    assert(isScalarLike(constant));
    return std::make_shared<opset1::Constant>(constant->get_element_type(), Shape{}, constant->get_data_ptr());
}
281
282 std::shared_ptr<Node> NetworkHelper::getConstantInput(std::shared_ptr<Node> node) {
283     std::shared_ptr<Node> constant1 = as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr());
284     if (!constant1) {
285         constant1 = as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr());
286     }
287     return constant1;
288 }
289
290 std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> node) {
291     std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(node);
292     if (!multiply) {
293         THROW_IE_LPT_EXCEPTION(*multiply) << "Unexpected operation type";
294     }
295
296     if (multiply->output(0).get_target_inputs().size() == 1) {
297         auto constant1 = getConstantInput(multiply);
298         if (!constant1 || constant1->output(0).get_target_inputs().size() != 1) {
299             return multiply;
300         }
301         auto nextMultiplyInput = *multiply->output(0).get_target_inputs().begin();
302         auto nextMultiply = as_type_ptr<opset1::Multiply>(nextMultiplyInput.get_node()->shared_from_this());
303         if (nextMultiply) {
304             auto constant2 = getConstantInput(nextMultiply);
305             auto constant2Inputs = constant2->output(0).get_target_inputs().size();
306             if (!constant2 || constant2->output(0).get_target_inputs().size() != 1) {
307                 return multiply;
308             }
309
310             auto newConst = fold<opset1::Multiply>(constant1, constant2);
311             auto newMultiply =
312                     std::make_shared<opset1::Multiply>(
313                             multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index()),
314                             newConst->output(0));
315             copy_runtime_info(multiply, newMultiply);
316             replace_node(nextMultiply, newMultiply);
317             return newMultiply;
318         }
319     }
320
321     return nullptr;
322 }
323
324 std::shared_ptr<opset1::Constant> NetworkHelper::round(std::shared_ptr<Node> node, element::Type target_type) {
325     const auto constant = as_type_ptr<opset1::Constant>(node);
326     assert(constant);
327
328     const auto castedConstant = as_type_ptr<ngraph::opset1::Constant>(fold<op::v0::Convert>(
329         fold<ngraph::op::v5::Round>(constant->output(0), ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO),
330         target_type));
331
332     return castedConstant;
333 }
334
// Folds `fq` without forcing any rounding behavior (roundValues left unset).
std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
    return foldFakeQuantize(fq, false, false);
}
338
// Folds `fq` with rounding behavior explicitly set to `roundValues`.
std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues) {
    return foldFakeQuantize(fq, roundValues, true);
}
342
// Constant-folds the dequantization sub-graph (Convert -> Subtract -> Multiply) feeding
// `node` on `branchIndex`, provided the dequantized data is a single-consumer Constant.
// Each stage that folds to a Constant is replaced in the graph and the dequantization
// is re-fetched before the next stage; any stage that cannot fold aborts the process.
void NetworkHelper::foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace) {
    FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
    if (dequantization.empty() || (dequantization.multiply == nullptr)) {
        return;
    }

    // Folding only makes sense when the data source is a Constant used exactly once.
    std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.data.get_node_shared_ptr());
    if ((constant == nullptr) || (constant->output(0).get_target_inputs().size() != 1ul)) {
        return;
    }

    // Stage 1: fold the Convert into the constant.
    if (dequantization.convert != nullptr) {
        const std::shared_ptr<Node> result = fold<opset1::Convert>(dequantization.data, dequantization.convert->get_element_type());
        if (!is_type<opset1::Constant>(result)) {
            return;
        }
        if (inPlace) {
            copyInfo(dequantization.convert, result);
        }
        replace_node(dequantization.convert, result);
        // Re-fetch: the graph changed, so the cached dequantization handles are stale.
        dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
    }

    // Stage 2: fold the Subtract (shift), if the element types agree.
    if (dequantization.subtract != nullptr) {
        if (dequantization.data.get_element_type() != dequantization.subtract->input(1).get_element_type()) {
            return;
        }
        const std::shared_ptr<Node> result = fold<opset1::Subtract>(dequantization.data, dequantization.subtract->get_input_node_shared_ptr(1));
        if (!is_type<opset1::Constant>(result)) {
            return;
        }
        if (inPlace) {
            copyInfo(dequantization.subtract, result);
        }
        replace_node(dequantization.subtract, result);
        dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
    }

    // Stage 3: fold the Multiply (scale), if the element types agree.
    if (dequantization.multiply != nullptr) {
        if (dequantization.data.get_element_type() != dequantization.multiply->input(1).get_element_type()) {
            return;
        }
        const std::shared_ptr<Node> result = fold<opset1::Multiply>(dequantization.data, dequantization.multiply->get_input_node_shared_ptr(1));
        if (!is_type<opset1::Constant>(result)) {
            return;
        }
        if (inPlace) {
            copyInfo(dequantization.multiply, result);
        }
        replace_node(dequantization.multiply, result);
        dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
    }
}
396
// Attempts to constant-fold a FakeQuantize:
// 1) special-case: all interval inputs constant with ranges [0, 254] -> [-127, 127] —
//    folds into a plain shift (data + outputLow), converting operands as needed;
// 2) constant data input: evaluates the quantization formula element-wise and returns
//    the quantized Constant, rounding when `roundValuesArg` asks for it (or, when it
//    was not set, when the output element type is integral).
// Otherwise returns `fq` unchanged.
std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
    const std::shared_ptr<opset1::FakeQuantize>& fq,
    const bool roundValuesArg,
    const bool roundValuesWasSet) {
    if (is_type<opset1::Constant>(fq->get_input_node_shared_ptr(0)) &&
        is_type<opset1::Constant>(fq->get_input_node_shared_ptr(1)) &&
        is_type<opset1::Constant>(fq->get_input_node_shared_ptr(2)) &&
        is_type<opset1::Constant>(fq->get_input_node_shared_ptr(3)) &&
        is_type<opset1::Constant>(fq->get_input_node_shared_ptr(4)) &&
        op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1)), 0.f) &&
        op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2)), 254.f) &&
        op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3)), -127.f) &&
        op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4)), 127.f)) {
        // [0, 254] -> [-127, 127] is a pure shift by outputLow; pick conversions so the
        // Add operands share a real type (both-integral falls back to f32).
        const auto type1 = fq->input_value(0).get_element_type();
        const auto type2 = fq->input_value(3).get_element_type();
        if (type1.is_real() && type2.is_real()) {
            return fold<opset1::Add>(fq->input_value(0), fq->input_value(3));
        }
        if (type1.is_real() && !type2.is_real()) {
            return fold<opset1::Add>(
                fq->input_value(0),
                fold<opset1::Convert>(fq->input_value(3), type1));
        }
        if (!type1.is_real() && type2.is_real()) {
            return fold<opset1::Add>(
                fold<opset1::Convert>(fq->input_value(0), type2),
                fq->input_value(3));
        }
        return fold<opset1::Add>(
            fold<opset1::Convert>(fq->input_value(0), element::f32),
            fold<opset1::Convert>(fq->input_value(3), element::f32));
    }

    auto constant = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(0));

    if (constant) {
        // When the caller did not force rounding, round only for integral outputs.
        const bool roundValues = roundValuesWasSet ? roundValuesArg : fq->output(0).get_element_type().is_integral();

        Shape constShape = fq->get_output_shape(0);
        if (constShape.empty() || constShape.size() > 5lu) {
            THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected dimensions count " << constShape.size();
        }

        // OIDHW
        const size_t OC = constShape[0];
        const size_t IC = constShape.size() > 1lu ? constShape[1] : 1;
        const size_t D = constShape.size() > 4lu ? constShape[constShape.size() - 3] : 1;
        const size_t H = constShape.size() > 2lu ? constShape.size() == 3lu ? constShape[2] : constShape[constShape.size() - 2] : 1;
        const size_t W = constShape.size() > 3lu ? constShape[constShape.size() - 1] : 1;

        const auto inputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1))->cast_vector<float>();
        const auto inputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2))->cast_vector<float>();
        const auto outputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3))->cast_vector<float>();
        const auto outputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4))->cast_vector<float>();

        const size_t inputLowSize = inputLowValues.size();
        const size_t inputHighSize = inputHighValues.size();
        const size_t outputLowSize = outputLowValues.size();
        const size_t outputHighSize = outputHighValues.size();

        // Each interval boundary is either per-output-channel (OC values) or broadcast (1 value).
        const bool isInputLowBroadcasted = inputLowSize != OC;
        if ((inputLowSize != 1) && (inputLowSize != OC)) {
            THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input low values count " << inputLowSize << " for " << OC << " channels";
        }
        const bool isInputHighBroadcasted = inputHighSize != OC;
        if ((inputHighSize != 1) && (inputHighSize != OC)) {
            THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input high values count " << inputHighSize << " for " << OC << " channels";
        }
        const bool isOutputLowBroadcasted = outputLowSize != OC;
        if ((outputLowSize != 1) && (outputLowSize != OC)) {
            THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output low values count " << outputLowSize << " for " << OC << " channels";
        }
        const bool isOutputHighBroadcasted = outputHighSize != OC;
        if ((outputHighSize != 1) && (outputHighSize != OC)) {
            THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output high values count " << outputHighSize << " for " << OC << " channels";
        }

        auto levels_1 = fq->get_levels() - 1.f;

        //const size_t DHW = D * H * W;
        const size_t IDHW = IC * D * H * W;

        const auto values = constant->cast_vector<float>();
        std::vector<float> quantizedValues(OC * IC * D * H * W);

        for (int oc = 0; oc < OC; ++oc) {
            for (int iidx = 0; iidx < IDHW; ++iidx) {
                // Quantization intervals are selected per output channel `oc`.
                const float inputLow = inputLowValues[isInputLowBroadcasted ? 0 : oc];
                const float inputHigh = inputHighValues[isInputHighBroadcasted ? 0 : oc];
                const float outputLow = outputLowValues[isOutputLowBroadcasted ? 0 : oc];
                const float outputHigh = outputHighValues[isOutputHighBroadcasted ? 0 : oc];

                const size_t idx = oc * IDHW + iidx;

                // Standard FakeQuantize semantics: clamp outside [inputLow, inputHigh],
                // otherwise snap to `levels` steps and rescale to the output interval.
                if (values[idx] <= inputLow) {
                    quantizedValues[idx] = roundValues ? std::roundf(outputLow) : outputLow;
                } else if (values[idx] > inputHigh) {
                    quantizedValues[idx] = roundValues ? std::roundf(outputHigh) : outputHigh;
                } else {
                    const float value = std::roundf((values[idx] - inputLow) / (inputHigh - inputLow) * levels_1) /
                        levels_1 * (outputHigh - outputLow) + outputLow;
                    quantizedValues[idx] = roundValues ? std::roundf(value) : value;
                }
            }
        }

        return std::make_shared<opset1::Constant>(fq->get_output_element_type(0), constShape, quantizedValues);
    }

    return fq;
}
508
// Decompose FakeQuantize to FakeQuantize with output integer limits (quantize), dequatized MultiplyAdd
// To align types the resulting sequence is FakeQuantize -> Convert -> Convert -> MultiplyAdd
// Returns {new FakeQuantize (possibly folded to a Constant), final dequantization Multiply}.
std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decomposeFakeQuantize(
    std::shared_ptr<opset1::FakeQuantize> fq,
    const element::Type precision,
    const float min,
    const float max,
    const bool hasZeroPoint,
    const bool updatePrecision) {
    using std::make_shared;

    const auto outputLow = fq->input_value(3);
    const auto outputHigh = fq->input_value(4);

    std::vector<float> outputLowValues = as_type_ptr<opset1::Constant>(outputLow.get_node_shared_ptr())->cast_vector<float>();
    std::vector<float> outputHighValues = as_type_ptr<opset1::Constant>(outputHigh.get_node_shared_ptr())->cast_vector<float>();
    size_t outputSize = outputLowValues.size();
    std::vector<float> minValues(outputSize, min);
    std::vector<float> maxValues(outputSize, max);
    std::vector<float> shifts(outputSize, 0.f);
    std::vector<float> scales(outputSize);

    // Per-channel scale/shift so that quantized values in [min, max] dequantize back to
    // [outputLow, outputHigh]: out = q * scale + shift.
    for (int i = 0; i < outputSize; ++i) {
        if (outputHighValues[i] != outputLowValues[i]) {
            shifts[i] = (min*outputHighValues[i] - max*outputLowValues[i]) / (outputHighValues[i] - outputLowValues[i]);
            scales[i] = (outputHighValues[i] - outputLowValues[i]) / (max - min);
            // Normalize negative zero so the isZero check below can drop the shift.
            if (shifts[i] == -0.f) {
                shifts[i] = 0.f;
            }
        } else {
            // Degenerate interval: fold the constant output value into the scale and
            // quantize this channel to the single value 1.
            scales[i] = outputHighValues[i];
            minValues[i] = 1.f;
            maxValues[i] = 1.f;
        }
    }

    std::shared_ptr<Node> shift = hasZeroPoint ?
        std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), shifts) :
        nullptr;
    std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), scales);

    auto newMin = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues);
    auto newMax = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues);

    if (isScalarLike(newMin)) {
        newMin = toScalar(newMin);
    }
    if (isScalarLike(newMax)) {
        newMax = toScalar(newMax);
    }

    // Clamp scales into a sane magnitude range to avoid overflow/underflow on dequantization.
    {
        static const float minQuantizationScale = 1e-32f;
        static const float maxQuantizationScale = 1e32f;

        auto scaleValues = scales;
        bool wasChanged = false;
        for (size_t i = 0; i < scaleValues.size(); ++i) {
            const float scale = scaleValues[i];
            if (fabs(scale) < minQuantizationScale) {
                scaleValues[i] = minQuantizationScale;
                wasChanged = true;
            } else if (fabs(scale) > maxQuantizationScale) {
                scaleValues[i] = scale > 0.f ? maxQuantizationScale : -maxQuantizationScale;
                wasChanged = true;
            }
        }

        if (wasChanged) {
            scale = std::make_shared<opset1::Constant>(scale->output(0).get_element_type(), scale->output(0).get_shape(), scaleValues);
        }
    }

    // An all-zero shift is equivalent to no shift at all; drop the Subtract entirely.
    if ((shift != nullptr) && isZero(as_type_ptr<opset1::Constant>(shift))) {
        shift = nullptr;
    }

    // Build a substitution sub-graph:

    // Retarget the FQ output interval to [newMin, newMax]; folds to a Constant when possible.
    std::shared_ptr<ngraph::Node> newFQ = fold_fake_quantize(
        std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
            fq->input_value(0),
            fq->input_value(1),
            fq->input_value(2),
            newMin->output(0),
            newMax->output(0),
            fq->get_levels(),
            fq->get_auto_broadcast()),
        true);
    NetworkHelper::copyInfo(fq, newFQ);

    std::shared_ptr<ngraph::Node> convert2;
    if (updatePrecision) {
        std::shared_ptr<Node> convert;
        std::shared_ptr<opset1::Constant> newFqConstant = as_type_ptr<opset1::Constant>(newFQ);

        if (is_type<opset1::Constant>(newFQ)) {
            convert = fold<opset1::Convert>(newFQ, precision);
        } else if (is_type<opset1::FakeQuantize>(newFQ)) {
            newFQ = setOutDataPrecision(as_type_ptr<opset1::FakeQuantize>(newFQ), precision);
            convert = newFQ;
        } else {
            THROW_IE_LPT_EXCEPTION(*newFQ) << "unexpected operation type";
        }

        convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
        convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
        ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
    } else {
        // Precision untouched: only add a Convert when the FQ output is not f32 already.
        if (newFQ->get_output_element_type(0) != element::f32) {
            convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
            convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
            ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
        }
    }

    // TODO: why type relaxed?
    const std::shared_ptr<ngraph::Node> sub = shift == nullptr ?
        nullptr :
        std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
    if (sub != nullptr) {
        sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
        ngraph::copy_runtime_info({ newFQ, sub }, sub);
    }

    // Final dequantization stage; its input is the last node actually created above.
    const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
        sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
        scale);
    dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
    ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);

    replace_node(fq, dequantize);

    return std::make_tuple(newFQ, dequantize);
}
644
645 std::shared_ptr<opset1::FakeQuantize> NetworkHelper::updateFakeQuantize(
646     std::shared_ptr<opset1::FakeQuantize> fq,
647     element::Type precision,
648     float min,
649     float max) {
650     auto newMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
651     auto newMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
652
653     std::shared_ptr<opset1::FakeQuantize> newFQ = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
654             fq->input_value(0),
655             fq->input_value(1),
656             fq->input_value(2),
657             newMin->output(0),
658             newMax->output(0),
659             fq->get_levels(),
660             fq->get_auto_broadcast());
661
662     NetworkHelper::setOutDataPrecision(newFQ, precision);
663     replace_node(fq, newFQ);
664
665     newFQ->set_friendly_name(fq->get_friendly_name());
666     return newFQ;
667 }
668
669 FakeQuantizeDequantization NetworkHelper::makeDequantization(
670     const float dequantizationMul,
671     const float dequantizationSub,
672     const ngraph::element::Type originalPrecision,
673     const ngraph::Shape dataNodeOutputShape,
674     element::Type precision,
675     float min,
676     float max) {
677     // TODO: we create input here! we really need it here?
678     const std::shared_ptr<opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(precision, dataNodeOutputShape);
679     std::shared_ptr<ngraph::Node> parent = input;
680
681     // TODO: convert should be optional: where is updatePrecision?
682     std::shared_ptr<DequantizationConvert> convert;
683     {
684         convert = std::make_shared<DequantizationConvert>(
685             input,
686             originalPrecision);
687         parent = convert;
688     }
689
690     std::shared_ptr<DequantizationSubtract> subtract;
691     if (std::abs(dequantizationSub) > 1e-6) {
692         subtract = std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(
693             parent,
694             std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationSub })));
695         subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
696         parent = subtract;
697     }
698
699     // mandatory
700     std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
701         parent,
702         std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationMul })));
703
704     return FakeQuantizeDequantization(input, convert, subtract, multiply);
705 }
706
// Derives a dequantization subgraph (Convert [-> Subtract] -> Multiply) from a FakeQuantize:
// scale = (outputHigh - outputLow) / (max - min), and, when hasZeroPoint is set,
// shift = (min * outputHigh - max * outputLow) / (outputHigh - outputLow).
// A shift that folds to an exact scalar zero is dropped. The returned
// FakeQuantizeDequantization references `fq` as its data node.
// NOTE(review): `updatePrecision` is not used in this body — TODO confirm whether it
// should gate the Convert creation.
FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
    std::shared_ptr<opset1::FakeQuantize> fq,
    element::Type precision,
    float min,
    float max,
    const bool hasZeroPoint,
    const bool updatePrecision) {
    using std::make_shared;

    // New low/high limits as scalar constants in the FQ output precision.
    const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
    auto newMin = make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
    auto newMax = make_shared<opset1::Constant>(fqPrecision, Shape{}, max);

    auto outputLow = fq->input_value(3);
    auto outputHigh = fq->input_value(4);

    // TODO: threshold values have to used here to avoid shifts

    // Constant-fold the scale from the original and new output intervals.
    const std::shared_ptr<Node> scale = fold<opset1::Divide>(
        fold<opset1::Subtract>(outputHigh, outputLow),
        fold<opset1::Subtract>(newMax, newMin));

    // Constant-fold the zero point; skipped entirely when the caller reports no zero point.
    std::shared_ptr<Node> shift = hasZeroPoint ?
        fold<opset1::Divide>(
            fold<opset1::Subtract>(fold<opset1::Multiply>(newMin, outputHigh), fold<opset1::Multiply>(newMax, outputLow)),
            fold<opset1::Subtract>(outputHigh, outputLow)) :
        nullptr;

    // Drop the shift when it folded to a scalar-like constant equal to zero.
    if (shift != nullptr) {
        std::shared_ptr<opset1::Constant> shiftConst = as_type_ptr<opset1::Constant>(shift);
        if (isScalarLike(shiftConst)) {
            auto scalar = toScalar(shiftConst);
            if (op::util::constantIsEqualTo(scalar, 0)) {
                shift = nullptr;
            }
        }
    }

    // Free-standing input in the requested (low) precision, converted back to the FQ precision.
    const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, fq->get_output_shape(0));
    const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(
        input,
        fq->get_output_element_type(0));

    // TypeRelaxed Subtract so the output type can be pinned to the FQ precision below.
    const std::shared_ptr<ngraph::opset1::Subtract> subtract = shift == nullptr ?
        nullptr :
        make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert, shift);
    if (subtract != nullptr) {
        subtract->set_output_type(0, fq->get_output_element_type(0), subtract->get_output_partial_shape(0));
    }

    // Multiply consumes the Subtract when present, otherwise the Convert directly.
    const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
        subtract == nullptr ? static_cast<std::shared_ptr<Node>>(convert) : subtract,
        scale);

    return FakeQuantizeDequantization(fq, convert, subtract, multiply);
}
763
// Walks upward from `node` (or from `node` itself when inPlace is set) and pattern-matches
// a dequantization chain Convert -> Subtract -> Multiply, where Subtract/Multiply must each
// have a Constant on one input. Matching stops at the first non-matching node; the partial
// chain matched so far is returned (possibly with null convert/subtract/multiply members).
FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex, const bool inPlace) {
    // The data operand is whichever input is NOT the Constant (input 1 checked first).
    auto getDataIndex = [](const std::shared_ptr<ngraph::Node>& node) {
        if (is_type<opset1::Constant>(node->get_input_node_ptr(1))) {
            return 0ul;
        } else {
            return 1ul;
        }
    };

    Output<Node> dataNode = inPlace ? node : node->input_value(parentIndex);

    // Multiply: required to have a Constant operand, otherwise no dequantization at all.
    const std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<ngraph::opset1::Multiply>(dataNode.get_node_shared_ptr());
    if (multiply != nullptr) {
        if (!is_type<opset1::Constant>(multiply->get_input_node_ptr(0)) && !is_type<opset1::Constant>(multiply->get_input_node_ptr(1))) {
            return FakeQuantizeDequantization();
        }
        dataNode = multiply->get_input_source_output(getDataIndex(multiply));
    }

    // Subtract: optional; without a Constant operand the chain ends with the Multiply only.
    const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<ngraph::opset1::Subtract>(dataNode.get_node_shared_ptr());
    if (subtract != nullptr) {
        if (!is_type<opset1::Constant>(subtract->get_input_node_ptr(0)) && !is_type<opset1::Constant>(subtract->get_input_node_ptr(1))) {
            return FakeQuantizeDequantization(dataNode, nullptr, nullptr, multiply);
        }
        dataNode = subtract->get_input_source_output(getDataIndex(subtract));
    }

    // Convert: accepted only as a low-precision-to-float conversion.
    // NOTE(review): this condition rejects the Convert only when input is neither i8 nor u8
    // AND output is not f32 — presumably an OR between the two checks was intended; confirm
    // against the transformations that consume this result before changing.
    const std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(dataNode.get_node_shared_ptr());
    if (convert != nullptr) {
        if ((convert->input(0).get_element_type() != element::i8) && (convert->input(0).get_element_type() != element::u8) &&
            (convert->output(0).get_element_type() != element::f32)) {
            return FakeQuantizeDequantization(dataNode, nullptr, subtract, multiply);
        }
        dataNode = convert->get_input_source_output(0);
    }

    return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
}
802
803 FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
804     std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();
805
806     std::shared_ptr<Node> multiply1Const = dequantization.multiply ?
807         dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
808         std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 1.f }));
809
810     std::shared_ptr<Node> subtract1Const = dequantization.subtract ?
811         dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
812         std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 0.f }));
813
814     subtract1Const->set_output_type(0, multiply1Const->get_output_element_type(0), subtract1Const->get_output_partial_shape(0));
815
816     return FakeQuantizeDequantizationValues(subtract1Const, multiply1Const);
817 }
818
819 bool NetworkHelper::isZeroConst(const std::shared_ptr<Node>& node) {
820     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
821
822     if (constant == nullptr)
823         return false;
824
825     if (NetworkHelper::isScalarLike(constant)) {
826         auto scalar = NetworkHelper::toScalar(constant);
827         if (op::util::constantIsEqualTo(scalar, 0)) {
828             return true;
829         } else {
830             return false;
831         }
832     } else {
833         return false;
834     }
835 }
836
// Folds a Convert feeding a Subtract into the Subtract itself: the shift constant is
// rounded to the Convert's input type and a TypeRelaxed Subtract (with the Convert's
// output type) replaces the original pair. Returns the original Subtract untouched when
// there is no preceding Convert or its output type is not a real type; otherwise returns
// the replacement node.
std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract) {
    auto convertOnSubtract = subtract->input_value(0).get_node_shared_ptr();
    if (as_type_ptr<opset1::Convert>(convertOnSubtract) == nullptr) {
        return subtract;
    }

    // TODO: replace assert to condition and omit conversion part if there is no convert
    // TODO: also check convertInputType to understand if we really want to propagate type
    assert(as_type_ptr<opset1::Convert>(convertOnSubtract));
    const element::Type convertInputType = convertOnSubtract->get_input_element_type(0);
    const element::Type convertOutputType = convertOnSubtract->get_output_element_type(0);

    // Only optimize when the Convert produces a floating-point type.
    if (!convertOutputType.is_real()) {
        return subtract;
    }

    // Round the shift so it is exactly representable in the Convert's (integer) input type.
    auto data = convertOnSubtract->input_value(0);
    auto shift = subtract->input_value(1).get_node_shared_ptr();
    auto roundedShift = NetworkHelper::round(shift, convertInputType);

    // Propagate convertInputType down
    const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
    NetworkHelper::copyInfo(subtract, replacement);
    NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
    replace_node(subtract, replacement);

    // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
    // TODO: check cases when Convert should be preserved

    // Try to optimize Add out if constant is zero
    // TODO: don't remove operation here: don't create this Subtraction operation in FQ decomposition
    // if (isScalarLike(roundedShift)) {
    //    auto scalar = distillToScalar(roundedShift);
    //    if (op::util::constantIsEqualTo(scalar, 0)) {
    //        replace_node(replacement, replacement->input_value(0).get_node_shared_ptr());
    //        replacement = nullptr;
    //    }
    // }

    return replacement;
}
878
// Moves a dequantization chain from before `operation` to after it: the operation is
// re-cloned directly on the dequantization's data input (or on its Subtract when
// moveSubtract is false), then Convert/Subtract/Multiply are re-created on the
// operation's output as needed, and the result replaces `operation` in the graph.
// When updatePrecision is set, the cloned operation (which must be TypeRelaxedBase)
// gets its output type overridden to its new input type. Returns both the cloned
// operation and the last node of the rebuilt chain.
NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter(
    const std::shared_ptr<ngraph::Node>& operation,
    const FakeQuantizeDequantization& dequantization,
    const bool updatePrecision,
    const bool moveSubtract) {
    std::vector<Output<Node>> inputs(operation->get_input_size());
    for (size_t i = 0; i < operation->get_input_size(); ++i) {
        inputs[i] = operation->get_input_node_shared_ptr(i);
    }

    // Rewire the input that was fed by the dequantization Multiply to bypass the chain.
    const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation);
    inputs[dequantizationIndex] = moveSubtract ?
        dequantization.data :
        (dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract);

    const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
    newOperation->set_friendly_name(operation->get_friendly_name());
    ngraph::copy_runtime_info(operation, newOperation);

    if (updatePrecision) {
        // Only TypeRelaxed operations can carry an output type differing from the inferred one.
        auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
        if (op == nullptr) {
            THROW_IE_LPT_EXCEPTION(*newOperation) << "not possible to update precision for not TypeRelaxedBase operation";
        }
        op->set_overridden_output_type(newOperation->get_input_element_type(0));
        std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
    }

    // A Convert is re-inserted only if the new operation's output type no longer matches
    // the type the original dequantization Multiply produced.
    const bool shouldConvert = (newOperation->get_output_element_type(0) != dequantization.multiply->get_output_element_type(0));

    auto parent = newOperation;
    if (shouldConvert) {
        const auto convertOutputPrecision = dequantization.convert != nullptr ?
            dequantization.convert->get_output_element_type(0) :
            dequantization.multiply->get_output_element_type(0);
        parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }

    if (moveSubtract && (dequantization.subtract != nullptr)) {
        auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
        const element::Type parentPrecision = parent->get_output_element_type(0);
        // The shift constant must fit into the data precision it is applied to.
        if (parentPrecision.bitwidth() < subtractConstant->output(0).get_element_type().bitwidth()) {
            THROW_IE_LPT_EXCEPTION(*parent) <<
                "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
                ", subtract dequantization constant " << subtractConstant->get_friendly_name() << ":" << subtractConstant->output(0).get_element_type();
        }

        // Reuse the constant as-is when types already match, otherwise fold a conversion.
        parent = std::make_shared<DequantizationSubtract>(
            parent,
            subtractConstant->output(0).get_element_type() == parentPrecision ?
                subtractConstant :
                fold<opset1::Convert>(subtractConstant->output(0), parentPrecision));
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }

    if (dequantization.multiply != nullptr) {
        auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
        const element::Type parentPrecision = parent->get_output_element_type(0);
        // Same bitwidth constraint for the scale constant.
        if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
            THROW_IE_LPT_EXCEPTION(*parent) <<
                "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
                ", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
        }

        parent = std::make_shared<DequantizationMultiply>(
            parent,
            multiplyConstant->output(0).get_element_type() == parentPrecision ?
                multiplyConstant :
                fold<opset1::Convert>(multiplyConstant->output(0), parentPrecision));
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }
    replace_node(operation, parent);

    // When the Subtract stays in place (not moved), its runtime info is cleaned instead.
    if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
        NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
        // issue #43088
        // NetworkHelper::optimizeElementwise(dequantization.subtract);
    }

    return InsertDequantizationResult(newOperation, parent);
}
961
962 void NetworkHelper::removeConvertIfPossible(
963     const std::shared_ptr<ngraph::Node>& operation,
964     const FakeQuantizeDequantization& dequantization) {
965     const element::Type precisionBeforeConvert = dequantization.convert->input(0).get_element_type();
966
967     if (checkConstantValuePrecision(precisionBeforeConvert, dequantization.subtract->get_input_node_shared_ptr(1))) {
968         auto newSubtract = dequantization.subtract->clone_with_new_inputs({
969             dequantization.convert->get_input_node_shared_ptr(0),
970             fold<opset1::Convert>(dequantization.subtract->get_input_node_shared_ptr(1), precisionBeforeConvert) });
971         replace_node(dequantization.subtract, newSubtract);
972     }
973 }
974
975 bool NetworkHelper::checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant) {
976     if (expectedPrecision.is_signed()) {
977         return true;
978     }
979
980     std::shared_ptr<opset1::Constant> constantOp = as_type_ptr<opset1::Constant>(constant);
981     if (constantOp == nullptr) {
982         return false;
983     }
984
985     const auto values = constantOp->cast_vector<float>();
986     const bool convertCanBeRemoved =
987         (expectedPrecision.is_signed() || (std::all_of(values.begin(), values.end(), [](const float value) { return value >= 0.f; })));
988     return convertCanBeRemoved;
989 }
990
991 size_t NetworkHelper::getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
992     for (size_t i = 0; i < child->get_input_size(); ++i) {
993         if (parent.get() == child->get_input_node_ptr(i)) {
994             return i;
995         }
996     }
997     THROW_IE_LPT_EXCEPTION(*child) << "child input index between " <<
998         parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
999 }
1000
1001 size_t NetworkHelper::getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1002     for (size_t i = 0; i < parent->get_output_size(); ++i) {
1003         const auto& targetInputs = parent->output(i).get_target_inputs();
1004         for (const auto& targetInput : targetInputs) {
1005             if (targetInput.get_node() == child.get()) {
1006                 return i;
1007             }
1008         }
1009     }
1010     THROW_IE_LPT_EXCEPTION(*child) << "parent output index between " <<
1011         parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
1012 }
1013
1014 std::vector<Output<Node>> NetworkHelper::getInputs(const std::shared_ptr<ngraph::Node>& node) {
1015     std::vector<Output<Node>> inputs(node->get_input_size());
1016     for (size_t i = 0; i < node->get_input_size(); ++i) {
1017         inputs[i] = node->get_input_node_shared_ptr(i);
1018     }
1019     return inputs;
1020 }
1021
1022 std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> node) {
1023     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
1024     if (constant == nullptr) {
1025         return node;
1026     }
1027
1028     if (!NetworkHelper::isScalarLike(constant)) {
1029         return node;
1030     }
1031
1032     return NetworkHelper::toScalar(constant);
1033 }
1034
1035 }  // namespace low_precision
1036 }  // namespace pass
1037 }  // namespace ngraph