0e30a8411b255d9e0f906d9839d67d6fecf6305f
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / mvn_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
#include "layer_transformation.hpp"

#include <string>
#include <sstream>
#include <memory>

#include <gtest/gtest.h>

#include <ngraph/ngraph.hpp>

#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include "low_precision/mvn.hpp"

#include "common_test_utils/ngraph_test_utils.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
#include "simple_low_precision_transformer.hpp"
#include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
21
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
25
26 class MVNTransformationTestValues {
27 public:
28     class Actual {
29     public:
30         ngraph::element::Type precisionBeforeDequantization;
31         ngraph::builder::subgraph::DequantizationOperations dequantization;
32     };
33
34     class Expected {
35     public:
36         ngraph::element::Type precisionBeforeDequantization;
37         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
38         ngraph::element::Type precisionAfterOperation;
39         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
40     };
41
42     ngraph::Shape inputShape;
43     ngraph::AxisSet reductionAxes;
44     bool normalizeVariance;
45     ngraph::pass::low_precision::LayerTransformation::Params params;
46     Actual actual;
47     Expected expected;
48 };
49
// Writes a vector to the stream as "{ a, b, c }".
// Elements are separated by ", "; an empty vector is printed as "{  }".
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";

    auto current = values.begin();
    const auto last = values.end();
    if (current != last) {
        // First element goes out bare; every following one is preceded by the separator.
        os << *current;
        for (++current; current != last; ++current) {
            os << ", " << *current;
        }
    }

    os << " }";
    return os;
}
62
63 class MVNTransformation : public LayerTransformation, public testing::WithParamInterface<MVNTransformationTestValues> {
64 public:
65     void SetUp() override {
66         const MVNTransformationTestValues testValues = GetParam();
67
68         actualFunction = ngraph::builder::subgraph::MVNFunction::getOriginal(
69             testValues.inputShape,
70             testValues.reductionAxes,
71             testValues.normalizeVariance,
72             testValues.actual.precisionBeforeDequantization,
73             testValues.actual.dequantization);
74
75         SimpleLowPrecisionTransformer transformer;
76         transformer.add<ngraph::pass::low_precision::MVNTransformation, ngraph::opset1::Interpolate>(testValues.params);
77         transformer.transform(actualFunction);
78
79         referenceFunction = ngraph::builder::subgraph::MVNFunction::getReference(
80             testValues.inputShape,
81             testValues.reductionAxes,
82             testValues.normalizeVariance,
83             testValues.expected.precisionBeforeDequantization,
84             testValues.expected.dequantizationBefore,
85             testValues.expected.precisionAfterOperation,
86             testValues.expected.dequantizationAfter);
87     }
88
89     static std::string getTestCaseName(testing::TestParamInfo<MVNTransformationTestValues> obj) {
90         const MVNTransformationTestValues testValues = obj.param;
91
92         std::ostringstream result;
93         result <<
94             testValues.inputShape << "_" <<
95             testValues.reductionAxes << "_" <<
96             testValues.normalizeVariance << "_" <<
97             testValues.actual.precisionBeforeDequantization << "_" <<
98             testValues.actual.dequantization << "_" <<
99             testValues.expected.dequantizationBefore;
100         return result.str();
101     }
102 };
103
104 const std::vector<MVNTransformationTestValues> testValues = {
105     {
106         ngraph::Shape{ 1, 4, 16, 16 },
107         {1, 2, 3},
108         true,
109         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
110         {
111             ngraph::element::u8,
112             {{ngraph::element::f32}, {-0.32f}, {0.45f}}
113         },
114         {
115             ngraph::element::u8,
116             {{ngraph::element::f32}, {-0.32f}, {0.45f}},
117             ngraph::element::f32,
118             { }
119         }
120     },
121     {
122         ngraph::Shape{ 1, 4, 16, 16 },
123         {1, 2, 3},
124         true,
125         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
126         {
127             ngraph::element::u8,
128             {{ngraph::element::f32}, {}, {0.45f}}
129         },
130         {
131             ngraph::element::u8,
132             { },
133             ngraph::element::f32,
134             {{}, {}, {1.f}}
135         }
136     },
137     {
138         ngraph::Shape{ 1, 4, 16, 16 },
139         {1, 2, 3},
140         true,
141         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
142         {
143             ngraph::element::u8,
144             {{ngraph::element::f32}, {127.f}, {0.45f}}
145         },
146         {
147             ngraph::element::u8,
148             {{ngraph::element::f32}, {127.f}, {}},
149             ngraph::element::f32,
150             {{}, {}, {1.f}}
151         }
152     },
153     {
154         ngraph::Shape{ 1, 4, 16, 16 },
155         {1, 2, 3},
156         true,
157         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
158         {
159             ngraph::element::u8,
160             {{ngraph::element::f32}, {12.5f}, {0.45f}}
161         },
162         {
163             ngraph::element::u8,
164             {{ngraph::element::f32}, {12.5f}, {0.45f}},
165             ngraph::element::f32,
166             {}
167         }
168     },
169     {
170         ngraph::Shape{ 1, 4, 16, 16 },
171         {1, 2, 3},
172         true,
173         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
174         {
175             ngraph::element::u8,
176             {{ngraph::element::f32}, {127.f}, {0.45f}}
177         },
178         {
179             ngraph::element::u8,
180             {{ngraph::element::f32}, {127.f}, {0.45f}},
181             ngraph::element::f32,
182             {}
183         }
184     },
185
186     {
187         ngraph::Shape{ 1, 4, 16, 16 },
188         {1, 2, 3},
189         true,
190         LayerTransformation::createParamsU8I8(),
191         {
192             ngraph::element::u8,
193             {{ngraph::element::f32}, {}, {-0.5f}}
194         },
195         {
196             ngraph::element::u8,
197             {{}, {}, {}},
198             ngraph::element::f32,
199             {{}, {}, {-1.f}}
200         }
201     },
202
203     {
204         ngraph::Shape{ 1, 4, 16, 16 },
205         {1, 2, 3},
206         false,
207         LayerTransformation::createParamsU8I8(),
208         {
209             ngraph::element::u8,
210             {{ngraph::element::f32}, {}, {0.45f}}
211         },
212         {
213             ngraph::element::u8,
214             {{}, {}, {}},
215             ngraph::element::f32,
216             {{}, {}, {0.45f}}
217         }
218     },
219     {
220         ngraph::Shape{ 1, 2, 2, 2 },
221         {1, 2, 3},
222         false,
223         LayerTransformation::createParamsU8I8(),
224         {
225             ngraph::element::u8,
226             {{ngraph::element::f32}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
227         },
228         {
229             ngraph::element::u8,
230             {{}, {}, {}},
231             ngraph::element::f32,
232             {{}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
233         }
234     },
235     {
236         ngraph::Shape{ 1, 2, 2, 2 },
237         {2, 3},
238         true,
239         LayerTransformation::createParamsU8I8(),
240         {
241             ngraph::element::u8,
242             {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
243         },
244         {
245             ngraph::element::u8,
246             {{}, {}, {}},
247             ngraph::element::f32,
248             {{}, {}, {{1.f, -1.f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
249         }
250     },
251     {
252         ngraph::Shape{ 1, 2, 2, 2 },
253         {1, 2, 3},
254         true,
255         LayerTransformation::createParamsU8I8(),
256         {
257             ngraph::element::u8,
258             {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
259         },
260         {
261             ngraph::element::u8,
262             {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}},
263             ngraph::element::f32,
264             {{}, {}, {}}
265         }
266     },
267 };
268
269 TEST_P(MVNTransformation, CompareFunctions) {
270     actualFunction->validate_nodes_and_infer_types();
271     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
272     ASSERT_TRUE(res.first) << res.second;
273 }
274
275 INSTANTIATE_TEST_CASE_P(
276     LPT,
277     MVNTransformation,
278     ::testing::ValuesIn(testValues),
279     MVNTransformation::getTestCaseName);