[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / max_pool_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <memory>
9
10 #include <gtest/gtest.h>
11
12 #include <transformations/utils/utils.hpp>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/max_pool.hpp>
15 #include <low_precision/transformer.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
20
21 using namespace testing;
22 using namespace ngraph::pass;
23
24 class MaxPoolTransformationTestValues {
25 public:
26     low_precision::LayerTransformation::Params params;
27     std::vector<float> subtractValues;
28     std::vector<float> mutliplyValues;
29 };
30
31 typedef std::tuple<
32     ngraph::element::Type,
33     ngraph::Shape,
34     MaxPoolTransformationTestValues> MaxPoolTransformationParams;
35
36 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
37 public:
38     void SetUp() override {
39         const ngraph::element::Type precision = std::get<0>(GetParam());
40         const ngraph::Shape shape = std::get<1>(GetParam());
41         const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
42
43         actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
44             precision,
45             shape,
46             {
47                 testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
48                 testValues.subtractValues,
49                 testValues.mutliplyValues
50             });
51
52         SimpleLowPrecisionTransformer transform;
53         transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
54         transform.transform(actualFunction);
55
56         referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
57             precision,
58             shape,
59             {
60                 testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
61                 testValues.subtractValues,
62                 testValues.mutliplyValues
63             });
64     }
65
66     static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
67         const ngraph::element::Type precision = std::get<0>(obj.param);
68         const ngraph::Shape shape = std::get<1>(obj.param);
69         const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
70
71         std::ostringstream result;
72         result <<
73             LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
74             testValues.subtractValues.size() << "_" <<
75             testValues.mutliplyValues.size() << "_";
76         return result.str();
77     }
78 };
79
80 TEST_P(MaxPoolTransformation, CompareFunctions) {
81     actualFunction->validate_nodes_and_infer_types();
82     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
83     ASSERT_TRUE(res.first) << res.second;
84 }
85
86 const std::vector<ngraph::element::Type> precisions = {
87     ngraph::element::f32,
88     // ngraph::element::f16
89 };
90
91 const std::vector<ngraph::Shape> shapes = {
92     { 1, 32, 72, 48 }
93 };
94
95 const std::vector<MaxPoolTransformationTestValues> testValues = {
96     { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
97     { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
98     { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
99     { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
100     { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
101 };
102
103 INSTANTIATE_TEST_CASE_P(
104     LPT,
105     MaxPoolTransformation,
106     ::testing::Combine(
107         ::testing::ValuesIn(precisions),
108         ::testing::ValuesIn(shapes),
109         ::testing::ValuesIn(testValues)),
110     MaxPoolTransformation::getTestCaseName);