1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
10 #include <gtest/gtest.h>
12 #include <transformations/utils/utils.hpp>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/max_pool.hpp>
15 #include <low_precision/transformer.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
20 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
23 using namespace testing;
24 using namespace ngraph::pass;
// Description of one MaxPool low-precision transformation test case.
// SetUp reads testValues.actual.* to build the function under test and
// testValues.expected.* to build the reference function. The two field groups
// below presumably belong to those nested `actual` / `expected` parts
// (their section headers are not visible here).
26 class MaxPoolTransformationTestValues {
// "Actual" side: element type of the input before dequantization.
30 ngraph::element::Type precisionBeforeDequantization;
// "Actual" side: first DequantizationOperations argument of MaxPoolFunction::get.
31 ngraph::builder::subgraph::DequantizationOperations dequantization1;
// "Actual" side: second DequantizationOperations argument of MaxPoolFunction::get.
32 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// "Expected" side: same three fields, used to build the reference function.
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantization1;
39 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// Transformation parameters passed to MaxPoolTransformation in SetUp.
42 ngraph::pass::low_precision::LayerTransformation::Params params;
49 MaxPoolTransformationTestValues> MaxPoolTransformationParams; // Test parameter: element 0 is read as ngraph::Shape, element 1 as MaxPoolTransformationTestValues.
// Parameterized fixture. For each (shape, test values) parameter it:
// builds the actual function, runs MaxPoolTransformation on it, and builds the reference function.
51 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
// Builds actualFunction from testValues.actual and transforms it in place with
// MaxPoolTransformation. Then builds referenceFunction from testValues.expected.
53 void SetUp() override {
54 const ngraph::Shape shape = std::get<0>(GetParam());
55 const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
// Function under test, built from the "actual" description.
57 actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
59 testValues.actual.precisionBeforeDequantization,
60 testValues.actual.dequantization1,
61 testValues.actual.dequantization2);
// Register only MaxPoolTransformation (for ngraph::opset1::MaxPool) with this
// case's params, then apply it to actualFunction.
63 SimpleLowPrecisionTransformer transform;
64 transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
65 transform.transform(actualFunction);
// Reference built from the "expected" description; the test compares actualFunction against it.
67 referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
69 testValues.expected.precisionBeforeDequantization,
70 testValues.expected.dequantization1,
71 testValues.expected.dequantization2);
// Builds the gtest case name from three parts: the actual precision, the shape and
// the params (via getTestCaseNameByParams), then the actual and expected dequantizations.
// NOTE(review): expected.precisionBeforeDequantization is not included, so two cases
// differing only in that field would get identical names — confirm this is intended.
74 static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
75 const ngraph::Shape shape = std::get<0>(obj.param);
76 const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
78 std::ostringstream result;
80 LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
81 testValues.actual.dequantization1 << "_" <<
82 testValues.actual.dequantization2 << "_" <<
83 testValues.expected.dequantization1 << "_" <<
84 testValues.expected.dequantization2 << "_";
// Checks that the transformed actual function matches the reference function.
// NOTE(review): the three boolean flags passed to compare_functions are unnamed here;
// confirm their meaning in common_test_utils/ngraph_test_utils.hpp.
89 TEST_P(MaxPoolTransformation, CompareFunctions) {
// Re-run validation and type inference on the transformed graph before comparing.
90 actualFunction->validate_nodes_and_infer_types();
91 auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
// res.second is streamed into the failure message when the comparison fails.
92 ASSERT_TRUE(res.first) << res.second;
// Input shapes passed to the test instantiation below, alongside testValues.
95 const std::vector<ngraph::Shape> shapes = {
// MaxPool transformation cases. Each case starts with its LayerTransformation params
// (createParamsU8I8 variants), followed by the actual and expected subgraph descriptions.
100 const std::vector<MaxPoolTransformationTestValues> testValues = {
103 LayerTransformation::createParamsU8I8(),
106 { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
112 { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
115 // Subtract + Multiply (no Convert)
117 LayerTransformation::createParamsU8I8(),
122 { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
123 { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
131 ngraph::element::f32,
132 { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
133 { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
137 // Convert + Subtract + Multiply
139 LayerTransformation::createParamsU8I8(),
142 { ngraph::element::f32, { 128 }, { 0.02f }},
148 { ngraph::element::f32, { 128 }, { 0.02f }}
151 // Convert + Multiply (empty Subtract)
153 LayerTransformation::createParamsU8I8(),
156 { ngraph::element::f32, {}, { 0.02f }},
162 { ngraph::element::f32, {}, { 0.02f }}
165 // Convert + Subtract + Multiply, updatePrecisions = false
167 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
170 { ngraph::element::f32, { 128 }, { 0.02f }},
176 { ngraph::element::f32, { 128 }, { 0.02f }}
179 // Convert + Multiply (empty Subtract), updatePrecisions = false
181 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
184 { ngraph::element::f32, {}, { 0.02f }},
190 { ngraph::element::f32, {}, { 0.02f }}
// Instantiates the MaxPoolTransformation tests over the shapes and testValues lists.
// Each case is named with MaxPoolTransformation::getTestCaseName.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the legacy googletest spelling; newer
// googletest releases prefer INSTANTIATE_TEST_SUITE_P — switch once the bundled gtest supports it.
195 INSTANTIATE_TEST_CASE_P(
197 MaxPoolTransformation,
199 ::testing::ValuesIn(shapes),
200 ::testing::ValuesIn(testValues)),
201 MaxPoolTransformation::getTestCaseName);