[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / normalize_l2_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/normalize_l2_function.hpp"
6
7 #include <ngraph_ops/type_relaxed.hpp>
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "low_precision/common/dequantization_op.hpp"
11
12 namespace ngraph {
13 namespace builder {
14 namespace subgraph {
15
16 std::shared_ptr<ngraph::Function> NormalizeL2Function::getOriginal(
17     const ngraph::element::Type precision,
18     const std::pair<ngraph::Shape, ngraph::Shape>& shapes,
19     const ngraph::element::Type precisionOnActivation,
20     const std::vector<uint64_t>& axes,
21     const bool fuseMultiply,
22     const bool shift) {
23     const float low = precisionOnActivation == ngraph::element::u8 ? (0.f + (shift ? 10.f : 0.f)) : (-128.f + (shift ? 10.f : 0.f));
24     const float high = precisionOnActivation == ngraph::element::u8 ? 255.f : 127.f;
25     const float inputScale = 10.f;
26     const float outputScale = 20.f;
27
28
29     const auto paramNode = std::make_shared<ngraph::opset1::Parameter>(precision, shapes.first);
30     paramNode->set_friendly_name("input");
31
32     const auto fakeQuantize = ngraph::builder::makeFakeQuantize(
33         paramNode->output(0), precision, 256, shapes.second,
34         { low / inputScale }, { high / inputScale }, { low / outputScale }, { high / outputScale });
35
36     fakeQuantize->set_friendly_name("fakeQuantize");
37
38     const auto axesNode = std::make_shared<ngraph::op::Constant>(ngraph::element::u64, ngraph::Shape{ axes.size() }, axes);
39     axesNode->set_friendly_name("axes");
40     const auto normalizeL2 = std::make_shared<ngraph::opset1::NormalizeL2>(fakeQuantize->output(0), axesNode, 1e-6, ngraph::op::EpsMode::ADD);
41     normalizeL2->set_friendly_name("normalizeL2");
42
43     ngraph::ResultVector results;
44     if (fuseMultiply) {
45         const auto multiplyConst = std::make_shared<ngraph::op::Constant>(
46             precision, ngraph::Shape{ shapes.first[0], shapes.first[1], 1ul, 1ul }, std::vector<float>{ 2.f });
47         multiplyConst->set_friendly_name("multiplyConst");
48         const auto multiply = std::make_shared<ngraph::opset1::Multiply>(normalizeL2->output(0), multiplyConst);
49         multiply->set_friendly_name("output");
50
51         results = { std::make_shared<ngraph::opset1::Result>(multiply) };
52     } else {
53         normalizeL2->set_friendly_name("output");
54         results = { std::make_shared<ngraph::opset1::Result>(normalizeL2) };
55     }
56
57     const auto function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ paramNode }, "NormalizeL2Transformation");
58     return function;
59 }
60
61 std::shared_ptr<ngraph::Function> NormalizeL2Function::getOriginal(
62     const ngraph::element::Type precision,
63     const ngraph::Shape& shape,
64     const ngraph::op::EpsMode& epsMode,
65     const NormalizeL2ActualValues& actualValues) {
66     const auto input = std::make_shared<ngraph::opset1::Parameter>(actualValues.precision, shape);
67     std::shared_ptr<ngraph::Node> parent = input;
68
69     const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, precision);
70     parent = convert;
71
72     if (!actualValues.subtractValues.empty()) {
73         const std::shared_ptr<ngraph::Node> subtract = std::make_shared< ngraph::opset1::Subtract >(
74             parent,
75             std::make_shared<ngraph::opset1::Constant>(
76                 precision, Shape({ actualValues.subtractValues.size() }), actualValues.subtractValues));
77         parent = subtract;
78     }
79
80     if (!actualValues.mutliplyValues.empty()) {
81         const std::shared_ptr<ngraph::Node> multiply = std::make_shared< ngraph::opset1::Multiply >(
82             parent,
83             std::make_shared<ngraph::opset1::Constant>(
84                 precision, Shape({ 1, actualValues.mutliplyValues.size(), 1, 1 }), actualValues.mutliplyValues));
85         parent = multiply;
86     }
87
88     const auto axesNode = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{ actualValues.axes.size() }, actualValues.axes);
89     const auto normalizeL2 = std::make_shared<ngraph::opset1::NormalizeL2>(parent, axesNode, 1e-6, epsMode);
90     normalizeL2->set_friendly_name("output");
91     auto& rtInfo = normalizeL2->get_rt_info();
92     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
93
94     ngraph::ResultVector results = { std::make_shared<ngraph::opset1::Result>(normalizeL2) };
95     const auto function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "NormalizeL2Transformation");
96     return function;
97 }
98
99 std::shared_ptr<ngraph::Function> NormalizeL2Function::getReference(
100     const ngraph::element::Type precision,
101     const ngraph::Shape& shape,
102     const ngraph::op::EpsMode& epsMode,
103     const NormalizeL2ExpectedValues& expectedValues) {
104     const auto input = std::make_shared<ngraph::opset1::Parameter>(expectedValues.precision, shape);
105     std::shared_ptr<ngraph::Node> parent = input;
106
107     if (!expectedValues.subtractValues.empty()) {
108         const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, precision);
109         parent = convert;
110
111         const std::shared_ptr<ngraph::Node> subtract = std::make_shared<op::TypeRelaxed<ngraph::opset1::Subtract>>(
112             std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
113             ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
114             ngraph::op::TemporaryReplaceOutputType(std::make_shared<ngraph::opset1::Constant>(
115                 precision,
116                 Shape({ expectedValues.subtractValues.size() }),
117                 expectedValues.subtractValues), element::f32).get());
118         parent = subtract;
119     }
120
121     const auto axesNode = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{ expectedValues.axes.size() }, expectedValues.axes);
122     const auto normalizeL2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::NormalizeL2>>(
123         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
124         ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
125         ngraph::op::TemporaryReplaceOutputType(axesNode, element::f32).get(),
126         1e-6,
127         epsMode);
128     auto& rtInfo = normalizeL2->get_rt_info();
129     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
130     std::shared_ptr<ngraph::Node> output = normalizeL2;
131
132     if (!expectedValues.mutliplyValues.empty()) {
133         const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::op::TypeRelaxed<pass::low_precision::DequantizationMultiply>>(
134             std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
135             ngraph::op::TemporaryReplaceOutputType(output, element::f32).get(),
136             ngraph::op::TemporaryReplaceOutputType(std::make_shared<ngraph::opset1::Constant>(
137                 precision, Shape({ 1, expectedValues.mutliplyValues.size(), 1, 1 }), expectedValues.mutliplyValues), element::f32).get());
138         multiply->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
139         output = multiply;
140     }
141     output->set_friendly_name("output");
142
143     ngraph::ResultVector results = { std::make_shared<ngraph::opset1::Result>(output) };
144     const auto function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "NormalizeL2Transformation");
145
146     return function;
147 }
148
149 }  // namespace subgraph
150 }  // namespace builder
151 }  // namespace ngraph