1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/concat.hpp>
17 #include <low_precision/concat_multi_channels.hpp>
19 #include "common_test_utils/ngraph_test_utils.hpp"
20 #include "ngraph_functions/low_precision_transformations/concat_function.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
22 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph;
26 using namespace ngraph::pass;
30 class ConcatTransformationActualValues {
32 ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
33 ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
// Streams a compact, '_'-prefixed description of the two input FakeQuantize
// descriptors (fakeQuantize1, then fakeQuantize2). getTestCaseName streams
// `testValues.actual` through this operator, so the output becomes part of the
// generated test name.
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
    return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
40 class ConcatTransformationResultValues {
42 ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
43 ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
44 ngraph::builder::subgraph::DequantizationOperations dequantizationOperations;
// Streams the expected-result description: both reference FakeQuantize
// descriptors followed by the expected dequantization operations, each prefixed
// with '_'. getTestCaseName streams `testValues.result` through this operator
// as the tail of the generated test name.
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
    return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2 << "_" << values.dequantizationOperations;
51 class ConcatTransformationTestValues {
53 ngraph::pass::low_precision::LayerTransformation::Params params;
55 ConcatTransformationActualValues actual;
56 ConcatTransformationResultValues result;
// Streams a whole test case: the multiChannels flag (printed as 0/1, since no
// std::boolalpha is applied), then the actual and expected values via their own
// operator<< overloads.
// NOTE(review): getTestCaseName streams the fields individually rather than using
// this overload; presumably it exists so GoogleTest can print parameter values on
// failure. Confirm before removing.
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
    return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
64 ngraph::element::Type,
67 ConcatTransformationTestValues
68 > ConcatTransformationParams;
70 class ConcatTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
72 void SetUp() override {
73 const ngraph::element::Type precision = std::get<0>(GetParam());
74 const bool updatePrecisions = std::get<1>(GetParam());
75 const ngraph::Shape shape = std::get<2>(GetParam());
76 ConcatTransformationTestValues testValues = std::get<3>(GetParam());
78 testValues.params.updatePrecisions = updatePrecisions;
79 if (!updatePrecisions) {
80 testValues.result.fakeQuantize1.outputPrecision = testValues.actual.fakeQuantize1.outputPrecision;
81 testValues.result.fakeQuantize2.outputPrecision = testValues.actual.fakeQuantize2.outputPrecision;
84 actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginal(
87 testValues.actual.fakeQuantize1,
88 testValues.actual.fakeQuantize2);
90 SimpleLowPrecisionTransformer transform;
91 if (testValues.multiChannels) {
92 transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
94 transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
96 transform.transform(actualFunction);
98 referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReference(
101 testValues.result.fakeQuantize1,
102 testValues.result.fakeQuantize2,
103 testValues.result.dequantizationOperations);
106 static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
107 const ngraph::element::Type precision = std::get<0>(obj.param);
108 const bool updatePrecision = std::get<1>(obj.param);
109 const ngraph::Shape shape = std::get<2>(obj.param);
110 const ConcatTransformationTestValues testValues = std::get<3>(obj.param);
112 std::ostringstream result;
114 LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
115 (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
116 (updatePrecision ? "updatePrecision_" : "notUpdatePrecision_") <<
117 testValues.actual << "_" <<
118 testValues.result << "_";
123 TEST_P(ConcatTransformation, CompareFunctions) {
124 actualFunction->validate_nodes_and_infer_types();
125 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
126 ASSERT_TRUE(res.first) << res.second;
129 const std::vector<ngraph::element::Type> precisions = {
130 ngraph::element::f32,
131 // ngraph::element::f16
// Each case runs with both values of the updatePrecisions flag. SetUp copies the
// flag into testValues.params; when it is false, the expected FakeQuantize output
// precisions are replaced by the actual (original) ones.
const std::vector<bool> updatePrecisions{ true, false };
136 const std::vector<ConcatTransformationTestValues> testValues = {
139 LayerTransformation::createParamsU8I8(),
142 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
143 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
146 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
147 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
148 { ngraph::element::f32, {}, { 0.01f } }
153 LayerTransformation::createParamsU8I8(),
156 { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
157 { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
160 { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
161 { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
162 { ngraph::element::f32, {}, { 0.01f } }
167 LayerTransformation::createParamsU8I8(),
170 { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
171 { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
174 { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
175 { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
176 { ngraph::element::f32, {}, { 0.01f } }
179 // U8: concat multi channels
181 LayerTransformation::createParamsU8I8(),
184 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
185 { 256ul, {}, {0.f}, {1.275f}, {0.f}, {1.275f} }
188 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
189 { 256ul, {}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
190 { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
193 // U8: concat multi channels
195 LayerTransformation::createParamsU8I8(),
198 { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
199 { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {1.275f}, {0.f}, {1.275f} }
202 { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
203 { 256ul, {{1}, {1}, {}, {}}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
204 { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
207 // U8: concat multi channels
209 LayerTransformation::createParamsU8I8(),
214 {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
215 {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f, 0.f, 0.f}, {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f},
220 {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
221 {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f, 0.f, 0.f}, {1.275f / 1.f, 1.275f / 2.f, 1.275f / 3.f},
228 {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
229 {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f}, {255.f},
234 {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
235 {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f}, {255.f},
238 { ngraph::element::f32, {}, {{ 0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.005f / 1.f, 0.005f / 2.f, 0.005f / 3.f }} }
241 // U8: concat multi channels with subtract
243 LayerTransformation::createParamsU8I8(),
246 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
247 { 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} }
250 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
251 { 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
253 ngraph::element::f32,
254 {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
255 {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
261 LayerTransformation::createParamsI8I8(),
264 { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
265 { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
268 { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
269 { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
270 { ngraph::element::f32, {}, { 0.01f } }
273 // mixed: U8 + I8: concat (check constant values here)
275 LayerTransformation::createParamsU8I8(),
278 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
279 { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
282 { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
283 { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
284 { ngraph::element::f32, { 85 }, { 0.015f } }
287 // mixed: U8 + I8: concat multi channels
289 LayerTransformation::createParamsU8I8(),
292 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
293 { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
296 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
297 { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
298 { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }
301 // mixed: I8 + U8: concat (check constant values here)
303 LayerTransformation::createParamsU8I8(),
306 { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
307 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
310 { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
311 { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
312 { ngraph::element::f32, { 85 }, { 0.015f } }
315 // real case from ctdet_coco_dlav0_384 model, coverage bad rounding
317 LayerTransformation::createParamsU8I8(),
320 { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
321 { 256ul, {}, {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
324 { 256ul, {}, {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
325 { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
326 { ngraph::element::f32, { 128 }, { 0.0302619f } }
331 const std::vector<ngraph::Shape> shapes = {
336 INSTANTIATE_TEST_CASE_P(
338 ConcatTransformation,
340 ::testing::ValuesIn(precisions),
341 ::testing::ValuesIn(updatePrecisions),
342 ::testing::ValuesIn(shapes),
343 ::testing::ValuesIn(testValues)),
344 ConcatTransformation::getTestCaseName);