1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/add_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
28 class AddTransformationTestValues {
32 ngraph::element::Type precision1;
33 ngraph::builder::subgraph::DequantizationOperations dequantization1;
34 ngraph::element::Type precision2;
35 ngraph::builder::subgraph::DequantizationOperations dequantization2;
36 std::vector<float> constValues;
41 ngraph::element::Type precision1;
42 ngraph::builder::subgraph::DequantizationOperations dequantization1;
43 ngraph::element::Type precision2;
44 ngraph::builder::subgraph::DequantizationOperations dequantization2;
45 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
46 std::vector<float> constValues;
47 std::string operationType;
49 Expected(const ngraph::element::Type& precision1,
50 ngraph::builder::subgraph::DequantizationOperations dequantization1,
51 const ngraph::element::Type& precision2,
52 ngraph::builder::subgraph::DequantizationOperations dequantization2,
53 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter,
54 std::vector<float> constValues,
55 std::string operationType = "Add"): precision1(precision1), dequantization1(std::move(dequantization1)),
56 precision2(precision2), dequantization2(std::move(dequantization2)),
57 dequantizationAfter(std::move(dequantizationAfter)), constValues(std::move(constValues)),
58 operationType(std::move(operationType)) {}
61 ngraph::element::Type precision;
62 ngraph::Shape inputShape;
65 ngraph::pass::low_precision::LayerTransformation::Params params;
68 std::string additionalLayer;
72 inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
74 for (size_t i = 0; i < values.size(); ++i) {
76 if (i != (values.size() - 1ul)) {
84 class AddTransformation : public LayerTransformation, public testing::WithParamInterface<AddTransformationTestValues> {
86 void SetUp() override {
87 const AddTransformationTestValues testValues = GetParam();
89 actualFunction = AddFunction::getOriginal(
91 testValues.inputShape,
94 testValues.actual.precision1,
95 testValues.actual.dequantization1,
96 testValues.actual.precision2,
97 testValues.actual.dequantization2,
98 testValues.constInput,
99 testValues.actual.constValues,
100 testValues.additionalLayer);
102 SimpleLowPrecisionTransformer transform;
103 transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
104 low_precision::LayerTransformation::Params(testValues.params));
105 transform.transform(actualFunction);
107 referenceFunction = AddFunction::getReference(
108 testValues.precision,
109 testValues.inputShape,
110 testValues.broadcast,
112 testValues.expected.precision1,
113 testValues.expected.dequantization1,
114 testValues.expected.precision2,
115 testValues.expected.dequantization2,
116 testValues.expected.dequantizationAfter,
117 // Constant operations after transformations are on 1 input only
118 testValues.constInput == -1 ? -1 : 1,
119 testValues.expected.constValues,
120 testValues.additionalLayer,
121 testValues.expected.operationType);
124 static std::string getTestCaseName(testing::TestParamInfo<AddTransformationTestValues> obj) {
125 const AddTransformationTestValues testValues = obj.param;
127 std::ostringstream result;
129 testValues.precision << "_" <<
130 testValues.inputShape << "_" <<
131 testValues.broadcast << "_" <<
132 testValues.actual.precision1 << "_" <<
133 testValues.actual.dequantization1 << "_" <<
134 testValues.actual.precision2 << "_" <<
135 testValues.actual.dequantization2 << "_" <<
136 testValues.constInput << "_" <<
137 testValues.actual.constValues << "_" <<
138 testValues.additionalLayer;
143 TEST_P(AddTransformation, CompareFunctions) {
144 actualFunction->validate_nodes_and_infer_types();
145 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
146 ASSERT_TRUE(res.first) << res.second;
149 const std::vector<AddTransformationTestValues> addTransformationTestValues = {
150 // Multiply with zero on the first branch
152 ngraph::element::f32,
153 ngraph::Shape{1, 4, 16, 16},
156 LayerTransformation::createParamsU8I8(),
158 ngraph::element::f32,
161 { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
165 ngraph::element::f32,
168 { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
174 // Multiply with zero on the second branch
176 ngraph::element::f32,
177 ngraph::Shape{1, 4, 16, 16},
180 LayerTransformation::createParamsU8I8(),
183 { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
184 ngraph::element::f32,
190 { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
191 ngraph::element::f32,
200 ngraph::element::f32,
201 ngraph::Shape{1, 4, 16, 16},
204 LayerTransformation::createParamsU8I8(),
207 { {ngraph::element::f32}, { 7.f }, { 10.f }},
209 { {ngraph::element::f32}, { 3.f }, { 5.f } },
214 { {ngraph::element::f32}, { 8.5f }, { 2.f }},
223 ngraph::element::f32,
224 ngraph::Shape{1, 4, 16, 16},
227 LayerTransformation::createParamsU8I8(),
230 { {ngraph::element::f32}, { 2.f }, { 10.f }},
232 { {ngraph::element::f32}, { }, { 5.f } },
237 { {ngraph::element::f32}, { 2.f }, { 2.f }},
246 ngraph::element::f32,
247 ngraph::Shape{1, 4, 16, 16},
250 LayerTransformation::createParamsU8I8(),
253 { {ngraph::element::f32}, { }, { 10.f }},
255 { {ngraph::element::f32}, { }, { 5.f } },
260 { {ngraph::element::f32}, { }, { 2.f }},
269 ngraph::element::f32,
270 ngraph::Shape{1, 4, 16, 16},
273 LayerTransformation::createParamsU8I8(),
276 { {ngraph::element::f32}, { 2.f }, { }},
278 { {ngraph::element::f32}, { }, { 5.f } },
283 { {ngraph::element::f32}, { 2.f }, { 0.2f }},
292 ngraph::element::f32,
293 ngraph::Shape{1, 4, 16, 16},
296 LayerTransformation::createParamsU8I8(),
299 { {ngraph::element::f32}, { 2.f }, { }},
301 { {ngraph::element::f32}, { 3.f }, { 5.f } },
306 { {ngraph::element::f32}, { 17.f }, { 0.2f }},
318 ngraph::element::f32,
319 ngraph::Shape{1, 4, 16, 16},
322 LayerTransformation::createParamsU8I8(),
325 { {ngraph::element::f32}, { 7.f }, { 10.f }},
327 { {ngraph::element::f32}, { 3.f }, { 5.f } },
332 { {ngraph::element::f32}, { 8.5f }, { 2.f }},
341 ngraph::element::f32,
342 ngraph::Shape{1, 4, 16, 16},
345 LayerTransformation::createParamsU8I8(),
348 { {ngraph::element::f32}, { 2.f }, { 10.f }},
350 { {ngraph::element::f32}, { }, { 5.f } },
355 { {ngraph::element::f32}, { 2.f }, { 2.f }},
364 ngraph::element::f32,
365 ngraph::Shape{1, 4, 16, 16},
368 LayerTransformation::createParamsU8I8(),
371 { {ngraph::element::f32}, { }, { 10.f }},
373 { {ngraph::element::f32}, { }, { 5.f } },
378 { {ngraph::element::f32}, { }, { 2.f }},
387 ngraph::element::f32,
388 ngraph::Shape{1, 4, 16, 16},
391 LayerTransformation::createParamsU8I8(),
394 { {ngraph::element::f32}, { 2.f }, { }},
396 { {ngraph::element::f32}, { }, { 5.f } },
401 { {ngraph::element::f32}, { 2.f }, { 0.2f }},
410 ngraph::element::f32,
411 ngraph::Shape{1, 4, 16, 16},
414 LayerTransformation::createParamsU8I8(),
417 { {ngraph::element::f32}, { 2.f }, { }},
419 { {ngraph::element::f32}, { 3.f }, { 5.f } },
424 { {ngraph::element::f32}, { 17.f }, { 0.2f }},
434 ngraph::element::f32,
438 LayerTransformation::createParamsU8I8(),
441 { {ngraph::element::f32}, { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
442 ngraph::element::f32,
444 { 5.f, 6.f, 7.f, 8.f }
448 { {ngraph::element::f32}, { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
449 ngraph::element::f32,
452 { 5.f, 6.f, 7.f, 8.f }
457 // constant input: Add -> Subtract
459 ngraph::element::f32,
460 ngraph::Shape{ 1, 2, 2, 2 },
463 LayerTransformation::createParamsU8I8(),
466 { {ngraph::element::f32}, {}, {5.f}},
469 { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
473 { {ngraph::element::f32}, { }, { }},
474 ngraph::element::f32,
477 { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
483 // constant input: Add -> Subtract
485 ngraph::element::f32,
486 ngraph::Shape{1, 2, 2, 2},
489 LayerTransformation::createParamsU8I8(),
494 { {ngraph::element::f32}, {}, { 5.f } },
495 { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
499 { {ngraph::element::f32}, {}, {} },
500 ngraph::element::f32,
504 { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
509 // convolution before FQ (choose that branch)
511 ngraph::element::f32,
512 ngraph::Shape{1, 4, 16, 16},
515 LayerTransformation::createParamsU8I8(),
518 { {ngraph::element::f32}, { 7.f }, { 10.f }},
520 { {ngraph::element::f32}, { 3.f }, { 5.f } },
527 { {ngraph::element::f32}, { 17.f }, { 0.5f }},
533 // group convolution before FQ (choose that branch)
535 ngraph::element::f32,
536 ngraph::Shape{1, 4, 16, 16},
539 LayerTransformation::createParamsU8I8(),
542 { {ngraph::element::f32}, { 7.f }, { 10.f }},
544 { {ngraph::element::f32}, { 3.f }, { 5.f } },
551 { {ngraph::element::f32}, { 17.f }, { 0.5f }},
559 INSTANTIATE_TEST_CASE_P(
562 ::testing::ValuesIn(addTransformationTestValues),
563 AddTransformation::getTestCaseName);