[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / clamp.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/clamp.hpp"
6 #include <algorithm>
7 #include <memory>
8 #include <ngraph/ngraph.hpp>
9 #include "low_precision/network_helper.hpp"
10
11 namespace ngraph {
12 namespace pass {
13 namespace low_precision {
14
15 ClampTransformation::ClampTransformation(const Params& params) : LayerTransformation(params) {}
16
17 void ClampTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
18     addPattern(pass,
19                context,
20                make_op_pattern<opset1::Clamp>({ make_op_label<opset1::Multiply>() }));
21 }
22
23 bool ClampTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) const {
24     auto subWithTheSameValues = [](std::shared_ptr<ngraph::opset1::Subtract> sub) {
25         if (sub == nullptr) {
26             return false;
27         }
28         const auto constant = as_type_ptr<ngraph::opset1::Constant>(sub->get_input_node_shared_ptr(1));
29
30         if (constant == nullptr) {
31             return false;
32         }
33
34         return NetworkHelper::isScalarLike(constant);
35     };
36
37     if (!canBeTransformed(context, m.get_match_root())) {
38         return false;
39     }
40
41     const std::shared_ptr<Node> clamp = separateInStandaloneBranch(m.get_match_root());
42     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(clamp);
43
44     const bool moveSubtract = subWithTheSameValues(dequantization.subtract);
45     // issue #43136
46     if (!moveSubtract && (dequantization.subtract != nullptr)) {
47         return false;
48     }
49     const auto newClamp = as_type_ptr<opset1::Clamp>(moveDequantizationAfter(context, clamp, dequantization, false, moveSubtract));
50     double min = newClamp->get_min();
51     double max = newClamp->get_max();
52
53     if (dequantization.multiply != nullptr) {
54         double scale = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1))->cast_vector<double>()[0];
55         if (scale < 0.0) {
56             std::swap(min, max);
57         }
58         min /= scale;
59         max /= scale;
60     }
61
62     if (dequantization.subtract != nullptr && moveSubtract) {
63         double shift = as_type_ptr<opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(1))->cast_vector<double>()[0];
64         min += shift;
65         max += shift;
66     }
67
68     const std::shared_ptr<ngraph::opset1::Clamp> replacement = std::make_shared<ngraph::opset1::Clamp>(newClamp->get_input_node_shared_ptr(0), min, max);
69     replace_node(newClamp, replacement);
70
71     element::Type outputClampType = dequantization.multiply ?
72         dequantization.multiply->get_output_element_type(0) :
73         dequantization.subtract->get_output_element_type(0);
74     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(replacement, outputClampType);
75     return true;
76 }
77
78 bool ClampTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
79     if (!LayerTransformation::canBeTransformed(context, op)) {
80         return false;
81     }
82     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op);
83
84     const auto mulConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
85     if (mulConst == nullptr) {
86         return false;
87     }
88
89     return NetworkHelper::isScalarLike(mulConst);
90 }
91
92 bool ClampTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
93     return false;
94 }
95
96 } // namespace low_precision
97 } // namespace pass
98 } // namespace ngraph