[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / concat_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/concat.hpp>
17 #include <low_precision/concat_multi_channels.hpp>
18
19 #include "common_test_utils/ngraph_test_utils.hpp"
20 #include "ngraph_functions/low_precision_transformations/concat_function.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
22 #include "simple_low_precision_transformer.hpp"
23
24 using namespace testing;
25 using namespace ngraph;
26 using namespace ngraph::pass;
27
28 namespace {
29
30 class ConcatTransformationActualValues {
31 public:
32     ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
33     ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
34 };
35
36 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
37     return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
38 }
39
40 class ConcatTransformationResultValues {
41 public:
42     ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
43     ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
44     ngraph::builder::subgraph::DequantizationOperations dequantizationOperations;
45 };
46
47 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
48     return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2 << "_" << values.dequantizationOperations;
49 }
50
51 class ConcatTransformationTestValues {
52 public:
53     ngraph::pass::low_precision::LayerTransformation::Params params;
54     bool multiChannels;
55     ConcatTransformationActualValues actual;
56     ConcatTransformationResultValues result;
57 };
58
59 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
60     return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
61 }
62
63 typedef std::tuple <
64     ngraph::element::Type,
65     bool,
66     ngraph::Shape,
67     ConcatTransformationTestValues
68 > ConcatTransformationParams;
69
70 class ConcatTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
71 public:
72     void SetUp() override {
73         const ngraph::element::Type precision = std::get<0>(GetParam());
74         const bool updatePrecisions = std::get<1>(GetParam());
75         const ngraph::Shape shape = std::get<2>(GetParam());
76         ConcatTransformationTestValues testValues = std::get<3>(GetParam());
77
78         testValues.params.updatePrecisions = updatePrecisions;
79         if (!updatePrecisions) {
80             testValues.result.fakeQuantize1.outputPrecision = testValues.actual.fakeQuantize1.outputPrecision;
81             testValues.result.fakeQuantize2.outputPrecision = testValues.actual.fakeQuantize2.outputPrecision;
82         }
83
84         actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginal(
85             precision,
86             shape,
87             testValues.actual.fakeQuantize1,
88             testValues.actual.fakeQuantize2);
89
90         SimpleLowPrecisionTransformer transform;
91         if (testValues.multiChannels) {
92             transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
93         } else {
94             transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
95         }
96         transform.transform(actualFunction);
97
98         referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReference(
99             precision,
100             shape,
101             testValues.result.fakeQuantize1,
102             testValues.result.fakeQuantize2,
103             testValues.result.dequantizationOperations);
104     }
105
106     static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
107         const ngraph::element::Type precision = std::get<0>(obj.param);
108         const bool updatePrecision = std::get<1>(obj.param);
109         const ngraph::Shape shape = std::get<2>(obj.param);
110         const ConcatTransformationTestValues testValues = std::get<3>(obj.param);
111
112         std::ostringstream result;
113         result <<
114             LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
115             (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
116             (updatePrecision ? "updatePrecision_" : "notUpdatePrecision_") <<
117             testValues.actual << "_" <<
118             testValues.result << "_";
119         return result.str();
120     }
121 };
122
123 TEST_P(ConcatTransformation, CompareFunctions) {
124     actualFunction->validate_nodes_and_infer_types();
125     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
126     ASSERT_TRUE(res.first) << res.second;
127 }
128
129 const std::vector<ngraph::element::Type> precisions = {
130     ngraph::element::f32,
131     // ngraph::element::f16
132 };
133
134 const std::vector<bool> updatePrecisions = { true, false };
135
136 const std::vector<ConcatTransformationTestValues> testValues = {
137     // U8: concat
138     {
139         LayerTransformation::createParamsU8I8(),
140         false,
141         {
142             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
143             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
144         },
145         {
146             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
147             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
148             { ngraph::element::f32, {}, { 0.01f } }
149         }
150     },
151     // U8: concat
152     {
153         LayerTransformation::createParamsU8I8(),
154         false,
155         {
156             { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
157             { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
158         },
159         {
160             { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
161             { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
162             { ngraph::element::f32, {}, { 0.01f } }
163         }
164     },
165     // U8: concat
166     {
167         LayerTransformation::createParamsU8I8(),
168         false,
169         {
170             { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
171             { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
172         },
173         {
174             { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
175             { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
176             { ngraph::element::f32, {}, { 0.01f } }
177         }
178     },
179     // U8: concat multi channels
180     {
181         LayerTransformation::createParamsU8I8(),
182         true,
183         {
184             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
185             { 256ul, {}, {0.f}, {1.275f}, {0.f}, {1.275f} }
186         },
187         {
188             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
189             { 256ul, {}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
190             { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
191         }
192     },
193     // U8: concat multi channels
194     {
195         LayerTransformation::createParamsU8I8(),
196         true,
197         {
198             { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
199             { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {1.275f}, {0.f}, {1.275f} }
200         },
201         {
202             { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
203             { 256ul, {{1}, {1}, {}, {}}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
204             { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
205         }
206     },
207     // U8: concat multi channels
208     {
209         LayerTransformation::createParamsU8I8(),
210         true,
211         {
212             {
213                 256ul,
214                 {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
215                 {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f, 0.f, 0.f}, {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f},
216                 ngraph::element::f32
217             },
218             {
219                 256ul,
220                 {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
221                 {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f, 0.f, 0.f}, {1.275f / 1.f, 1.275f / 2.f, 1.275f / 3.f},
222                 ngraph::element::f32
223             }
224         },
225         {
226             {
227                 256ul,
228                 {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
229                 {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f}, {255.f},
230                 ngraph::element::u8
231             },
232             {
233                 256ul,
234                 {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
235                 {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f}, {255.f},
236                 ngraph::element::u8
237             },
238             { ngraph::element::f32, {}, {{ 0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.005f / 1.f, 0.005f / 2.f, 0.005f / 3.f }} }
239         }
240     },
241     // U8: concat multi channels with subtract
242     {
243         LayerTransformation::createParamsU8I8(),
244         true,
245         {
246             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
247             { 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} }
248         },
249         {
250             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
251             { 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
252             {
253                 ngraph::element::f32,
254                 {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
255                 {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
256             }
257         }
258     },
259     // I8
260     {
261         LayerTransformation::createParamsI8I8(),
262         false,
263         {
264             { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
265             { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
266         },
267         {
268             { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
269             { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
270             { ngraph::element::f32, {}, { 0.01f } }
271         }
272     },
273     // mixed: U8 + I8: concat (check constant values here)
274     {
275         LayerTransformation::createParamsU8I8(),
276         false,
277         {
278             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
279             { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
280         },
281         {
282             { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
283             { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
284             { ngraph::element::f32, { 85 }, { 0.015f } }
285         }
286     },
287     // mixed: U8 + I8: concat multi channels
288     {
289         LayerTransformation::createParamsU8I8(),
290         true,
291         {
292             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
293             { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
294         },
295         {
296             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
297             { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
298             { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }
299         }
300     },
301     // mixed: I8 + U8: concat (check constant values here)
302     {
303         LayerTransformation::createParamsU8I8(),
304         false,
305         {
306             { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
307             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
308         },
309         {
310             { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
311             { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
312             { ngraph::element::f32, { 85 }, { 0.015f } }
313         }
314     },
315     // real case from ctdet_coco_dlav0_384 model, coverage bad rounding
316     {
317         LayerTransformation::createParamsU8I8(),
318         false,
319         {
320             { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
321             { 256ul, {}, {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
322         },
323         {
324             { 256ul, {}, {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
325             { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
326             { ngraph::element::f32, { 128 }, { 0.0302619f } }
327         }
328     }
329 };
330
331 const std::vector<ngraph::Shape> shapes = {
332     { 1, 3, 9, 9 },
333     { 4, 3, 9, 9 }
334 };
335
336 INSTANTIATE_TEST_CASE_P(
337     LPT,
338     ConcatTransformation,
339     ::testing::Combine(
340         ::testing::ValuesIn(precisions),
341         ::testing::ValuesIn(updatePrecisions),
342         ::testing::ValuesIn(shapes),
343         ::testing::ValuesIn(testValues)),
344     ConcatTransformation::getTestCaseName);
345 }  // namespace