eaee3ef697892a34f59730a613547150d3c00de4
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / low_precision / normalize_l2.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/low_precision/normalize_l2.hpp"
6
7 #include <string>
8 #include <memory>
9 #include <cmath>
10 #include <vector>
11
12 #include "ngraph/type/element_type.hpp"
13 #include "ngraph/type/element_type_traits.hpp"
14 #include "transformations/low_precision/network_helper.hpp"
15
16 using namespace ngraph;
17 using namespace ngraph::pass;
18 using namespace ngraph::pass::low_precision;
19
20 namespace {
21
22 template<typename T>
23 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
24     std::vector<T> source = originalConst.cast_vector<T>();
25
26     std::vector<T> newData(source.size());
27     for (size_t i = 0; i < source.size(); ++i) {
28         newData[i] = source[i] < 0 ? -1 : 1;
29     }
30
31     const ngraph::element::Type type = originalConst.get_output_element_type(0);
32     return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
33 }
34
35 } // namespace
36
37 bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
38     if (!LayerTransformation::canBeTransformed(context, operation)) {
39         return false;
40     }
41
42     if (!canSubtractBeHandled(operation)) {
43         return false;
44     }
45
46     const std::shared_ptr<Node> multiply = operation->get_input_node_shared_ptr(0);
47     auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
48     if (scalesConst == nullptr) {
49         scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
50     }
51     if (scalesConst == nullptr) {
52         return false;
53     }
54
55     // TODO: Expand transformation for all cases of axes values
56     const auto axes = as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
57     const std::vector<int64_t> axesAcrossSpatial = { 1 };
58     const std::vector<int64_t> axesByChannels = { 1, 2, 3 };
59
60     std::vector<int64_t> axesValues = axes->cast_vector<int64_t>();
61     if (!(axesValues == axesAcrossSpatial || axesValues == axesByChannels)) {
62         return false;
63     }
64
65     const ngraph::Shape outputShape = scalesConst->get_output_shape(0);
66     const size_t size = ngraph::shape_size(outputShape);
67     const size_t channels = operation->get_output_shape(0)[1];
68
69     if (size != channels && size != 1) {
70         return false;
71     }
72
73     if (!NetworkHelper::isScalarLike(scalesConst)) {
74         return false;
75     }
76
77     return true;
78 }
79
80 void NormalizeL2Transformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
81     addPattern(
82         pass,
83         context,
84         make_op_pattern<ngraph::opset1::NormalizeL2>({
85             make_op_label<ngraph::opset1::Multiply>(),
86             make_op_label<ngraph::opset1::Constant>()
87             }));
88 }
89
90 bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
91     std::shared_ptr<Node> operation = m.get_match_root();
92     if (!canBeTransformed(context, operation)) {
93         return false;
94     }
95
96     auto normalize = as_type_ptr<opset1::NormalizeL2>(separateInStandaloneBranch(operation));
97
98     const auto axes = as_type_ptr<opset1::Constant>(normalize->get_input_node_shared_ptr(1));
99     FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(normalize);
100     auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
101     if (scalesConst == nullptr) {
102         scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
103     }
104
105     std::shared_ptr<opset1::Constant> newScalesConst;
106     const auto type = scalesConst->get_output_element_type(0);
107     switch (type) {
108         case ngraph::element::Type_t::f16: {
109             newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
110             break;
111         }
112         case ngraph::element::Type_t::f32: {
113             newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
114             break;
115         }
116         default: {
117             THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
118         }
119     }
120
121     auto newNormalize = std::make_shared<op::TypeRelaxed<opset1::NormalizeL2>>(
122         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
123         ngraph::op::TemporaryReplaceOutputType(dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract, element::f32).get(),
124         ngraph::op::TemporaryReplaceOutputType(axes->clone_with_new_inputs({}), element::f32).get(),
125         normalize->get_eps(),
126         normalize->get_eps_mode());
127     NetworkHelper::copyInfo(normalize, newNormalize);
128
129     auto newMultiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
130         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
131         ngraph::op::TemporaryReplaceOutputType(newNormalize, element::f32).get(),
132         ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());
133
134     replace_node(normalize, newMultiply);
135
136     updateOutput(context, newMultiply, normalize);
137     return true;
138 }
139
140 bool NormalizeL2Transformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
141     return false;
142 }