1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/add_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
// Parameters of one AddTransformation test case:
// - the graph built before the low-precision transformation ("actual");
// - the graph expected after it ("expected");
// - settings shared by both (precision, input shape, transformation params, extra layer).
// NOTE(review): this view is a lossy extract. Braces, access specifiers and the nested
// class headers are not shown; the comments below describe only the visible lines.
15 class AddTransformationTestValues {
// Presumably the members of the nested description that SetUp reads via `testValues.actual`.
// Its header line is not visible here.
32 ngraph::element::Type precision1;
33 ngraph::builder::subgraph::DequantizationOperations dequantization1;
34 ngraph::element::Type precision2;
35 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// Values of the constant input of the original graph; passed to getOriginal together with constInput.
36 std::vector<float> constValues;
// Members of `Expected` (constructor below), which SetUp reads via `testValues.expected`
// to build the reference graph.
41 ngraph::element::Type precision1;
42 ngraph::builder::subgraph::DequantizationOperations dequantization1;
43 ngraph::element::Type precision2;
44 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// Passed to getReference; presumably the dequantization expected after the resulting operation.
45 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
46 std::vector<float> constValues;
// Name of the operation expected in the reference graph; passed last to getReference.
47 std::string operationType;
// Copies/moves every argument into the matching member. The by-value parameters are moved.
// operationType defaults to "Add".
49 Expected(const ngraph::element::Type& precision1,
50 ngraph::builder::subgraph::DequantizationOperations dequantization1,
51 const ngraph::element::Type& precision2,
52 ngraph::builder::subgraph::DequantizationOperations dequantization2,
53 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter,
54 std::vector<float> constValues,
55 std::string operationType = "Add"): precision1(precision1), dequantization1(std::move(dequantization1)),
56 precision2(precision2), dequantization2(std::move(dequantization2)),
57 dequantizationAfter(std::move(dequantizationAfter)), constValues(std::move(constValues)),
58 operationType(std::move(operationType)) {}
// Element type passed as the first argument of getReference; also part of the test name.
61 ngraph::element::Type precision;
62 ngraph::Shape inputShape;
// Low-precision transformation parameters. SetUp registers AddTransformation with a copy of them.
65 ngraph::pass::low_precision::LayerTransformation::Params params;
// Extra layer name passed to both getOriginal and getReference; also part of the test name.
68 std::string additionalLayer;
// Streams a std::vector as text so that vector-valued parameters (e.g. actual.constValues)
// can be part of the generated test-case name.
// NOTE(review): only part of the body is visible here. The `i != size() - 1` check suggests
// a separator is written between elements but not after the last one.
// For an empty vector the loop body never runs, so `size() - 1ul` is never evaluated.
72 inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
74 for (size_t i = 0; i < values.size(); ++i) {
76 if (i != (values.size() - 1ul)) {
// Value-parameterized fixture. SetUp does three things:
// 1. builds the original graph;
// 2. runs the low-precision AddTransformation on it;
// 3. builds the reference graph the transformed result is compared against.
84 class AddTransformation : public LayerTransformation, public testing::WithParamInterface<AddTransformationTestValues> {
// Builds actualFunction from the `actual` description and transforms it.
// Then builds referenceFunction from the `expected` description.
86 void SetUp() override {
// NOTE(review): this copies the whole parameter set. GetParam() returns a const reference,
// so binding `const AddTransformationTestValues&` would avoid the copy.
87 const AddTransformationTestValues testValues = GetParam();
89 actualFunction = AddFunction::getOriginal(
91 testValues.inputShape,
94 testValues.actual.precision1,
95 testValues.actual.dequantization1,
96 testValues.actual.precision2,
97 testValues.actual.dequantization2,
98 testValues.constInput,
99 testValues.actual.constValues,
100 testValues.additionalLayer);
// Register only the transformation under test (keyed on ngraph::opset1::Add) with the
// case's params, then apply it to actualFunction.
102 SimpleLowPrecisionTransformer transform;
103 transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
104 low_precision::LayerTransformation::Params(testValues.params));
105 transform.transform(actualFunction);
107 referenceFunction = AddFunction::getReference(
108 testValues.precision,
109 testValues.inputShape,
110 testValues.broadcast,
112 testValues.expected.precision1,
113 testValues.expected.dequantization1,
114 testValues.expected.precision2,
115 testValues.expected.dequantization2,
116 testValues.expected.dequantizationAfter,
117 // Constant operations after transformations are on 1 input only
118 testValues.constInput == -1 ? -1 : 1,
119 testValues.expected.constValues,
120 testValues.additionalLayer,
121 testValues.expected.operationType);
// Builds the gtest name suffix from the case's inputs. It uses: precision, shape, broadcast,
// the `actual` description, constInput, constValues and additionalLayer.
// NOTE(review): expected values and params are not part of the name, so two cases differing
// only in those would get the same name; confirm the names stay unique.
// NOTE(review): `obj.param` is copied here too; a const reference would suffice.
124 static std::string getTestCaseName(testing::TestParamInfo<AddTransformationTestValues> obj) {
125 const AddTransformationTestValues testValues = obj.param;
127 std::ostringstream result;
129 testValues.precision << "_" <<
130 testValues.inputShape << "_" <<
131 testValues.broadcast << "_" <<
132 testValues.actual.precision1 << "_" <<
133 testValues.actual.dequantization1 << "_" <<
134 testValues.actual.precision2 << "_" <<
135 testValues.actual.dequantization2 << "_" <<
136 testValues.constInput << "_" <<
137 testValues.actual.constValues << "_" <<
138 testValues.additionalLayer;
// Re-validates the transformed graph, then requires it to match the reference graph
// built in SetUp.
// NOTE(review): the meaning of the three `true` flags depends on compare_functions'
// signature, which is not visible here. Presumably they enable extra comparison checks; confirm.
143 TEST_P(AddTransformation, CompareFunctions) {
144 actualFunction->validate_nodes_and_infer_types();
145 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
// res.second is streamed as the failure message when the graphs differ.
146 ASSERT_TRUE(res.first) << res.second;
// Test cases for AddTransformation. Each entry gives, in order:
// - precision, input shape and transformation params;
// - the `actual` description (dequantizations as { convert type, subtract values, multiply values });
// - the `expected` description.
// NOTE(review): this view is a lossy extract. Braces and some fields (e.g. broadcast,
// constInput, additionalLayer, operationType) are not visible for every entry.
149 const std::vector<AddTransformationTestValues> addTransformationTestValues = {
152 ngraph::element::f32,
153 ngraph::Shape{1, 4, 16, 16},
156 LayerTransformation::createParamsU8I8(),
159 { {ngraph::element::f32}, { 7.f }, { 10.f }},
161 { {ngraph::element::f32}, { 3.f }, { 5.f } },
166 { {ngraph::element::f32}, { 8.5f }, { 2.f }},
175 ngraph::element::f32,
176 ngraph::Shape{1, 4, 16, 16},
179 LayerTransformation::createParamsU8I8(),
182 { {ngraph::element::f32}, { 2.f }, { 10.f }},
184 { {ngraph::element::f32}, { }, { 5.f } },
189 { {ngraph::element::f32}, { 2.f }, { 2.f }},
198 ngraph::element::f32,
199 ngraph::Shape{1, 4, 16, 16},
202 LayerTransformation::createParamsU8I8(),
205 { {ngraph::element::f32}, { }, { 10.f }},
207 { {ngraph::element::f32}, { }, { 5.f } },
212 { {ngraph::element::f32}, { }, { 2.f }},
221 ngraph::element::f32,
222 ngraph::Shape{1, 4, 16, 16},
225 LayerTransformation::createParamsU8I8(),
228 { {ngraph::element::f32}, { 2.f }, { }},
230 { {ngraph::element::f32}, { }, { 5.f } },
235 { {ngraph::element::f32}, { 2.f }, { 0.2f }},
244 ngraph::element::f32,
245 ngraph::Shape{1, 4, 16, 16},
248 LayerTransformation::createParamsU8I8(),
251 { {ngraph::element::f32}, { 2.f }, { }},
253 { {ngraph::element::f32}, { 3.f }, { 5.f } },
258 { {ngraph::element::f32}, { 17.f }, { 0.2f }},
// The next five cases repeat the visible dequantization values of the five above.
// They presumably differ in fields not shown in this view (e.g. broadcast); confirm.
270 ngraph::element::f32,
271 ngraph::Shape{1, 4, 16, 16},
274 LayerTransformation::createParamsU8I8(),
277 { {ngraph::element::f32}, { 7.f }, { 10.f }},
279 { {ngraph::element::f32}, { 3.f }, { 5.f } },
284 { {ngraph::element::f32}, { 8.5f }, { 2.f }},
293 ngraph::element::f32,
294 ngraph::Shape{1, 4, 16, 16},
297 LayerTransformation::createParamsU8I8(),
300 { {ngraph::element::f32}, { 2.f }, { 10.f }},
302 { {ngraph::element::f32}, { }, { 5.f } },
307 { {ngraph::element::f32}, { 2.f }, { 2.f }},
316 ngraph::element::f32,
317 ngraph::Shape{1, 4, 16, 16},
320 LayerTransformation::createParamsU8I8(),
323 { {ngraph::element::f32}, { }, { 10.f }},
325 { {ngraph::element::f32}, { }, { 5.f } },
330 { {ngraph::element::f32}, { }, { 2.f }},
339 ngraph::element::f32,
340 ngraph::Shape{1, 4, 16, 16},
343 LayerTransformation::createParamsU8I8(),
346 { {ngraph::element::f32}, { 2.f }, { }},
348 { {ngraph::element::f32}, { }, { 5.f } },
353 { {ngraph::element::f32}, { 2.f }, { 0.2f }},
362 ngraph::element::f32,
363 ngraph::Shape{1, 4, 16, 16},
366 LayerTransformation::createParamsU8I8(),
369 { {ngraph::element::f32}, { 2.f }, { }},
371 { {ngraph::element::f32}, { 3.f }, { 5.f } },
376 { {ngraph::element::f32}, { 17.f }, { 0.2f }},
// Per-channel multiply constant of shape {4, 1}; actual and expected descriptions are identical.
386 ngraph::element::f32,
390 LayerTransformation::createParamsU8I8(),
393 { {ngraph::element::f32}, { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
394 ngraph::element::f32,
396 { 5.f, 6.f, 7.f, 8.f }
400 { {ngraph::element::f32}, { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
401 ngraph::element::f32,
404 { 5.f, 6.f, 7.f, 8.f }
409 // constant input: Add -> Subtract
// Each expected constant is -(actual constant / 5), where 5 is the visible multiply value,
// e.g. 10 -> -2 and 12 -> -2.4. The expected dequantization is empty.
411 ngraph::element::f32,
412 ngraph::Shape{ 1, 2, 2, 2 },
415 LayerTransformation::createParamsU8I8(),
418 { {ngraph::element::f32}, {}, {5.f}},
421 { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
425 { {ngraph::element::f32}, { }, { }},
426 ngraph::element::f32,
429 { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
435 // constant input: Add -> Subtract
// Same values as the previous case, but the dequantization is listed in a different
// position (presumably the other input); expected constants again are -(actual / 5).
437 ngraph::element::f32,
438 ngraph::Shape{1, 2, 2, 2},
441 LayerTransformation::createParamsU8I8(),
446 { {ngraph::element::f32}, {}, { 5.f } },
447 { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
451 { {ngraph::element::f32}, {}, {} },
452 ngraph::element::f32,
456 { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
461 // convolution before FQ (choose that branch)
463 ngraph::element::f32,
464 ngraph::Shape{1, 4, 16, 16},
467 LayerTransformation::createParamsU8I8(),
470 { {ngraph::element::f32}, { 7.f }, { 10.f }},
472 { {ngraph::element::f32}, { 3.f }, { 5.f } },
479 { {ngraph::element::f32}, { 17.f }, { 0.5f }},
485 // group convolution before FQ (choose that branch)
487 ngraph::element::f32,
488 ngraph::Shape{1, 4, 16, 16},
491 LayerTransformation::createParamsU8I8(),
494 { {ngraph::element::f32}, { 7.f }, { 10.f }},
496 { {ngraph::element::f32}, { 3.f }, { 5.f } },
503 { {ngraph::element::f32}, { 17.f }, { 0.5f }},
// Instantiates the AddTransformation tests once per entry of addTransformationTestValues.
// Each instance is named with AddTransformation::getTestCaseName.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the legacy spelling. Newer GoogleTest releases
// deprecate it in favour of INSTANTIATE_TEST_SUITE_P; switch once the bundled gtest allows.
511 INSTANTIATE_TEST_CASE_P(
514 ::testing::ValuesIn(addTransformationTestValues),
515 AddTransformation::getTestCaseName);