[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / mat_mul_with_constant_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/transformer.hpp>
15 #include <low_precision/mat_mul.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
19 #include "ngraph_functions/subgraph_builders.hpp"
20 #include "simple_low_precision_transformer.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
22
23 namespace {
24
25 using namespace testing;
26 using namespace ngraph::pass;
27
28 class MatMullTransformationTestValues {
29 public:
30     class Actual {
31     public:
32         ngraph::Shape inputShape;
33         ngraph::element::Type precisionBeforeDequantization;
34         ngraph::builder::subgraph::DequantizationOperations dequantization;
35         ngraph::Shape weightsConstShape;
36         std::vector<float> weightsConstValues;
37         ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
38     };
39
40     class Expected {
41     public:
42         ngraph::Shape inputShape;
43         ngraph::element::Type precisionBeforeDequantization;
44         ngraph::builder::subgraph::DequantizationOperations dequantization;
45         ngraph::element::Type weightsConstPrecision;
46         ngraph::Shape weightsConstShape;
47         std::vector<float> weightsConstValues;
48
49         ngraph::element::Type precisionBeforeOperation;
50         ngraph::builder::subgraph::DequantizationOperations resultDequantization;
51
52         ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
53     };
54
55     ngraph::pass::low_precision::LayerTransformation::Params params;
56     Actual actual;
57     Expected expected;
58 };
59
60 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
61     return out << "_" <<
62         actual.inputShape << "_" <<
63         actual.precisionBeforeDequantization << "_" <<
64         actual.dequantization << "_" <<
65         actual.weightsConstShape << "_" <<
66         actual.fqOnWeights;
67 }
68
69 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
70     return out << "_" <<
71         expected.weightsConstShape <<"_" <<
72         expected.dequantization << "_" <<
73         expected.precisionBeforeOperation << "_" <<
74         expected.resultDequantization << "_" <<
75         expected.fqOnWeights;
76 }
77
78 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
79     return out << "_" << values.actual << "_" << values.expected;
80 }
81
82 typedef std::tuple<
83     ngraph::element::Type,
84     size_t,
85     MatMullTransformationTestValues> MatMulTransformationParams;
86
87 class MatMulWithConstantTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
88 public:
89     void SetUp() override {
90         const ngraph::element::Type precision = std::get<0>(GetParam());
91         const size_t batch = std::get<1>(GetParam());
92
93         MatMullTransformationTestValues testValues = std::get<2>(GetParam());
94         testValues.actual.inputShape[0] = batch;
95         testValues.expected.inputShape[0] = batch;
96
97         actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
98             precision,
99             testValues.actual.inputShape,
100             testValues.actual.precisionBeforeDequantization,
101             testValues.actual.dequantization,
102             testValues.actual.weightsConstShape,
103             testValues.actual.weightsConstValues,
104             testValues.actual.fqOnWeights);
105
106         SimpleLowPrecisionTransformer transformer;
107         transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
108         transformer.transform(actualFunction);
109
110         referenceFunction = testValues.expected.fqOnWeights.empty() ?
111             ngraph::builder::subgraph::MatMulFunction::getReference(
112                 precision,
113                 testValues.expected.inputShape,
114                 testValues.expected.precisionBeforeDequantization,
115                 testValues.expected.dequantization,
116                 testValues.expected.weightsConstPrecision,
117                 testValues.expected.weightsConstShape,
118                 testValues.expected.weightsConstValues,
119                 testValues.expected.resultDequantization) :
120             ngraph::builder::subgraph::MatMulFunction::getOriginal(
121                 precision,
122                 testValues.expected.inputShape,
123                 testValues.expected.precisionBeforeDequantization,
124                 testValues.expected.dequantization,
125                 testValues.expected.weightsConstShape,
126                 testValues.expected.weightsConstValues,
127                 testValues.expected.fqOnWeights);
128     }
129
130     static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
131         ngraph::element::Type precision;
132         size_t batch;
133         MatMullTransformationTestValues testValues;
134         std::tie(precision, batch, testValues) = obj.param;
135
136         std::stringstream ss;
137         ss << precision << "_" << batch << "_" << testValues;
138         return ss.str();
139     }
140 };
141
142 TEST_P(MatMulWithConstantTransformation, CompareFunctions) {
143     actualFunction->validate_nodes_and_infer_types();
144     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
145     ASSERT_TRUE(res.first) << res.second;
146 }
147
148 const std::vector<ngraph::element::Type> precisions = {
149     ngraph::element::f32,
150     // ngraph::element::f16
151 };
152
// Batch sizes written into the first input dimension of every test case.
const std::vector<size_t> batches{ 1, 4 };
154
155 std::vector<MatMullTransformationTestValues> testValues = {
156     // supported 3D: U8 & I8
157     {
158         LayerTransformation::createParamsU8I8(),
159         {
160             { 1, 384, 1024 },
161             ngraph::element::u8,
162             { ngraph::element::f32, {}, { 0.02f } },
163             { 1024, 1024 },
164             std::vector<float>(1024 * 1024, 1.f),
165             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
166         },
167         {
168             { 1, 384, 1024 },
169             ngraph::element::u8,
170             { {}, {}, {} },
171             ngraph::element::i8,
172             { 1024, 1024 },
173             std::vector<float>(1024 * 1024, -126),
174             ngraph::element::i8,
175             { {}, {}, { 0.02f * 0.1f } },
176             {}
177         }
178     },
179
180     // not supported 3D: U8 & I8
181     {
182         LayerTransformation::createParamsU8I8(),
183         {
184             { 1, 3, 4 },
185             ngraph::element::u8,
186             { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
187             { 4, 4 },
188             std::vector<float>(4 * 4, 1.f),
189             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
190         },
191         {
192             { 1, 3, 4 },
193             ngraph::element::u8,
194             { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
195             ngraph::element::i8,
196             {4, 4},
197             std::vector<float>(4 * 4, 1.f),
198             ngraph::element::f32,
199             {},
200             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
201         }
202     },
203
204     // not supported 3D: U8 & I8
205     {
206         LayerTransformation::createParamsU8I8(),
207         {
208             { 1, 3, 4 },
209             ngraph::element::u8,
210             { ngraph::element::f32, {}, { 0.02f } },
211             { 4, 4 },
212             std::vector<float>(4 * 4, 1.f),
213             {
214                 255,
215                 { 4, 1 },
216                 {0.f, 0.f, 0.f, 0.f},
217                 {254.f, 254.f, 254.f, 254.f},
218                 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
219                 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
220             },
221         },
222         {
223             { 1, 3, 4 },
224             ngraph::element::u8,
225             { ngraph::element::f32, {}, { 0.02f } },
226             ngraph::element::i8,
227             {4, 4},
228             std::vector<float>(4 * 4, 1.f),
229             ngraph::element::f32,
230             {},
231             {
232                 255,
233                 { 4, 1 },
234                 {0.f, 0.f, 0.f, 0.f},
235                 {254.f, 254.f, 254.f, 254.f},
236                 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
237                 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
238             },
239         }
240     },
241
242     // 2D: U8 & I8
243     {
244         LayerTransformation::createParamsU8I8(),
245         {
246             { 1, 2048 },
247             ngraph::element::u8,
248             { ngraph::element::f32, {}, { 0.02f } },
249             { 2048, 1000 },
250             std::vector<float>(2048 * 1000, 1.f),
251             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
252         },
253         {
254             { 1, 2048 },
255             ngraph::element::u8,
256             { {}, {}, {} },
257             ngraph::element::i8,
258             {2048, 1000},
259             std::vector<float>(2048 * 1000, -126),
260             ngraph::element::i8,
261             { {}, {}, { 0.02f * 0.1f } },
262             {}
263         }
264     },
265     // 2D: I8 & I8
266     {
267         LayerTransformation::createParamsI8I8(),
268         {
269             { 1, 2048 },
270             ngraph::element::i8,
271             { ngraph::element::f32, {}, { 0.02f } },
272             { 2048, 1000 },
273             std::vector<float>(2048 * 1000, 1.f),
274             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
275         },
276         {
277             { 1, 2048 },
278             ngraph::element::i8,
279             { {}, {}, {} },
280             ngraph::element::i8,
281             {2048, 1000},
282             std::vector<float>(2048 * 1000, -126),
283             ngraph::element::i8,
284             { {}, {}, { 0.02f * 0.1f } },
285             {}
286         }
287     },
288     // 2D: FP32 & FP328
289     {
290         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
291         {
292             { 1, 2048 },
293             ngraph::element::f32,
294             { {}, {}, { 0.02f } },
295             { 2048, 1000 },
296             std::vector<float>(2048 * 1000, 1.f),
297             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
298         },
299         {
300             { 1, 2048 },
301             ngraph::element::f32,
302             { {}, {}, {} },
303             ngraph::element::f32,
304             {2048, 1000},
305             std::vector<float>(2048 * 1000, -126),
306             ngraph::element::f32,
307             { {}, {}, { 0.02f * 0.1f } },
308             {}
309         }
310     },
311 };
312
313 INSTANTIATE_TEST_CASE_P(
314     LPT,
315     MatMulWithConstantTransformation,
316     ::testing::Combine(
317         ::testing::ValuesIn(precisions),
318         ::testing::ValuesIn(batches),
319         ::testing::ValuesIn(testValues)),
320     MatMulWithConstantTransformation::getTestCaseName);
321
322 } // namespace