1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include "low_precision/mvn.hpp"
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "simple_low_precision_transformer.hpp"
20 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
// Parameter bundle describing one MVN low-precision-transformation test case.
// NOTE(review): this view of the class is missing lines (access specifiers, any nested
// type declarations, closing braces). The grouping notes below are inferred from how the
// fixture reads the fields (testValues.actual.* / testValues.expected.*) - confirm.
26 class MVNTransformationTestValues {
// Presumably the "actual" group: precision and dequantization used to build the original function.
30 ngraph::element::Type precisionBeforeDequantization;
31 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Presumably the "expected" group: precisions and dequantizations used to build the reference function.
36 ngraph::element::Type precisionBeforeDequantization;
37 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
38 ngraph::element::Type precisionAfterOperation;
39 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Settings the fixture passes to both the original and the reference function builders.
42 ngraph::Shape inputShape;
43 ngraph::AxisSet reductionAxes;
44 bool normalizeVariance;
// Parameters the fixture passes to the transformer when registering the transformation.
45 ngraph::pass::low_precision::LayerTransformation::Params params;
// Stream-insertion helper for std::vector<T>; walks the elements by index.
// NOTE(review): the template header and the statements that actually write to `os`
// are not visible in this view.
51 inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
53 for (size_t i = 0; i < values.size(); ++i) {
// True for every element except the last one - presumably guards a separator write; confirm.
55 if (i != (values.size() - 1ul)) {
// Value-parameterized fixture: for each MVNTransformationTestValues entry it builds the
// function under test, runs the low-precision MVN transformation on it, and builds the
// reference function it will be compared against.
63 class MVNTransformation : public LayerTransformation, public testing::WithParamInterface<MVNTransformationTestValues> {
// Prepares actualFunction (transformed) and referenceFunction (expected) for the current parameter.
65 void SetUp() override {
66 const MVNTransformationTestValues testValues = GetParam();
// Build the original function from the shared settings and the "actual" fields.
68 actualFunction = ngraph::builder::subgraph::MVNFunction::getOriginal(
69 testValues.inputShape,
70 testValues.reductionAxes,
71 testValues.normalizeVariance,
72 testValues.actual.precisionBeforeDequantization,
73 testValues.actual.dequantization);
75 SimpleLowPrecisionTransformer transformer;
// NOTE(review): MVNTransformation is registered with ngraph::opset1::Interpolate as the
// operation type even though this test targets MVN - looks like a copy-paste leftover;
// confirm which operation type is intended here.
76 transformer.add<ngraph::pass::low_precision::MVNTransformation, ngraph::opset1::Interpolate>(testValues.params);
77 transformer.transform(actualFunction);
// Build the reference function from the shared settings and the "expected" fields.
79 referenceFunction = ngraph::builder::subgraph::MVNFunction::getReference(
80 testValues.inputShape,
81 testValues.reductionAxes,
82 testValues.normalizeVariance,
83 testValues.expected.precisionBeforeDequantization,
84 testValues.expected.dequantizationBefore,
85 testValues.expected.precisionAfterOperation,
86 testValues.expected.dequantizationAfter);
// Composes a test-case name from the parameter fields, joined with "_".
// NOTE(review): the line that starts streaming into `result` and the return statement
// are not visible in this view.
89 static std::string getTestCaseName(testing::TestParamInfo<MVNTransformationTestValues> obj) {
90 const MVNTransformationTestValues testValues = obj.param;
92 std::ostringstream result;
94 testValues.inputShape << "_" <<
95 testValues.reductionAxes << "_" <<
96 testValues.normalizeVariance << "_" <<
97 testValues.actual.precisionBeforeDequantization << "_" <<
98 testValues.actual.dequantization << "_" <<
99 testValues.expected.dequantizationBefore;
// Parameter sets supplied to the MVNTransformation suite through ::testing::ValuesIn.
// NOTE(review): many lines of each entry (braces, reduction axes, normalizeVariance, some
// precisions and dequantizations) are not visible in this view. The visible parts of each
// entry appear to be: input shape, transformation params, the dequantization on the actual
// side, and the precisions / dequantizations on the expected side - confirm against the
// full file before editing individual cases.
104 const std::vector<MVNTransformationTestValues> testValues = {
106 ngraph::Shape{ 1, 4, 16, 16 },
109 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
112 {{ngraph::element::f32}, {-0.32f}, {0.45f}}
116 {{ngraph::element::f32}, {-0.32f}, {0.45f}},
117 ngraph::element::f32,
122 ngraph::Shape{ 1, 4, 16, 16 },
125 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
128 {{ngraph::element::f32}, {}, {0.45f}}
133 ngraph::element::f32,
138 ngraph::Shape{ 1, 4, 16, 16 },
141 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
144 {{ngraph::element::f32}, {127.f}, {0.45f}}
148 {{ngraph::element::f32}, {127.f}, {}},
149 ngraph::element::f32,
154 ngraph::Shape{ 1, 4, 16, 16 },
157 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
160 {{ngraph::element::f32}, {12.5f}, {0.45f}}
164 {{ngraph::element::f32}, {12.5f}, {0.45f}},
165 ngraph::element::f32,
170 ngraph::Shape{ 1, 4, 16, 16 },
173 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
176 {{ngraph::element::f32}, {127.f}, {0.45f}}
180 {{ngraph::element::f32}, {127.f}, {0.45f}},
181 ngraph::element::f32,
187 ngraph::Shape{ 1, 4, 16, 16 },
190 LayerTransformation::createParamsU8I8(),
193 {{ngraph::element::f32}, {}, {-0.5f}}
198 ngraph::element::f32,
204 ngraph::Shape{ 1, 4, 16, 16 },
207 LayerTransformation::createParamsU8I8(),
210 {{ngraph::element::f32}, {}, {0.45f}}
215 ngraph::element::f32,
220 ngraph::Shape{ 1, 2, 2, 2 },
223 LayerTransformation::createParamsU8I8(),
226 {{ngraph::element::f32}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
231 ngraph::element::f32,
232 {{}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
236 ngraph::Shape{ 1, 2, 2, 2 },
239 LayerTransformation::createParamsU8I8(),
242 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
247 ngraph::element::f32,
248 {{}, {}, {{1.f, -1.f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
252 ngraph::Shape{ 1, 2, 2, 2 },
255 LayerTransformation::createParamsU8I8(),
258 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
262 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}},
263 ngraph::element::f32,
// Re-runs node validation / type inference on the transformed function, then compares it
// with the reference function; on mismatch the assertion fails and res.second is streamed
// into the failure output.
269 TEST_P(MVNTransformation, CompareFunctions) {
270 actualFunction->validate_nodes_and_infer_types();
// NOTE(review): the meaning of the three `true` flags is defined by compare_functions,
// which is not visible here - check its declaration before changing them.
271 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
272 ASSERT_TRUE(res.first) << res.second;
// Instantiates the parameterized suite over every entry of testValues and names each case
// with MVNTransformation::getTestCaseName.
// NOTE(review): the instantiation-prefix and suite-name arguments are not visible in this view.
275 INSTANTIATE_TEST_CASE_P(
278 ::testing::ValuesIn(testValues),
279 MVNTransformation::getTestCaseName);