1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/clamp.hpp"
8 #include <ngraph/ngraph.hpp>
9 #include "low_precision/network_helper.hpp"
13 namespace low_precision {
15 ClampTransformation::ClampTransformation(const Params& params) : LayerTransformation(params) {}
// Registers the graph pattern this transformation reacts to.
// The pattern built below is an opset1::Clamp whose input is produced by an
// opset1::Multiply (the label node), i.e. a Clamp sitting directly after a Multiply.
17 void ClampTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
// NOTE(review): the start of the call that receives this pattern (and uses `pass`/`context`)
// is not present in this listing; only its final argument is visible below.
20 make_op_pattern<opset1::Clamp>({ make_op_label<opset1::Multiply>() }));
// Rewrites a matched Clamp by moving its dequantization operations after it
// (moveDequantizationAfter) and then replacing the resulting Clamp with a new one
// whose min/max bounds are derived from the original bounds and the scalar
// dequantization constants (scale from the Multiply, shift from the Subtract).
// NOTE(review): several lines of this function are absent from this listing
// (e.g. the bodies of the `if` statements and the min/max adjustment arithmetic);
// the comments below describe only the visible lines.
23 bool ClampTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) const {
// Returns whether the Subtract's second input is a Constant accepted by
// NetworkHelper::isScalarLike, i.e. whether it subtracts the same value everywhere.
// NOTE(review): `sub` is dereferenced below, but no nullptr guard is visible in this
// listing, even though dequantization.subtract is compared against nullptr further down;
// confirm a guard exists on the missing lines.
24 auto subWithTheSameValues = [](std::shared_ptr<ngraph::opset1::Subtract> sub) {
28 const auto constant = as_type_ptr<ngraph::opset1::Constant>(sub->get_input_node_shared_ptr(1));
30 if (constant == nullptr) {
34 return NetworkHelper::isScalarLike(constant);
// Gate on the shared applicability checks (canBeTransformed); the branch body is not
// visible here (presumably an early `return false`).
37 if (!canBeTransformed(context, m.get_match_root())) {
// Isolate the matched node via separateInStandaloneBranch before modifying it
// (presumably so other consumers of the original branch are unaffected), then
// collect its dequantization operations.
41 const std::shared_ptr<Node> clamp = separateInStandaloneBranch(m.get_match_root());
42 const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(clamp);
// True only when a Subtract exists and its constant is scalar-like (see lambda above).
44 const bool moveSubtract = subWithTheSameValues(dequantization.subtract);
// A Subtract that exists but is not a scalar-like constant lands here; the branch body
// is not visible in this listing.
46 if (!moveSubtract && (dequantization.subtract != nullptr)) {
// Move the dequantization after the Clamp. moveSubtract is forwarded as the last
// argument (presumably controlling whether the Subtract moves too); the meaning of the
// `false` argument is defined by moveDequantizationAfter, which is not shown here.
// NOTE(review): the as_type_ptr result is used below without a nullptr check.
49 const auto newClamp = as_type_ptr<opset1::Clamp>(moveDequantizationAfter(context, clamp, dequantization, false, moveSubtract));
// Start from the bounds of the Clamp produced by moveDequantizationAfter.
50 double min = newClamp->get_min();
51 double max = newClamp->get_max();
// scale: element 0 of the Multiply's second-input Constant, read as double.
// canBeTransformed only accepts a scalar-like Multiply constant, so element 0 stands for
// all values. The as_type_ptr result is dereferenced without a check, relying on that.
53 if (dequantization.multiply != nullptr) {
54 double scale = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1))->cast_vector<double>()[0];
// shift: element 0 of the Subtract's second-input Constant; moveSubtract being true means
// the lambda above already found a scalar-like Constant there.
62 if (dequantization.subtract != nullptr && moveSubtract) {
63 double shift = as_type_ptr<opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(1))->cast_vector<double>()[0];
// Build a fresh Clamp on newClamp's input with the (adjusted) bounds and swap it into the
// graph in place of newClamp.
68 const std::shared_ptr<ngraph::opset1::Clamp> replacement = std::make_shared<ngraph::opset1::Clamp>(newClamp->get_input_node_shared_ptr(0), min, max);
69 replace_node(newClamp, replacement);
// The replacement's output precision follows the Multiply's output element type, falling
// back to the Subtract's when there is no Multiply.
// NOTE(review): if both were null the fallback would dereference a null Subtract;
// canBeTransformed dereferences the Multiply, which presumably rules that out — confirm.
71 element::Type outputClampType = dequantization.multiply ?
72 dequantization.multiply->get_output_element_type(0) :
73 dequantization.subtract->get_output_element_type(0);
74 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(replacement, outputClampType);
// Decides whether `op` is eligible for this transformation: the generic
// LayerTransformation checks must pass, and the dequantization Multiply's second
// input must be a Constant that NetworkHelper::isScalarLike accepts.
78 bool ClampTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
// Base-class checks first; the branch body is not visible in this listing
// (presumably `return false`).
79 if (!LayerTransformation::canBeTransformed(context, op)) {
82 const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op);
// NOTE(review): dequantization.multiply is dereferenced with no nullptr check visible in
// this listing; confirm getDequantization always yields a Multiply for nodes reaching here.
84 const auto mulConst = as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
// A non-Constant Multiply operand is rejected (branch body not visible here).
85 if (mulConst == nullptr) {
// Only a scalar-like scale is accepted, since transform() reads just element 0 of it.
89 return NetworkHelper::isScalarLike(mulConst);
// Reports whether this transformation treats `layer` as precision-preserving.
// NOTE(review): the function body (and therefore the returned value) is not present in
// this listing; only the signature is visible.
92 bool ClampTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
96 } // namespace low_precision