342f13891c696c8dc4b621c22b7c2d974dd297cf
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / max_pool_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <memory>
9
10 #include <gtest/gtest.h>
11
12 #include <transformations/utils/utils.hpp>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/max_pool.hpp>
15 #include <low_precision/transformer.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
20
21 using namespace testing;
22 using namespace ngraph::pass;
23
24 class MaxPoolTransformationTestValues {
25 public:
26     low_precision::LayerTransformation::Params params;
27     std::vector<float> subtractValues;
28     std::vector<float> mutliplyValues;
29 };
30
31 typedef std::tuple<
32     ngraph::element::Type,
33     ngraph::Shape,
34     MaxPoolTransformationTestValues> MaxPoolTransformationParams;
35
36 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
37 public:
38     void SetUp() override {
39         const ngraph::element::Type precision = std::get<0>(GetParam());
40         const ngraph::Shape shape = std::get<1>(GetParam());
41         const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
42
43         actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
44             precision,
45             shape,
46             {
47                 testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
48                 testValues.subtractValues,
49                 testValues.mutliplyValues
50             });
51
52         SimpleLowPrecisionTransformer transform;
53         transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
54         transform.transform(actualFunction);
55
56         referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
57             precision,
58             shape,
59             {
60                 testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
61                 testValues.subtractValues,
62                 testValues.mutliplyValues
63             });
64     }
65
66     static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
67         const ngraph::element::Type precision = std::get<0>(obj.param);
68         const ngraph::Shape shape = std::get<1>(obj.param);
69         const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
70
71         std::ostringstream result;
72         result <<
73             LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
74             testValues.subtractValues.size() << "_" <<
75             testValues.mutliplyValues.size() << "_";
76         return result.str();
77     }
78 };
79
80 TEST_P(MaxPoolTransformation, CompareFunctions) {
81     InitNodeInfo().run_on_function(actualFunction);
82     actualFunction->validate_nodes_and_infer_types();
83
84     auto res = compare_functions(referenceFunction, actualFunction, true, true);
85     ASSERT_TRUE(res.first) << res.second;
86 }
87
88 const std::vector<ngraph::element::Type> precisions = {
89     ngraph::element::f32,
90     // ngraph::element::f16
91 };
92
93 const std::vector<ngraph::Shape> shapes = {
94     { 1, 32, 72, 48 }
95 };
96
97 const std::vector<MaxPoolTransformationTestValues> testValues = {
98     { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
99     { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
100     { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
101     { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
102     { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
103 };
104
105 INSTANTIATE_TEST_CASE_P(
106     LPT,
107     MaxPoolTransformation,
108     ::testing::Combine(
109         ::testing::ValuesIn(precisions),
110         ::testing::ValuesIn(shapes),
111         ::testing::ValuesIn(testValues)),
112     MaxPoolTransformation::getTestCaseName);