1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <low_precision/layer_transformation.hpp>
6 #include <low_precision/network_helper.hpp>
16 #include <unordered_set>
22 namespace low_precision {
// Suffix appended to an operation's friendly name when a new node takes over
// its original name at a function output (see LayerTransformation::updateOutput).
24 const char LayerTransformation::originalLayerPostfix[] = "_original";
// Copies the user-supplied transformation parameters and initializes internal
// state. The two manager pointers are deliberately null here: they are injected
// later via setParamsManager / setLayerTransformationsManager (raw pointers,
// not owned by this class).
26 LayerTransformation::LayerTransformation(const Params& params) :
27 updatePrecisions(params.updatePrecisions),
28 quantizedTensorAlignmentOnActivations(params.quantizedTensorAlignmentOnActivations),
29 quantizedTensorAlignmentOnWeights(params.quantizedTensorAlignmentOnWeights),
30 supportAsymmetricQuantization(params.supportAsymmetricQuantization),
31 precisionsOnActivations(params.precisionsOnActivations),
32 precisionsOnWeights(params.precisionsOnWeights),
33 layerTransformationsManager(nullptr),
34 paramsManager(nullptr),
// Relative interval-ratio deviation tolerated before a zero point is required
// (used by getPrecisionDetails).
35 quantizationIntervalAsymmetryThreshold(0.002f),
// Magnitudes below this are treated as zero when analysing quantization intervals.
36 zeroThreshold(1.e-6f),
37 minQuantizationLevels(2ul) {}
// Dependency injection: stores the params manager later queried in
// fillAvailablePrecisions for child precisions. Pointer is borrowed, not owned.
39 void LayerTransformation::setParamsManager(IParamsManager* paramsManager) noexcept {
40 this->paramsManager = paramsManager;
// Dependency injection: stores the transformations manager used to ask whether
// neighbouring operations are quantized / precision-preserved. Borrowed pointer.
43 void LayerTransformation::setLayerTransformationsManager(ILayerTransformationsManager* layerTransformationsManager) noexcept {
44 this->layerTransformationsManager = layerTransformationsManager;
// Enables/disables actual precision updates on the graph for this transformation.
47 void LayerTransformation::setUpdatePrecisions(const bool updatePrecisions) {
48 this->updatePrecisions = updatePrecisions;
// Overrides the activation-tensor alignment policy supplied in Params.
51 void LayerTransformation::setQuantizedTensorAlignmentOnActivations(
52 const QuantizedTensorAlignment quantizedTensorAlignmentOnActivations) {
53 this->quantizedTensorAlignmentOnActivations = quantizedTensorAlignmentOnActivations;
// Overrides the weight-tensor alignment policy supplied in Params.
56 void LayerTransformation::setQuantizedTensorAlignmentOnWeights(
57 const QuantizedTensorAlignment quantizedTensorAlignmentOnWeights) {
58 this->quantizedTensorAlignmentOnWeights = quantizedTensorAlignmentOnWeights;
// Read-only accessor for the low precisions allowed on activations (e.g. u8/i8).
61 const std::vector<element::Type>& LayerTransformation::getPrecisionsOnActivations() const {
62 return precisionsOnActivations;
// Read-only accessor for the low precisions allowed on weights.
65 const std::vector<element::Type>& LayerTransformation::getPrecisionsOnWeights() const {
66 return precisionsOnWeights;
// Base feasibility check shared by all transformations. Rejects the layer when:
//  - it is not quantized (per the virtual isQuantized);
//  - any output rank is outside [2, 5];
//  - an existing dequantization Subtract/Multiply uses a constant that is not
//    per-tensor/per-channel (only the channel axis of the constant may differ).
// NOTE(review): the early-exit bodies (return statements) are outside this view.
69 bool LayerTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
70 if (!isQuantized(layer)) {
74 for (const auto& output : layer->outputs()) {
75 const size_t size = output.get_shape().size();
76 if ((size < 2ul) || (size > 5ul)) {
81 const auto dequantization = NetworkHelper::getDequantization(layer);
82 if (!dequantization.empty()) {
// Accepts constShape as broadcastable per-channel: after padding a missing
// leading (batch) dim with 1, dim 0 must be 1 and every dim from index 2 on
// must be 1 — i.e. only dim 1 (channels) may vary.
83 auto perChannelQuantization = [](const Shape dataShape, Shape constShape) {
// Constant of rank dataRank-1: treat as having an implicit leading batch dim of 1.
84 if ((dataShape.size() - constShape.size()) == 1ul) {
85 constShape.insert(constShape.begin(), 1ul);
// Batch dim of the constant must be 1.
88 if ((constShape.size() >= 2ul) && (constShape[0] != 1ul)) {
// All spatial dims must be 1.
92 for (size_t i = 2; i < constShape.size(); ++i) {
93 if (constShape[i] != 1ul) {
// Check the shift (Subtract) constant, if present.
100 if ((dequantization.subtract != nullptr) && (!perChannelQuantization(
101 dequantization.subtract->output(0).get_shape(),
102 dequantization.subtract->input(1).get_shape()))) {
// Check the scale (Multiply) constant, if present.
106 if ((dequantization.multiply != nullptr) && (!perChannelQuantization(
107 dequantization.multiply->output(0).get_shape(),
108 dequantization.multiply->input(1).get_shape()))) {
// Convenience overload: resolves the dequantization chain on the given parent
// input and forwards to the main overload below.
116 bool LayerTransformation::canSubtractBeHandled(const std::shared_ptr<Node>& op, const size_t parentIndex) const {
117 return canSubtractBeHandled(op, NetworkHelper::getDequantization(op, parentIndex));
// Decides whether the dequantization Subtract (zero point, i.e. asymmetric
// quantization) can be handled: requires asymmetric support to be enabled,
// precisions to be updated, and the pre-dequantization data type to already be
// i8/u8. NOTE(review): the return statements between the checks are outside
// this view.
120 bool LayerTransformation::canSubtractBeHandled(const std::shared_ptr<Node>& op, const FakeQuantizeDequantization& dequantization) const {
121 if (dequantization.empty() || (dequantization.subtract == nullptr)) {
125 if (!supportAsymmetricQuantization) {
129 if (!updatePrecisions) {
// The type the Subtract actually operates on: the Convert's input type when a
// Convert precedes it, otherwise the Subtract's own data-input type.
133 const element::Type operationType = dequantization.convert == nullptr ?
134 dequantization.subtract->input(0).get_element_type() :
135 dequantization.convert->input(0).get_element_type();
137 if ((operationType != element::i8) && (operationType != element::u8)) {
144 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
// Debug-only helper: formats at most the first 9 dequantization values as a
// comma-separated list into a stringstream.
145 std::stringstream toStream(const std::vector<float>& dequantizationValues) {
146 std::stringstream ss;
// Cap at 9 entries to keep debug output readable for large tensors.
147 const size_t scalesCount = dequantizationValues.size() > 9ul ? 9ul : dequantizationValues.size();
149 for (size_t i = 0ul; i < scalesCount; ++i) {
// No trailing comma after the last printed value.
150 ss << dequantizationValues[i] << (i < (scalesCount - 1) ? "," : "");
// Debug-only (LPT_PRINT_DEQUANTIZATION_INFO): dumps the FakeQuantize details of
// the given layer, tagged as on-weights or on-activations.
156 void LayerTransformation::printDequantizationInfo(const std::shared_ptr<Node>& layer) {
157 const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(as_type_ptr<opset1::FakeQuantize>(layer));
159 layer->get_type_name() << (NetworkHelper::onWeights(layer) ? " on weights " : " on activations ") <<
160 layer->get_friendly_name() << ":" << std::endl <<
161 "   details  : " << quantizationDetails << std::endl;
// Debug-only: prints the chosen target DataPrecision.
164 void LayerTransformation::printDequantizationInfo(const DataPrecision& dataPrecision) {
165 std::cout << "   precision: " << dataPrecision << std::endl;
// Debug-only: prints truncated scale/shift vectors via toStream above.
168 void LayerTransformation::printDequantizationValues(
169 const std::vector<float>& dequantizationScales,
170 const std::vector<float>& dequantizationShifts) {
172 "   scales   : " << toStream(dequantizationScales).str() << std::endl <<
173 "   shifts   : " << toStream(dequantizationShifts).str() << std::endl;
// Tunes the interval-asymmetry tolerance used by getPrecisionDetails.
177 void LayerTransformation::setQuantizationIntervalAsymmetryThreshold(const float value) {
178 this->quantizationIntervalAsymmetryThreshold = value;
// Tunes the near-zero magnitude threshold used by getPrecisionDetails.
181 void LayerTransformation::setZeroThreshold(const float value) {
182 this->zeroThreshold = value;
// Sets the minimum number of quantization levels a FakeQuantize must have.
185 void LayerTransformation::setMinQuantizationLevels(const size_t levels) {
186 this->minQuantizationLevels = levels;
// Analyses the FakeQuantize output intervals and decides which signedness of
// low precision fits them and whether a zero point (asymmetric shift) is
// required:
//  - i8 when every interval is signed-symmetric "enough";
//  - u8 when every interval is non-negative;
//  - element::undefined when intervals are mixed.
// NOTE(review): several mutating lines inside the loop (e.g. setting
// hasNegative, the else branch entry) are outside this view; comments below
// describe only the visible logic.
189 LayerTransformation::PrecisionDetails LayerTransformation::getPrecisionDetails(const QuantizationDetails& quantizationDetails) const {
// Ideal low/high ratio for a symmetric signed 8-bit interval: -128/127.
190 const float asymmetricIntervalSideRatio256 = -128.f / 127.f;
191 bool hasNegative = false;
192 bool signedPrecision = true;
193 bool unsignedPrecision = true;
195 bool hasZeroPoint = false;
196 for (size_t i = 0; i < quantizationDetails.outputLowValues.size(); ++i) {
// Interval crosses zero when low and high have opposite signs.
197 const bool signedInterval = std::signbit(quantizationDetails.outputLowValues[i]) != std::signbit(quantizationDetails.outputHighValues[i]);
// Ignore degenerate bounds that are (numerically) zero.
198 const bool boundaryValuesAreNotZero =
199 (std::fabs(quantizationDetails.outputLowValues[i]) >= zeroThreshold) &&
200 (std::fabs(quantizationDetails.outputHighValues[i]) >= zeroThreshold);
201 if (signedInterval && boundaryValuesAreNotZero) {
// Signed interval => unsigned precision is ruled out.
203 unsignedPrecision = false;
// A zero point is needed when the interval deviates from the ideal signed
// ratio by more than quantizationIntervalAsymmetryThreshold (relative).
206 const float expectedRatio = quantizationDetails.levels == 256 ? asymmetricIntervalSideRatio256 : -1.f;
207 const float actualRatio = quantizationDetails.outputLowValues[i] / quantizationDetails.outputHighValues[i];
208 const float actual = std::fabs((actualRatio - expectedRatio) / std::min(actualRatio, expectedRatio));
209 if (actual > quantizationIntervalAsymmetryThreshold) {
212 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
214 std::cout << "   actual: " << actual << ", threshold: " << quantizationIntervalAsymmetryThreshold << std::endl;
215 std::cout << "   hasZeroPoint: " << (hasZeroPoint ? "True" : "False") << std::endl;
// Non-crossing (one-sided) interval => signed precision is ruled out; any
// non-zero bound implies an asymmetric shift.
220 signedPrecision = false;
221 if (boundaryValuesAreNotZero) {
222 hasZeroPoint = boundaryValuesAreNotZero;
225 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
// Debug: report which non-zero bound forced the zero point.
227 const float actual = quantizationDetails.outputLowValues[i] > 0.f ?
228 quantizationDetails.outputLowValues[i] :
229 quantizationDetails.outputHighValues[i];
230 std::cout << "   actual: " << actual << ", threshold: 0.0" << std::endl;
231 std::cout << "   hasZeroPoint: " << (hasZeroPoint ? "True" : "False") << std::endl;
// Unambiguously signed intervals => i8.
238 if (signedPrecision && (!unsignedPrecision)) {
239 return LayerTransformation::PrecisionDetails(element::i8, hasNegative, hasZeroPoint);
// Unambiguously unsigned intervals => u8.
242 if ((!signedPrecision) && unsignedPrecision) {
243 return LayerTransformation::PrecisionDetails(element::u8, hasNegative, hasZeroPoint);
// Mixed/contradictory intervals: let the caller fall back.
247 return LayerTransformation::PrecisionDetails(element::undefined, hasNegative, hasZeroPoint);
// Virtual predicate used by canBeTransformed; derived transformations override
// it. NOTE(review): the body is outside this view — presumably the base class
// returns true so that only derived classes opt layers out; confirm upstream.
250 bool LayerTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
// Chooses the target low precision for the given FakeQuantize:
//  1. derive a candidate precision from the output intervals (getPrecisionDetails);
//  2. restrict the supported set to what downstream consumers accept
//     (fillAvailablePrecisions);
//  3. if the interval-derived precision is not in the supported set, fall back
//     to the first supported precision and force a zero point (hasZeroPoint = true);
//  4. with no supported precisions at all, return element::undefined.
254 DataPrecision LayerTransformation::getDataPrecision(
255 std::shared_ptr<Node> layer,
256 const QuantizationDetails& quantizationDetails,
257 const bool onWeights) const {
258 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
259 printDequantizationInfo(layer);
// Start from the configured precision set for this tensor kind.
261 std::vector<element::Type> precisions = onWeights ? precisionsOnWeights : precisionsOnActivations;
262 PrecisionDetails precisionDetailsAtOutputIntervals = getPrecisionDetails(quantizationDetails);
264 if (precisionDetailsAtOutputIntervals.precision != element::undefined) {
// Intersect with the precisions the consumers of this layer can take.
266 fillAvailablePrecisions(layer, precisions);
269 // if supportedPrecisions is empty then use the first available, not supported layer will be in original precision
270 if (!precisions.empty()) {
271 const auto foundIt = std::find(precisions.begin(), precisions.end(), precisionDetailsAtOutputIntervals.precision);
272 const element::Type resultPrecision = foundIt != precisions.end() ?
273 precisionDetailsAtOutputIntervals.precision :
276 const DataPrecision dataPrecision(
278 DataPrecision::getMinValue(resultPrecision, quantizationDetails.levels),
279 DataPrecision::getMaxValue(resultPrecision, quantizationDetails.levels),
// Forcing a non-preferred precision requires a zero point.
280 foundIt != precisions.end() ? precisionDetailsAtOutputIntervals.hasZeroPoint : true);
282 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
283 printDequantizationInfo(dataPrecision);
285 return dataPrecision;
// Interval analysis was inconclusive: take the first configured precision, or
// undefined when none are configured.
290 const DataPrecision dataPrecision = precisions.empty() ?
291 DataPrecision(element::undefined, 0.f, 0.f, false) :
294 DataPrecision::getMinValue(*precisions.begin(), quantizationDetails.levels),
295 DataPrecision::getMaxValue(*precisions.begin(), quantizationDetails.levels),
297 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
298 printDequantizationInfo(dataPrecision);
300 return dataPrecision;
// Narrows availablePrecisions (in place) to the precisions every downstream
// consumer of `layer` supports, recursing through precision-preserving children.
// Traversal notes (visible logic only; several branch exits are outside this view):
//  - a FakeQuantize child re-quantizes, so its requirements are not propagated;
//  - a non-quantized child interrupts the low-precision chain;
//  - a child with an empty precision list imposes no constraint.
303 void LayerTransformation::fillAvailablePrecisions(std::shared_ptr<Node> layer, std::vector<element::Type>& availablePrecisions) const {
// Nothing left to narrow — stop early.
304 if (availablePrecisions.empty()) {
308 const std::vector<std::shared_ptr<Node>> children = NetworkHelper::consumers(layer);
309 for (auto child : children) {
310 if (child->get_type_info().is_castable(opset1::FakeQuantize::get_type_info_static())) {
311 // FakeQuantize layer updates precision
315 if (!layerTransformationsManager->isQuantized(child)) {
316 // low precision chain is interrupted here: next operation supported precisions are ignored
320 const std::vector<element::Type> childPrecisionsOnActivations = paramsManager->getPrecisionsOnActivations(*child);
321 if (childPrecisionsOnActivations.size() == 0ul) {
// Erase every candidate the child does not support; index is advanced only
// when nothing was erased (erase shifts subsequent elements left).
325 for (size_t index = 0ul; index < availablePrecisions.size();) {
326 const element::Type availablePrecision = availablePrecisions[index];
328 childPrecisionsOnActivations.begin(),
329 childPrecisionsOnActivations.end(),
330 [&](const element::Type precision) { return availablePrecision == precision; })) {
331 availablePrecisions.erase(availablePrecisions.begin() + index);
// Only precision-preserving children propagate the constraint further down.
337 if (!layerTransformationsManager->isPrecisionPreserved(child)) {
341 fillAvailablePrecisions(child, availablePrecisions);
342 if (availablePrecisions.empty()) {
// BFS over the consumers of `op`: non-precision-preserving operations are
// collected into the result; precision-preserving ones are looked through and
// their own consumers enqueued. NOTE(review): duplicates are possible since
// visited nodes are not tracked in the visible code — confirm whether callers
// tolerate repeats (a node reachable via two branches is enqueued twice).
348 std::vector<std::shared_ptr<Node>> LayerTransformation::getChildrenRecursivelyExceptPrecisionPreserved(
349 const std::shared_ptr<Node>& op) const noexcept {
350 std::queue<std::shared_ptr<Node>> notHandledChildren;
// Seed the queue with the direct consumers of every output.
352 for (const auto& output : op->outputs()) {
353 for (const auto& input : output.get_target_inputs()) {
354 std::shared_ptr<Node> child = input.get_node()->shared_from_this();
355 notHandledChildren.emplace(child);
359 std::vector<std::shared_ptr<Node>> resultChildren;
361 while (!notHandledChildren.empty()) {
362 const std::shared_ptr<ngraph::Node> operation = notHandledChildren.front();
363 notHandledChildren.pop();
// Terminal for this branch: collect and (presumably) skip its children.
365 if (!this->layerTransformationsManager->isPrecisionPreserved(operation)) {
366 resultChildren.push_back(operation);
// Precision-preserving: keep walking through it.
370 for (const auto& output : operation->outputs()) {
371 for (const auto& input : output.get_target_inputs()) {
372 std::shared_ptr<Node> child = input.get_node()->shared_from_this();
373 notHandledChildren.emplace(child);
378 return resultChildren;
// When `node`'s dequantization chain (Convert -> Subtract -> Multiply) is shared
// with other consumers, clones the chain into a private branch so that this
// transformation can mutate it safely, then re-wires `node` (via a clone that
// replaces the original) onto the new branch. Each cloned stage gets a "_new"
// friendly-name suffix; the replacing node inherits the original friendly name.
382 std::shared_ptr<ngraph::Node> LayerTransformation::separateInStandaloneBranch(std::shared_ptr<ngraph::Node> node) const {
383 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node);
384 if (dequantization.isShared()) {
// Rebuild the chain stage by stage starting from the raw data output.
385 Output<Node> parent = dequantization.data;
386 if (dequantization.convert != nullptr) {
387 parent = dequantization.convert->clone_with_new_inputs({ parent });
388 parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
// Subtract/Multiply also clone their constant input (clone_with_new_inputs({})
// on the constant) so the branch shares nothing with the original chain.
391 if (dequantization.subtract != nullptr) {
392 parent = dequantization.subtract->clone_with_new_inputs({
394 dequantization.subtract->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) });
395 parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
398 if (dequantization.multiply != nullptr) {
399 parent = dequantization.multiply->clone_with_new_inputs({
401 dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) });
402 parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
// Swap the input that used to consume the shared Multiply for the new branch.
405 std::vector<Output<Node>> inputs = NetworkHelper::getInputs(node);
406 const size_t inputIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, node);
407 inputs[inputIndex] = parent;
408 const std::shared_ptr<Node> newNode = node->clone_with_new_inputs(inputs);
410 replace_node(node, newNode);
411 newNode->set_friendly_name(node->get_friendly_name());
// Thin wrapper over NetworkHelper::moveDequantizationAfter: moves the
// dequantization ops past `operation`, then fixes function-output friendly
// names via updateOutput. Returns the relocated operation.
419 std::shared_ptr<ngraph::Node> LayerTransformation::moveDequantizationAfter(
420 TransformationContext &context,
421 const std::shared_ptr<ngraph::Node>& operation,
422 const FakeQuantizeDequantization& dequantization,
423 const bool updatePrecision,
424 const bool moveSubtract) const {
425 const auto result = ngraph::pass::low_precision::NetworkHelper::moveDequantizationAfter(operation, dequantization, updatePrecision, moveSubtract);
426 updateOutput(context, result.lastDequantization, result.newOperation);
427 return result.newOperation;
// Removes the dequantization Convert when the Subtract constant already fits
// the Convert's output element type (checked via checkConstantValuePrecision),
// after detaching the chain into a standalone branch.
// NOTE(review): dequantization.convert is dereferenced here while only
// dequantization.subtract is null-checked — if getDequantization can yield a
// chain with a Subtract but no Convert this is a null dereference; an explicit
// (dequantization.convert != nullptr) guard looks required. Confirm upstream.
430 void LayerTransformation::fuseConvertIfPossible(const std::shared_ptr<ngraph::Node>& operation) const {
431 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(operation, 0);
432 if ((dequantization.subtract != nullptr) &&
433 NetworkHelper::checkConstantValuePrecision(
434 dequantization.convert->get_output_element_type(0),
435 dequantization.subtract->get_input_node_shared_ptr(1))) {
436 auto newOperation = separateInStandaloneBranch(operation);
// Re-resolve: separateInStandaloneBranch replaced the original chain.
437 dequantization = NetworkHelper::getDequantization(operation, 0);
438 // TODO: It is correct to use optimizeSubtract here: uncomment following rows and fix it
439 //auto newSubtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
440 //replace_node(dequantization.subtract, newSubtract);
441 NetworkHelper::removeConvertIfPossible(operation, dequantization);
// If `lastNode` now feeds a function Result, transfers the user-visible output
// name from `originalNode` to it: the original node gets the "_original"
// suffix, `lastNode` takes the original friendly name.
445 void LayerTransformation::updateOutput(
446 TransformationContext &context,
447 std::shared_ptr<ngraph::Node> lastNode,
448 std::shared_ptr<ngraph::Node> originalNode) const {
449 const size_t outputSize = context.function->get_output_size();
450 for (size_t i = 0; i < outputSize; ++i) {
451 std::shared_ptr<ngraph::Node> result = context.function->get_output_op(i);
452 std::shared_ptr<ngraph::Node> outputNode = result->get_input_node_shared_ptr(0);
// Pointer identity: only the exact node object feeding this Result matches.
453 if (outputNode.get() == lastNode.get()) {
454 const std::string originalName = originalNode->get_friendly_name();
455 originalNode->set_friendly_name(originalName + LayerTransformation::originalLayerPostfix);
456 lastNode->set_friendly_name(originalName);
// Variant taking the original name directly: renames `lastNode` to
// `originalName` when it feeds a function Result (no suffixing of any old node).
462 void LayerTransformation::updateOutput(
463 TransformationContext& context,
464 std::shared_ptr<ngraph::Node> lastNode,
465 std::string originalName) const {
466 const size_t outputSize = context.function->get_output_size();
467 for (size_t i = 0; i < outputSize; ++i) {
468 std::shared_ptr<ngraph::Node> result = context.function->get_output_op(i);
469 std::shared_ptr<ngraph::Node> outputNode = result->get_input_node_shared_ptr(0);
470 if (outputNode.get() == lastNode.get()) {
471 lastNode->set_friendly_name(originalName);
// Registers a single-node matcher on the given GraphRewrite pass whose callback
// invokes this transformation's transform() with the shared context. The
// callback captures `this` and `context` by reference, so both must outlive the
// pass execution. Debug builds (LPT_DISPLAY_PRECISION) also log the transformed
// operation and its output precision.
477 void LayerTransformation::addPattern(ngraph::pass::GraphRewrite& pass, TransformationContext& context, std::shared_ptr<Node> patternRoot) const {
478 ngraph::graph_rewrite_callback internal_callback = [this, &context](ngraph::pattern::Matcher &m) {
479 const bool result = transform(context, m);
480 #ifdef LPT_DISPLAY_PRECISION
482 auto operationNode = m.get_match_root();
483 std::cout << "Operation was transformed: " <<
484 operationNode->get_type_name() << ", " <<
485 operationNode->get_friendly_name() << ", output operation precision: " <<
// Only meaningful for single-output ops; otherwise prints a default Type.
486 ((operationNode->get_output_size() == 1u) ? operationNode->get_output_element_type(0) : ngraph::element::Type()) <<
492 // TODO: better name for matcher? required?
493 auto m = std::make_shared<ngraph::pattern::Matcher>(patternRoot, "SingleNodeMatcher");
494 pass.add_matcher(m, internal_callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
497 } // namespace low_precision
499 } // namespace ngraph