1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <low_precision/network_helper.hpp>
13 #include <unordered_set>
18 #include <ngraph/rt_info.hpp>
19 #include "low_precision/common/ie_lpt_exception.hpp"
20 #include "low_precision/common/dequantization_op.hpp"
24 namespace low_precision {
26 // Returns true if `type` is castable to at least one of the entries of `types`.
// Each candidate is tested with `type.is_castable(candidate)`.
// NOTE(review): the statements executed on a match and after the loop are not visible in this
// block - presumably `return true` / `return false`; confirm.
27 bool NetworkHelper::is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types) {
28 for (auto another : types) {
29 if (type.is_castable(another)) {
36 // Collect and return a vector with all nodes that consumes any of the `node` output
37 std::vector<Input<Node>> NetworkHelper::consumer_inputs(std::shared_ptr<Node> node) {
38 std::vector<Input<Node>> result;
39 for (const auto& output_port : node->outputs()) {
40 for (const auto &input : output_port.get_target_inputs()) {
41 result.push_back(input);
47 std::vector<std::shared_ptr<Node>> NetworkHelper::consumers(std::shared_ptr<Node> node) {
48 auto inputs = consumer_inputs(node);
49 std::vector<std::shared_ptr<Node>> result(inputs.size());
50 std::transform(inputs.begin(), inputs.end(), result.begin(), [](Input<Node> input){ return input.get_node()->shared_from_this(); });
// Walks the consumers of `layer` recursively, looking for a Convolution, GroupConvolution or
// MatMul that has at least two inputs. For such a consumer the parents collected through its
// input port 1 are compared (by pointer) with `layer`.
// NOTE(review): the return statements are not visible in this block; the function yields an int
// status code - confirm the exact values before relying on them.
54 int NetworkHelper::onWeightsInDepth(std::shared_ptr<Node> layer) {
55 const std::vector<std::shared_ptr<Node>> children = consumers(layer);
// NOTE(review): `child` is taken by value, which copies a shared_ptr on every iteration;
// `const auto&` would avoid the reference-count traffic.
56 for (std::shared_ptr<Node> child : children) {
57 if ((is_type<opset1::Convolution>(child) ||
58 is_type<opset1::GroupConvolution>(child) ||
59 is_type<opset1::MatMul>(child)) &&
60 (child->inputs().size() >= 2lu)) {
// Empty exception-type set, port index 1: only input port 1 of the child is inspected.
61 const std::vector<std::shared_ptr<Node>> parents = getParentsRecursivelyExceptTypes(child, {}, 1);
62 for (const std::shared_ptr<Node>& parent : parents) {
63 if (parent.get() == layer.get()) {
// Otherwise continue the search below this child.
70 const int result = onWeightsInDepth(child);
// Boolean wrapper over onWeightsInDepth() for `layer`.
// NOTE(review): the conversion of the int status to bool is not visible in this block - confirm
// which status value means "on weights".
78 bool NetworkHelper::onWeights(std::shared_ptr<Node> layer) {
79 const int result = onWeightsInDepth(layer);
// Returns the output channel count of `layer`, read from the partial shape of its single output.
// Throws a transformation exception when the layer has no outputs or more than one output.
// NOTE(review): the branch that consumes `isOnWeights` is not visible in this block. The visible
// returns read dimension 0 or dimension 1 of the shape - presumably dimension 0 on weights and
// for rank-1 shapes, dimension 1 otherwise; confirm.
// NOTE(review): get_length() is called without checking that the dimension is static - confirm
// that callers guarantee static shapes.
83 size_t NetworkHelper::getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights) {
84 if (layer->outputs().size() == 0) {
85 THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " doesn't have output tensors";
88 if (layer->outputs().size() > 1) {
89 THROW_TRANSFORMATION_EXCEPTION << "Layer " << layer->get_friendly_name() << " has too many output tensors, expected one";
92 PartialShape shape = layer->get_output_partial_shape(0);
93 if (shape.rank() == 0) {
94 THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in output of " << layer->get_friendly_name() << " layer on weights";
97 return shape[0].get_length();
99 if (shape.rank() == 1) {
100 return shape[0].get_length();
102 return shape[1].get_length();
// Collects the parents of `layer`. A parent whose type is castable to one of
// `exceptionLayerTypes` is looked through: the function recurses into it and appends that
// parent's own parents instead.
// `portIndex == -1` visits every input; otherwise only the input for which `portIndex == i`.
// NOTE(review): the counter `i` is declared and advanced on lines not visible in this block -
// presumably it is the running input index; confirm.
// NOTE(review): the recursive call passes no port index, so it relies on a default argument of
// the declaration - presumably -1 (all ports); confirm.
// NOTE(review): the return statement is not visible - presumably `parents`.
106 std::vector<std::shared_ptr<Node>> NetworkHelper::getParentsRecursivelyExceptTypes(
107 std::shared_ptr<Node> layer,
108 const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes,
109 const int portIndex) {
110 std::vector<std::shared_ptr<Node>> parents;
112 for (auto input : layer->inputs()) {
113 if ((portIndex == -1) || (portIndex == i)) {
114 auto parent = input.get_source_output().get_node_shared_ptr();
115 if (is_castable_to_one_of(parent->get_type_info(), exceptionLayerTypes)) {
116 const std::vector<std::shared_ptr<Node>> tmpParents = getParentsRecursivelyExceptTypes(parent, exceptionLayerTypes);
117 parents.insert(parents.end(), tmpParents.begin(), tmpParents.end());
119 parents.push_back(parent);
// Returns dimension 1 of the partial shape of input 0 of `layer`.
// Throws a transformation exception when the layer has no inputs or when the input rank is <= 1.
// NOTE(review): the second message hard-codes "(0)" although it is raised for rank 0 and rank 1.
// NOTE(review): rank().get_length() and shape[1].get_length() assume a static rank/dimension -
// confirm behaviour for dynamic shapes.
128 size_t NetworkHelper::getInputChannelsCount(std::shared_ptr<Node> layer) {
129 if (layer->get_input_size() == 0) {
130 THROW_TRANSFORMATION_EXCEPTION << "There are no input layers";
133 PartialShape shape = layer->get_input_partial_shape(0);
134 if (shape.rank().get_length() <= 1) {
135 THROW_TRANSFORMATION_EXCEPTION << "Invalid dimensions count (0) in input of " << layer->get_friendly_name();
138 return shape[1].get_length();
// Returns the group count of a convolution-like `layer`.
// For GroupConvolution it is dimension 0 of the static shape of input 1 (the weights).
// A transformation exception is raised for other node types.
// NOTE(review): the value returned for a plain Convolution is not visible in this block -
// presumably 1; confirm.
// NOTE(review): `group_convolution` is only used as the branch condition, and the error text
// below misspells "Convolution" ("Convolutino").
141 size_t NetworkHelper::getGroupsCount(std::shared_ptr<Node> layer) {
142 if (as_type_ptr<opset1::Convolution>(layer)) {
144 } else if (auto group_convolution = as_type_ptr<opset1::GroupConvolution>(layer)) {
145 return layer->get_input_shape(1)[0]; // input weights for opset1::GC is in format GOI..., see the specification
147 THROW_TRANSFORMATION_EXCEPTION << "Invalid layer type of " << layer->get_friendly_name() << "; expected Convolutino or GroupConvolution";
151 // Assumin tensor in NC... layout, append necessary number of 1s to shape to align it to a give rank
152 Shape NetworkHelper::alignShapeForChannelDim(const Shape& shape, Rank rank) {
153 assert(shape.size() == 1);
154 assert(rank.is_static());
155 Shape result = shape;
156 result.resize(rank.get_length() - 1, 1);
160 void NetworkHelper::removeLayer(std::shared_ptr<Node> layer) {
161 ngraph::replace_output_update_name(layer->output(0), layer->input_value(0));
// Rewrites Multiply -> Add into Add -> Multiply, i.e. x*a + b ==> (x + b/a)*a, and replaces
// `addAfterMultiply` in the graph with the new DequantizationMultiply.
// `multiplyBranch` is the input index of `addAfterMultiply` that is fed by the Multiply.
// Returns `addAfterMultiply` unchanged when neither input of the Multiply is a Constant.
// NOTE(review): the final return statement is not visible in this block - presumably the new
// Multiply; confirm.
164 std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch) {
165 // Multiply --> Add(addAfterMultiply) ==> Add(new) --> Multiply(new)
166 // That means x*a + b ==> (x + b/a)*a; tries to fold b/a
167 const auto multiply = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch);
169 const auto multiplyParent1 = multiply->get_input_node_shared_ptr(0);
170 const auto multiplyParent2 = multiply->get_input_node_shared_ptr(1);
// NOTE(review): `multiplyInput` is assigned here and below but never read in the visible code.
172 auto multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent1);
173 auto multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent2);
174 int multiplyInputBranch = 0;
// Constant not on input 1: try the swapped arrangement (constant on input 0, data on input 1).
176 if (multiplyConst == nullptr) {
177 multiplyInput = as_type_ptr<opset1::Multiply>(multiplyParent2);
178 multiplyConst = as_type_ptr<opset1::Constant>(multiplyParent1);
179 multiplyInputBranch = 1;
182 if (multiplyConst == nullptr)
183 return addAfterMultiply;
// x: data input of the Multiply, a: its constant, b: the other input of the Add.
185 const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
186 const auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
187 const auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
188 std::shared_ptr<Node> bDivA;
// Manual element-wise division is used when one operand is a single value or both hold the same
// number of elements; other shapes go through fold<Divide> below.
190 if (shape_size(b->get_output_shape(0)) == 1 ||
191 shape_size(a->get_output_shape(0)) == 1 ||
192 shape_size(b->get_output_shape(0)) == shape_size(a->get_output_shape(0))) {
193 // safe division: b/a is replaced by 0 when both values are 0, to avoid NaN
// NOTE(review): `b` is cast to Constant and dereferenced without a null check; only the
// Multiply constant `a` is verified above.
194 const std::vector<float> bValues = as_type_ptr<opset1::Constant>(b)->cast_vector<float>();
195 const std::vector<float> aValues = as_type_ptr<opset1::Constant>(a)->cast_vector<float>();
196 const bool aBroadcasted = bValues.size() > aValues.size();
197 const bool bBroadcasted = bValues.size() < aValues.size();
198 std::vector<float> bDivAValues(aBroadcasted ? bValues.size() : aValues.size());
// NOTE(review): `int i` is compared with the unsigned size() of the vector.
200 for (int i = 0; i < bDivAValues.size(); ++i) {
201 const auto bi = bValues[bBroadcasted ? 0 : i];
202 const auto ai = aValues[aBroadcasted ? 0 : i];
203 if (bi != 0.f || ai != 0.f) {
204 bDivAValues[i] = bi / ai;
206 bDivAValues[i] = 0.f;
210 bDivA = std::make_shared<opset1::Constant>(
211 b->get_output_element_type(0),
212 aBroadcasted ? b->get_output_shape(0) : a->get_output_shape(0),
215 bDivA = fold<opset1::Divide>(b, a);
// NOTE(review): the lines that fill `inputs` are not visible in this block - presumably x and
// b/a placed according to `multiplyBranch`; confirm.
218 std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
// The new Add is type-relaxed to compute in f32; its output precision is then set to the
// output type of the original Add.
223 std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
224 std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
225 ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
226 ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
227 copyInfo(addAfterMultiply, newAdd);
229 NetworkHelper::setOutDataPrecision(newAdd, addAfterMultiply->get_output_element_type(0));
231 auto newMultiply = std::make_shared<DequantizationMultiply>(newAdd, a);
232 copyInfo(multiply, newMultiply);
234 replace_node(addAfterMultiply, newMultiply);
238 void NetworkHelper::copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target) {
239 // TODO: merge_runtime_info with correctly defined DEQUANTIZATION
240 const auto& sourceAttributes = source->get_rt_info();
241 auto& targetAttrubutes = target->get_rt_info();
242 for (auto attribute : sourceAttributes) {
243 targetAttrubutes[attribute.first] = attribute.second;
246 const std::string friendlyName = source->get_friendly_name();
247 if (!friendlyName.empty()) {
248 target->set_friendly_name(friendlyName);
252 void NetworkHelper::cleanRunTimeInfo(const std::shared_ptr<Node>& layer) {
253 auto& rt_info = layer->get_rt_info();
254 auto attributeIter = rt_info.find("DEQUANTIZATION");
255 if (rt_info.find("DEQUANTIZATION") != rt_info.end()) {
256 rt_info.erase(attributeIter);
260 bool NetworkHelper::isScalarLike(std::shared_ptr<opset1::Constant> constant) {
261 return constant->get_all_data_elements_bitwise_identical();
// Compares the values of `constant` (read as float) with a tiny threshold (1e-32).
// NOTE(review): the return statements are not visible in this block - presumably false as soon
// as one |value| exceeds the threshold and true otherwise; confirm.
264 bool NetworkHelper::isZero(std::shared_ptr<opset1::Constant> constant) {
265 static const float minQuantizationShift = 1e-32f;
267 auto values = constant->cast_vector<float>();
268 for (size_t i = 0; i < values.size(); ++i) {
269 if (fabs(values[i]) > minQuantizationShift) {
277 std::shared_ptr<opset1::Constant> NetworkHelper::toScalar(std::shared_ptr<opset1::Constant> constant) {
278 assert(isScalarLike(constant));
279 return std::make_shared<opset1::Constant>(constant->get_element_type(), Shape{}, constant->get_data_ptr());
// Returns the Constant feeding `node`: input 0 is tried first, then input 1.
// NOTE(review): the condition guarding the second lookup and the return statement are not
// visible in this block - presumably input 1 is consulted only when input 0 is not a Constant,
// and the (possibly null) result is returned; confirm.
282 std::shared_ptr<Node> NetworkHelper::getConstantInput(std::shared_ptr<Node> node) {
283 std::shared_ptr<Node> constant1 = as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr());
285 constant1 = as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr());
// Tries to merge a Multiply-by-constant (`node`) with the Multiply that directly consumes it:
// the two constants are folded into one and the consuming Multiply is replaced by a new Multiply
// of the non-constant input of `node` and the folded constant.
// The merge is attempted only when `node` has exactly one consumer and each of the two constants
// has exactly one consumer.
// NOTE(review): the early-exit bodies and the return statements are not visible in this block.
290 std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> node) {
291 std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(node);
// NOTE(review): if this throw is guarded by a failed cast (multiply == nullptr), `*multiply`
// dereferences a null pointer - `*node` is probably what was meant; confirm against the guard.
293 THROW_IE_LPT_EXCEPTION(*multiply) << "Unexpected operation type";
296 if (multiply->output(0).get_target_inputs().size() == 1) {
297 auto constant1 = getConstantInput(multiply);
298 if (!constant1 || constant1->output(0).get_target_inputs().size() != 1) {
301 auto nextMultiplyInput = *multiply->output(0).get_target_inputs().begin();
// NOTE(review): `nextMultiply` is null when the consumer is not a Multiply - confirm it is
// checked before the calls below.
302 auto nextMultiply = as_type_ptr<opset1::Multiply>(nextMultiplyInput.get_node()->shared_from_this());
304 auto constant2 = getConstantInput(nextMultiply);
// NOTE(review): `constant2` is dereferenced here before the null check on the next line, and
// `constant2Inputs` is never used in the visible code.
305 auto constant2Inputs = constant2->output(0).get_target_inputs().size();
306 if (!constant2 || constant2->output(0).get_target_inputs().size() != 1) {
310 auto newConst = fold<opset1::Multiply>(constant1, constant2);
312 std::make_shared<opset1::Multiply>(
// Selects the input of `multiply` that is not the constant: 1 - (index of the input that
// consumes constant1; constant1 has exactly one consumer, checked above).
313 multiply->input_value(1 - constant1->output(0).get_target_inputs().begin()->get_index()),
314 newConst->output(0));
315 copy_runtime_info(multiply, newMultiply);
316 replace_node(nextMultiply, newMultiply);
// Rounds the values of the Constant `node` (mode HALF_AWAY_FROM_ZERO) and folds a Convert on the
// result; returns the outcome as a Constant (nullptr if folding did not yield a Constant).
// NOTE(review): the handling of a non-Constant `node` and the Convert's target-type argument
// are on lines not visible in this block - presumably `target_type`; confirm.
324 std::shared_ptr<opset1::Constant> NetworkHelper::round(std::shared_ptr<Node> node, element::Type target_type) {
325 const auto constant = as_type_ptr<opset1::Constant>(node);
328 const auto castedConstant = as_type_ptr<ngraph::opset1::Constant>(fold<op::v0::Convert>(
329 fold<ngraph::op::v5::Round>(constant->output(0), ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO),
332 return castedConstant;
335 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
336 return foldFakeQuantize(fq, false, false);
339 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues) {
340 return foldFakeQuantize(fq, roundValues, true);
// Constant-folds a dequantization chain (Convert, Subtract, Multiply) whose data input is a
// Constant with exactly one consumer. The operations are processed in that order: each present
// operation is folded with the data constant, the folded node replaces the operation in the
// graph, and the dequantization is re-read from `node` before the next step.
// `branchIndex` and `inPlace` are forwarded to getDequantization().
// NOTE(review): the bodies of the guard branches (empty dequantization, non-constant data,
// fold result that is not a Constant, element-type mismatch) are not visible in this block -
// presumably early returns; confirm.
343 void NetworkHelper::foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace) {
344 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
345 if (dequantization.empty() || (dequantization.multiply == nullptr)) {
349 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.data.get_node_shared_ptr());
350 if ((constant == nullptr) || (constant->output(0).get_target_inputs().size() != 1ul)) {
// Step 1: fold the Convert into the data constant.
354 if (dequantization.convert != nullptr) {
355 const std::shared_ptr<Node> result = fold<opset1::Convert>(dequantization.data, dequantization.convert->get_element_type());
356 if (!is_type<opset1::Constant>(result)) {
360 copyInfo(dequantization.convert, result);
362 replace_node(dequantization.convert, result);
363 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Step 2: fold the Subtract with its constant on input 1.
366 if (dequantization.subtract != nullptr) {
367 if (dequantization.data.get_element_type() != dequantization.subtract->input(1).get_element_type()) {
370 const std::shared_ptr<Node> result = fold<opset1::Subtract>(dequantization.data, dequantization.subtract->get_input_node_shared_ptr(1));
371 if (!is_type<opset1::Constant>(result)) {
375 copyInfo(dequantization.subtract, result);
377 replace_node(dequantization.subtract, result);
378 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Step 3: fold the Multiply with its constant on input 1.
381 if (dequantization.multiply != nullptr) {
382 if (dequantization.data.get_element_type() != dequantization.multiply->input(1).get_element_type()) {
385 const std::shared_ptr<Node> result = fold<opset1::Multiply>(dequantization.data, dequantization.multiply->get_input_node_shared_ptr(1));
386 if (!is_type<opset1::Constant>(result)) {
390 copyInfo(dequantization.multiply, result);
392 replace_node(dequantization.multiply, result);
393 dequantization = NetworkHelper::getDequantization(node, branchIndex, inPlace);
// Constant-folds a FakeQuantize.
// Fast path: when all five inputs are Constants and the intervals are exactly input [0, 254] and
// output [-127, 127], the result is folded as data + output_low, converting operands so that
// both have a real element type (f32 when neither is real).
// General path: the data constant is quantized element by element with per-output-channel
// intervals (channel = dimension 0 of the FQ output shape) and returned as a new Constant with
// the FQ output element type and shape.
// Rounding of the produced values: `roundValuesArg` is used only when `roundValuesWasSet` is
// true; otherwise values are rounded when the FQ output element type is integral.
// Each interval constant must hold either 1 value or OC values, otherwise an LPT exception is
// thrown; output shapes of rank 0 or rank > 5 are rejected as well.
// NOTE(review): some lines of the body (branch closers, the handling of a non-Constant data
// input, the final fallback return) are not visible in this block.
397 std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
398 const std::shared_ptr<opset1::FakeQuantize>& fq,
399 const bool roundValuesArg,
400 const bool roundValuesWasSet) {
401 if (is_type<opset1::Constant>(fq->get_input_node_shared_ptr(0)) &&
402 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(1)) &&
403 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(2)) &&
404 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(3)) &&
405 is_type<opset1::Constant>(fq->get_input_node_shared_ptr(4)) &&
406 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1)), 0.f) &&
407 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2)), 254.f) &&
408 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3)), -127.f) &&
409 op::util::constantIsEqualTo(as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4)), 127.f)) {
// type1: data element type, type2: output_low element type.
410 const auto type1 = fq->input_value(0).get_element_type();
411 const auto type2 = fq->input_value(3).get_element_type();
412 if (type1.is_real() && type2.is_real()) {
413 return fold<opset1::Add>(fq->input_value(0), fq->input_value(3));
415 if (type1.is_real() && !type2.is_real()) {
416 return fold<opset1::Add>(
418 fold<opset1::Convert>(fq->input_value(3), type1));
420 if (!type1.is_real() && type2.is_real()) {
421 return fold<opset1::Add>(
422 fold<opset1::Convert>(fq->input_value(0), type2),
425 return fold<opset1::Add>(
426 fold<opset1::Convert>(fq->input_value(0), element::f32),
427 fold<opset1::Convert>(fq->input_value(3), element::f32));
430 auto constant = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(0));
433 const bool roundValues = roundValuesWasSet ? roundValuesArg : fq->output(0).get_element_type().is_integral();
435 Shape constShape = fq->get_output_shape(0);
436 if (constShape.empty() || constShape.size() > 5lu) {
437 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected dimensions count " << constShape.size();
// The output shape is read as OC x IC x [D] x [H] x [W]; absent dimensions count as 1.
441 const size_t OC = constShape[0];
442 const size_t IC = constShape.size() > 1lu ? constShape[1] : 1;
443 const size_t D = constShape.size() > 4lu ? constShape[constShape.size() - 3] : 1;
444 const size_t H = constShape.size() > 2lu ? constShape.size() == 3lu ? constShape[2] : constShape[constShape.size() - 2] : 1;
445 const size_t W = constShape.size() > 3lu ? constShape[constShape.size() - 1] : 1;
447 const auto inputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(1))->cast_vector<float>();
448 const auto inputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(2))->cast_vector<float>();
449 const auto outputLowValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(3))->cast_vector<float>();
450 const auto outputHighValues = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(4))->cast_vector<float>();
452 const size_t inputLowSize = inputLowValues.size();
453 const size_t inputHighSize = inputHighValues.size();
454 const size_t outputLowSize = outputLowValues.size();
455 const size_t outputHighSize = outputHighValues.size();
// "Broadcasted" means the interval holds a single value that is shared by all output channels.
457 const bool isInputLowBroadcasted = inputLowSize != OC;
458 if ((inputLowSize != 1) && (inputLowSize != OC)) {
459 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input low values count " << inputLowSize << " for " << OC << " channels";
461 const bool isInputHighBroadcasted = inputHighSize != OC;
462 if ((inputHighSize != 1) && (inputHighSize != OC)) {
463 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected input high values count " << inputHighSize << " for " << OC << " channels";
465 const bool isOutputLowBroadcasted = outputLowSize != OC;
466 if ((outputLowSize != 1) && (outputLowSize != OC)) {
467 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output low values count " << outputLowSize << " for " << OC << " channels";
469 const bool isOutputHighBroadcasted = outputHighSize != OC;
470 if ((outputHighSize != 1) && (outputHighSize != OC)) {
471 THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected output high values count " << outputHighSize << " for " << OC << " channels";
474 auto levels_1 = fq->get_levels() - 1.f;
476 //const size_t DHW = D * H * W;
477 const size_t IDHW = IC * D * H * W;
479 const auto values = constant->cast_vector<float>();
480 std::vector<float> quantizedValues(OC * IC * D * H * W);
// NOTE(review): the `int` loop counters are compared with size_t bounds.
482 for (int oc = 0; oc < OC; ++oc) {
483 for (int iidx = 0; iidx < IDHW; ++iidx) {
484 const float inputLow = inputLowValues[isInputLowBroadcasted ? 0 : oc];
485 const float inputHigh = inputHighValues[isInputHighBroadcasted ? 0 : oc];
486 const float outputLow = outputLowValues[isOutputLowBroadcasted ? 0 : oc];
487 const float outputHigh = outputHighValues[isOutputHighBroadcasted ? 0 : oc];
489 const size_t idx = oc * IDHW + iidx;
// Values at or below input_low map to output_low, values above input_high map to output_high,
// everything else is snapped to one of the quantization levels and rescaled to the output range.
491 if (values[idx] <= inputLow) {
492 quantizedValues[idx] = roundValues ? std::roundf(outputLow) : outputLow;
493 } else if (values[idx] > inputHigh) {
494 quantizedValues[idx] = roundValues ? std::roundf(outputHigh) : outputHigh;
496 const float value = std::roundf((values[idx] - inputLow) / (inputHigh - inputLow) * levels_1) /
497 levels_1 * (outputHigh - outputLow) + outputLow;
498 quantizedValues[idx] = roundValues ? std::roundf(value) : value;
503 return std::make_shared<opset1::Constant>(fq->get_output_element_type(0), constShape, quantizedValues);
509 // Decompose FakeQuantize to FakeQuantize with output integer limits (quantize), dequantized MultiplyAdd
510 // To align types the resulting sequence is FakeQuantize -> Convert -> Convert -> MultiplyAdd
// Per element, from the original output low/high constants and the new limits `min`/`max`:
//   shift = (min * high - max * low) / (high - low),  scale = (high - low) / (max - min)
// when high != low; the case high == low uses scale = high (shift stays at its initial 0).
// `fq` is replaced in the graph by the final DequantizationMultiply; the returned tuple is
// (new quantize node, dequantization Multiply).
// NOTE(review): `min` and `max` are declared on parameter lines not visible in this block, and
// several statements of the body are elided as well.
511 std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decomposeFakeQuantize(
512 std::shared_ptr<opset1::FakeQuantize> fq,
513 const element::Type precision,
516 const bool hasZeroPoint,
517 const bool updatePrecision) {
518 using std::make_shared;
520 const auto outputLow = fq->input_value(3);
521 const auto outputHigh = fq->input_value(4);
523 std::vector<float> outputLowValues = as_type_ptr<opset1::Constant>(outputLow.get_node_shared_ptr())->cast_vector<float>();
524 std::vector<float> outputHighValues = as_type_ptr<opset1::Constant>(outputHigh.get_node_shared_ptr())->cast_vector<float>();
525 size_t outputSize = outputLowValues.size();
526 std::vector<float> minValues(outputSize, min);
527 std::vector<float> maxValues(outputSize, max);
528 std::vector<float> shifts(outputSize, 0.f);
529 std::vector<float> scales(outputSize);
// NOTE(review): `outputHighValues` is indexed with the element count of `outputLowValues` -
// this assumes both constants hold the same number of values; confirm. `int i` is also compared
// with a size_t bound.
531 for (int i = 0; i < outputSize; ++i) {
532 if (outputHighValues[i] != outputLowValues[i]) {
533 shifts[i] = (min*outputHighValues[i] - max*outputLowValues[i]) / (outputHighValues[i] - outputLowValues[i]);
534 scales[i] = (outputHighValues[i] - outputLowValues[i]) / (max - min);
535 if (shifts[i] == -0.f) {
539 scales[i] = outputHighValues[i];
// All new constants reuse the element type and shape of the original output_low input.
545 std::shared_ptr<Node> shift = hasZeroPoint ?
546 std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), shifts) :
548 std::shared_ptr<Node> scale = std::make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), scales);
550 auto newMin = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), minValues);
551 auto newMax = make_shared<opset1::Constant>(outputLow.get_element_type(), outputLow.get_shape(), maxValues);
553 if (isScalarLike(newMin)) {
554 newMin = toScalar(newMin);
556 if (isScalarLike(newMax)) {
557 newMax = toScalar(newMax);
// Clamp scale magnitudes into [1e-32, 1e32]; a magnitude below the lower bound becomes +1e-32.
561 static const float minQuantizationScale = 1e-32f;
562 static const float maxQuantizationScale = 1e32f;
564 auto scaleValues = scales;
565 bool wasChanged = false;
566 for (size_t i = 0; i < scaleValues.size(); ++i) {
567 const float scale = scaleValues[i];
568 if (fabs(scale) < minQuantizationScale) {
569 scaleValues[i] = minQuantizationScale;
571 } else if (fabs(scale) > maxQuantizationScale) {
572 scaleValues[i] = scale > 0.f ? maxQuantizationScale : -maxQuantizationScale;
578 scale = std::make_shared<opset1::Constant>(scale->output(0).get_element_type(), scale->output(0).get_shape(), scaleValues);
582 if ((shift != nullptr) && isZero(as_type_ptr<opset1::Constant>(shift))) {
586 // Build a substitution sub-graph:
588 std::shared_ptr<ngraph::Node> newFQ = fold_fake_quantize(
589 std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
596 fq->get_auto_broadcast()),
598 NetworkHelper::copyInfo(fq, newFQ);
600 std::shared_ptr<ngraph::Node> convert2;
601 if (updatePrecision) {
602 std::shared_ptr<Node> convert;
603 std::shared_ptr<opset1::Constant> newFqConstant = as_type_ptr<opset1::Constant>(newFQ);
605 if (is_type<opset1::Constant>(newFQ)) {
606 convert = fold<opset1::Convert>(newFQ, precision);
607 } else if (is_type<opset1::FakeQuantize>(newFQ)) {
608 newFQ = setOutDataPrecision(as_type_ptr<opset1::FakeQuantize>(newFQ), precision);
611 THROW_IE_LPT_EXCEPTION(*newFQ) << "unexpected operation type";
// NOTE(review): `convert` is dereferenced below; its assignment on the FakeQuantize branch is
// not visible in this block - confirm it is set before use.
614 convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
615 convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
616 ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
618 if (newFQ->get_output_element_type(0) != element::f32) {
619 convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
620 convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
621 ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
625 // TODO: why type relaxed?
626 const std::shared_ptr<ngraph::Node> sub = shift == nullptr ?
628 std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
629 if (sub != nullptr) {
630 sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
631 ngraph::copy_runtime_info({ newFQ, sub }, sub);
// The Multiply consumes the last node that exists in the chain: Subtract, else Convert, else newFQ.
634 const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
635 sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
637 dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
638 ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);
640 replace_node(fq, dequantize);
642 return std::make_tuple(newFQ, dequantize);
// Replaces `fq` in the graph with a new type-relaxed FakeQuantize, sets the output precision of
// the new node to `precision` and gives it the friendly name of `fq`.
// Scalar constants `newMin` / `newMax` are created from `min` / `max` with the output element
// type of `fq`.
// NOTE(review): the `min`/`max` parameters, the constructor arguments of the new FakeQuantize
// and the return statement are on lines not visible in this block - presumably newMin/newMax
// become the new output limits and the new node is returned; confirm.
645 std::shared_ptr<opset1::FakeQuantize> NetworkHelper::updateFakeQuantize(
646 std::shared_ptr<opset1::FakeQuantize> fq,
647 element::Type precision,
650 auto newMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
651 auto newMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
653 std::shared_ptr<opset1::FakeQuantize> newFQ = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
660 fq->get_auto_broadcast());
662 NetworkHelper::setOutDataPrecision(newFQ, precision);
663 replace_node(fq, newFQ);
665 newFQ->set_friendly_name(fq->get_friendly_name());
// Builds a stand-alone dequantization sub-graph on a fresh Parameter of type `precision` and
// shape `dataNodeOutputShape`: Convert, an optional Subtract and a Multiply, with scalar
// constants `dequantizationSub` / `dequantizationMul` of type `originalPrecision`.
// The Subtract is created only when |dequantizationSub| > 1e-6; its output type is forced to
// `originalPrecision`.
// Returns the operations wrapped in a FakeQuantizeDequantization whose data is the new Parameter.
// NOTE(review): part of the parameter list and the first arguments of the Convert / Subtract /
// Multiply constructors are on lines not visible in this block.
669 FakeQuantizeDequantization NetworkHelper::makeDequantization(
670 const float dequantizationMul,
671 const float dequantizationSub,
672 const ngraph::element::Type originalPrecision,
673 const ngraph::Shape dataNodeOutputShape,
674 element::Type precision,
677 // TODO: we create input here! we really need it here?
678 const std::shared_ptr<opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(precision, dataNodeOutputShape);
679 std::shared_ptr<ngraph::Node> parent = input;
681 // TODO: convert should be optional: where is updatePrecision?
682 std::shared_ptr<DequantizationConvert> convert;
684 convert = std::make_shared<DequantizationConvert>(
690 std::shared_ptr<DequantizationSubtract> subtract;
691 if (std::abs(dequantizationSub) > 1e-6) {
692 subtract = std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(
694 std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationSub })));
695 subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
700 std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
702 std::make_shared<ngraph::opset1::Constant>(originalPrecision, ngraph::Shape({}), std::vector<float>({ dequantizationMul })));
704 return FakeQuantizeDequantization(input, convert, subtract, multiply);
// Derives the dequantization operations that correspond to quantizing `fq` to [min, max]:
//   scale = (outputHigh - outputLow) / (max - min)
//   shift = (min * outputHigh - max * outputLow) / (outputHigh - outputLow)   (hasZeroPoint only)
// and builds Convert, an optional Subtract(shift) and a Multiply on a fresh Parameter of type
// `precision` with the output shape of `fq`.
// Returns a FakeQuantizeDequantization with `fq` as its data.
// NOTE(review): the `min`/`max` parameters, the Convert input, the Multiply constant and the
// handling of a zero shift are on lines not visible in this block.
// NOTE(review): `updatePrecision` is not used in the visible code.
707 FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
708 std::shared_ptr<opset1::FakeQuantize> fq,
709 element::Type precision,
712 const bool hasZeroPoint,
713 const bool updatePrecision) {
714 using std::make_shared;
716 const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
717 auto newMin = make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
718 auto newMax = make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
720 auto outputLow = fq->input_value(3);
721 auto outputHigh = fq->input_value(4);
723 // TODO: threshold values have to used here to avoid shifts
725 const std::shared_ptr<Node> scale = fold<opset1::Divide>(
726 fold<opset1::Subtract>(outputHigh, outputLow),
727 fold<opset1::Subtract>(newMax, newMin));
729 std::shared_ptr<Node> shift = hasZeroPoint ?
730 fold<opset1::Divide>(
731 fold<opset1::Subtract>(fold<opset1::Multiply>(newMin, outputHigh), fold<opset1::Multiply>(newMax, outputLow)),
732 fold<opset1::Subtract>(outputHigh, outputLow)) :
735 if (shift != nullptr) {
// NOTE(review): `shiftConst` is null if folding did not produce a Constant; it is passed to
// isScalarLike() without a check.
736 std::shared_ptr<opset1::Constant> shiftConst = as_type_ptr<opset1::Constant>(shift);
737 if (isScalarLike(shiftConst)) {
738 auto scalar = toScalar(shiftConst);
739 if (op::util::constantIsEqualTo(scalar, 0)) {
745 const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, fq->get_output_shape(0));
746 const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(
748 fq->get_output_element_type(0));
750 const std::shared_ptr<ngraph::opset1::Subtract> subtract = shift == nullptr ?
752 make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert, shift);
753 if (subtract != nullptr) {
754 subtract->set_output_type(0, fq->get_output_element_type(0), subtract->get_output_partial_shape(0));
757 const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<DequantizationMultiply>(
758 subtract == nullptr ? static_cast<std::shared_ptr<Node>>(convert) : subtract,
761 return FakeQuantizeDequantization(fq, convert, subtract, multiply);
// Recognises a dequantization chain (optional Convert, optional Subtract, optional Multiply)
// that ends at `node` itself (inPlace == true) or at the producer of input `parentIndex` of
// `node` (inPlace == false). The chain is walked upwards: Multiply, then Subtract, then Convert.
//  - A Multiply with no Constant input yields an empty FakeQuantizeDequantization.
//  - A Subtract with no Constant input stops the walk: it becomes the data node and only the
//    Multiply is reported.
//  - A Convert rejected by the type check below stops the walk in the same way.
764 FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex, const bool inPlace) {
// Helper selecting the data (non-constant) input of a two-input operation.
// NOTE(review): its return statements are not visible in this block - presumably 0 when input 1
// is a Constant and 1 otherwise; confirm.
765 auto getDataIndex = [](const std::shared_ptr<ngraph::Node>& node) {
766 if (is_type<opset1::Constant>(node->get_input_node_ptr(1))) {
773 Output<Node> dataNode = inPlace ? node : node->input_value(parentIndex);
775 const std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<ngraph::opset1::Multiply>(dataNode.get_node_shared_ptr());
776 if (multiply != nullptr) {
777 if (!is_type<opset1::Constant>(multiply->get_input_node_ptr(0)) && !is_type<opset1::Constant>(multiply->get_input_node_ptr(1))) {
778 return FakeQuantizeDequantization();
780 dataNode = multiply->get_input_source_output(getDataIndex(multiply));
783 const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<ngraph::opset1::Subtract>(dataNode.get_node_shared_ptr());
784 if (subtract != nullptr) {
785 if (!is_type<opset1::Constant>(subtract->get_input_node_ptr(0)) && !is_type<opset1::Constant>(subtract->get_input_node_ptr(1))) {
786 return FakeQuantizeDequantization(dataNode, nullptr, nullptr, multiply);
788 dataNode = subtract->get_input_source_output(getDataIndex(subtract));
791 const std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(dataNode.get_node_shared_ptr());
792 if (convert != nullptr) {
// NOTE(review): the three conditions are joined with &&, so a Convert is rejected only when its
// input is neither i8 nor u8 AND its output is not f32 - confirm this is intended rather than
// "(input is not i8/u8) || (output is not f32)".
793 if ((convert->input(0).get_element_type() != element::i8) && (convert->input(0).get_element_type() != element::u8) &&
794 (convert->output(0).get_element_type() != element::f32)) {
795 return FakeQuantizeDequantization(dataNode, nullptr, subtract, multiply);
797 dataNode = convert->get_input_source_output(0);
800 return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
// Returns the subtract / multiply constants of `dequantization` as stand-alone nodes.
// An existing constant (input 1 of the Subtract / Multiply) is cloned; a missing operation is
// represented by a neutral scalar constant - 0 for subtract, 1 for multiply - typed like the
// output of the dequantization parent (the Convert if present, otherwise the data node).
// The output type of the subtract constant is then forced to the multiply constant's element type.
803 FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
804 std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();
806 std::shared_ptr<Node> multiply1Const = dequantization.multiply ?
807 dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
808 std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 1.f }));
810 std::shared_ptr<Node> subtract1Const = dequantization.subtract ?
811 dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) :
812 std::make_shared<opset1::Constant>(parent->get_output_element_type(0), Shape({}), std::vector<float>({ 0.f }));
814 subtract1Const->set_output_type(0, multiply1Const->get_output_element_type(0), subtract1Const->get_output_partial_shape(0));
816 return FakeQuantizeDequantizationValues(subtract1Const, multiply1Const);
// Tests whether `node` is a scalar-like Constant (all elements bitwise identical) equal to 0.
// NOTE(review): the return statements are not visible in this block - presumably false when
// `node` is not a Constant or the value differs from 0, true on a match; confirm.
819 bool NetworkHelper::isZeroConst(const std::shared_ptr<Node>& node) {
820 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
822 if (constant == nullptr)
825 if (NetworkHelper::isScalarLike(constant)) {
826 auto scalar = NetworkHelper::toScalar(constant);
827 if (op::util::constantIsEqualTo(scalar, 0)) {
// Bypasses the Convert that feeds input 0 of `subtract`: a type-relaxed Subtract is created
// directly on the Convert's input, with the shift constant rounded and converted to the
// Convert's input type, and its output precision set to the Convert's output type.
// The new node takes over the info of `subtract` (copyInfo) and replaces it in the graph.
// NOTE(review): the bodies of the two guards (input 0 is not a Convert; Convert output type is
// not real) and the final return are not visible in this block - presumably early returns and
// a return of the replacement; confirm.
837 std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract) {
838 auto convertOnSubtract = subtract->input_value(0).get_node_shared_ptr();
839 if (as_type_ptr<opset1::Convert>(convertOnSubtract) == nullptr) {
843 // TODO: replace assert to condition and omit conversion part if there is no convert
844 // TODO: also check convertInputType to understand if we really want to propagate type
845 assert(as_type_ptr<opset1::Convert>(convertOnSubtract));
846 const element::Type convertInputType = convertOnSubtract->get_input_element_type(0);
847 const element::Type convertOutputType = convertOnSubtract->get_output_element_type(0);
849 if (!convertOutputType.is_real()) {
853 auto data = convertOnSubtract->input_value(0);
854 auto shift = subtract->input_value(1).get_node_shared_ptr();
855 auto roundedShift = NetworkHelper::round(shift, convertInputType);
857 // Propagate convertInputType down
858 const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
859 NetworkHelper::copyInfo(subtract, replacement);
860 NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
861 replace_node(subtract, replacement);
863 // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
864 // TODO: check cases when Convert should be preserved
866 // Try to optimize Add out if constant is zero
867 // TODO: don't remove operation here: don't create this Subtraction operation in FQ decomposition
868 // if (isScalarLike(roundedShift)) {
869 // auto scalar = distillToScalar(roundedShift);
870 // if (op::util::constantIsEqualTo(scalar, 0)) {
871 // replace_node(replacement, replacement->input_value(0).get_node_shared_ptr());
872 // replacement = nullptr;
// Re-creates `operation` in front of the dequantization chain and re-builds the
// dequantization operations (Convert / Subtract / Multiply) behind the new node.
//
// operation       - node that has an input produced by dequantization.multiply.
// dequantization  - the dequantization operations currently feeding `operation`.
// updatePrecision - when true, the output type of the re-created operation is overridden
//                   with the element type of its input 0; the node is cast to
//                   ngraph::op::TypeRelaxedBase for that.
// moveSubtract    - when true, the re-created operation reads dequantization.data and the
//                   Subtract (if any) is re-built after it; when false, an existing Subtract
//                   stays in front of the re-created operation.
//
// Returns InsertDequantizationResult(newOperation, parent), where `parent` is the last node
// of the chain built on top of newOperation. `operation` is replaced in the graph by
// that last node.
// Failures are reported through THROW_IE_LPT_EXCEPTION (non-TypeRelaxedBase node with
// updatePrecision, or a dequantization constant wider than the data it is applied to).
879 NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter(
880 const std::shared_ptr<ngraph::Node>& operation,
881 const FakeQuantizeDequantization& dequantization,
882 const bool updatePrecision,
883 const bool moveSubtract) {
// Snapshot the current inputs of `operation`.
// NOTE(review): each entry is built from the producer *node*, not from the producing
// output port - presumably this binds to the producer's default output, so a producer
// with several outputs may not be represented correctly; confirm whether input_value(i)
// is intended.
884 std::vector<Output<Node>> inputs(operation->get_input_size());
885 for (size_t i = 0; i < operation->get_input_size(); ++i) {
886 inputs[i] = operation->get_input_node_shared_ptr(i);
// Redirect the input fed by dequantization.multiply: to the dequantization data when the
// Subtract moves as well, otherwise to the Subtract if there is one.
889 const size_t dequantizationIndex = getChildInputIndex(dequantization.multiply, operation);
890 inputs[dequantizationIndex] = moveSubtract ?
891 dequantization.data :
892 (dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract);
// Clone the operation on the redirected inputs, keeping its friendly name and runtime info.
894 const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
895 newOperation->set_friendly_name(operation->get_friendly_name());
896 ngraph::copy_runtime_info(operation, newOperation);
// Optionally force the output element type of the clone to the element type of its input 0.
898 if (updatePrecision) {
899 auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
// NOTE(review): the condition guarding this throw is not visible in this listing -
// presumably a null check on `op`; confirm.
901 THROW_IE_LPT_EXCEPTION(*newOperation) << "not possible to update precision for not TypeRelaxedBase operation";
903 op->set_overridden_output_type(newOperation->get_input_element_type(0));
904 std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
// A Convert is wanted when the clone's output type differs from the Multiply's output type.
// NOTE(review): dequantization.multiply is dereferenced here unconditionally although a
// `dequantization.multiply != nullptr` check appears further down; confirm which is intended.
907 const bool shouldConvert = (newOperation->get_output_element_type(0) != dequantization.multiply->get_output_element_type(0));
// `parent` tracks the tail of the chain being built after the clone.
909 auto parent = newOperation;
// NOTE(review): the `if` that presumably guards the Convert creation with shouldConvert
// is not visible in this listing.
// Target type: the original Convert's output type if there was a Convert, else the Multiply's.
911 const auto convertOutputPrecision = dequantization.convert != nullptr ?
912 dequantization.convert->get_output_element_type(0) :
913 dequantization.multiply->get_output_element_type(0);
914 parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
915 ngraph::copy_runtime_info({ newOperation, parent }, parent);
// Re-build the Subtract after the clone when requested and present.
918 if (moveSubtract && (dequantization.subtract != nullptr)) {
919 auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
920 const element::Type parentPrecision = parent->get_output_element_type(0);
// The constant must not be wider (in bits) than the data it is applied to.
921 if (parentPrecision.bitwidth() < subtractConstant->output(0).get_element_type().bitwidth()) {
922 THROW_IE_LPT_EXCEPTION(*parent) <<
923 "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
924 ", subtract dequantization constant " << subtractConstant->get_friendly_name() << ":" << subtractConstant->output(0).get_element_type();
// When the constant's element type differs from the parent precision it is folded through
// a Convert to that precision (some constructor arguments are not visible in this listing).
927 parent = std::make_shared<DequantizationSubtract>(
929 subtractConstant->output(0).get_element_type() == parentPrecision ?
931 fold<opset1::Convert>(subtractConstant->output(0), parentPrecision));
932 ngraph::copy_runtime_info({ newOperation, parent }, parent);
// Re-build the Multiply after the clone, with the same width check and constant folding.
935 if (dequantization.multiply != nullptr) {
936 auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
937 const element::Type parentPrecision = parent->get_output_element_type(0);
938 if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
939 THROW_IE_LPT_EXCEPTION(*parent) <<
940 "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
941 ", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
944 parent = std::make_shared<DequantizationMultiply>(
946 multiplyConstant->output(0).get_element_type() == parentPrecision ?
948 fold<opset1::Convert>(multiplyConstant->output(0), parentPrecision));
949 ngraph::copy_runtime_info({ newOperation, parent }, parent);
// Substitute the original operation with the tail of the new chain.
951 replace_node(operation, parent);
// The Subtract stayed in front of the clone and a Convert exists: cleanRunTimeInfo is
// applied to that Subtract. The elementwise optimization below is disabled.
953 if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
954 NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
956 // NetworkHelper::optimizeElementwise(dequantization.subtract);
959 return InsertDequantizationResult(newOperation, parent);
// Bypasses the dequantization Convert for the Subtract when the Subtract constant
// allows it.
// The element type on the Convert's input is taken as the target precision. If
// checkConstantValuePrecision accepts the Subtract constant for that precision, the
// Subtract is cloned onto the Convert's input node with its constant folded through a
// Convert to the same precision, and the clone replaces the original Subtract.
// NOTE(review): `operation` is not referenced by any line visible here, and both
// dequantization.convert and dequantization.subtract are dereferenced without a null
// check - callers presumably guarantee both are set; confirm.
962 void NetworkHelper::removeConvertIfPossible(
963 const std::shared_ptr<ngraph::Node>& operation,
964 const FakeQuantizeDequantization& dequantization) {
965 const element::Type precisionBeforeConvert = dequantization.convert->input(0).get_element_type();
// Rewrite only when the constant on input 1 of the Subtract passes the precision check.
967 if (checkConstantValuePrecision(precisionBeforeConvert, dequantization.subtract->get_input_node_shared_ptr(1))) {
// New inputs: the node feeding the Convert, and the constant folded to the pre-Convert type.
968 auto newSubtract = dequantization.subtract->clone_with_new_inputs({
969 dequantization.convert->get_input_node_shared_ptr(0),
970 fold<opset1::Convert>(dequantization.subtract->get_input_node_shared_ptr(1), precisionBeforeConvert) });
971 replace_node(dequantization.subtract, newSubtract);
// Checks whether the values of `constant` fit the sign of `expectedPrecision`.
// On the path visible here the node is cast to opset1::Constant, its values are read as
// float, and the result is true when the precision is signed or every value is >= 0.
// NOTE(review): the bodies of the two early-exit branches are not visible in this
// listing - presumably `true` for a signed precision and `false` when `constant` is not
// an opset1::Constant; confirm against the full file.
975 bool NetworkHelper::checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant) {
// Early exit for signed precisions.
976 if (expectedPrecision.is_signed()) {
980 std::shared_ptr<opset1::Constant> constantOp = as_type_ptr<opset1::Constant>(constant);
// Early exit when `constant` is not an opset1::Constant.
981 if (constantOp == nullptr) {
985 const auto values = constantOp->cast_vector<float>();
// NOTE(review): is_signed() is tested again below; if the first branch already left the
// function for signed precisions, that operand is always false at this point.
986 const bool convertCanBeRemoved =
987 (expectedPrecision.is_signed() || (std::all_of(values.begin(), values.end(), [](const float value) { return value >= 0.f; })));
988 return convertCanBeRemoved;
// Finds which input of `child` is produced by `parent`.
// Inputs of `child` are scanned in index order and matched by comparing the raw node
// pointer of each input's producer with `parent`.
// If no input of `child` comes from `parent`, the failure is reported through
// THROW_IE_LPT_EXCEPTION with a message naming both nodes.
// NOTE(review): the statement executed on a match is not visible in this listing -
// presumably `return i;`.
991 size_t NetworkHelper::getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
992 for (size_t i = 0; i < child->get_input_size(); ++i) {
993 if (parent.get() == child->get_input_node_ptr(i)) {
// Reached only when the scan above found no matching input.
997 THROW_IE_LPT_EXCEPTION(*child) << "child input index between " <<
998 parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
// Finds which output of `parent` is consumed by `child`.
// Outputs of `parent` are scanned in index order; for each one, its target inputs are
// checked for an input that belongs to `child` (raw node pointer comparison).
// If no output of `parent` feeds `child`, the failure is reported through
// THROW_IE_LPT_EXCEPTION with a message naming both nodes.
// NOTE(review): the statement executed on a match is not visible in this listing -
// presumably `return i;`.
1001 size_t NetworkHelper::getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child) {
1002 for (size_t i = 0; i < parent->get_output_size(); ++i) {
1003 const auto& targetInputs = parent->output(i).get_target_inputs();
1004 for (const auto& targetInput : targetInputs) {
1005 if (targetInput.get_node() == child.get()) {
// Reached only when the scan above found no consumer input belonging to `child`.
1010 THROW_IE_LPT_EXCEPTION(*child) << "parent output index between " <<
1011 parent->get_friendly_name() << " and " << child->get_friendly_name() << " was not found";
// Collects the inputs of `node` into a vector with one entry per input, in input order.
// NOTE(review): each entry is built from the producer *node* rather than from the
// producing output port - presumably this binds to the producer's default output, so
// the port index of a multi-output producer is not carried; confirm whether
// input_value(i) is intended.
// NOTE(review): the return statement is not visible in this listing - presumably
// `return inputs;`.
1014 std::vector<Output<Node>> NetworkHelper::getInputs(const std::shared_ptr<ngraph::Node>& node) {
1015 std::vector<Output<Node>> inputs(node->get_input_size());
1016 for (size_t i = 0; i < node->get_input_size(); ++i) {
1017 inputs[i] = node->get_input_node_shared_ptr(i);
// Reduces `node` with NetworkHelper::toScalar when it is an opset1::Constant that
// NetworkHelper::isScalarLike accepts.
// NOTE(review): the bodies of the two guard branches are not visible in this listing -
// presumably they return `node` unchanged when it is not a constant or not scalar-like;
// confirm against the full file.
1022 std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> node) {
1023 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(node);
// Guard: `node` is not an opset1::Constant.
1024 if (constant == nullptr) {
// Guard: the constant is not scalar-like.
1028 if (!NetworkHelper::isScalarLike(constant)) {
1032 return NetworkHelper::toScalar(constant);
1035 } // namespace low_precision
1037 } // namespace ngraph