1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include "low_precision/mvn.hpp"
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "simple_low_precision_transformer.hpp"
20 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
// Parameters of a single MVN low-precision transformation test case.
// The fixture reads them as testValues.actual.*, testValues.expected.* and the plain fields below.
26 class MVNTransformationTestValues {
// Input side (read as testValues.actual.*): element type of the graph input and the
// dequantization operations fed to MVNFunction::getOriginal (presumably applied before MVN).
30 ngraph::element::Type precisionBeforeDequantization;
31 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Expected side (read as testValues.expected.*), fed to MVNFunction::getReference:
// input element type, dequantization expected before MVN, element type presumably produced
// by MVN after the transformation, and dequantization expected after MVN.
36 ngraph::element::Type precisionBeforeDequantization;
37 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
38 ngraph::element::Type precisionAfterOperation;
39 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Shape of the graph input passed to both MVNFunction builders.
42 ngraph::Shape inputShape;
// Reduction axes forwarded to the MVN builders (presumably MVN's mean/variance axes).
43 ngraph::AxisSet reductionAxes;
// Forwarded to the MVN builders; presumably MVN's normalize_variance flag.
44 bool normalizeVariance;
// Low-precision transformation parameters, e.g. whether asymmetric quantization is supported.
45 ngraph::pass::low_precision::LayerTransformation::Params params;
// Streams the elements of a std::vector to `os`, presumably separated by a delimiter
// (the statements that write the elements and separators are not visible in this listing).
51 inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
53 for (size_t i = 0; i < values.size(); ++i) {
// Every element except the last one is followed by a separator. The check never runs for an
// empty vector, so `values.size() - 1ul` cannot wrap around here.
55 if (i != (values.size() - 1ul)) {
63 class MVNTransformation : public LayerTransformation, public testing::WithParamInterface<MVNTransformationTestValues> {
65 void SetUp() override {
66 const MVNTransformationTestValues testValues = GetParam();
68 actualFunction = ngraph::builder::subgraph::MVNFunction::getOriginal(
69 testValues.inputShape,
70 testValues.reductionAxes,
71 testValues.normalizeVariance,
72 testValues.actual.precisionBeforeDequantization,
73 testValues.actual.dequantization);
75 SimpleLowPrecisionTransformer transformer;
76 transformer.add<ngraph::pass::low_precision::MVNTransformation, ngraph::opset1::Interpolate>(testValues.params);
77 transformer.transform(actualFunction);
79 referenceFunction = ngraph::builder::subgraph::MVNFunction::getReference(
80 testValues.inputShape,
81 testValues.reductionAxes,
82 testValues.normalizeVariance,
83 testValues.expected.precisionBeforeDequantization,
84 testValues.expected.dequantizationBefore,
85 testValues.expected.precisionAfterOperation,
86 testValues.expected.dequantizationAfter);
89 static std::string getTestCaseName(testing::TestParamInfo<MVNTransformationTestValues> obj) {
90 const MVNTransformationTestValues testValues = obj.param;
92 std::ostringstream result;
94 toString(testValues.params) << "_" <<
95 testValues.inputShape << "_" <<
96 testValues.reductionAxes << "_" <<
97 testValues.normalizeVariance << "_" <<
98 testValues.actual.precisionBeforeDequantization << "_" <<
99 testValues.actual.dequantization << "_" <<
100 testValues.expected.dequantizationBefore;
// Test cases for MVNTransformation. The fields of each entry appear to follow the member order
// of MVNTransformationTestValues: inputShape, reductionAxes, normalizeVariance, params, actual,
// expected. The reductionAxes/normalizeVariance lines are not visible in this listing.
// Each dequantization triple is presumably { convert, subtract, multiply }.
105 const std::vector<MVNTransformationTestValues> testValues = {
// Per-tensor subtract (-0.32) + multiply, asymmetric quantization disabled.
// expected.dequantizationBefore repeats the input dequantization, so it is expected to stay before MVN.
107 ngraph::Shape{ 1, 4, 16, 16 },
110 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
113 {{ngraph::element::f32}, {-0.32f}, {0.45f}}
117 {{ngraph::element::f32}, {-0.32f}, {0.45f}},
118 ngraph::element::f32,
// Per-tensor multiply only (no subtract), asymmetric quantization disabled.
123 ngraph::Shape{ 1, 4, 16, 16 },
126 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
129 {{ngraph::element::f32}, {}, {0.45f}}
134 ngraph::element::f32,
// Subtract 127 + multiply, asymmetric quantization enabled: dequantization expected to stay before MVN.
139 ngraph::Shape{ 1, 4, 16, 16 },
142 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
145 {{ngraph::element::f32}, {127.f}, {0.45f}}
149 {{ngraph::element::f32}, {127.f}, {0.45f}},
150 ngraph::element::f32,
// Subtract 12.5 + multiply, asymmetric quantization enabled: dequantization expected to stay before MVN.
155 ngraph::Shape{ 1, 4, 16, 16 },
158 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
161 {{ngraph::element::f32}, {12.5f}, {0.45f}}
165 {{ngraph::element::f32}, {12.5f}, {0.45f}},
166 ngraph::element::f32,
// Subtract 127 + multiply, asymmetric quantization disabled: dequantization expected to stay before MVN.
171 ngraph::Shape{ 1, 4, 16, 16 },
174 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
177 {{ngraph::element::f32}, {127.f}, {0.45f}}
181 {{ngraph::element::f32}, {127.f}, {0.45f}},
182 ngraph::element::f32,
// Negative per-tensor scale (-0.5), no subtract, default params.
188 ngraph::Shape{ 1, 4, 16, 16 },
191 LayerTransformation::createParamsU8I8(),
194 {{ngraph::element::f32}, {}, {-0.5f}}
199 ngraph::element::f32,
// Positive per-tensor scale (0.45), no subtract, default params.
205 ngraph::Shape{ 1, 4, 16, 16 },
208 LayerTransformation::createParamsU8I8(),
211 {{ngraph::element::f32}, {}, {0.45f}}
216 ngraph::element::f32,
// Per-channel scales {0.45, 0.45}: expected.dequantizationAfter carries the same multiply,
// so the scale is expected after MVN.
221 ngraph::Shape{ 1, 2, 2, 2 },
224 LayerTransformation::createParamsU8I8(),
227 {{ngraph::element::f32}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
232 ngraph::element::f32,
233 {{}, {}, {{0.45f, 0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
// Per-channel scales with mixed signs {0.45, -0.45}: after MVN only the signs remain ({1, -1}).
// NOTE(review): this presumes normalizeVariance is true for this case (not visible here),
// which cancels the scale magnitude. Confirm.
237 ngraph::Shape{ 1, 2, 2, 2 },
240 LayerTransformation::createParamsU8I8(),
243 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
248 ngraph::element::f32,
249 {{}, {}, {{1.f, -1.f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
// Same mixed-sign per-channel scales, but the dequantization is expected to stay before MVN.
// NOTE(review): presumably reductionAxes/normalizeVariance differ from the previous case
// (not visible here). Confirm.
253 ngraph::Shape{ 1, 2, 2, 2 },
256 LayerTransformation::createParamsU8I8(),
259 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}}
263 {{ngraph::element::f32}, {}, {{0.45f, -0.45f}, ngraph::element::f32, ngraph::Shape{ 1, 2, 1, 1 }}},
264 ngraph::element::f32,
270 TEST_P(MVNTransformation, CompareFunctions) {
271 actualFunction->validate_nodes_and_infer_types();
272 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
273 ASSERT_TRUE(res.first) << res.second;
// Instantiates the fixture's parameterized tests once per entry of `testValues`.
// Each instance is named with MVNTransformation::getTestCaseName.
276 INSTANTIATE_TEST_CASE_P(
279 ::testing::ValuesIn(testValues),
280 MVNTransformation::getTestCaseName);