[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / mvn_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include "low_precision/mvn.hpp"
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "simple_low_precision_transformer.hpp"
20 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
21
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
25
26 class MVNTransformationTestValues {
27 public:
28     class Actual {
29     public:
30         ngraph::element::Type precisionBeforeDequantization;
31         ngraph::builder::subgraph::DequantizationOperations dequantization;
32     };
33
34     class Expected {
35     public:
36         ngraph::element::Type precisionBeforeDequantization;
37         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
38         ngraph::element::Type precisionAfterOperation;
39         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
40     };
41
42     ngraph::Shape inputShape;
43     ngraph::AxisSet reductionAxes;
44     bool normalizeVariance;
45     ngraph::pass::low_precision::LayerTransformation::Params params;
46     Actual actual;
47     Expected expected;
48 };
49
/**
 * Streams a vector as "{ a, b, c }" (an empty vector prints as "{  }").
 * @param os      destination stream
 * @param values  elements to print, each written with its own operator<<
 * @return os, to allow chaining
 */
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";
    // The separator is empty before the first element and ", " afterwards,
    // so no index arithmetic is needed to detect the last element.
    const char* separator = "";
    for (const auto& value : values) {
        os << separator << value;
        separator = ", ";
    }
    return os << " }";
}
62
63 class MVNTransformation : public LayerTransformation, public testing::WithParamInterface<MVNTransformationTestValues> {
64 public:
65     void SetUp() override {
66         const MVNTransformationTestValues testValues = GetParam();
67
68         actualFunction = ngraph::builder::subgraph::MVNFunction::getOriginal(
69             testValues.inputShape,
70             testValues.reductionAxes,
71             testValues.normalizeVariance,
72             testValues.actual.precisionBeforeDequantization,
73             testValues.actual.dequantization);
74
75         SimpleLowPrecisionTransformer transformer;
76         transformer.add<ngraph::pass::low_precision::MVNTransformation, ngraph::opset1::Interpolate>(testValues.params);
77         transformer.transform(actualFunction);
78
79         referenceFunction = ngraph::builder::subgraph::MVNFunction::getReference(
80             testValues.inputShape,
81             testValues.reductionAxes,
82             testValues.normalizeVariance,
83             testValues.expected.precisionBeforeDequantization,
84             testValues.expected.dequantizationBefore,
85             testValues.expected.precisionAfterOperation,
86             testValues.expected.dequantizationAfter);
87     }
88
89     static std::string getTestCaseName(testing::TestParamInfo<MVNTransformationTestValues> obj) {
90         const MVNTransformationTestValues testValues = obj.param;
91
92         std::ostringstream result;
93         result <<
94             toString(testValues.params) << "_" <<
95             testValues.inputShape << "_" <<
96             testValues.reductionAxes << "_" <<
97             testValues.normalizeVariance << "_" <<
98             testValues.actual.precisionBeforeDequantization << "_" <<
99             testValues.actual.dequantization << "_" <<
100             testValues.expected.dequantizationBefore;
101         return result.str();
102     }
103 };
104
// Test table for MVNTransformation. Each entry is, positionally:
//   input shape, MVN reduction axes, normalizeVariance,
//   transformation params,
//   actual   { input precision, dequantization before MVN },
//   expected { input precision, dequantization kept before MVN,
//              MVN output precision, dequantization after MVN }.
// Dequantization entries are written as three parts; presumably
// {convert, subtract, multiply} per DequantizationOperations — confirm there.
const std::vector<MVNTransformationTestValues> testValues = {
    // Symmetric-only params, dequantization with a subtract (-0.32f):
    // expected graph keeps the dequantization before MVN, nothing after.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {-0.32f}, {0.45f}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {-0.32f}, {0.45f}},
            ngraph::element::f32,
            { }
        }
    },
    // Symmetric-only params, no subtract, normalizeVariance = true:
    // expected graph has nothing before MVN and {1.f} in the last part after it.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {0.45f}}
        },
        {
            ngraph::element::u8,
            { },
            ngraph::element::f32,
            {{}, {}, {1.f}}
        }
    },
    // Asymmetric params, subtract 127.f: dequantization expected to stay before MVN.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {127.f}, {0.45f}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {127.f}, {0.45f}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // Asymmetric params, non-integer subtract 12.5f: dequantization expected to stay before MVN.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {12.5f}, {0.45f}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {12.5f}, {0.45f}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // Symmetric-only params, subtract 127.f: dequantization expected to stay before MVN.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {127.f}, {0.45f}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {127.f}, {0.45f}},
            ngraph::element::f32,
            {}
        }
    },

    // Default params, negative scale -0.5f, normalizeVariance = true:
    // expected graph has nothing before MVN and {-1.f} in the last part after it.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {-0.5f}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::f32,
            {{}, {}, {-1.f}}
        }
    },

    // Default params, normalizeVariance = false:
    // expected graph has nothing before MVN and the original 0.45f after it.
    {
        ngraph::Shape{ 1, 4, 16, 16 },
        {1, 2, 3},
        false,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {0.45f}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::f32,
            {{}, {}, {0.45f}}
        }
    },
    // Per-channel values of shape {1, 2, 1, 1} (both equal), axes {1, 2, 3},
    // normalizeVariance = false: same per-channel values expected after MVN.
    {
        ngraph::Shape{ 1, 2, 2, 2 },
        {1, 2, 3},
        false,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::f32,
            {{}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
        }
    },
    // Per-channel values of opposite sign, axes {2, 3} only, normalizeVariance = true:
    // expected graph has nothing before MVN and {1.f, -1.f} per channel after it.
    {
        ngraph::Shape{ 1, 2, 2, 2 },
        {2, 3},
        true,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::f32,
            {{}, {}, {{1.f, -1.f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
        }
    },
    // Per-channel values of opposite sign, axes {1, 2, 3} (includes axis 1),
    // normalizeVariance = true: dequantization expected to stay before MVN.
    {
        ngraph::Shape{ 1, 2, 2, 2 },
        {1, 2, 3},
        true,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
};
269
270 TEST_P(MVNTransformation, CompareFunctions) {
271     actualFunction->validate_nodes_and_infer_types();
272     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
273     ASSERT_TRUE(res.first) << res.second;
274 }
275
// Instantiates the MVNTransformation tests once per entry of testValues under
// the "LPT" prefix; instance names come from MVNTransformation::getTestCaseName.
INSTANTIATE_TEST_CASE_P(
    LPT,
    MVNTransformation,
    ::testing::ValuesIn(testValues),
    MVNTransformation::getTestCaseName);