572262854d65a4ef58bf1fb41c1ef6e3601b612b
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / mat_mul_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/mat_mul.hpp>
17
18 #include "common_test_utils/ngraph_test_utils.hpp"
19 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
20 #include "ngraph_functions/subgraph_builders.hpp"
21 #include "simple_low_precision_transformer.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
23
24 namespace {
25
26 using namespace testing;
27 using namespace ngraph::pass;
28
29 class MatMullTransformationTestValues {
30 public:
31     class Actual {
32     public:
33         ngraph::element::Type precisionBeforeDequantization1;
34         ngraph::builder::subgraph::DequantizationOperations dequantization1;
35         ngraph::element::Type precisionBeforeDequantization2;
36         ngraph::builder::subgraph::DequantizationOperations dequantization2;
37     };
38
39     class Expected {
40     public:
41         ngraph::element::Type precisionBeforeDequantization1;
42         ngraph::builder::subgraph::DequantizationOperations dequantization1;
43         ngraph::element::Type precisionBeforeDequantization2;
44         ngraph::builder::subgraph::DequantizationOperations dequantization2;
45         ngraph::element::Type precisionBeforeOperation1;
46         ngraph::element::Type precisionBeforeOperation2;
47         ngraph::builder::subgraph::DequantizationOperations result;
48     };
49
50     ngraph::pass::low_precision::LayerTransformation::Params params;
51     Actual actual;
52     Expected expected;
53 };
54
55 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
56     return out << "_" << actual.dequantization1 << "_" << actual.dequantization2;
57 }
58
59 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
60     return out << "_" <<
61         expected.precisionBeforeDequantization1 << "_" <<
62         expected.dequantization1 << "_" <<
63         expected.precisionBeforeDequantization2 << "_" <<
64         expected.dequantization2 << "_" <<
65         expected.precisionBeforeOperation1 << "_" <<
66         expected.precisionBeforeOperation2 << "_" <<
67         expected.result;
68 }
69
70 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
71     return out << "_" <<
72         values.params.supportAsymmetricQuantization << "_" <<
73         values.params.updatePrecisions << "_" <<
74         values.actual << "_" <<
75         values.expected;
76 }
77
78 typedef std::tuple<
79     ngraph::element::Type,
80     std::pair<ngraph::Shape, ngraph::Shape>,
81     MatMullTransformationTestValues> MatMulTransformationParams;
82
83 class MatMulTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
84 public:
85     void SetUp() override {
86         const ngraph::element::Type precision = std::get<0>(GetParam());
87         const std::pair<ngraph::Shape, ngraph::Shape> shapes = std::get<1>(GetParam());
88         const MatMullTransformationTestValues testValues = std::get<2>(GetParam());
89
90         actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
91             shapes.first,
92             testValues.actual.precisionBeforeDequantization1,
93             testValues.actual.dequantization1,
94             shapes.second,
95             testValues.actual.precisionBeforeDequantization2,
96             testValues.actual.dequantization2);
97
98         SimpleLowPrecisionTransformer transformer;
99         transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
100         transformer.transform(actualFunction);
101
102         referenceFunction =
103             (testValues.expected.precisionBeforeOperation1 == ngraph::element::f32) && testValues.expected.result.empty() ?
104             ngraph::builder::subgraph::MatMulFunction::getOriginal(
105                 shapes.first,
106                 testValues.actual.precisionBeforeDequantization1,
107                 testValues.actual.dequantization1,
108                 shapes.second,
109                 testValues.actual.precisionBeforeDequantization2,
110                 testValues.actual.dequantization2) :
111             ngraph::builder::subgraph::MatMulFunction::getReference(
112                 precision,
113                 shapes.first,
114                 testValues.expected.precisionBeforeDequantization1,
115                 testValues.expected.dequantization1,
116                 shapes.second,
117                 testValues.expected.precisionBeforeDequantization2,
118                 testValues.expected.dequantization2,
119                 testValues.expected.result);
120     }
121
122     static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
123         ngraph::element::Type precision;
124         std::pair<ngraph::Shape, ngraph::Shape> shapes;
125         MatMullTransformationTestValues testValues;
126         std::tie(precision, shapes, testValues) = obj.param;
127
128         std::stringstream ss;
129         ss << precision << "_" << shapes.first << "_" << shapes.second << "_" << testValues;
130         return ss.str();
131     }
132 };
133
134 TEST_P(MatMulTransformation, CompareFunctions) {
135     actualFunction->validate_nodes_and_infer_types();
136     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
137     ASSERT_TRUE(res.first) << res.second;
138 }
139
140 const std::vector<ngraph::element::Type> precisions = {
141     ngraph::element::f32,
142     // ngraph::element::f16
143 };
144
145 const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> shapes = {
146     { { 1, 16, 384, 64 }, { 1, 16, 64, 384 } },
147     { { 4, 16, 384, 64 }, { 4, 16, 64, 384 } }
148 };
149
150 const std::vector<bool> updatePrecisions = { true, false };
151
152 std::vector<MatMullTransformationTestValues> testValues = {
153     // U8 + I8: Constant on dequantization operations on 0 branch
154     // {
155     //    LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
156     //    {
157     //        ngraph::element::u8,
158     //        { ngraph::element::f32, { 127.f }, { {0.02f}, ngraph::element::f32, {}, true, 0 } },
159     //        ngraph::element::i8,
160     //        { ngraph::element::f32, {}, { 0.03f } },
161     //    },
162     //    {
163     //        ngraph::element::u8,
164     //        { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
165     //        ngraph::element::i8,
166     //        { },
167     //        ngraph::element::f32,
168     //        ngraph::element::f32,
169     //        { {}, {}, { 0.0006f } },
170     //    }
171     // },
172     // U8 + I8
173     {
174         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
175         {
176             ngraph::element::u8,
177             { ngraph::element::f32, { 127.f }, { 0.02f } },
178             ngraph::element::i8,
179             { ngraph::element::f32, {}, { 0.03f } },
180         },
181         {
182             ngraph::element::u8,
183             { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
184             ngraph::element::i8,
185             { },
186             ngraph::element::f32,
187             ngraph::element::f32,
188             { {}, {}, { 0.0006f } },
189         }
190     },
191     // I8 + I8
192     {
193         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
194         {
195             ngraph::element::i8,
196             { ngraph::element::f32, { 127.f }, { 0.02f } },
197             ngraph::element::i8,
198             { ngraph::element::f32, {}, { 0.03f } },
199         },
200         {
201             ngraph::element::i8,
202             { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
203             ngraph::element::i8,
204             { },
205             ngraph::element::f32,
206             ngraph::element::f32,
207             { {}, {}, { 0.0006f } },
208         }
209     },
210     // U8 + I8, Subtract with not int
211     {
212         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
213         {
214             ngraph::element::u8,
215             { ngraph::element::f32, { 127.5f }, { 0.02f } },
216             ngraph::element::i8,
217             { ngraph::element::f32, {}, { 0.03f } },
218         },
219         {
220             ngraph::element::u8,
221             { ngraph::element::f32, { 127.5f }, { 0.02f } },
222             ngraph::element::i8,
223             { ngraph::element::f32, {}, { 0.03f } },
224             ngraph::element::f32,
225             ngraph::element::f32,
226             {},
227         }
228     },
229     // U8 + FP32
230     {
231         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
232         {
233             ngraph::element::u8,
234             { ngraph::element::f32, { 127.f }, { 0.02f } },
235             ngraph::element::f32,
236             { {}, {}, { 0.03f } },
237         },
238         {
239             ngraph::element::u8,
240             { ngraph::element::f32, { 127.f }, { 0.02f } },
241             ngraph::element::f32,
242             { {}, {}, { 0.03f } },
243             ngraph::element::f32,
244             ngraph::element::f32,
245             { },
246         }
247     },
248     // FP32 + I8
249     {
250         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
251         {
252             ngraph::element::f32,
253             { {}, { 127.f }, { 0.02f } },
254             ngraph::element::i8,
255             { ngraph::element::f32, {}, { 0.03f } },
256         },
257         {
258             ngraph::element::f32,
259             { {}, { 127.f }, { 0.02f } },
260             ngraph::element::i8,
261             { ngraph::element::f32, {}, { 0.03f } },
262             ngraph::element::f32,
263             ngraph::element::f32,
264             { },
265         }
266     },
267     {
268         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(false),
269         {
270             ngraph::element::u8,
271             { ngraph::element::f32, { 127.f }, { 0.02f } },
272             ngraph::element::i8,
273             { ngraph::element::f32, {}, { 0.03f } },
274         },
275         {
276             ngraph::element::u8,
277             { ngraph::element::f32, { 127.f }, { 0.02f } },
278             ngraph::element::i8,
279             { ngraph::element::f32, {}, { 0.03f } },
280             ngraph::element::f32,
281             ngraph::element::f32,
282             { },
283         }
284     },
285     {
286         LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(false),
287         {
288             ngraph::element::u8,
289             { ngraph::element::f32, {}, { 0.02f } },
290             ngraph::element::i8,
291             { ngraph::element::f32, {}, { 0.03f } },
292         },
293         {
294             ngraph::element::u8,
295             { {}, {}, {} },
296             ngraph::element::i8,
297             { {}, {}, {} },
298             ngraph::element::u8,
299             ngraph::element::i8,
300             { {}, {}, { 0.02f * 0.03f } },
301         }
302     },
303     {
304         LayerTransformation::createParamsU8U8(),
305         {
306             ngraph::element::u8,
307             { ngraph::element::f32, {}, { 0.02f } },
308             ngraph::element::i8,
309             { ngraph::element::f32, {}, { 0.03f } },
310         },
311         {
312             ngraph::element::u8,
313             { {}, {}, {} },
314             ngraph::element::i8,
315             { {}, {}, {} },
316             ngraph::element::u8,
317             ngraph::element::i8,
318             { {}, {}, { 0.02f * 0.03f } },
319         }
320     },
321     {
322         LayerTransformation::createParamsU8U8(),
323         {
324             ngraph::element::u8,
325             { ngraph::element::f32, {}, { 0.02f } },
326             ngraph::element::u8,
327             { ngraph::element::f32, {}, { 0.03f } },
328         },
329         {
330             ngraph::element::u8,
331             { {}, {}, {} },
332             ngraph::element::u8,
333             { {}, {}, {} },
334             ngraph::element::u8,
335             ngraph::element::u8,
336             { {}, {}, { 0.02f * 0.03f } },
337         }
338     },
339     {
340         LayerTransformation::createParamsI8I8().setUpdatePrecisions(true),
341         {
342             ngraph::element::i8,
343             { ngraph::element::f32, {}, { 0.02f } },
344             ngraph::element::i8,
345             { ngraph::element::f32, {}, { 0.03f } },
346         },
347         {
348             ngraph::element::i8,
349             { {}, {}, {} },
350             ngraph::element::i8,
351             { {}, {}, {} },
352             ngraph::element::i8,
353             ngraph::element::i8,
354             { {}, {}, { 0.02f * 0.03f } },
355         }
356     },
357     {
358         LayerTransformation::createParamsI8I8().setUpdatePrecisions(false),
359         {
360             ngraph::element::f32,
361             { {}, {}, { 0.02f } },
362             ngraph::element::f32,
363             { {}, {}, { 0.03f } },
364         },
365         {
366             ngraph::element::f32,
367             { {}, {}, {} },
368             ngraph::element::f32,
369             { {}, {}, {} },
370             ngraph::element::f32,
371             ngraph::element::f32,
372             { {}, {}, { 0.02f * 0.03f } },
373         }
374     }
375 };
376
377 INSTANTIATE_TEST_CASE_P(
378     LPT,
379     MatMulTransformation,
380     ::testing::Combine(
381         ::testing::ValuesIn(precisions),
382         ::testing::ValuesIn(shapes),
383         ::testing::ValuesIn(testValues)),
384     MatMulTransformation::getTestCaseName);
385
386 } // namespace