[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / normalize_l2_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include "simple_low_precision_transformer.hpp"
15 #include <low_precision/normalize_l2.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/normalize_l2_function.hpp"
19
20 using namespace testing;
21 using namespace ngraph::pass;
22 using namespace ngraph::builder::subgraph;
23
24 class NormalizeL2TransformationTestValues {
25 public:
26     low_precision::LayerTransformation::Params transformationParams;
27
28     NormalizeL2ActualValues actual;
29     NormalizeL2ExpectedValues expected;
30 };
31
32 typedef std::tuple<
33     ngraph::element::Type,
34     ngraph::Shape,
35     ngraph::op::EpsMode,
36     NormalizeL2TransformationTestValues> NormalizeL2TransformationParams;
37
38 class NormalizeL2Transformation : public LayerTransformation, public testing::WithParamInterface<NormalizeL2TransformationParams> {
39 public:
40     void SetUp() override {
41         const ngraph::element::Type precision = std::get<0>(GetParam());
42         const ngraph::Shape shape = std::get<1>(GetParam());
43         const ngraph::op::EpsMode epsMode = std::get<2>(GetParam());
44         const NormalizeL2TransformationTestValues params = std::get<3>(GetParam());
45
46         actualFunction = ngraph::builder::subgraph::NormalizeL2Function::getOriginal(
47             precision,
48             shape,
49             epsMode,
50             params.actual);
51         SimpleLowPrecisionTransformer transform;
52         transform.add<low_precision::NormalizeL2Transformation, ngraph::opset1::NormalizeL2>(
53             low_precision::LayerTransformation::Params(params.transformationParams));
54         transform.transform(actualFunction);
55
56         referenceFunction = (!params.transformationParams.supportAsymmetricQuantization) && (!params.expected.subtractValues.empty()) ?
57             ngraph::builder::subgraph::NormalizeL2Function::getOriginal(
58                 precision,
59                 shape,
60                 epsMode,
61                 params.actual) :
62             ngraph::builder::subgraph::NormalizeL2Function::getReference(
63                 precision,
64                 shape,
65                 epsMode,
66                 params.expected);
67     }
68
69     static std::string getTestCaseName(testing::TestParamInfo<NormalizeL2TransformationParams> obj) {
70         ngraph::element::Type precision;
71         ngraph::Shape shape;
72         ngraph::Shape axes;
73         ngraph::op::EpsMode epsMode;
74         NormalizeL2TransformationTestValues params;
75         std::tie(precision, shape, epsMode, params) = obj.param;
76
77         std::ostringstream result;
78         result << toString(params.transformationParams) << precision << "_" << shape << "_" <<
79             axes << epsMode << params.actual << params.expected;
80         return result.str();
81     }
82 };
83
84 TEST_P(NormalizeL2Transformation, CompareFunctions) {
85     actualFunction->validate_nodes_and_infer_types();
86     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
87     ASSERT_TRUE(res.first) << res.second;
88 }
89
90 const std::vector<ngraph::element::Type> precisions = {
91     ngraph::element::f32,
92     // ngraph::element::f16
93 };
94
95 const std::vector<ngraph::Shape> shapes = {
96     { 1, 4, 16, 16 }
97 };
98
99 std::vector<ngraph::op::EpsMode> epsMode = {
100     ngraph::op::EpsMode::ADD,
101     ngraph::op::EpsMode::MAX
102 };
103
104 const std::vector<NormalizeL2TransformationTestValues> normalizeL2TransformationTestValues = {
105     {
106         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
107         { ngraph::element::u8, { 1 }, { 2.f }, { -12.3f, -12.3f, -12.3f, -12.3f }},
108         { ngraph::element::u8, { 1 }, { 2.f }, { -1.f,   -1.f,   -1.f, -1.f}}
109     },
110
111     // U8
112     {
113         LayerTransformation::createParamsU8I8(),
114         { ngraph::element::u8, { 1 }, { 2.f }, { -12.3f, -12.3f, -12.3f, -12.3f }},
115         { ngraph::element::u8, { 1 }, { 2.f }, { -1.f,   -1.f,   -1.f, -1.f}}
116     },
117
118     {
119         LayerTransformation::createParamsU8I8(),
120         { ngraph::element::u8, { 1, 2, 3 }, { }, { 12.3f }},
121         { ngraph::element::u8, { 1, 2, 3 }, { }, { 1.f }}
122     },
123
124     // I8
125     {
126         LayerTransformation::createParamsI8I8(),
127         { ngraph::element::i8, { 1 }, { 2.f }, { -12.3f, -12.3f, -12.3f, -12.3f }},
128         { ngraph::element::i8, { 1 }, { 2.f }, { -1.f,   -1.f,   -1.f, -1.f}}
129     },
130
131     {
132         LayerTransformation::createParamsI8I8(),
133         { ngraph::element::i8, { 1, 2, 3 }, { }, { 12.3f }},
134         { ngraph::element::i8, { 1, 2, 3 }, { }, { 1.f }}
135     },
136 };
137
138 INSTANTIATE_TEST_CASE_P(
139     LPT,
140     NormalizeL2Transformation,
141     ::testing::Combine(
142         ::testing::ValuesIn(precisions),
143         ::testing::ValuesIn(shapes),
144         ::testing::ValuesIn(epsMode),
145         ::testing::ValuesIn(normalizeL2TransformationTestValues)),
146     NormalizeL2Transformation::getTestCaseName);