[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / reshape_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/reshape.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/reshape_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
21
22 namespace {
23
24 using namespace testing;
25 using namespace ngraph::pass;
26
27 class ReshapeTransformationTestValues {
28 public:
29     class Actual {
30     public:
31         ngraph::element::Type precisionBeforeDequantization;
32         ngraph::builder::subgraph::DequantizationOperations dequantization;
33     };
34
35     class Expected {
36     public:
37         ngraph::element::Type precisionBeforeDequantization;
38         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39         ngraph::element::Type precisionAfterOperation;
40         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
41     };
42
43     ngraph::Shape inputShape;
44     std::vector<int> reshapeConstValues;
45     ngraph::pass::low_precision::LayerTransformation::Params params;
46     Actual actual;
47     Expected expected;
48 };
49
50 inline std::ostream& operator<<(std::ostream& os, const std::vector<int>& values) {
51     os << "{ ";
52     for (size_t i = 0; i < values.size(); ++i) {
53         os << values[i];
54         if (i != (values.size() - 1ul)) {
55             os << ", ";
56         }
57     }
58     os << " }";
59     return os;
60 }
61
62 class ReshapeTransformation : public LayerTransformation, public testing::WithParamInterface<ReshapeTransformationTestValues> {
63 public:
64     void SetUp() override {
65         const ReshapeTransformationTestValues testValues = GetParam();
66
67         actualFunction = ngraph::builder::subgraph::ReshapeFunction::getOriginal(
68             testValues.inputShape,
69             testValues.reshapeConstValues,
70             testValues.actual.precisionBeforeDequantization,
71             testValues.actual.dequantization);
72
73         SimpleLowPrecisionTransformer transformer;
74         transformer.add<ngraph::pass::low_precision::ReshapeTransformation, ngraph::opset1::Reshape>(testValues.params);
75         transformer.transform(actualFunction);
76
77         referenceFunction = ngraph::builder::subgraph::ReshapeFunction::getReference(
78             testValues.inputShape,
79             testValues.reshapeConstValues,
80             testValues.expected.precisionBeforeDequantization,
81             testValues.expected.dequantizationBefore,
82             testValues.expected.precisionAfterOperation,
83             testValues.expected.dequantizationAfter);
84     }
85
86     static std::string getTestCaseName(testing::TestParamInfo<ReshapeTransformationTestValues> obj) {
87         const ReshapeTransformationTestValues testValues = obj.param;
88
89         std::ostringstream result;
90         result <<
91             testValues.inputShape << "_" <<
92             testValues.reshapeConstValues << "_" <<
93             testValues.actual.precisionBeforeDequantization << "_" <<
94             testValues.actual.dequantization << "_" <<
95             testValues.expected.precisionAfterOperation << "_" <<
96             testValues.expected.dequantizationAfter << "_" <<
97             testValues.expected.dequantizationBefore;
98         return result.str();
99     }
100 };
101
102 const std::vector<ReshapeTransformationTestValues> testValues = {
103     // U8: no subtract 3D -> 4D: channels are not affected
104     {
105         ngraph::Shape({ 1, 384, 1024 }),
106         { 1, 384, 16, 64},
107         LayerTransformation::createParamsU8I8(),
108         {
109             ngraph::element::u8,
110             {{ngraph::element::f32}, {}, {0.1f}}
111         },
112         {
113             ngraph::element::u8,
114             {{}, {}, {}},
115             ngraph::element::u8,
116             {{ngraph::element::f32}, {}, {0.1f}}
117         }
118     },
119     // U8: no subtract 3D -> 4D: channels are not affected
120     {
121         ngraph::Shape({ 4, 384, 1024 }),
122         { 4, 384, 16, 64},
123         LayerTransformation::createParamsU8I8(),
124         {
125             ngraph::element::u8,
126             {{ngraph::element::f32}, {}, {0.1f}}
127         },
128         {
129             ngraph::element::u8,
130             {{}, {}, {}},
131             ngraph::element::u8,
132             {{ngraph::element::f32}, {}, {0.1f}}
133         }
134     },
135     // U8: no subtract 3D -> 4D: channels are not affected: no subtract
136     {
137         ngraph::Shape({ 1, 3, 20 }),
138         { 1, 3, 4, 5},
139         LayerTransformation::createParamsU8I8(),
140         {
141             ngraph::element::u8,
142             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}}
143         },
144         {
145             ngraph::element::u8,
146             {{}, {}, {}},
147             ngraph::element::u8,
148             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
149         }
150     },
151     // U8: no subtract 3D -> 4D: channels are not affected: no subtract
152     {
153         ngraph::Shape({ 4, 3, 20 }),
154         { 4, 3, 4, 5},
155         LayerTransformation::createParamsU8I8(),
156         {
157             ngraph::element::u8,
158             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}}
159         },
160         {
161             ngraph::element::u8,
162             {{}, {}, {}},
163             ngraph::element::u8,
164             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
165         }
166     },
167     // U8: no subtract 3D -> 4D: channels are not affected: with subtract
168     {
169         ngraph::Shape({ 1, 3, 20 }),
170         { 1, 3, 4, 5},
171         LayerTransformation::createParamsU8I8(),
172         {
173             ngraph::element::u8,
174             {
175                 {ngraph::element::f32},
176                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1}},
177                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}
178             }
179         },
180         {
181             ngraph::element::u8,
182             {{}, {}, {}},
183             ngraph::element::u8,
184             {
185                 {ngraph::element::f32},
186                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
187                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
188             }
189         }
190     },
191     // U8: no subtract 3D -> 4D: channels are not affected: with subtract
192     {
193         ngraph::Shape({ 1, 3, 20 }),
194         { 1, -1, 4, 5},
195         LayerTransformation::createParamsU8I8(),
196         {
197             ngraph::element::u8,
198             {
199                 {ngraph::element::f32},
200                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1}},
201                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}
202             }
203         },
204         {
205             ngraph::element::u8,
206             {{}, {}, {}},
207             ngraph::element::u8,
208             {
209                 {ngraph::element::f32},
210                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
211                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
212             }
213         }
214     },
215     // U8: no subtract 4D -> 6D: channels are not affected: no subtract
216     {
217         ngraph::Shape({ 1, 3, 4, 5 }),
218         { 1, 3, 20, 1, 1, 1},
219         LayerTransformation::createParamsU8I8(),
220         {
221             ngraph::element::u8,
222             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
223         },
224         {
225             ngraph::element::u8,
226             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}},
227             ngraph::element::f32,
228             {}
229         }
230     },
231     // U8: no subtract 4D -> 6D: channels are not affected: with subtract
232     {
233         ngraph::Shape({ 1, 3, 4, 5 }),
234         { 1, 3, 20, 1, 1, 1},
235         LayerTransformation::createParamsU8I8(),
236         {
237             ngraph::element::u8,
238             {
239                 {ngraph::element::f32},
240                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
241                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
242             }
243         },
244         {
245             ngraph::element::u8,
246             {
247                 { ngraph::element::f32 },
248                 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
249                 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
250             },
251             ngraph::element::f32,
252             {}
253         }
254     },
255     // U8: no subtract 2D -> 4D: channels are affected: per tensor quantization
256     // TODO: story 38439
257     {
258         ngraph::Shape({ 1, 16, 384, 384 }),
259         { 6144, -1 },
260         LayerTransformation::createParamsU8I8(),
261         {
262             ngraph::element::u8,
263             {{ngraph::element::f32}, {}, {0.1f}}
264         },
265         {
266             ngraph::element::u8,
267             {{ngraph::element::f32}, {}, {0.1f}},
268             ngraph::element::f32,
269             {}
270         }
271     },
272     // U8: no subtract 2D -> 4D: channels are affected: per channel quantization
273     {
274         ngraph::Shape({ 1, 3, 4, 5 }),
275         { 12, -1 },
276         LayerTransformation::createParamsU8I8(),
277         {
278             ngraph::element::u8,
279             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
280         },
281         {
282             ngraph::element::u8,
283             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}},
284             ngraph::element::f32,
285             {{}, {}, {}}
286         }
287     },
288     // U8: no subtract 2D -> 4D: channels are affected: per channel quantization
289     {
290         ngraph::Shape({ 1, 3, 4, 8 }),
291         { 12, -1 },
292         LayerTransformation::createParamsU8I8(),
293         {
294             ngraph::element::u8,
295             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}}, {{0.1f, 0.2f, 0.3f}}}
296         },
297         {
298             ngraph::element::u8,
299             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32}, {{0.1f, 0.2f, 0.3f}}},
300             ngraph::element::f32,
301             {{}, {}, {}}
302         }
303     },
304     // empty: FP32
305     {
306         ngraph::Shape({ 1, 3, 4, 8 }),
307         { 12, -1 },
308         LayerTransformation::createParamsU8I8(),
309         {
310             ngraph::element::f32,
311             {}
312         },
313         {
314             ngraph::element::f32,
315             {},
316             ngraph::element::f32,
317             {{}, {}, {}}
318         }
319     },
320     // empty: U8
321     {
322         ngraph::Shape({ 1, 3, 4, 8 }),
323         { 12, -1 },
324         LayerTransformation::createParamsU8I8(),
325         {
326             ngraph::element::u8,
327             {}
328         },
329         {
330             ngraph::element::u8,
331             {},
332             ngraph::element::u8,
333             {}
334         }
335     },
336     // U8: no subtract 4D -> 6D: channels are not affected: no subtract
337     {
338         ngraph::Shape({ 1, 3, 1, 1 }),
339         { 1, 3, 1, 1, 1, 1 },
340         LayerTransformation::createParamsU8I8(),
341         {
342             ngraph::element::u8,
343             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
344         },
345         {
346             ngraph::element::u8,
347             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}},
348             ngraph::element::f32,
349             {}
350         }
351     },
352     // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
353     // TODO: story 38439
354     {
355         ngraph::Shape({ 1, 3, 4, 5 }),
356         { 0, -1 },
357         LayerTransformation::createParamsU8I8(),
358         {
359             ngraph::element::u8,
360             {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
361         },
362         {
363             ngraph::element::u8,
364             {},
365             ngraph::element::u8,
366             {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
367         }
368     },
369     // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
370     {
371         ngraph::Shape({ 1, 3, 2, 2 }),
372         { 0, -1 },
373         LayerTransformation::createParamsU8I8(),
374         {
375             ngraph::element::u8,
376             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
377         },
378         {
379             ngraph::element::u8,
380             {},
381             ngraph::element::u8,
382             {
383                 {ngraph::element::f32},
384                 {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
385                 {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
386             }
387         }
388     },
389     // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
390     {
391         ngraph::Shape({ 4, 3, 2, 2 }),
392         { 0, -1 },
393         LayerTransformation::createParamsU8I8(),
394         {
395             ngraph::element::u8,
396             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
397         },
398         {
399             ngraph::element::u8,
400             {},
401             ngraph::element::u8,
402             {
403                 {ngraph::element::f32},
404                 {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
405                 {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
406             }
407         }
408     },
409     // U8: no subtract 4D -> 2D: channels are not affected: per channel quantization: case #1: dequantization operation constant needs broadcast
410     {
411         ngraph::Shape({ 1, 3, 1, 1 }),
412         { 0, -1 },
413         LayerTransformation::createParamsU8I8(),
414         {
415             ngraph::element::u8,
416             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
417         },
418         {
419             ngraph::element::u8,
420             {{}, {}, {}},
421             ngraph::element::u8,
422             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
423         }
424     },
425     // U8: no subtract 4D -> 2D: channels are not affected: per channel quantization: case #2: dequantization operation constant doesn't need broadcast
426     {
427         ngraph::Shape({ 1, 3, 1, 1 }),
428         { 0, -1 },
429         LayerTransformation::createParamsU8I8(),
430         {
431             ngraph::element::u8,
432             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
433         },
434         {
435             ngraph::element::u8,
436             {{}, {}, {}},
437             ngraph::element::u8,
438             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
439         }
440     },
441     // U8: no subtract 4D -> 3D: channels are affected: per tensor quantization: case #1: dequantization operation constant needs broadcast
442     {
443         ngraph::Shape({ 1, 3, 4, 5 }),
444         { 0, 0, -1 },
445         LayerTransformation::createParamsU8I8(),
446         {
447             ngraph::element::u8,
448             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
449         },
450         {
451             ngraph::element::u8,
452             {{}, {}, {}},
453             ngraph::element::u8,
454             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
455         }
456     },
457     // U8: no subtract 4D -> 3D: channels are affected: per tensor quantization: case #2: dequantization operation constant doesn't need broadcast
458     {
459         ngraph::Shape({ 1, 3, 4, 5 }),
460         { 0, 0, -1 },
461         LayerTransformation::createParamsU8I8(),
462         {
463             ngraph::element::u8,
464             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
465         },
466         {
467             ngraph::element::u8,
468             {{}, {}, {}},
469             ngraph::element::u8,
470             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
471         }
472     },
473     // U8: no subtract 4D -> 2D
474     {
475         ngraph::Shape({ 1, 2048, 1, 1 }),
476         { 1, -1 },
477         LayerTransformation::createParamsU8I8(),
478         {
479             ngraph::element::u8,
480             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
481         },
482         {
483             ngraph::element::u8,
484             {{}, {}, {}},
485             ngraph::element::u8,
486             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
487         }
488     },
489     // U8: no subtract 4D -> 2D
490     {
491         ngraph::Shape({ 2, 2048, 1, 1 }),
492         { 2, -1 },
493         LayerTransformation::createParamsU8I8(),
494         {
495             ngraph::element::u8,
496             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
497         },
498         {
499             ngraph::element::u8,
500             {{}, {}, {}},
501             ngraph::element::u8,
502             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
503         }
504     },
505     // U8: no subtract 4D -> 2D
506     {
507         ngraph::Shape({ 1, 2048, 1, 1 }),
508         { 1, -1 },
509         LayerTransformation::createParamsU8I8(),
510         {
511             ngraph::element::u8,
512             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
513         },
514         {
515             ngraph::element::u8,
516             {{}, {}, {}},
517             ngraph::element::u8,
518             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
519         }
520     },
521     // U8: no subtract 4D -> 2D: channels are not affected
522     {
523         ngraph::Shape({ 2, 2048, 1, 1 }),
524         { 2, -1},
525         LayerTransformation::createParamsU8I8(),
526         {
527             ngraph::element::u8,
528             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
529         },
530         {
531             ngraph::element::u8,
532             {{}, {}, {}},
533             ngraph::element::u8,
534             {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
535         }
536     }
537 };
538
539 TEST_P(ReshapeTransformation, CompareFunctions) {
540     InitNodeInfo().run_on_function(actualFunction);
541     actualFunction->validate_nodes_and_infer_types();
542     auto res = compare_functions(referenceFunction, actualFunction, true, true);
543     ASSERT_TRUE(res.first) << res.second;
544 }
545
546 INSTANTIATE_TEST_CASE_P(
547     LPT,
548     ReshapeTransformation,
549     ::testing::ValuesIn(testValues),
550     ReshapeTransformation::getTestCaseName);
551
552 } // namespace