1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
16 #include <low_precision/network_helper.hpp>
18 #include "common_test_utils/ngraph_test_utils.hpp"
19 #include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
20 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
// Parameters of one MoveDequantizationAfter test case.
// NOTE(review): this view of the class is incomplete. Access specifiers, the
// nested "actual"/"expected" holders and the updatePrecision/moveSubtract
// flags read by the fixture (testValues.actual.*, testValues.expected.*,
// testValues.updatePrecision, testValues.moveSubtract) are on lines not shown.
26 class MoveDequantizationAfterTransformationParams {
// Dequantization used to build the original (actual) function; the fixture
// reads it as testValues.actual.dequantization.
30 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Expected values passed to MoveDequantizationAfterFunction::getReference():
// the dequantization expected before the operation, the expected precision
// after the operation, and the dequantization expected after the operation.
35 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
36 ngraph::element::Type precisionAfterOperation;
37 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Precision passed to both getOriginal() and getReference().
40 ngraph::element::Type originalPrecision;
// Low-precision transformation settings. The test values appear to fill this
// from LayerTransformation::createParamsU8I8() / createParamsI8I8().
// NOTE(review): presumably; the full field order of the initializers is not
// visible here.
41 ngraph::pass::low_precision::LayerTransformation::Params params;
// Per-test parameter tuple: std::get<0> is the input shape and std::get<1> is
// the MoveDequantizationAfterTransformationParams (see SetUp/getTestCaseName).
50 MoveDequantizationAfterTransformationParams> MoveDequantizationAfterTransformationTestValues;
// Parameterized fixture.
// SetUp() builds the original function and applies
// NetworkHelper::moveDequantizationAfter() to it. It also builds the expected
// reference function, which CompareFunctions then compares against.
52 class MoveDequantizationAfterTransformation :
53 public LayerTransformation,
54 public testing::WithParamInterface<MoveDequantizationAfterTransformationTestValues> {
56 void SetUp() override {
57 const auto inputShape = std::get<0>(GetParam());
58 const auto testValues = std::get<1>(GetParam());
// Original function, built from the "actual" dequantization of the case.
59 actualFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getOriginal(
60 testValues.originalPrecision,
62 testValues.actual.dequantization);
// Transformation target: the node producing input 0 of the function's
// first output op.
64 const auto targetNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
// Dequantization description extracted from the target node.
65 const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(targetNode);
// Helper under test. Its result is not assigned here; the effect is checked
// through actualFunction in CompareFunctions.
// NOTE(review): the leading arguments of this call are on lines not shown.
66 ngraph::pass::low_precision::NetworkHelper::moveDequantizationAfter(
69 testValues.updatePrecision,
70 testValues.moveSubtract);
// Expected function, built from the "expected" values of the case.
72 referenceFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getReference(
73 testValues.originalPrecision,
75 testValues.expected.dequantizationBefore,
76 testValues.expected.precisionAfterOperation,
77 testValues.expected.dequantizationAfter);
// Test-name generator passed to INSTANTIATE_TEST_CASE_P.
// NOTE(review): "don't_..." contains an apostrophe. gtest versions that
// validate parameterized test names accept only alphanumerics and '_', so
// consider "dont_..." if names are validated. Note also that "updatePrecision"
// uses a different casing from "don't_update_precision".
80 static std::string getTestCaseName(testing::TestParamInfo<MoveDequantizationAfterTransformationTestValues> obj) {
81 const auto inputShape = std::get<0>(obj.param);
82 const auto testValues = std::get<1>(obj.param);
84 std::ostringstream result;
86 testValues.originalPrecision << "_" <<
88 testValues.actual.dequantization << "_" <<
89 (testValues.moveSubtract ? "move_subtract_" : "don't_move_subtract_") <<
90 (testValues.updatePrecision ? "updatePrecision" : "don't_update_precision");
// Checks that the transformed actual function is valid and matches the
// reference function built in SetUp().
95 TEST_P(MoveDequantizationAfterTransformation, CompareFunctions) {
96 actualFunction->validate_nodes_and_infer_types();
// The three booleans are positional compare_functions options. Their meaning
// is given by its declaration, presumably in
// common_test_utils/ngraph_test_utils.hpp.
97 auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
// res.first is the verdict; res.second is streamed as the failure message.
98 ASSERT_TRUE(res.first) << res.second;
// Input shapes fed to the test instantiation via ::testing::ValuesIn(inputShapes).
// NOTE(review): the shape entries and the closing brace are not visible here.
101 const std::vector<ngraph::Shape> inputShapes = {
// Test cases (MoveDequantizationAfterTransformationParams), in three groups:
// - four cases with LayerTransformation::createParamsU8I8();
// - the same four cases with createParamsI8I8();
// - two per-channel cases with createParamsU8I8().
// Within each group of four, cases are labeled by the flags they turn off.
// The unlabeled first case presumably keeps both flags on; its flag lines are
// not shown.
106 const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
110 LayerTransformation::createParamsU8I8(),
114 { {ngraph::element::f32}, { 7.f }, { 10.f } },
119 { {ngraph::element::f32}, { 7.f }, { 10.f } },
122 // moveSubtract = false
125 LayerTransformation::createParamsU8I8(),
129 { {ngraph::element::f32}, { 7.f }, { 10.f } },
132 { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
133 ngraph::element::f32,
134 { {}, {}, { 10.f } },
137 // updatePrecision = false
140 LayerTransformation::createParamsU8I8(),
144 { {ngraph::element::f32}, { 7.f }, { 10.f } },
148 ngraph::element::f32,
149 { {}, { 7.f }, { 10.f } },
152 // moveSubtract = false & updatePrecision = false
155 LayerTransformation::createParamsU8I8(),
159 { {ngraph::element::f32}, { 7.f }, { 10.f } },
162 { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
163 ngraph::element::f32,
164 { {}, {}, { 10.f } },
170 LayerTransformation::createParamsI8I8(),
174 { {ngraph::element::f32}, { 7.f }, { 10.f } },
179 { {ngraph::element::f32}, { 7.f }, { 10.f } },
182 // moveSubtract = false
185 LayerTransformation::createParamsI8I8(),
189 { {ngraph::element::f32}, { 7.f }, { 10.f } },
192 { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
193 ngraph::element::f32,
194 { {}, {}, { 10.f } },
197 // updatePrecision = false
200 LayerTransformation::createParamsI8I8(),
204 { {ngraph::element::f32}, { 7.f }, { 10.f } },
208 ngraph::element::f32,
209 { {}, { 7.f }, { 10.f } },
212 // moveSubtract = false & updatePrecision = false
215 LayerTransformation::createParamsI8I8(),
219 { {ngraph::element::f32}, { 7.f }, { 10.f } },
222 { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
223 ngraph::element::f32,
224 { {}, {}, { 10.f } },
227 // per-channel quantizations with the same values
230 LayerTransformation::createParamsU8I8(),
234 { {ngraph::element::f32}, { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
237 { {}, { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
238 ngraph::element::f32,
239 { {}, {}, { { 10.f, 10.f, 10.f } } },
242 // per-channel quantizations with different values
245 LayerTransformation::createParamsU8I8(),
249 { {ngraph::element::f32}, { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
252 { {}, { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
253 ngraph::element::f32,
254 { {}, {}, { { 10.f, 12.f, 16.f } } },
// Registers the fixture over inputShapes and testValues, with names produced
// by getTestCaseName. The enclosing ::testing::Combine( call is presumably on a
// line not shown here.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the legacy spelling. gtest >= 1.10
// names it INSTANTIATE_TEST_SUITE_P; switch if the bundled gtest supports it.
259 INSTANTIATE_TEST_CASE_P(
261 MoveDequantizationAfterTransformation,
263 ::testing::ValuesIn(inputShapes),
264 ::testing::ValuesIn(testValues)),
265 MoveDequantizationAfterTransformation::getTestCaseName);