[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / mvn.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/mvn.hpp"
6
7 #include <algorithm>
8 #include <string>
9 #include <memory>
10 #include <cmath>
11 #include <vector>
12
13 #include "ngraph/type/element_type.hpp"
14 #include "ngraph/type/element_type_traits.hpp"
15 #include "low_precision/network_helper.hpp"
16 #include "low_precision/common/dequantization_op.hpp"
17
18 using namespace ngraph;
19 using namespace ngraph::pass;
20 using namespace ngraph::pass::low_precision;
21
22 namespace mvn {
23
24 template<typename T>
25 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
26     std::vector<T> source = originalConst.cast_vector<T>();
27
28     std::vector<T> newData(source.size());
29     for (size_t i = 0; i < source.size(); ++i) {
30         newData[i] = source[i] < 0 ? T{-1} : T{1};
31     }
32
33     const ngraph::element::Type type = originalConst.get_output_element_type(0);
34     return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
35 }
36
37 } // namespace mvn
38
39 bool MVNTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
40     if (!LayerTransformation::canBeTransformed(context, operation)) {
41         return false;
42     }
43
44     if (!canSubtractBeHandled(operation)) {
45         return false;
46     }
47
48     auto mvn = as_type_ptr<op::MVN>(operation);
49
50     const std::shared_ptr<Node> multiply = mvn->get_input_node_shared_ptr(0);
51     auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
52     if (scalesConst == nullptr) {
53         scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
54     }
55     if (scalesConst == nullptr) {
56         return false;
57     }
58
59     const bool acrossChannels = mvn->get_reduction_axes().count(1) > 0;
60     const bool normalizeVariance = mvn->get_normalize_variance();
61
62     if (!NetworkHelper::isScalarLike(scalesConst) && acrossChannels) {
63         return false;
64     }
65     return true;
66 }
67
68 void MVNTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
69     addPattern(
70         pass,
71         context,
72         make_op_pattern<ngraph::op::MVN>({ make_op_label<ngraph::opset1::Multiply>() }));
73 }
74
75 bool MVNTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
76     std::shared_ptr<Node> operation = m.get_match_root();
77     if (!canBeTransformed(context, operation)) {
78         return false;
79     }
80
81     auto mvn = as_type_ptr<op::MVN>(separateInStandaloneBranch(operation));
82
83     FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(mvn);
84     auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
85     if (scalesConst == nullptr) {
86         scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
87     }
88
89     const bool acrossChannels = mvn->get_reduction_axes().count(1) > 0;
90     const bool normalizeVariance = mvn->get_normalize_variance();
91
92     auto newScalesConst = scalesConst;
93     const auto type = scalesConst->get_output_element_type(0);
94     if (normalizeVariance) {
95         switch (type) {
96             case ngraph::element::Type_t::f16: {
97                 newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
98                 break;
99             }
100             case ngraph::element::Type_t::f32: {
101                 newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
102                 break;
103             }
104             default: {
105                 THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
106             }
107         }
108     }
109
110     auto newMVN = std::make_shared<op::TypeRelaxed<op::MVN>>(
111         op::MVN(dequantization.subtract ?
112                     dequantization.subtract :
113                     dequantization.data,
114                 mvn->get_reduction_axes(),
115                 mvn->get_normalize_variance(),
116                 mvn->get_eps()),
117         type);
118     NetworkHelper::copyInfo(mvn, newMVN);
119
120     auto newMultiply = std::make_shared<DequantizationMultiply>(newMVN, newScalesConst);
121     ngraph::copy_runtime_info({ mvn, newMultiply }, newMultiply);
122
123     replace_node(mvn, newMultiply);
124
125     updateOutput(context, newMultiply, newMVN);
126     return true;
127 }
128
129 bool MVNTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
130     return false;
131 }