1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <low_precision/network_helper.hpp>
13 #include <unordered_set>
18 #include <ngraph/rt_info.hpp>
19 #include "low_precision/common/ie_lpt_exception.hpp"
20 #include "low_precision/common/dequantization_op.hpp"
24 namespace low_precision {
// Returns true when `type` is castable to at least one entry of `types`
// (the loop body visible here checks each candidate with NodeTypeInfo::is_castable).
26 // Return true if `type` can be castable to at least one of `type`
27 bool NetworkHelper::is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types) {
28 for (auto another : types) {
29 if (type.is_castable(another)) {
36 // Collect and return a vector with all nodes that consumes any of the `node` output
// Walks every output port of `node` and collects each Input<Node> connected
// to it, i.e. every downstream input that consumes this node's results.
37 std::vector<Input<Node>> NetworkHelper::consumer_inputs(std::shared_ptr<Node> node) {
38 std::vector<Input<Node>> result;
39 for (const auto& output_port : node->outputs()) {
40 for (const auto &input : output_port.get_target_inputs()) {
41 result.push_back(input);
// Returns the owner nodes of all consumer inputs of `node`
// (may contain duplicates when one consumer reads several outputs).
47 std::vector<std::shared_ptr<Node>> NetworkHelper::consumers(std::shared_ptr<Node> node) {
48 auto inputs = consumer_inputs(node);
49 std::vector<std::shared_ptr<Node>> result(inputs.size());
50 std::transform(inputs.begin(), inputs.end(), result.begin(), [](Input<Node> input){ return input.get_node()->shared_from_this(); });
// Recursive helper: determines whether `layer` (directly or through its
// children) feeds the weights input (port 1) of a Convolution,
// GroupConvolution or MatMul. The visible code checks the immediate
// children; the recursive call continues the search in depth.
54 int NetworkHelper::onWeightsInDepth(std::shared_ptr<Node> layer) {
55 const std::vector<std::shared_ptr<Node>> children = consumers(layer);
56 for (std::shared_ptr<Node> child : children) {
// Weighted operations: only these types are inspected, and only when
// they actually have a second (weights) input.
57 if ((is_type<opset1::Convolution>(child) ||
58 is_type<opset1::GroupConvolution>(child) ||
59 is_type<opset1::MatMul>(child)) &&
60 (child->inputs().size() >= 2lu)) {
// Collect the ancestors reachable through input port 1 (the weights
// branch) and test whether `layer` is among them.
61 const std::vector<std::shared_ptr<Node>> parents = getParentsRecursivelyExceptTypes(child, {}, 1);
62 for (const std::shared_ptr<Node>& parent : parents) {
63 if (parent.get() == layer.get()) {
// Recurse through children that are not weighted operations.
70 const int result = onWeightsInDepth(child);
// Public wrapper over onWeightsInDepth: reports whether `layer` is located
// on a weights path of some consumer (interpretation of `result` happens
// below this visible fragment).
78 bool NetworkHelper::onWeights(std::shared_ptr<Node> layer) {
79 const int result = onWeightsInDepth(layer);
// Returns the channels count of the single output of `layer`.
// For rank-1 shapes the only dimension is used; otherwise dim 1 (NC... layout).
// NOTE(review): the `isOnWeights` parameter presumably selects the
// shape[0]-based branch (weights are OIYX-like, channels at dim 0) — the
// branching line itself is not visible in this fragment; confirm.
83 size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
84 if (layer->outputs().size() == 0) {
85 THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";
// Exactly one output is expected; multi-output layers are rejected.
88 if (layer->outputs().size() > 1) {
89 THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " has too many output tensors, expected one";
92 PartialShape shape = layer->get_output_partial_shape(0);
93 if (shape.rank() == 0) {
94 THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in output of " << layer->get_friendly_name() << " layer on weights";
// Weights case: output channels are the first dimension.
97 return shape[0].get_length();
99 if (shape.rank() == 1) {
100 return shape[0].get_length();
// Activations case: channels dimension in NC... layout.
102 return shape[1].get_length();
// Collects the parents of `layer`, recursively skipping (looking through)
// any parent whose type is castable to one of `exceptionLayerTypes`.
// `portIndex == -1` means "all input ports", otherwise only that port.
106 std::vector<std::shared_ptr<Node>> NetworkHelper::getParentsRecursivelyExceptTypes(
107 std::shared_ptr<Node> layer,
108 const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes,
109 const int portIndex) {
110 std::vector<std::shared_ptr<Node>> parents;
// NOTE(review): `i` is the running input-port counter; its declaration and
// increment are not visible in this fragment — confirm it is maintained
// alongside this loop.
112 for (auto input : layer->inputs()) {
113 if ((portIndex == -1) || (portIndex == i)) {
114 auto parent = input.get_source_output().get_node_shared_ptr();
// "Transparent" parent: descend through it and take its parents instead.
115 if (is_castable_to_one_of(parent->get_type_info(), exceptionLayerTypes)) {
116 const std::vector<std::shared_ptr<Node>> tmpParents = getParentsRecursivelyExceptTypes(parent, exceptionLayerTypes);
117 parents.insert(parents.end(), tmpParents.begin(), tmpParents.end());
119 parents.push_back(parent);
// Returns the channels count (dim 1, NC... layout) of the first input of `layer`.
128 size_t NetworkHelper::getInputChannelsCount(std::shared_ptr<Node> layer) {
129 if (layer->get_input_size() == 0) {
130 THROW_TRANSFORMATION_EXCEPTION << "There are no input layers";
133 PartialShape shape = layer->get_input_partial_shape(0);
// A channels dimension requires rank >= 2.
134 if (shape.rank().get_length() <= 1) {
135 THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in input of " << layer->get_friendly_name();
138 return shape[1].get_length();
// Returns the groups count: 1 for a plain Convolution (return not visible in
// this fragment), or the leading (G) dimension of the weights for a
// GroupConvolution.
141 size_t NetworkHelper::getGroupsCount(std::shared_ptr<Node> layer) {
142 if (as_type_ptr<opset1::Convolution>(layer)) {
144 } else if (auto group_convolution = as_type_ptr<opset1::GroupConvolution>(layer)) {
145 return layer->get_input_shape(1)[0]; // input weights for opset1::GC is in format GOI..., see the specification
// NOTE(review): the exception message below contains a typo ("Convolutino");
// worth fixing in a code-changing pass.
147 THROW_TRANSFORMATION_EXCEPTION << "Invalid layer type of " << layer->get_friendly_name() << "; expected Convolutino or GroupConvolution";
// Assuming a 1-D per-channel shape for a tensor in NC... layout, appends 1s
// so the shape aligns to the given rank (e.g. {C} -> {C,1,1} for rank 4).
151 // Assumin tensor in NC... layout, append necessary number of 1s to shape to align it to a give rank
152 Shape NetworkHelper::alignShapeForChannelDim(const Shape& shape, Rank rank) {
153 assert(shape.size() == 1);
154 assert(rank.is_static());
155 Shape result = shape;
// rank - 1 because the batch (N) dimension is not represented in `shape`.
156 result.resize(rank.get_length() - 1, 1);
// Removes a single-input/single-output `layer` from the graph by rewiring its
// consumers to its input, preserving the friendly name via
// replace_output_update_name.
160 void NetworkHelper::removeLayer(std::shared_ptr<Node> layer) {
161 ngraph::replace_output_update_name(layer->output(0), layer->input_value(0));
// Swaps a Multiply->Add pair into Add->Multiply so the multiplicative scale
// stays last (dequantization form): x*a + b ==> (x + b/a)*a. The constant
// b/a is folded. Returns the new Multiply, or the original Add unchanged if
// the Multiply has no constant operand.
164 std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch) {
165 // Multiply --> Add(addAfterMultiply) ==> Add(new) --> Multiply(new)
166 // That means x*a + b ==> (x + b/a)*a; tries to fold b/a
167 const auto multiply = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch);
169 const auto multiplyParent1 = multiply->get_input_node_shared_ptr(0);
170 const auto multiplyParent2 = multiply->get_input_node_shared_ptr(1);
// Identify which Multiply operand is the data branch and which the constant.
172 auto multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent1);
173 auto multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent2);
174 int multiplyInputBranch = 0;
176 if (multiplyConst == nullptr) {
177 multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent2);
178 multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent1);
179 multiplyInputBranch = 1;
// No constant operand at all: the transformation cannot be applied.
182 if (multiplyConst == nullptr)
183 return addAfterMultiply;
// x = data input, a = multiply constant, b = add constant.
185 const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
186 const auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
187 const auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
188 std::shared_ptr<Node> bDivA;
// Manual element-wise division when one side is scalar-broadcastable or
// sizes match; otherwise fall back to constant folding below.
190 if (shape_size(b->get_output_shape(0)) == 1 ||
191 shape_size(a->get_output_shape(0)) == 1 ||
192 shape_size(b->get_output_shape(0)) == shape_size(a->get_output_shape(0))) {
193 // safely division to avoid NaN
194 const std::vector<float> bValues = as_type_ptr<opset1::Constant>(b)->cast_vector<float>();
195 const std::vector<float> aValues = as_type_ptr<opset1::Constant>(a)->cast_vector<float>();
196 const bool aBroadcasted = bValues.size() > aValues.size();
197 const bool bBroadcasted = bValues.size() < aValues.size();
198 std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size());
200 for (int i = 0; i < bDivAValues.size(); ++i) {
201 const auto bi = bValues[bBroadcasted ? 0 : i];
202 const auto ai = aValues[aBroadcasted ? 0 : i];
// 0/0 is mapped to 0 to avoid NaN.
// NOTE(review): bi != 0 with ai == 0 still divides by zero (inf) —
// presumably upstream guarantees a non-zero scale; confirm.
203 if (bi != 0.f || ai != 0.f) {
204 bDivAValues[i] = bi / ai;
206 bDivAValues[i] = 0.f;
210 bDivA = std::make_shared<opset1::Constant>(
211 b->get_output_element_type(0),
212 aBroadcasted ? b->get_output_shape(0) : a->get_output_shape(0),
// General case: let constant folding compute b/a.
215 bDivA = fold<opset1::Divide>(b, a);
218 std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
// TypeRelaxed Add computes in f32 regardless of the inputs' stored types.
223 std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
224 std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
225 ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
226 ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
227 copyInfo(addAfterMultiply, newAdd);
// Preserve the original Add's output precision on the new node.
229 NetworkHelper::setOutDataPrecision(newAdd, addAfterMultiply->get_output_element_type(0));
231 auto newMultiply = std::make_shared<DequantizationMultiply>(newAdd, a);
232 copyInfo(multiply, newMultiply);
234 replace_node(addAfterMultiply, newMultiply);
// Copies runtime info attributes and the friendly name from `source` to
// `target`; existing target attributes with the same key are overwritten.
238 void NetworkHelper::copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target) {
239 // TODO: merge_runtime_info with correctly defined DEQUANTIZATION
240 const auto& sourceAttributes = source->get_rt_info();
241 auto& targetAttrubutes = target->get_rt_info();
242 for (auto attribute : sourceAttributes) {
243 targetAttrubutes[attribute.first] = attribute.second;
// An empty friendly name would erase the target's auto-generated one.
246 const std::string friendlyName = source->get_friendly_name();
247 if (!friendlyName.empty()) {
248 target->set_friendly_name(friendlyName);
// Removes the "DEQUANTIZATION" attribute from the layer's runtime info, if present.
252 void NetworkHelper::cleanRunTimeInfo(const std::shared_ptr<Node>& layer) {
253 auto& rt_info = layer->get_rt_info();
254 auto attributeIter = rt_info.find("DEQUANTIZATION");
// NOTE(review): the second find() duplicates the lookup above; checking
// `attributeIter != rt_info.end()` would be equivalent and cheaper.
255 if (rt_info.find("DEQUANTIZATION") != rt_info.end()) {
256 rt_info.erase(attributeIter);
// A constant is "scalar-like" when all of its elements are bitwise identical,
// regardless of its shape — it can then be collapsed to a true scalar.
260 bool NetworkHelper::isScalarLike(std::shared_ptr<opset1::Constant> constant) {
261 return constant->get_all_data_elements_bitwise_identical();
// Returns whether every element of `constant` is (numerically) zero, using a
// tiny tolerance so denormal quantization shifts still count as zero.
264 bool NetworkHelper::isZero(std::shared_ptr<opset1::Constant> constant) {
265 static const float minQuantizationShift = 1e-32f;
267 auto values = constant->cast_vector<float>();
268 for (size_t i = 0; i < values.size(); ++i) {
// Any element above the tolerance disqualifies the constant.
269 if (fabs(values[i]) > minQuantizationShift) {
// Collapses a scalar-like constant (see isScalarLike) into a true scalar
// (Shape{}) constant sharing the first element's value.
277 std::shared_ptr<opset1::Constant> NetworkHelper::toScalar(std::shared_ptr<opset1::Constant> constant) {
278 assert(isScalarLike(constant));
279 return std::make_shared<opset1::Constant>(constant->get_element_type(), Shape{}, constant->get_data_ptr());
// Returns the Constant operand of a binary node, checking input 0 first and
// falling back to input 1 (nullptr handling continues below this fragment).
282 std::shared_ptr<Node> NetworkHelper::getConstantInput(std::shared_ptr<Node> node) {
283 std::shared_ptr<Node> constant1 = as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr());
285 constant1 = as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr());
// Fuses two consecutive Multiply-by-constant nodes into one: folds the two
// constants into a single product and replaces the second Multiply with a new
// one reading the first Multiply's data input. Only applies when each node's
// constant has exactly one consumer and the first Multiply has one consumer.
290 std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> node) {
291 std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(node);
293 THROW_IE_LPT_EXCEPTION(*multiply) << "Unexpected operation type";
// The fusion is only safe when nobody else consumes the intermediate result.
296 if (multiply->output(0).get_target_inputs().size() == 1) {
297 auto constant1 = getConstantInput(multiply);
298 if (!constant1 || constant1->output(0).get_target_inputs().size() != 1) {
301 auto nextMultiplyInput = *multiply->output(0).get_target_inputs().begin();
302 auto nextMultiply = as_type_ptr<opset1::Multiply>(nextMultiplyInput.get_node()->shared_from_this());
304 auto constant2 = getConstantInput(nextMultiply);
// NOTE(review): `constant2` is dereferenced here BEFORE the nullptr check on
// the next line — potential null dereference when the second Multiply has no
// constant input; `constant2Inputs` also appears unused. Worth fixing.
305 auto constant2Inputs = constant2->output(0).get_target_inputs().size();
306 if (!constant2 || constant2->output(0).get_target_inputs().size() != 1) {
// Fold the two scale constants into one and rebuild the multiply.
310 auto newConst = fold<opset1::Multiply>(constant1, constant2);
312 std::make_shared<opset1::Multiply>(
313 multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index()),
314 newConst->output(0));
315 copy_runtime_info(multiply, newMultiply);
316 replace_node(nextMultiply, newMultiply);
// Tries to convert a float constant to `target_type` such that the converted
// values round-trip within `tolerance`. First attempts a direct Convert; if
// that drifts too far, retries after nudging the values by +0.5, -0.5 and
// +1.0 (to compensate truncation during the integer conversion).
324 std::shared_ptr<opset1::Constant> NetworkHelper::roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance) {
325 auto constant = as_type_ptr<opset1::Constant>(node);
327 auto values = constant->cast_vector<float>();
// Direct conversion attempt.
329 auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(constant, target_type));
330 auto castedValues = castedConstant->cast_vector<float>();
332 // TODO: implement with constant folding when ReduceAnd constant folding is ready
333 if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
334 return castedConstant;
// Helper: add `increaseValue` to every element, convert, and accept the
// result only if it still matches the original values within tolerance.
338 const std::shared_ptr<opset1::Constant>& constant,
339 element::Type target_type,
341 std::vector<float>& values,
342 float increaseValue) -> std::shared_ptr<opset1::Constant> {
343 const auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(
344 fold<opset1::Add>(constant, std::make_shared<opset1::Constant>(constant->get_output_element_type(0), Shape{ 1 }, increaseValue)),
346 const auto castedValues = castedConstant->cast_vector<float>();
347 if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
348 return castedConstant;
// Retry sequence: round-half-up, round-half-down, then ceil.
354 castedConstant = round(constant, target_type, tolerance, values, 0.5f);
355 if (castedConstant != nullptr) {
356 return castedConstant;
359 castedConstant = round(constant, target_type, tolerance, values, -0.5f);
360 if (castedConstant != nullptr) {
361 return castedConstant;
364 castedConstant = round(constant, target_type, tolerance, values, 1.f);
365 if (castedConstant != nullptr) {
366 return castedConstant;
// Convenience overload: fold a FakeQuantize without rounding the results
// (roundValues flag not explicitly set).
372 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
373 return foldFakeQuantize(fq, false, false);
// Convenience overload: fold a FakeQuantize with an explicit rounding choice.
376 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues) {
377 return foldFakeQuantize(fq, roundValues, true);
// Constant-folds a dequantization chain (Convert -> Subtract -> Multiply)
// hanging off `node`'s branch `branchIndex` when its data input is a Constant
// with a single consumer: each stage is folded in turn and the dequantization
// description is refreshed after every replacement.
380 void NetworkHelper::foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace) {
381 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
382 if (dequantization.empty() || (dequantization.multiply == nullptr)) {
// Only fold when the data source is a Constant consumed by this chain alone.
386 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.data.get_node_shared_ptr());
387 if ((constant == nullptr) || (constant->output(0).get_target_inputs().size() != 1ul)) {
// Stage 1: fold Convert into the constant.
391 if (dequantization.convert != nullptr) {
392 const std::shared_ptr<Node> result = fold<opset1::Convert>(dequantization.data, dequantization.convert->get_element_type());
393 if (!is_type<opset1::Constant>(result)) {
397 copyInfo(dequantization.convert, result);
399 replace_node(dequantization.convert, result);
400 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Stage 2: fold Subtract (only when element types already agree).
403 if (dequantization.subtract != nullptr) {
404 if (dequantization.data.get_element_type() != dequantization.subtract->input(1).get_element_type()) {
407 const std::shared_ptr<Node> result = fold<opset1::Subtract>(dequantization.data, dequantization.subtract->get_input_node_shared_ptr(1));
408 if (!is_type<opset1::Constant>(result)) {
412 copyInfo(dequantization.subtract, result);
414 replace_node(dequantization.subtract, result);
415 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Stage 3: fold Multiply (same element-type precondition).
418 if (dequantization.multiply != nullptr) {
419 if (dequantization.data.get_element_type() != dequantization.multiply->input(1).get_element_type()) {
422 const std::shared_ptr<Node> result = fold<opset1::Multiply>(dequantization.data, dequantization.multiply->get_input_node_shared_ptr(1));
423 if (!is_type<opset1::Constant>(result)) {
427 copyInfo(dequantization.multiply, result);
429 replace_node(dequantization.multiply, result);
430 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Evaluates a FakeQuantize at compile time when its data input is constant.
// Two paths: (1) a special identity-like case where the interval constants
// match the fixed pattern in=[0,254], out=[-127,127] — then the FQ reduces to
// an Add of input and output-low (with type alignment); (2) the general
// per-output-channel quantization formula applied element-wise.
434 std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
435 const std::shared_ptr<opset1::FakeQuantize>& fq,
436 const bool roundValuesArg,
437 const bool roundValuesWasSet) {
// Special case: all five inputs constant with the exact 0/254/-127/127
// interval pattern — the quantization is a pure shift by outputLow.
438 if (is_type<opset1::Constant>(fq->get_input_node_shared_ptr(0)) &&
439 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(1)) &&
440 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(2)) &&
441 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(3)) &&
442 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(4)) &&
443 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1)), 0.f) &&
444 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2)), 254.f) &&
445 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3)), -127.f) &&
446 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4)), 127.f)) {
447 const auto type1 = fq->input_value(0).get_element_type();
448 const auto type2 = fq->input_value(3).get_element_type();
// Align element types before folding the Add; mixed integer case falls
// back to f32.
449 if (type1.is_real() && type2.is_real()) {
450 return fold<opset1::Add>(fq->input_value(0), fq->input_value(3));
452 if (type1.is_real() && !type2.is_real()) {
453 return fold<opset1::Add>(
455 fold<opset1::Convert>(fq->input_value(3), type1));
457 if (!type1.is_real() && type2.is_real()) {
458 return fold<opset1::Add>(
459 fold<opset1::Convert>(fq->input_value(0), type2),
462 return fold<opset1::Add>(
463 fold<opset1::Convert>(fq->input_value(0), element::f32),
464 fold<opset1::Convert>(fq->input_value(3), element::f32));
// General path: requires the data input to be a Constant.
467 auto constant = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(0));
// Round results when the caller asked for it, or (by default) when the FQ
// output type is integral.
470 const bool roundValues = roundValuesWasSet ? roundValuesArg : fq->output(0).get_element_type().is_integral();
472 Shape constShape = fq->get_output_shape(0);
473 if (constShape.empty() || constShape.size() > 5lu) {
474 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected dimensions count " << constShape.size();
// Decompose the shape as OC x IC x D x H x W (missing dims default to 1).
478 const size_t OC = constShape[0];
479 const size_t IC = constShape.size() > 1lu ? constShape[1] : 1;
480 const size_t D = constShape.size() > 4lu ? constShape[constShape.size() - 3] : 1;
481 const size_t H = constShape.size() > 2lu ? constShape.size() == 3lu ? constShape[2] : constShape[constShape.size() - 2] : 1;
482 const size_t W = constShape.size() > 3lu ? constShape[constShape.size() - 1] : 1;
484 const auto inputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1))->cast_vector<float>();
485 const auto inputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2))->cast_vector<float>();
486 const auto outputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3))->cast_vector<float>();
487 const auto outputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4))->cast_vector<float>();
489 const size_t inputLowSize = inputLowValues.size();
490 const size_t inputHighSize = inputHighValues.size();
491 const size_t outputLowSize = outputLowValues.size();
492 const size_t outputHighSize = outputHighValues.size();
// Each interval constant is either per-channel (size OC) or a broadcast
// scalar (size 1); anything else is rejected.
494 const bool isInputLowBroadcasted = inputLowSize != OC;
495 if ((inputLowSize != 1) && (inputLowSize != OC)) {
496 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input low values count " << inputLowSize << " for " << OC << " channels";
498 const bool isInputHighBroadcasted = inputHighSize != OC;
499 if ((inputHighSize != 1) && (inputHighSize != OC)) {
500 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input high values count " << inputHighSize << " for " << OC << " channels";
502 const bool isOutputLowBroadcasted = outputLowSize != OC;
503 if ((outputLowSize != 1) && (outputLowSize != OC)) {
504 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output low values count " << outputLowSize << " for " << OC << " channels";
506 const bool isOutputHighBroadcasted = outputHighSize != OC;
507 if ((outputHighSize != 1) && (outputHighSize != OC)) {
508 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output high values count " << outputHighSize << " for " << OC << " channels";
// levels - 1 is the number of quantization steps.
511 auto levels_1 = fq->get_levels() - 1.f;
513 //const size_t DHW = D * H * W;
514 const size_t IDHW = IC * D * H * W;
516 const auto values = constant->cast_vector<float>();
517 std::vector<float> quantizedValues(OC * IC * D * H * W);
// Apply the FakeQuantize reference formula per output channel.
519 for (int oc = 0; oc < OC; ++oc) {
520 for (int iidx = 0; iidx < IDHW; ++iidx) {
521 const float inputLow = inputLowValues[isInputLowBroadcasted ? 0 : oc];
522 const float inputHigh = inputHighValues[isInputHighBroadcasted ? 0 : oc];
523 const float outputLow = outputLowValues[isOutputLowBroadcasted ? 0 : oc];
524 const float outputHigh = outputHighValues[isOutputHighBroadcasted ? 0 : oc];
526 const size_t idx = oc * IDHW + iidx;
// Clamp below/above the input interval to the output bounds; interpolate
// (with step rounding) in between.
528 if (values[idx] <= inputLow) {
529 quantizedValues[idx] = roundValues ? std::roundf(outputLow) : outputLow;
530 } else if (values[idx] > inputHigh) {
531 quantizedValues[idx] = roundValues ? std::roundf(outputHigh) : outputHigh;
533 const float value = std::roundf((values[idx] - inputLow) / (inputHigh - inputLow) * levels_1) /
534 levels_1 * (outputHigh - outputLow) + outputLow;
535 quantizedValues[idx] = roundValues ? std::roundf(value) : value;
540 return std::make_shared<opset1::Constant>(fq->get_output_element_type(0), constShape, quantizedValues);
546 // Decompose FakeQuantize to FakeQuantize with output integer limits (quantize), dequatized MultiplyAdd
547 // To align types the resulting sequence is FakeQuantize -> Convert -> Convert -> MultiplyAdd
// Splits `fq` into a quantizing FakeQuantize (output range [min, max] in the
// target integer precision) followed by dequantization operations
// (optional Convert, optional Subtract by zero-point, Multiply by scale).
// Returns {new FakeQuantize, final dequantization Multiply}.
548 std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decomposeFakeQuantize(
549 std::shared_ptr<opset1::FakeQuantize> fq,
550 const element::Type precision,
553 const bool hasZeroPoint,
554 const bool updatePrecision) {
555 using std::make_shared;
557 const auto outputLow = fq->input_value(3);
558 const auto outputHigh = fq->input_value(4);
560 std::vector<float> outputLowValues = as_type_ptr<opset1::Constant>(outputLow.get_node_shared_ptr())->cast_vector<float>();
561 std::vector<float> outputHighValues = as_type_ptr<opset1::Constant>(outputHigh.get_node_shared_ptr())->cast_vector<float>();
562 size_t outputSize = outputLowValues.size();
563 std::vector<float> minValues(outputSize, min);
564 std::vector<float> maxValues(outputSize, max);
565 std::vector<float> shifts(outputSize, 0.f);
566 std::vector<float> scales(outputSize);
// Per-channel scale/shift so that [outputLow, outputHigh] maps onto
// [min, max]: scale = (hi-lo)/(max-min), shift = (min*hi - max*lo)/(hi-lo).
568 for (int i = 0; i < outputSize; ++i) {
569 if (outputHighValues[i] != outputLowValues[i]) {
570 shifts[i] = (min*outputHighValues[i] - max*outputLowValues[i]) / (outputHighValues[i] - outputLowValues[i]);
571 scales[i] = (outputHighValues[i] - outputLowValues[i]) / (max - min);
// Normalize negative zero.
572 if (shifts[i] == -0.f) {
// Degenerate interval (hi == lo): the channel collapses to a constant.
576 scales[i] = outputHighValues[i];
// Zero-point (shift) constant is only materialized when requested.
582 std::shared_ptr<Node> shift = hasZeroPoint ?
583 std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), shifts) :
585 std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), scales);
587 auto newMin = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues);
588 auto newMax = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues);
// Collapse broadcastable constants to true scalars where possible.
590 if (isScalarLike(newMin)) {
591 newMin = toScalar(newMin);
593 if (isScalarLike(newMax)) {
594 newMax = toScalar(newMax);
// Clamp scale magnitudes into a numerically safe range to avoid
// overflow/underflow in the later division.
598 static const float minQuantizationScale = 1e-32f;
599 static const float maxQuantizationScale = 1e32f;
601 auto scaleValues = scales;
602 bool wasChanged = false;
603 for (size_t i = 0; i < scaleValues.size(); ++i) {
604 const float scale = scaleValues[i];
605 if (fabs(scale) < minQuantizationScale) {
606 scaleValues[i] = minQuantizationScale;
608 } else if (fabs(scale) > maxQuantizationScale) {
609 scaleValues[i] = scale > 0.f ? maxQuantizationScale : -maxQuantizationScale;
615 scale = std::make_shared<opset1::Constant>(scale->output(0).get_element_type(), scale->output(0).get_shape(), scaleValues);
// An all-zero shift carries no information: drop it.
619 if ((shift != nullptr) && isZero(as_type_ptr<opset1::Constant>(shift))) {
623 // Build a substitution sub-graph:
// Fold the rewritten FQ (new output limits) immediately when possible.
625 std::shared_ptr<ngraph::Node> newFQ = fold_fake_quantize(
626 std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
633 fq->get_auto_broadcast()),
635 NetworkHelper::copyInfo(fq, newFQ);
637 std::shared_ptr<ngraph::Node> convert2;
638 if (updatePrecision) {
639 std::shared_ptr<Node> convert;
640 std::shared_ptr<opset1::Constant> newFqConstant = as_type_ptr<opset1::Constant>(newFQ);
// Force the quantized precision either by folding a Convert (constant)
// or by retyping the FQ output (live node).
642 if (is_type<opset1::Constant>(newFQ)) {
643 convert = fold<opset1::Convert>(newFQ, precision);
644 } else if (is_type<opset1::FakeQuantize>(newFQ)) {
645 newFQ = setOutDataPrecision(as_type_ptr<opset1::FakeQuantize>(newFQ), precision);
648 THROW_IE_LPT_EXCEPTION(*newFQ) << "unexpected operation type";
// Dequantization side always computes in f32.
651 convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
652 convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
653 ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
// Without a precision update, a Convert to f32 is inserted only if needed.
655 if (newFQ->get_output_element_type(0) != element::f32) {
656 convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
657 convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
658 ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
662 // TODO: why type relaxed?
663 const std::shared_ptr<ngraph::Node> sub = shift == nullptr ?
665 std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
666 if (sub != nullptr) {
667 sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
668 ngraph::copy_runtime_info({ newFQ, sub }, sub);
// Final Multiply by scale closes the dequantization chain.
671 const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
672 sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
674 dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
675 ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);
677 replace_node(fq, dequantize);
679 return std::make_tuple(newFQ, dequantize);
// Rebuilds `fq` as a TypeRelaxed FakeQuantize with new scalar output limits
// [min, max] and the requested output precision, replacing the original node
// in the graph while keeping its friendly name.
682 std::shared_ptr<opset1::FakeQuantize> NetworkHelper::updateFakeQuantize(
683 std::shared_ptr<opset1::FakeQuantize> fq,
684 element::Type precision,
687 auto newMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
688 auto newMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
690 std::shared_ptr<opset1::FakeQuantize> newFQ = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
697 fq->get_auto_broadcast());
699 NetworkHelper::setOutDataPrecision(newFQ, precision);
700 replace_node(fq, newFQ);
702 newFQ->set_friendly_name(fq->get_friendly_name());
// Builds a standalone dequantization sub-graph (Parameter -> [Convert] ->
// [Subtract dequantizationSub] -> Multiply dequantizationMul) with scalar
// shift/scale constants in `originalPrecision`, and returns it wrapped in a
// FakeQuantizeDequantization descriptor.
706 FakeQuantizeDequantization NetworkHelper::makeDequantization(
707 const float dequantizationMul,
708 const float dequantizationSub,
709 const ngraph::element::Type originalPrecision,
710 const ngraph::Shape dataNodeOutputShape,
711 element::Type precision,
714 // TODO: we create input here! we really need it here?
715 const std::shared_ptr<opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(precision, dataNodeOutputShape);
716 std::shared_ptr<ngraph::Node> parent = input;
718 // TODO: convert should be optional: where is updatePrecision?
719 std::shared_ptr<DequantizationConvert> convert;
721 convert = std::make_shared<DequantizationConvert>(
// Subtract is only emitted for a non-negligible shift.
727 std::shared_ptr<DequantizationSubtract> subtract;
728 if (std::abs(dequantizationSub) > 1e-6) {
729 subtract = std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(
731 std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationSub })));
// Pin the Subtract output to the original precision despite relaxed inputs.
732 subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
737 std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
739 std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationMul })));
741 return FakeQuantizeDequantization(input, convert, subtract, multiply);
// Derives the dequantization chain equivalent to `fq` when its output is
// re-expressed over [min, max]: scale = (outHigh-outLow)/(max-min) and
// (optionally) shift = (min*outHigh - max*outLow)/(outHigh-outLow), both
// obtained by constant folding. A scalar-zero shift is dropped.
744 FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
745 std::shared_ptr<opset1::FakeQuantize> fq,
746 element::Type precision,
749 const bool hasZeroPoint,
750 const bool updatePrecision) {
751 using std::make_shared;
753 const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
754 auto newMin = make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
755 auto newMax = make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
757 auto outputLow = fq->input_value(3);
758 auto outputHigh = fq->input_value(4);
760 // TODO: threshold values have to used here to avoid shifts
// scale = (outputHigh - outputLow) / (max - min)
762 const std::shared_ptr<Node> scale = fold<opset1::Divide>(
763 fold<opset1::Subtract>(outputHigh, outputLow),
764 fold<opset1::Subtract>(newMax, newMin));
// shift (zero point) is only computed when the caller asked for one.
766 std::shared_ptr<Node> shift = hasZeroPoint ?
767 fold<opset1::Divide>(
768 fold<opset1::Subtract>(fold<opset1::Multiply>(newMin, outputHigh), fold<opset1::Multiply>(newMax, outputLow)),
769 fold<opset1::Subtract>(outputHigh, outputLow)) :
// Drop a shift that folds to a scalar zero — it would be a no-op Subtract.
772 if (shift != nullptr) {
773 std::shared_ptr<opset1::Constant> shiftConst = as_type_ptr<opset1::Constant>(shift);
774 if (isScalarLike(shiftConst)) {
775 auto scalar = toScalar(shiftConst);
776 if (op::util::constantIsEqualTo(scalar, 0)) {
// Assemble Parameter -> Convert -> [Subtract] -> Multiply.
782 const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, fq->get_output_shape(0));
783 const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(
785 fq->get_output_element_type(0));
787 const std::shared_ptr<ngraph::opset1::Subtract> subtract = shift == nullptr ?
789 make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert, shift);
790 if (subtract != nullptr) {
791 subtract->set_output_type(0, fq->get_output_element_type(0), subtract->get_output_partial_shape(0));
794 const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
795 subtract == nullptr ? static_cast<std::shared_ptr<Node>>(convert) : subtract,
798 return FakeQuantizeDequantization(fq, convert, subtract, multiply);
// Pattern-matches a dequantization chain upstream of `node` (or of `node`
// itself when `inPlace`): Multiply-by-const, then Subtract-by-const, then a
// Convert from i8/u8 (or to f32). Matching stops — returning the partial
// chain found so far — as soon as a stage does not fit the pattern.
801 FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex, const bool inPlace) {
// Returns the index of the non-constant (data) input of a binary node.
802 auto getDataIndex = [](const std::shared_ptr<ngraph::Node>& node) {
803 if (is_type<opset1::Constant>(node->get_input_node_ptr(1))) {
810 Output<Node> dataNode = inPlace ? node : node->input_value(parentIndex);
// Stage 1: Multiply with a Constant operand.
812 const std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<ngraph::opset1::Multiply>(dataNode.get_node_shared_ptr());
813 if (multiply != nullptr) {
814 if (!is_type<opset1::Constant>(multiply->get_input_node_ptr(0)) && !is_type<opset1::Constant>(multiply->get_input_node_ptr(1))) {
815 return FakeQuantizeDequantization();
817 dataNode = multiply->get_input_source_output(getDataIndex(multiply));
// Stage 2: Subtract with a Constant operand.
820 const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<ngraph::opset1::Subtract>(dataNode.get_node_shared_ptr());
821 if (subtract != nullptr) {
822 if (!is_type<opset1::Constant>(subtract->get_input_node_ptr(0)) && !is_type<opset1::Constant>(subtract->get_input_node_ptr(1))) {
823 return FakeQuantizeDequantization(dataNode, nullptr, nullptr, multiply);
825 dataNode = subtract->get_input_source_output(getDataIndex(subtract));
// Stage 3: Convert from a quantized integer type (i8/u8) or to f32.
828 const std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(dataNode.get_node_shared_ptr());
829 if (convert != nullptr) {
830 if ((convert->input(0).get_element_type() != element::i8) && (convert->input(0).get_element_type() != element::u8) &&
831 (convert->output(0).get_element_type() != element::f32)) {
832 return FakeQuantizeDequantization(dataNode, nullptr, subtract, multiply);
834 dataNode = convert->get_input_source_output(0);
837 return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
// Extracts standalone (input-less clones of) the subtract and multiply
// constants of a dequantization chain, substituting neutral defaults
// (0 for subtract, 1 for multiply) for missing stages.
840 FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
841 std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();
843 std::shared_ptr<Node> multiply1Const = dequantization.multiply ?
844 dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
845 std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 1.f }));
847 std::shared_ptr<Node> subtract1Const = dequantization.subtract ?
848 dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
849 std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 0.f }));
// Align the subtract constant's element type with the multiply constant's.
851 subtract1Const->set_output_type(0, multiply1Const->get_output_element_type(0), subtract1Const->get_output_partial_shape(0));
853 return FakeQuantizeDequantizationValues(subtract1Const, multiply1Const);
// Returns true when `node` is a Constant whose value is (scalar-like) zero.
856 bool NetworkHelper::isZeroConst(const std::shared_ptr<Node>& node) {
857 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
859 if (constant == nullptr)
// Only the scalar-like case is checked in this visible fragment; a constant
// with differing elements cannot be a single zero.
862 if (NetworkHelper::isScalarLike(constant)) {
863 auto scalar = NetworkHelper::toScalar(constant);
864 if (op::util::constantIsEqualTo(scalar, 0)) {
// Attempts to fold the Convert preceding a dequantization Subtract:
//   Convert(intX -> real) -> Subtract(shift)
// becomes a TypeRelaxed<Subtract> that consumes the integer data directly,
// provided the shift constant can be represented exactly (within tolerance)
// in the Convert's integer input type. The original real output precision is
// restored via the TypeRelaxed override.
// Returns the replacement node, or `subtract` unchanged when the pattern
// does not apply (no Convert on the data path, or a non-real Convert output).
std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract) {
    auto convertOnSubtract = subtract->input_value(0).get_node_shared_ptr();
    if (as_type_ptr<opset1::Convert>(convertOnSubtract) == nullptr) {
        // Nothing to fuse: the Subtract is not fed by a Convert.
        return subtract;
    }

    // TODO: replace assert to condition and omit conversion part if there is no convert
    // TODO: also check convertInputType to understand if we really want to propagate type
    assert(as_type_ptr<opset1::Convert>(convertOnSubtract));
    const element::Type convertInputType = convertOnSubtract->get_input_element_type(0);
    const element::Type convertOutputType = convertOnSubtract->get_output_element_type(0);

    if (!convertOutputType.is_real()) {
        // Not a dequantization Convert (output is not floating point); leave as-is.
        return subtract;
    }

    auto data = convertOnSubtract->input_value(0);
    auto shift = subtract->input_value(1).get_node_shared_ptr();
    // Round the shift to the integer input type; when rounding is exact within
    // tolerance the result's element type equals convertInputType.
    auto roundedShift = NetworkHelper::roundWithTolerance(shift, convertInputType);

    std::shared_ptr<Node> replacement;
    if (roundedShift->get_element_type() == convertInputType) {
        // Propagate convertInputType down
        replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
        NetworkHelper::copyInfo(subtract, replacement);
        NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
        replace_node(subtract, replacement);
    }

    // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
    // TODO: check cases when Convert should be preserved

    // Try to optimize Add out if constant is zero
    // TODO: don't remove operation here: don't create this Subtraction operation in FQ decomposition
    // if (isScalarLike(roundedShift)) {
    //    auto scalar = distillToScalar(roundedShift);
    //    if (op::util::constantIsEqualTo(scalar, 0)) {
    //        replace_node(replacement, replacement->input_value(0).get_node_shared_ptr());
    //        replacement = nullptr;
    //    }
    // }

    return replacement;
}
// Moves the dequantization sub-graph (Convert -> Subtract -> Multiply) from the
// input side of `operation` to its output side, so the operation itself executes
// on low-precision data.
//   operation       - node currently consuming the dequantization Multiply
//   dequantization  - the chain to relocate
//   updatePrecision - if true, override the cloned operation's output type with
//                     its (low-precision) input type; requires TypeRelaxedBase
//   moveSubtract    - if false, the Subtract stays on the input side
// Returns the cloned operation and the last node of the re-inserted chain.
NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter(
    const std::shared_ptr<ngraph::Node>& operation,
    const FakeQuantizeDequantization& dequantization,
    const bool updatePrecision,
    const bool moveSubtract) {
    // Clone the operation with its dequantization input rewired to the raw data.
    std::vector<Output<Node>> inputs(operation->get_input_size());
    for (size_t i = 0; i < operation->get_input_size(); ++i) {
        inputs[i] = operation->get_input_node_shared_ptr(i);
    }

    const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation);
    // If the Subtract is kept on the input side, the clone consumes the Subtract
    // output instead of the raw data.
    inputs[dequantizationIndex] = moveSubtract ?
        dequantization.data :
        (dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract);

    const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
    newOperation->set_friendly_name(operation->get_friendly_name());
    ngraph::copy_runtime_info(operation, newOperation);

    if (updatePrecision) {
        // Only TypeRelaxed operations support overriding the output precision.
        auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
        if (op == nullptr) {
            THROW_IE_LPT_EXCEPTION(*newOperation) << "not possible to update precision for not TypeRelaxedBase operation";
        }
        op->set_overridden_output_type(newOperation->get_input_element_type(0));
        std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
    }

    // Re-insert a Convert only when the clone's output precision differs from the
    // precision the original dequantization produced.
    const bool shouldConvert = (newOperation->get_output_element_type(0) != dequantization.multiply->get_output_element_type(0));

    auto parent = newOperation;
    if (shouldConvert) {
        const auto convertOutputPrecision = dequantization.convert != nullptr ?
            dequantization.convert->get_output_element_type(0) :
            dequantization.multiply->get_output_element_type(0);
        parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }

    if (moveSubtract && (dequantization.subtract != nullptr)) {
        auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
        const element::Type parentPrecision = parent->get_output_element_type(0);
        // The shift constant must fit within the precision of the data it is applied to.
        if (parentPrecision.bitwidth() < subtractConstant->output(0).get_element_type().bitwidth()) {
            THROW_IE_LPT_EXCEPTION(*parent) <<
                "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
                ", subtract dequantization constant " << subtractConstant->get_friendly_name() << ":" << subtractConstant->output(0).get_element_type();
        }

        // Convert the constant to the data precision when they differ.
        parent = std::make_shared<DequantizationSubtract>(
            parent,
            subtractConstant->output(0).get_element_type() == parentPrecision ?
                subtractConstant :
                fold<opset1::Convert>(subtractConstant->output(0), parentPrecision));
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }

    if (dequantization.multiply != nullptr) {
        auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
        const element::Type parentPrecision = parent->get_output_element_type(0);
        if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
            THROW_IE_LPT_EXCEPTION(*parent) <<
                "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
                ", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
        }

        parent = std::make_shared<DequantizationMultiply>(
            parent,
            multiplyConstant->output(0).get_element_type() == parentPrecision ?
                multiplyConstant :
                fold<opset1::Convert>(multiplyConstant->output(0), parentPrecision));
        ngraph::copy_runtime_info({ newOperation, parent }, parent);
    }
    replace_node(operation, parent);

    // The Subtract that stayed on the input side may now be foldable with its
    // preceding Convert; try that after clearing its runtime info.
    if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
        NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
        optimizeSubtract(dequantization.subtract);
    }

    return InsertDequantizationResult(newOperation, parent);
}
1001 void NetworkHelper::removeConvertIfPossible(
1002 const std::shared_ptr<ngraph::Node>& operation,
1003 const FakeQuantizeDequantization& dequantization) {
1004 const element::Type precisionBeforeConvert = dequantization.convert->input(0).get_element_type();
1006 if (checkConstantValuePrecision(precisionBeforeConvert, dequantization.subtract->get_input_node_shared_ptr(1))) {
1007 auto newSubtract = dequantization.subtract->clone_with_new_inputs({
1008 dequantization.convert->get_input_node_shared_ptr(0),
1009 fold<opset1::Convert>(dequantization.subtract->get_input_node_shared_ptr(1), precisionBeforeConvert) });
1010 replace_node(dequantization.subtract, newSubtract);
1014 bool NetworkHelper::checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant) {
1015 if (expectedPrecision.is_signed()) {
1019 std::shared_ptr<opset1::Constant> constantOp = as_type_ptr<opset1::Constant>(constant);
1020 if (constantOp == nullptr) {
1024 const auto values = constantOp->cast_vector<float>();
1025 const bool convertCanBeRemoved =
1026 (expectedPrecision.is_signed() || (std::all_of(values.begin(), values.end(), [](const float value) { return value >= 0.f; })));
1027 return convertCanBeRemoved;
1030 size_t NetworkHelper::getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1031 for (size_t i = 0; i < child->get_input_size(); ++i) {
1032 if (parent.get() == child->get_input_node_ptr(i)) {
1036 THROW_IE_LPT_EXCEPTION(*child) << "child input index between " <<
1037 parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
1040 size_t NetworkHelper::getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1041 for (size_t i = 0; i < parent->get_output_size(); ++i) {
1042 const auto& targetInputs = parent->output(i).get_target_inputs();
1043 for (const auto& targetInput : targetInputs) {
1044 if (targetInput.get_node() == child.get()) {
1049 THROW_IE_LPT_EXCEPTION(*child) << "parent output index between " <<
1050 parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
1053 std::vector<Output<Node>> NetworkHelper::getInputs(const std::shared_ptr<ngraph::Node>& node) {
1054 std::vector<Output<Node>> inputs(node->get_input_size());
1055 for (size_t i = 0; i < node->get_input_size(); ++i) {
1056 inputs[i] = node->get_input_node_shared_ptr(i);
1061 std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> node) {
1062 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
1063 if (constant == nullptr) {
1067 if (!NetworkHelper::isScalarLike(constant)) {
1071 return NetworkHelper::toScalar(constant);
1074 } // namespace low_precision
1076 } // namespace ngraph