[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / max_pool_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <memory>
9
10 #include <gtest/gtest.h>
11
12 #include <transformations/utils/utils.hpp>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/max_pool.hpp>
15 #include <low_precision/transformer.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
20 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
21
22
23 using namespace testing;
24 using namespace ngraph::pass;
25
26 class MaxPoolTransformationTestValues {
27 public:
28     class Actual {
29     public:
30         ngraph::element::Type precisionBeforeDequantization;
31         ngraph::builder::subgraph::DequantizationOperations dequantization1;
32         ngraph::builder::subgraph::DequantizationOperations dequantization2;
33     };
34
35     class Expected {
36     public:
37         ngraph::element::Type precisionBeforeDequantization;
38         ngraph::builder::subgraph::DequantizationOperations dequantization1;
39         ngraph::builder::subgraph::DequantizationOperations dequantization2;
40     };
41
42     ngraph::pass::low_precision::LayerTransformation::Params params;
43     Actual actual;
44     Expected expected;
45 };
46
47 typedef std::tuple<
48     ngraph::Shape,
49     MaxPoolTransformationTestValues> MaxPoolTransformationParams;
50
51 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
52 public:
53     void SetUp() override {
54         const ngraph::Shape shape = std::get<0>(GetParam());
55         const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
56
57         actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
58             shape,
59             testValues.actual.precisionBeforeDequantization,
60             testValues.actual.dequantization1,
61             testValues.actual.dequantization2);
62
63         SimpleLowPrecisionTransformer transform;
64         transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
65         transform.transform(actualFunction);
66
67         referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
68             shape,
69             testValues.expected.precisionBeforeDequantization,
70             testValues.expected.dequantization1,
71             testValues.expected.dequantization2);
72     }
73
74     static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
75         const ngraph::Shape shape = std::get<0>(obj.param);
76         const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
77
78         std::ostringstream result;
79         result <<
80             LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
81             testValues.actual.dequantization1 << "_" <<
82             testValues.actual.dequantization2 << "_" <<
83             testValues.expected.dequantization1 << "_" <<
84             testValues.expected.dequantization2 << "_";
85         return result.str();
86     }
87 };
88
89 TEST_P(MaxPoolTransformation, CompareFunctions) {
90     actualFunction->validate_nodes_and_infer_types();
91     auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
92     ASSERT_TRUE(res.first) << res.second;
93 }
94
95 const std::vector<ngraph::Shape> shapes = {
96     { 1, 32, 72, 48 },
97     { 4, 32, 72, 48 }
98 };
99
100 const std::vector<MaxPoolTransformationTestValues> testValues = {
101     // Multiply
102     {
103         LayerTransformation::createParamsU8I8(),
104         {
105             ngraph::element::u8,
106             { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
107             {}
108         },
109         {
110             ngraph::element::u8,
111             {},
112             { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
113         }
114     },
115     // Subtract + Multiply
116     {
117         LayerTransformation::createParamsU8I8(),
118         {
119             ngraph::element::u8,
120             {
121                 {},
122                 { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
123                 { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
124             },
125             {}
126         },
127         {
128             ngraph::element::u8,
129             {},
130             {
131                 ngraph::element::f32,
132                 { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
133                 { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
134             }
135         }
136     },
137     // Convert + Subtract + Multiply
138     {
139         LayerTransformation::createParamsU8I8(),
140         {
141             ngraph::element::u8,
142             { ngraph::element::f32, { 128 }, { 0.02f }},
143             {}
144         },
145         {
146             ngraph::element::u8,
147             {},
148             { ngraph::element::f32, { 128 }, { 0.02f }}
149         }
150     },
151     // Convert + Subtract + Multiply
152     {
153         LayerTransformation::createParamsU8I8(),
154         {
155             ngraph::element::u8,
156             { ngraph::element::f32, {}, { 0.02f }},
157             {}
158         },
159         {
160             ngraph::element::u8,
161             {},
162             { ngraph::element::f32, {}, { 0.02f }}
163         }
164     },
165     // Convert + Subtract + Multiply
166     {
167         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
168         {
169             ngraph::element::u8,
170             { ngraph::element::f32, { 128 }, { 0.02f }},
171             {}
172         },
173         {
174             ngraph::element::u8,
175             {},
176             { ngraph::element::f32, { 128 }, { 0.02f }}
177         }
178     },
179     // Convert + Subtract + Multiply
180     {
181         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
182         {
183             ngraph::element::u8,
184             { ngraph::element::f32, {}, { 0.02f }},
185             {}
186         },
187         {
188             ngraph::element::u8,
189             {},
190             { ngraph::element::f32, {}, { 0.02f }}
191         }
192     }
193 };
194
195 INSTANTIATE_TEST_CASE_P(
196     LPT,
197     MaxPoolTransformation,
198     ::testing::Combine(
199         ::testing::ValuesIn(shapes),
200         ::testing::ValuesIn(testValues)),
201     MaxPoolTransformation::getTestCaseName);