15999cc1154f0ce1d7bb1f89f4d467c6914b055a
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / normalize_l2.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/normalize_l2.hpp"
6
7 #include <string>
8 #include <memory>
9 #include <cmath>
10 #include <vector>
11
12 #include "ngraph/type/element_type.hpp"
13 #include "ngraph/type/element_type_traits.hpp"
14 #include "low_precision/network_helper.hpp"
15 #include "low_precision/common/dequantization_op.hpp"
16
17 using namespace ngraph;
18 using namespace ngraph::pass;
19 using namespace ngraph::pass::low_precision;
20
21 namespace normalize_l2 {
22
23 template<typename T>
24 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
25     std::vector<T> source = originalConst.cast_vector<T>();
26
27     std::vector<T> newData(source.size());
28     for (size_t i = 0; i < source.size(); ++i) {
29         newData[i] = source[i] < 0 ? T{-1} : T{1};
30     }
31
32     const ngraph::element::Type type = originalConst.get_output_element_type(0);
33     return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
34 }
35
36 } // namespace normalize_l2
37
38 bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
39     if (!LayerTransformation::canBeTransformed(context, operation)) {
40         return false;
41     }
42
43     if (!canSubtractBeHandled(operation)) {
44         return false;
45     }
46
47     const std::shared_ptr<Node> multiply = operation->get_input_node_shared_ptr(0);
48     auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
49     if (scalesConst == nullptr) {
50         scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
51     }
52     if (scalesConst == nullptr) {
53         return false;
54     }
55
56     // TODO: Expand transformation for all cases of axes values
57     const auto axes = as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
58     const std::vector<int64_t> axesAcrossSpatial = { 1 };
59     const std::vector<int64_t> axesByChannels = { 1, 2, 3 };
60
61     std::vector<int64_t> axesValues = axes->cast_vector<int64_t>();
62     if (!(axesValues == axesAcrossSpatial || axesValues == axesByChannels)) {
63         return false;
64     }
65
66     const ngraph::Shape outputShape = scalesConst->get_output_shape(0);
67     const size_t size = ngraph::shape_size(outputShape);
68     const size_t channels = operation->get_output_shape(0)[1];
69
70     if (size != channels && size != 1) {
71         return false;
72     }
73
74     if (!NetworkHelper::isScalarLike(scalesConst)) {
75         return false;
76     }
77
78     return true;
79 }
80
81 void NormalizeL2Transformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
82     addPattern(
83         pass,
84         context,
85         make_op_pattern<ngraph::opset1::NormalizeL2>({
86             make_op_label<ngraph::opset1::Multiply>(),
87             make_op_label<ngraph::opset1::Constant>()
88             }));
89 }
90
91 bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
92     std::shared_ptr<Node> operation = m.get_match_root();
93     if (!canBeTransformed(context, operation)) {
94         return false;
95     }
96
97     auto normalize = as_type_ptr<opset1::NormalizeL2>(separateInStandaloneBranch(operation));
98
99     const auto axes = as_type_ptr<opset1::Constant>(normalize->get_input_node_shared_ptr(1));
100     FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(normalize);
101     auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
102     if (scalesConst == nullptr) {
103         scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
104     }
105
106     std::shared_ptr<opset1::Constant> newScalesConst;
107     const auto type = scalesConst->get_output_element_type(0);
108     switch (type) {
109         case ngraph::element::Type_t::f16: {
110             newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
111             break;
112         }
113         case ngraph::element::Type_t::f32: {
114             newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
115             break;
116         }
117         default: {
118             THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
119         }
120     }
121
122     auto newNormalize = std::make_shared<op::TypeRelaxed<opset1::NormalizeL2>>(
123         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
124         ngraph::op::TemporaryReplaceOutputType(dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract, element::f32).get(),
125         ngraph::op::TemporaryReplaceOutputType(axes->clone_with_new_inputs({}), element::f32).get(),
126         normalize->get_eps(),
127         normalize->get_eps_mode());
128     NetworkHelper::copyInfo(normalize, newNormalize);
129
130     auto newMultiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
131         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
132         ngraph::op::TemporaryReplaceOutputType(newNormalize, element::f32).get(),
133         ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());
134
135     replace_node(normalize, newMultiply);
136
137     updateOutput(context, newMultiply, normalize);
138     return true;
139 }
140
141 bool NormalizeL2Transformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
142     return false;
143 }