37eb5dda76a1ba99262709014c9e7d9a7ab5645c
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / relu.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/relu.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10
11 #include "low_precision/common/ie_lpt_exception.hpp"
12 #include "low_precision/network_helper.hpp"
13
14 namespace ngraph {
15 namespace pass {
16 namespace low_precision {
17
18 void ReluTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
19     addPattern(
20         pass,
21         context,
22         make_op_pattern<opset1::Relu>({ make_op_label<opset1::Multiply>()}));
23 }
24
25 bool ReluTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
26     std::shared_ptr<Node> relu = m.get_match_root();
27     if (!LayerTransformation::canBeTransformed(context, relu)) {
28         return false;
29     }
30
31     if (!canBeTransformed(context, relu)) {
32         return false;
33     }
34
35     relu = separateInStandaloneBranch(relu);
36     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(relu, 0);
37     moveDequantizationAfter(context, relu, dequantization, false, false);
38     return true;
39 }
40
41 bool ReluTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const noexcept {
42     return true;
43 }
44
45 bool ReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
46     if (!LayerTransformation::canBeTransformed(context, op)) {
47         return false;
48     }
49
50     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
51     if (dequantization.empty()) {
52         return false;
53     }
54
55     if (!canSubtractBeHandled(op, dequantization)) {
56         return false;
57     }
58
59     const std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.multiply->input_value(1).get_node_shared_ptr());
60     const auto scales = constant->cast_vector<float>();
61     if (std::any_of(scales.begin(), scales.end(), [](const float value) { return value < 0.f; })) {
62         return false;
63     }
64
65     return true;
66 }
67
68 } // namespace low_precision
69 } // namespace pass
70 } // namespace ngraph