[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / clamp_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <gtest/gtest.h>
10
11 #include <transformations/init_node_info.hpp>
12 #include <low_precision/clamp.hpp>
13
14 #include "common_test_utils/ngraph_test_utils.hpp"
15 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
16 #include "ngraph_functions/low_precision_transformations/clamp_function.hpp"
17 #include "simple_low_precision_transformer.hpp"
18
19
20 namespace {
21 using namespace testing;
22 using namespace ngraph::pass;
23
24 class ClampTransformationTestValues {
25 public:
26     class Actual {
27     public:
28         ngraph::element::Type precisionBeforeDequantization;
29         ngraph::builder::subgraph::DequantizationOperations dequantization;
30     };
31
32     class Expected {
33     public:
34         ngraph::element::Type precisionBeforeDequantization;
35         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
36         ngraph::element::Type precisionAfterOperation;
37         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
38     };
39
40     ngraph::Shape inputShape;
41     ngraph::pass::low_precision::LayerTransformation::Params params;
42     Actual actual;
43     Expected expected;
44 };
45
46 class ClampTransformation : public LayerTransformation, public testing::WithParamInterface<ClampTransformationTestValues> {
47 public:
48     void SetUp() override {
49         const ClampTransformationTestValues testValues = GetParam();
50
51         actualFunction = ngraph::builder::subgraph::ClampFunction::getOriginal(
52             testValues.inputShape,
53             testValues.actual.precisionBeforeDequantization,
54             testValues.actual.dequantization);
55
56         SimpleLowPrecisionTransformer transformer;
57         transformer.add<ngraph::pass::low_precision::ClampTransformation, ngraph::opset1::Clamp>(testValues.params);
58         transformer.transform(actualFunction);
59
60         referenceFunction = ngraph::builder::subgraph::ClampFunction::getReference(
61             testValues.inputShape,
62             testValues.expected.precisionBeforeDequantization,
63             testValues.expected.dequantizationBefore,
64             testValues.expected.precisionAfterOperation,
65             testValues.expected.dequantizationAfter);
66     }
67
68     static std::string getTestCaseName(testing::TestParamInfo<ClampTransformationTestValues> obj) {
69         const ClampTransformationTestValues testValues = obj.param;
70
71         std::ostringstream result;
72         result << toString(testValues.params) << "_" <<
73             testValues.inputShape << "_" <<
74             testValues.actual.precisionBeforeDequantization << "_" <<
75             testValues.actual.dequantization << "_" <<
76             testValues.expected.dequantizationBefore;
77         return result.str();
78     }
79 };
80
81 TEST_P(ClampTransformation, CompareFunctions) {
82     InitNodeInfo().run_on_function(actualFunction);
83     actualFunction->validate_nodes_and_infer_types();
84
85     auto res = compare_functions(referenceFunction, actualFunction, true, true);
86     ASSERT_TRUE(res.first) << res.second;
87 }
88
89 const std::vector<ClampTransformationTestValues> testValues = {
90     // U8 per tensor quantization
91     {
92         ngraph::Shape({ 1, 3, 224, 224 }),
93         LayerTransformation::createParamsU8I8(),
94         // ActualValues
95         {
96             ngraph::element::u8,
97             {{ngraph::element::f32}, {128.f}, {3.f}}
98         },
99         // ExpectedValues
100         {
101             ngraph::element::u8,
102             {{}, {}, {}},
103             ngraph::element::f32,
104             {{}, {128.f}, {3.f}}
105         }
106     },
107     // I8 per tensor quantization
108     {
109         ngraph::Shape({ 1, 3, 224, 224 }),
110         LayerTransformation::createParamsI8I8(),
111         {
112             ngraph::element::i8,
113             {{ngraph::element::f32}, {128.f}, {-5.f}}
114         },
115         {
116             ngraph::element::i8,
117             {{}, {}, {}},
118             ngraph::element::f32,
119             {{}, {128.f}, {-5.f}}
120         }
121     },
122     // U8 without convert
123     {
124         ngraph::Shape({ 1, 3, 224, 224 }),
125         LayerTransformation::createParamsU8I8(),
126         {
127             ngraph::element::f32,
128             {{}, {128.f}, {3.f}}
129         },
130         {
131             ngraph::element::f32,
132             {{}, {}, {}},
133             ngraph::element::f32,
134             {{}, {128.f}, {3.f}}
135         }
136     },
137     // I8 without convert
138     {
139         ngraph::Shape({ 1, 3, 224, 224 }),
140         LayerTransformation::createParamsI8I8(),
141         {
142             ngraph::element::f32,
143             {{}, {128.f}, {3.f}}
144         },
145         {
146             ngraph::element::f32,
147             {{}, {}, {}},
148             ngraph::element::f32,
149             {{}, {128.f}, {3.f}}
150         }
151 },
152     // U8 without subtract
153     {
154         ngraph::Shape({ 1, 3, 224, 224 }),
155         LayerTransformation::createParamsU8I8(),
156         {
157             ngraph::element::u8,
158             {{ngraph::element::f32}, {}, {3.f}}
159         },
160         {
161             ngraph::element::u8,
162             {{}, {}, {}},
163             ngraph::element::f32,
164             {{}, {}, {3.f}}
165         }
166     },
167     // I8 without subtract
168     {
169         ngraph::Shape({ 1, 3, 224, 224 }),
170         LayerTransformation::createParamsI8I8(),
171         {
172             ngraph::element::i8,
173             {{ngraph::element::f32}, {}, {3.f}}
174         },
175         {
176             ngraph::element::i8,
177             {{}, {}, {}},
178             ngraph::element::f32,
179             {{}, {}, {3.f}}
180         }
181     },
182     // U8 per channel quantization with different values
183     {
184         ngraph::Shape({ 1, 3, 224, 224 }),
185         LayerTransformation::createParamsU8I8(),
186         {
187             ngraph::element::u8,
188             {
189                 {ngraph::element::f32},
190                 {{128.f, 0.f, 128.f / 2}},
191                 {{3.f, 1.f, 2.f}}
192             }
193         },
194         {
195             ngraph::element::u8,
196             {
197                 {ngraph::element::f32},
198                 {{128.f, 0.f, 128.f / 2}},
199                 {{3.f, 1.f, 2.f}}
200             },
201             ngraph::element::f32,
202             {{}, {}, {}}
203         }
204     },
205     // I8 per channel quantization with different values
206     {
207         ngraph::Shape({ 1, 3, 224, 224 }),
208         LayerTransformation::createParamsI8I8(),
209         {
210             ngraph::element::i8,
211             {
212                 {ngraph::element::f32},
213                 {{128.f, 0.f, 128.f / 2}},
214                 {{3.f, 1.f, 2.f}}
215             }
216         },
217         {
218             ngraph::element::i8,
219             {
220                 {ngraph::element::f32},
221                 {{128.f, 0.f, 128.f / 2}},
222                 {{3.f, 1.f, 2.f}}
223             },
224             ngraph::element::f32,
225             {{}, {}, {}}
226         }
227     },
228     // U8 per channel quantization with the same values
229     {
230         ngraph::Shape({ 1, 3, 224, 224 }),
231         LayerTransformation::createParamsU8I8(),
232         {
233             ngraph::element::u8,
234             {
235                 {ngraph::element::f32},
236                 {{128.f, 128.f, 128.f}},
237                 {{3.f, 3.f, 3.f}}
238             }
239         },
240         {
241             ngraph::element::u8,
242             {{}, {}, {}},
243             ngraph::element::f32,
244             {
245                 {},
246                 {{128.f, 128.f, 128.f}},
247                 {{3.f, 3.f, 3.f}}
248             },
249         }
250     },
251     // I8 per channel quantization with the same values
252     {
253         ngraph::Shape({ 1, 3, 224, 224 }),
254         LayerTransformation::createParamsI8I8(),
255         {
256             ngraph::element::i8,
257             {
258                 {ngraph::element::f32},
259                 {{128.f, 128.f, 128.f}},
260                 {{3.f, 3.f, 3.f}}
261             }
262         },
263         {
264             ngraph::element::i8,
265             {{}, {}, {}},
266             ngraph::element::f32,
267             {
268                 {},
269                 {{128.f, 128.f, 128.f}},
270                 {{3.f, 3.f, 3.f}}
271             },
272         }
273     },
274     // U8 dequantization in second dimension
275     {
276         ngraph::Shape({ 1, 3, 4, 4 }),
277         LayerTransformation::createParamsU8I8(),
278         {
279             ngraph::element::u8,
280             {
281                 {ngraph::element::f32},
282                 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
283                 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
284             }
285         },
286         {
287             ngraph::element::u8,
288             {
289                 {ngraph::element::f32},
290                 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
291                 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
292             },
293             ngraph::element::f32,
294             {{}, {}, {}}
295         }
296     },
297     // I8 dequantization in second dimension
298     {
299         ngraph::Shape({ 1, 3, 4, 4 }),
300         LayerTransformation::createParamsI8I8(),
301         {
302             ngraph::element::i8,
303             {
304                 {ngraph::element::f32},
305                 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
306                 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
307             }
308         },
309         {
310             ngraph::element::i8,
311             {
312                 {ngraph::element::f32},
313                 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
314                 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
315             },
316             ngraph::element::f32,
317             {{}, {}, {}}
318         }
319     },
320     // U8 asymmetric quantization
321     {
322         ngraph::Shape({ 1, 3, 224, 224 }),
323         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
324         {
325             ngraph::element::u8,
326             {
327                 {ngraph::element::f32},
328                 {{ 128.f, 0.f, 128.f }},
329                 {{ 3.f, 3.f, 3.f }}
330             }
331         },
332         {
333             ngraph::element::u8,
334             {
335                 {ngraph::element::f32},
336                 {{ 128.f, 0.f, 128.f }},
337                 {{ 3.f, 3.f, 3.f }}
338             },
339             ngraph::element::f32,
340             {{}, {}, {}}
341         }
342     },
343     // U8 without asymmetric quantization
344     {
345         ngraph::Shape({ 1, 3, 224, 224 }),
346         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
347         {
348             ngraph::element::u8,
349             {
350                 {ngraph::element::f32},
351                 {{ 128.f, 0.f, 128.f }},
352                 {{ 3.f, 3.f, 3.f }}
353             }
354         },
355         {
356             ngraph::element::u8,
357             {
358                 {ngraph::element::f32},
359                 {{ 128.f, 0.f, 128.f }},
360                 {{ 3.f, 3.f, 3.f }}
361             },
362             ngraph::element::f32,
363             {{}, {}, {}}
364         }
365     },
366     // per channel quantization with small values
367     {
368         ngraph::Shape({ 1, 3, 224, 224 }),
369         LayerTransformation::createParamsU8I8(),
370         {
371             ngraph::element::u8,
372             {
373                     {ngraph::element::f32},
374                     {{1e-14, 1e-12, 1e-15}},
375                     {{1e-14, 1e-12, 1e-15}}
376             }
377         },
378         {
379             ngraph::element::u8,
380             {
381                     {ngraph::element::f32},
382                     {{1e-14, 1e-12, 1e-15}},
383                     {{1e-14, 1e-12, 1e-15}}
384             },
385             ngraph::element::f32,
386             {{}, {}, {}}
387         }
388     },
389 };
390 INSTANTIATE_TEST_CASE_P(
391     LPT,
392     ClampTransformation,
393     ::testing::ValuesIn(testValues),
394     ClampTransformation::getTestCaseName);
395 } // namespace