e090447963d8279b38bfeb11c5adff8b61737b26
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / convolution_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/convolution.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/convolution_function.hpp"
20
21 using namespace testing;
22 using namespace ngraph;
23 using namespace ngraph::pass;
24
25 class ConvolutionTransformationTestValues {
26 public:
27     class Actual {
28     public:
29         ngraph::element::Type precisionBeforeDequantization;
30         ngraph::builder::subgraph::DequantizationOperations dequantization;
31         std::shared_ptr<ngraph::opset1::Constant> weights;
32         builder::subgraph::FakeQuantizeOnWeights fakeQuantizeOnWeights;
33     };
34
35     class Expected {
36     public:
37         ngraph::element::Type precisionBeforeDequantization;
38         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39         std::shared_ptr<ngraph::opset1::Constant> weights;
40         builder::subgraph::FakeQuantizeOnWeights fakeQuantizeOnWeights;
41         ngraph::element::Type precisionAfterOperation;
42         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
43         ngraph::element::Type precisionAfterDequantization;
44     };
45
46     ngraph::pass::low_precision::LayerTransformation::Params params;
47     Actual actual;
48     Expected expected;
49 };
50
51 typedef std::tuple<
52     ngraph::Shape,
53     ConvolutionTransformationTestValues> ConvolutionTransformationParams;
54
55 class ConvolutionTransformation : public LayerTransformation, public testing::WithParamInterface<ConvolutionTransformationParams> {
56 public:
57     void SetUp() override {
58         const auto inputShape = std::get<0>(GetParam());
59         const auto testValues = std::get<1>(GetParam());
60
61         actualFunction = ngraph::builder::subgraph::ConvolutionFunction::getOriginal(
62             testValues.actual.precisionBeforeDequantization,
63             inputShape,
64             testValues.actual.dequantization,
65             testValues.actual.weights,
66             testValues.actual.fakeQuantizeOnWeights);
67
68         SimpleLowPrecisionTransformer transform;
69         transform.add<ngraph::pass::low_precision::ConvolutionTransformation, ngraph::opset1::Convolution>(testValues.params);
70         transform.transform(actualFunction);
71
72         referenceFunction = ngraph::builder::subgraph::ConvolutionFunction::getReference(
73                 testValues.expected.precisionBeforeDequantization,
74                 inputShape,
75                 testValues.expected.dequantizationBefore,
76                 testValues.expected.weights,
77                 testValues.expected.fakeQuantizeOnWeights,
78                 testValues.expected.precisionAfterOperation,
79                 testValues.expected.dequantizationAfter,
80                 testValues.expected.precisionAfterDequantization);
81     }
82
83     static std::string getTestCaseName(testing::TestParamInfo<ConvolutionTransformationParams> obj) {
84         auto inputShape = std::get<0>(obj.param);
85         ConvolutionTransformationTestValues testValues = std::get<1>(obj.param);
86
87         std::ostringstream result;
88         result << toString(testValues.params) << "_" <<
89             inputShape << "_" <<
90             testValues.actual.precisionBeforeDequantization << "_" <<
91             testValues.actual.dequantization << "_" << "_weights_" <<
92             testValues.actual.weights->get_element_type() << "_" << "{ " <<
93             testValues.actual.weights->cast_vector<float>()[0] << " }_" <<
94             testValues.actual.fakeQuantizeOnWeights << "_";
95         return result.str();
96     }
97 };
98
99 TEST_P(ConvolutionTransformation, CompareFunctions) {
100     actualFunction->validate_nodes_and_infer_types();
101     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
102     ASSERT_TRUE(res.first) << res.second;
103 }
104
105 const std::vector<ngraph::Shape> shapes = {
106     ngraph::Shape({ 1, 3, 72, 48 }),
107     ngraph::Shape({ 4, 3, 72, 48 })
108 };
109
110 const std::vector<ConvolutionTransformationTestValues> testValues = {
111     // with zero point
112     {
113         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
114         // ActualValues
115         {
116             ngraph::element::u8,
117             {{ngraph::element::f32}, { 128.f }, { 0.02f }},
118             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
119             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
120         },
121         // ExpectedValues
122         {
123             ngraph::element::u8,
124             {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
125             op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
126             {},
127             ngraph::element::f32,
128             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
129         }
130     },
131     // with zero point
132     {
133         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
134         // ActualValues
135         {
136             ngraph::element::u8,
137             {{ngraph::element::f32}, { 128.f }, { 0.02f }},
138             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
139             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
140         },
141         // ExpectedValues
142         {
143             ngraph::element::u8,
144             {{ ngraph::element::f32 }, { 128.f }, { 0.02f }},
145             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
146             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
147             ngraph::element::f32,
148             {}
149         }
150     },
151     // with zero point, not update precisions
152     {
153         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
154         // ActualValues
155         {
156             ngraph::element::f32,
157             {{ngraph::element::f32}, { 128.f }, { 0.02f }},
158             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
159             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
160         },
161         // ExpectedValues
162         {
163             ngraph::element::f32,
164             {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
165             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ -125.f }),
166             {},
167             ngraph::element::f32,
168             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
169         }
170     },
171     // without zero point
172     {
173         LayerTransformation::createParamsU8I8(),
174         // ActualValues
175         {
176             ngraph::element::u8,
177             {{ngraph::element::f32}, {}, { 0.02f }},
178             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
179             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
180         },
181         // ExpectedValues
182         {
183             ngraph::element::u8,
184             {},
185             op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
186             {},
187             ngraph::element::f32,
188             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
189         }
190     },
191     // without zero point
192     {
193         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
194         // ActualValues
195         {
196             ngraph::element::u8,
197             {{ngraph::element::f32}, {}, { 0.02f }},
198             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
199             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
200         },
201         // ExpectedValues
202         {
203             ngraph::element::u8,
204             {},
205             op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
206             {},
207             ngraph::element::f32,
208             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
209         }
210     },
211     // without zero point, not update precisions
212     {
213         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
214         // ActualValues
215         {
216             ngraph::element::f32,
217             {{ngraph::element::f32}, {}, { 0.02f }},
218             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
219             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
220         },
221         // ExpectedValues
222         {
223             ngraph::element::f32,
224             {},
225             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ -125.f }),
226             {},
227             ngraph::element::f32,
228             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
229         }
230     },
231     // with zero point, per-channel quantization with the same values
232     {
233         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
234         // ActualValues
235         {
236             ngraph::element::u8,
237             {{ngraph::element::f32}, { { 128.f }, ngraph::element::f32, {1, 3, 1, 1} }, { { 0.02f },  ngraph::element::f32, {1, 3, 1, 1} }},
238             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
239             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
240         },
241         // ExpectedValues
242         {
243             ngraph::element::u8,
244             {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
245             op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
246             {},
247             ngraph::element::f32,
248             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
249         }
250     },
251     // with zero point, per-channel quantization with different values
252     {
253         LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
254         // ActualValues
255         {
256             ngraph::element::u8,
257             {
258                 {ngraph::element::f32},
259                 {{ 128.f, 0.f, 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }},
260                 {{ 0.02f, 0.01f, 0.03f }, ngraph::element::f32, {1, 3, 1, 1}}
261             },
262             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
263             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
264         },
265         // ExpectedValues
266         {
267             ngraph::element::u8,
268             {
269                 {ngraph::element::f32},
270                 {{ 128.f, 0.f, 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }},
271                 {{ 0.02f, 0.01f, 0.03f }, ngraph::element::f32, {1, 3, 1, 1}}
272             },
273             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
274             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
275             ngraph::element::f32,
276             {}
277         }
278     },
279     // dequantization in second dimension
280     {
281         LayerTransformation::createParamsU8I8(),
282         // ActualValues
283         {
284             ngraph::element::f32,
285             {
286                 {ngraph::element::f32},
287                 {{ 128.f }, ngraph::element::f32, { 1, 1, 1, 1 }},
288                 {{ 0.02f }, ngraph::element::f32, {1, 1, 1, 1}}
289             },
290             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
291             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
292         },
293         // ExpectedValues
294         {
295             ngraph::element::f32,
296             {
297                 {ngraph::element::f32},
298                 {{ 128.f }, ngraph::element::f32, { 1, 1, 1, 1 }},
299                 {{ 0.02f }, ngraph::element::f32, {1, 1, 1, 1}}
300             },
301             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
302             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
303             ngraph::element::f32,
304             {}
305         }
306     },
307     // without dequantization operations
308     {
309         LayerTransformation::createParamsU8I8(),
310         // ActualValues
311         {
312             ngraph::element::f32,
313             {},
314             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
315             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
316         },
317         // ExpectedValues
318         {
319             ngraph::element::f32,
320             {},
321             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
322             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
323             ngraph::element::f32,
324             {}
325         }
326     },
327     // without zero point, without convert
328     {
329         LayerTransformation::createParamsU8I8(),
330         // ActualValues
331         {
332             ngraph::element::f32,
333             {{}, {}, { 0.02f }},
334             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
335             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
336         },
337         // ExpectedValues
338         {
339             ngraph::element::f32,
340             {{}, {}, { {0.02f}, element::f32 }},
341             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
342             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
343             ngraph::element::f32,
344             {}
345         }
346     },
347     // without zero point
348     {
349         LayerTransformation::createParamsU8I8(),
350         // ActualValues
351         {
352             ngraph::element::u8,
353             {{element::f32}, {}, { {0.02f}, element::f32 }},
354             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
355             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
356         },
357         // ExpectedValues
358         {
359             ngraph::element::u8,
360             {},
361             op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
362             {},
363             ngraph::element::f32,
364             {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
365         }
366     },
367 };
368
369 INSTANTIATE_TEST_CASE_P(
370     LPT,
371     ConvolutionTransformation,
372     ::testing::Combine(
373         ::testing::ValuesIn(shapes),
374         ::testing::ValuesIn(testValues)),
375     ConvolutionTransformation::getTestCaseName);