29af334c6a9bf72570e3017570ac9cfdeaa646e2
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / concat_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/concat.hpp>
17 #include <low_precision/concat_multi_channels.hpp>
18
19 #include "common_test_utils/ngraph_test_utils.hpp"
20 #include "ngraph_functions/low_precision_transformations/concat_function.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
22 #include "simple_low_precision_transformer.hpp"
23
24 using namespace testing;
25 using namespace ngraph;
26 using namespace ngraph::pass;
27
28 namespace {
29
30 class ConcatTransformationActualValues {
31 public:
32     ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
33     ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
34 };
35
36 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
37     return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
38 }
39
40 class ConcatTransformationResultValues {
41 public:
42     ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
43     ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
44     ngraph::builder::subgraph::DequantizationOperations dequantizationOperations;
45 };
46
47 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
48     return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2 << "_" << values.dequantizationOperations;
49 }
50
51 class ConcatTransformationTestValues {
52 public:
53     ngraph::pass::low_precision::LayerTransformation::Params params;
54     bool multiChannels;
55     ConcatTransformationActualValues actual;
56     ConcatTransformationResultValues result;
57 };
58
59 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
60     return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
61 }
62
63 typedef std::tuple <
64     ngraph::element::Type,
65     bool,
66     ngraph::Shape,
67     ConcatTransformationTestValues
68 > ConcatTransformationParams;
69
70 class ConcatTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
71 public:
72     void SetUp() override {
73         const ngraph::element::Type precision = std::get<0>(GetParam());
74         const bool updatePrecisions = std::get<1>(GetParam());
75         const ngraph::Shape shape = std::get<2>(GetParam());
76         ConcatTransformationTestValues testValues = std::get<3>(GetParam());
77
78         testValues.params.updatePrecisions = updatePrecisions;
79         if (!updatePrecisions) {
80             testValues.result.fakeQuantize1.outputPrecision = testValues.actual.fakeQuantize1.outputPrecision;
81             testValues.result.fakeQuantize2.outputPrecision = testValues.actual.fakeQuantize2.outputPrecision;
82         }
83
84         actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginal(
85             precision,
86             shape,
87             testValues.actual.fakeQuantize1,
88             testValues.actual.fakeQuantize2);
89         SimpleLowPrecisionTransformer transform;
90         if (testValues.multiChannels) {
91             transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
92         } else {
93             transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
94         }
95         transform.transform(actualFunction);
96
97         referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReference(
98             precision,
99             shape,
100             testValues.result.fakeQuantize1,
101             testValues.result.fakeQuantize2,
102             testValues.result.dequantizationOperations);
103     }
104
105     static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
106         const ngraph::element::Type precision = std::get<0>(obj.param);
107         const bool updatePrecision = std::get<1>(obj.param);
108         const ngraph::Shape shape = std::get<2>(obj.param);
109         const ConcatTransformationTestValues testValues = std::get<3>(obj.param);
110
111         std::ostringstream result;
112         result <<
113             LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
114             (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
115             (updatePrecision ? "updatePrecision_" : "notUpdatePrecision_") <<
116             testValues.actual << "_" <<
117             testValues.result << "_";
118         return result.str();
119     }
120 };
121
122 TEST_P(ConcatTransformation, CompareFunctions) {
123     actualFunction->validate_nodes_and_infer_types();
124     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
125     ASSERT_TRUE(res.first) << res.second;
126 }
127
128 const std::vector<ngraph::element::Type> precisions = {
129     ngraph::element::f32,
130     // ngraph::element::f16
131 };
132
// Every test case is run with precision updates enabled and with them disabled.
const std::vector<bool> updatePrecisions{ true, false };
134
135 const std::vector<ConcatTransformationTestValues> testValues = {
136     // U8: concat
137     {
138         LayerTransformation::createParamsU8I8(),
139         false,
140         {
141             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
142             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
143         },
144         {
145             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
146             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
147             { ngraph::element::f32, {}, { 0.01f } }
148         }
149     },
150     // U8: concat multi channels
151     {
152         LayerTransformation::createParamsU8I8(),
153         true,
154         {
155             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
156             { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {1.275f} }
157         },
158         {
159             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
160             { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
161             { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
162         }
163     },
164     // U8: concat multi channels with subtract
165     {
166         LayerTransformation::createParamsU8I8(),
167         true,
168         {
169             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
170             { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {1.275f}, {2.55f} }
171         },
172         {
173             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
174             { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
175             {
176                 ngraph::element::f32,
177                 {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
178                 {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
179             }
180         }
181     },
182     // I8
183     {
184         LayerTransformation::createParamsI8I8(),
185         false,
186         {
187             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
188             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
189         },
190         {
191             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
192             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
193             { ngraph::element::f32, {}, { 0.01f } }
194         }
195     },
196     // mixed: U8 + I8: concat (check constant values here)
197     {
198         LayerTransformation::createParamsU8I8(),
199         false,
200         {
201             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
202             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
203         },
204         {
205             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
206             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
207             { ngraph::element::f32, { 85 }, { 0.015f } }
208         }
209     },
210     // mixed: U8 + I8: concat multi channels
211     {
212         LayerTransformation::createParamsU8I8(),
213         true,
214         {
215             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
216             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
217         },
218         {
219             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
220             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
221             { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }
222         }
223     },
224     // mixed: I8 + U8: concat (check constant values here)
225     {
226         LayerTransformation::createParamsU8I8(),
227         false,
228         {
229             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
230             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
231         },
232         {
233             { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
234             { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
235             { ngraph::element::f32, { 85 }, { 0.015f } }
236         }
237     },
238     // real case from ctdet_coco_dlav0_384 model, coverage bad rounding
239     {
240             LayerTransformation::createParamsU8I8(),
241             false,
242             {
243                     { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
244                     { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
245             },
246             {
247                     { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
248                     { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
249                     { ngraph::element::f32, { 128 }, { 0.0302619f } }
250             }
251     },
252 };
253
254 const std::vector<ngraph::Shape> shapes = {
255     { 1, 3, 9, 9 },
256     { 4, 3, 9, 9 }
257 };
258
259 INSTANTIATE_TEST_CASE_P(
260     LPT,
261     ConcatTransformation,
262     ::testing::Combine(
263         ::testing::ValuesIn(precisions),
264         ::testing::ValuesIn(updatePrecisions),
265         ::testing::ValuesIn(shapes),
266         ::testing::ValuesIn(testValues)),
267     ConcatTransformation::getTestCaseName);
268 }  // namespace