[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / move_dequantization_after_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <utility>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
16 #include <low_precision/network_helper.hpp>
17
18 #include "common_test_utils/ngraph_test_utils.hpp"
19 #include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
20 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
21
22 using namespace testing;
23 using namespace ngraph::pass;
24 using namespace ngraph::builder::subgraph;
25
26 class MoveDequantizationAfterTransformationParams {
27 public:
28     class Actual {
29     public:
30         ngraph::builder::subgraph::DequantizationOperations dequantization;
31     };
32
33     class Expected {
34     public:
35         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
36         ngraph::element::Type precisionAfterOperation;
37         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
38     };
39
40     ngraph::element::Type originalPrecision;
41     ngraph::pass::low_precision::LayerTransformation::Params params;
42     bool updatePrecision;
43     bool moveSubtract;
44     Actual actual;
45     Expected expected;
46 };
47
48 typedef std::tuple<
49     ngraph::Shape,
50     MoveDequantizationAfterTransformationParams> MoveDequantizationAfterTransformationTestValues;
51
52 class MoveDequantizationAfterTransformation :
53     public LayerTransformation,
54     public testing::WithParamInterface<MoveDequantizationAfterTransformationTestValues> {
55 public:
56     void SetUp() override {
57         const auto inputShape = std::get<0>(GetParam());
58         const auto testValues = std::get<1>(GetParam());
59         actualFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getOriginal(
60             testValues.originalPrecision,
61             inputShape,
62             testValues.actual.dequantization);
63
64         const auto targetNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
65         const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(targetNode);
66         ngraph::pass::low_precision::NetworkHelper::moveDequantizationAfter(
67             targetNode,
68             dequantization,
69             testValues.updatePrecision,
70             testValues.moveSubtract);
71
72         referenceFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getReference(
73             testValues.originalPrecision,
74             inputShape,
75             testValues.expected.dequantizationBefore,
76             testValues.expected.precisionAfterOperation,
77             testValues.expected.dequantizationAfter);
78     }
79
80     static std::string getTestCaseName(testing::TestParamInfo<MoveDequantizationAfterTransformationTestValues> obj) {
81         const auto inputShape = std::get<0>(obj.param);
82         const auto testValues = std::get<1>(obj.param);
83
84         std::ostringstream result;
85         result <<
86             testValues.originalPrecision << "_" <<
87             inputShape << "_" <<
88             testValues.actual.dequantization << "_" <<
89             (testValues.moveSubtract ? "move_subtract_" : "don't_move_subtract_") <<
90             (testValues.updatePrecision ? "updatePrecision" : "don't_update_precision");
91         return result.str();
92     }
93 };
94
95 TEST_P(MoveDequantizationAfterTransformation, CompareFunctions) {
96     actualFunction->validate_nodes_and_infer_types();
97     auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
98     ASSERT_TRUE(res.first) << res.second;
99 }
100
101 const std::vector<ngraph::Shape> inputShapes = {
102     { 1, 3, 16, 16 },
103     { 4, 3, 16, 16 }
104 };
105
106 const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
107     // U8
108     {
109         ngraph::element::u8,
110         LayerTransformation::createParamsU8I8(),
111         true,
112         true,
113         {
114             { {ngraph::element::f32},  { 7.f }, { 10.f } },
115         },
116         {
117             { {},  {}, {} },
118             ngraph::element::u8,
119             { {ngraph::element::f32},  { 7.f }, { 10.f } },
120         },
121     },
122     // moveSubtract = false
123     {
124         ngraph::element::u8,
125         LayerTransformation::createParamsU8I8(),
126         true,
127         false,
128         {
129             { {ngraph::element::f32},  { 7.f }, { 10.f } },
130         },
131         {
132             { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
133             ngraph::element::f32,
134             { {},  {}, { 10.f } },
135         },
136     },
137     // updatePrecision = false
138     {
139         ngraph::element::u8,
140         LayerTransformation::createParamsU8I8(),
141         false,
142         true,
143         {
144             { {ngraph::element::f32},  { 7.f }, { 10.f } },
145         },
146         {
147             { {},  {}, {} },
148             ngraph::element::f32,
149             { {},  { 7.f }, { 10.f } },
150         },
151     },
152     // moveSubtract = false & updatePrecision = false
153     {
154         ngraph::element::u8,
155         LayerTransformation::createParamsU8I8(),
156         false,
157         false,
158         {
159             { {ngraph::element::f32},  { 7.f }, { 10.f } },
160         },
161         {
162             { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
163             ngraph::element::f32,
164             { {},  {}, { 10.f } },
165         },
166     },
167     // I8
168     {
169         ngraph::element::i8,
170         LayerTransformation::createParamsI8I8(),
171         true,
172         true,
173         {
174             { {ngraph::element::f32},  { 7.f }, { 10.f } },
175         },
176         {
177             { {},  {}, {} },
178             ngraph::element::i8,
179             { {ngraph::element::f32},  { 7.f }, { 10.f } },
180         },
181     },
182     // moveSubtract = false
183     {
184         ngraph::element::i8,
185         LayerTransformation::createParamsI8I8(),
186         true,
187         false,
188         {
189             { {ngraph::element::f32},  { 7.f }, { 10.f } },
190         },
191         {
192             { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
193             ngraph::element::f32,
194             { {},  {}, { 10.f } },
195         },
196     },
197     // updatePrecision = false
198     {
199         ngraph::element::i8,
200         LayerTransformation::createParamsI8I8(),
201         false,
202         true,
203         {
204             { {ngraph::element::f32},  { 7.f }, { 10.f } },
205         },
206         {
207             { {},  {}, {} },
208             ngraph::element::f32,
209             { {},  { 7.f }, { 10.f } },
210         },
211     },
212     // moveSubtract = false & updatePrecision = false
213     {
214         ngraph::element::i8,
215         LayerTransformation::createParamsI8I8(),
216         false,
217         false,
218         {
219             { {ngraph::element::f32},  { 7.f }, { 10.f } },
220         },
221         {
222             { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
223             ngraph::element::f32,
224             { {},  {}, { 10.f } },
225         },
226     },
227     // per-channel quantizations with the same values
228     {
229         ngraph::element::u8,
230         LayerTransformation::createParamsU8I8(),
231         false,
232         false,
233         {
234             { {ngraph::element::f32},  { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
235         },
236         {
237             { {},  { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
238             ngraph::element::f32,
239             { {},  {}, { { 10.f, 10.f, 10.f } } },
240         },
241     },
242     // per-channel quantizations with the same values
243     {
244         ngraph::element::u8,
245         LayerTransformation::createParamsU8I8(),
246         false,
247         false,
248         {
249             { {ngraph::element::f32},  { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
250         },
251         {
252             { {},  { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
253             ngraph::element::f32,
254             { {},  {}, { { 10.f, 12.f, 16.f } } },
255         },
256     },
257 };
258
259 INSTANTIATE_TEST_CASE_P(
260     LPT,
261     MoveDequantizationAfterTransformation,
262     ::testing::Combine(
263         ::testing::ValuesIn(inputShapes),
264         ::testing::ValuesIn(testValues)),
265     MoveDequantizationAfterTransformation::getTestCaseName);