1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/concat.hpp>
17 #include <low_precision/concat_multi_channels.hpp>
19 #include "common_test_utils/ngraph_test_utils.hpp"
20 #include "ngraph_functions/low_precision_transformations/concat_function.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
22 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph;
26 using namespace ngraph::pass;
// Input values for the original (pre-transformation) function: the two
// FakeQuantize descriptions passed to ConcatFunction::getOriginal in SetUp.
30 class ConcatTransformationActualValues {
32 ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
33 ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
// Streams both FakeQuantize descriptions, each preceded by "_"; used when
// building the parameterized test case name.
36 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
37 return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
// Expected values used to build the reference function with
// ConcatFunction::getReference: the two FakeQuantize descriptions after the
// transformation plus the expected dequantization operations.
// When the test runs with updatePrecisions == false, SetUp replaces the
// outputPrecision of both FakeQuantize entries with the actual ones.
40 class ConcatTransformationResultValues {
42 ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
43 ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
44 ngraph::builder::subgraph::DequantizationOperations dequantizationOperations;
// Streams both FakeQuantize descriptions and the dequantization operations,
// each preceded by "_"; used when building the test case name.
47 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
48 return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2 << "_" << values.dequantizationOperations;
// One test case: low-precision transformation params, the original (actual)
// values and the expected (result) values.
// NOTE(review): `multiChannels` is read by operator<< below and by the
// fixture (SetUp, getTestCaseName), but no declaration of it appears in this
// listing — confirm the member is declared.
51 class ConcatTransformationTestValues {
// SetUp overwrites params.updatePrecisions with the test tuple's flag.
53 ngraph::pass::low_precision::LayerTransformation::Params params;
55 ConcatTransformationActualValues actual;
56 ConcatTransformationResultValues result;
// Streams multiChannels, the actual values and the result values, each
// preceded by "_".
59 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
60 return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
// Test parameter tuple. SetUp and getTestCaseName read it as:
// <0> element type, <1> updatePrecisions flag, <2> input shape,
// <3> test values.
// NOTE(review): the `typedef std::tuple<` opener and the bool/Shape element
// lines are not present in this listing — confirm.
64 ngraph::element::Type,
67 ConcatTransformationTestValues
68 > ConcatTransformationParams;
// Parameterized fixture. SetUp builds the original function, runs a
// low-precision concat transformation on it, and builds the expected
// reference function; the CompareFunctions test compares the two.
70 class ConcatTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
72 void SetUp() override {
73 const ngraph::element::Type precision = std::get<0>(GetParam());
74 const bool updatePrecisions = std::get<1>(GetParam());
75 const ngraph::Shape shape = std::get<2>(GetParam());
// Taken by value (non-const) because it is adjusted below.
76 ConcatTransformationTestValues testValues = std::get<3>(GetParam());
// The tuple's flag overrides the updatePrecisions value stored in the test
// case's params.
78 testValues.params.updatePrecisions = updatePrecisions;
79 if (!updatePrecisions) {
// Without precision updates, the expected FakeQuantize output precisions
// are the original ones rather than those listed in the test case.
80 testValues.result.fakeQuantize1.outputPrecision = testValues.actual.fakeQuantize1.outputPrecision;
81 testValues.result.fakeQuantize2.outputPrecision = testValues.actual.fakeQuantize2.outputPrecision;
84 actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginal(
87 testValues.actual.fakeQuantize1,
88 testValues.actual.fakeQuantize2);
// Register either the multi-channel or the plain concat transformation for
// opset1::Concat, depending on multiChannels, then run it on actualFunction.
89 SimpleLowPrecisionTransformer transform;
90 if (testValues.multiChannels) {
91 transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
93 transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
95 transform.transform(actualFunction);
97 referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReference(
100 testValues.result.fakeQuantize1,
101 testValues.result.fakeQuantize2,
102 testValues.result.dequantizationOperations);
// Builds the test name suffix from all four tuple elements.
// NOTE(review): the head of the stream expression (presumably `result <<`)
// and the final `return result.str();` are not present in this listing —
// confirm.
105 static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
106 const ngraph::element::Type precision = std::get<0>(obj.param);
107 const bool updatePrecision = std::get<1>(obj.param);
108 const ngraph::Shape shape = std::get<2>(obj.param);
109 const ConcatTransformationTestValues testValues = std::get<3>(obj.param);
111 std::ostringstream result;
113 LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
114 (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
115 (updatePrecision ? "updatePrecision_" : "notUpdatePrecision_") <<
116 testValues.actual << "_" <<
117 testValues.result << "_";
// Runs validate_nodes_and_infer_types on the transformed function, then
// compares it with the reference function. On mismatch, res.second is
// streamed as the failure message.
// NOTE(review): the meaning of the three `true` flags passed to
// compare_functions is not visible here — confirm which checks they enable.
122 TEST_P(ConcatTransformation, CompareFunctions) {
123 actualFunction->validate_nodes_and_infer_types();
124 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
125 ASSERT_TRUE(res.first) << res.second;
// Element types for tuple element <0>; f16 is currently commented out.
128 const std::vector<ngraph::element::Type> precisions = {
129 ngraph::element::f32,
130 // ngraph::element::f16
// Every test case runs both with and without precision updates
// (tuple element <1>, applied in the fixture's SetUp).
133 const std::vector<bool> updatePrecisions = { true, false };
// Test cases (tuple element <3>), each written in ConcatTransformationTestValues
// field order: params, actual values, expected result values.
// NOTE(review): FakeQuantizeOnData initializers appear to be
// { levels, constant shape, input low, input high, output low, output high
// [, output precision] } and dequantization initializers
// { precision, subtract values, multiply values } — confirm against the
// builder definitions.
135 const std::vector<ConcatTransformationTestValues> testValues = {
// U8 + U8 with identical FakeQuantize intervals
138 LayerTransformation::createParamsU8I8(),
141 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
142 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
145 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
146 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
147 { ngraph::element::f32, {}, { 0.01f } }
150 // U8: concat multi channels
152 LayerTransformation::createParamsU8I8(),
155 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
156 { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {1.275f} }
159 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
160 { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
161 { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
164 // U8: concat multi channels with subtract
166 LayerTransformation::createParamsU8I8(),
169 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
170 { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {1.275f}, {2.55f} }
173 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
174 { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
176 ngraph::element::f32,
177 {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
178 {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }}
// I8 + I8 with identical FakeQuantize intervals
184 LayerTransformation::createParamsI8I8(),
187 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
188 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
191 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
192 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
193 { ngraph::element::f32, {}, { 0.01f } }
196 // mixed: U8 + I8: concat (check constant values here)
198 LayerTransformation::createParamsU8I8(),
201 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
202 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
205 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
206 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
207 { ngraph::element::f32, { 85 }, { 0.015f } }
210 // mixed: U8 + I8: concat multi channels
212 LayerTransformation::createParamsU8I8(),
215 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
216 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
219 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
220 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
221 { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }
224 // mixed: I8 + U8: concat (check constant values here)
226 LayerTransformation::createParamsU8I8(),
229 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
230 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
233 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
234 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
235 { ngraph::element::f32, { 85 }, { 0.015f } }
238 // real case from the ctdet_coco_dlav0_384 model; covers bad rounding
240 LayerTransformation::createParamsU8I8(),
243 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
244 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
247 { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
248 { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
249 { ngraph::element::f32, { 128 }, { 0.0302619f } }
// Input shapes for tuple element <2>.
254 const std::vector<ngraph::Shape> shapes = {
// Instantiates the suite with one value generator per tuple element, listed
// in ConcatTransformationParams order (precision, updatePrecisions, shape,
// test values); test names come from ConcatTransformation::getTestCaseName.
// NOTE(review): newer googletest releases name this macro
// INSTANTIATE_TEST_SUITE_P — confirm the gtest version in use before renaming.
259 INSTANTIATE_TEST_CASE_P(
261 ConcatTransformation,
263 ::testing::ValuesIn(precisions),
264 ::testing::ValuesIn(updatePrecisions),
265 ::testing::ValuesIn(shapes),
266 ::testing::ValuesIn(testValues)),
267 ConcatTransformation::getTestCaseName);