934adb21f49b9858e3946475f18bf56b2d1c4c3a
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / mat_mul_with_constant_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/transformer.hpp>
15 #include <low_precision/mat_mul.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
19 #include "ngraph_functions/subgraph_builders.hpp"
20 #include "simple_low_precision_transformer.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
22
23 namespace {
24
25 using namespace testing;
26 using namespace ngraph::pass;
27
28 class MatMullTransformationTestValues {
29 public:
30     class Actual {
31     public:
32         ngraph::Shape inputShape;
33         ngraph::element::Type precisionBeforeDequantization;
34         ngraph::builder::subgraph::DequantizationOperations dequantization;
35         ngraph::Shape weightsConstShape;
36         std::vector<float> weightsConstValues;
37         ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
38     };
39
40     class Expected {
41     public:
42         ngraph::Shape inputShape;
43         ngraph::element::Type precisionBeforeDequantization;
44         ngraph::builder::subgraph::DequantizationOperations dequantization;
45         ngraph::element::Type weightsConstPrecision;
46         ngraph::Shape weightsConstShape;
47         std::vector<float> weightsConstValues;
48
49         ngraph::element::Type precisionBeforeOperation;
50         ngraph::builder::subgraph::DequantizationOperations resultDequantization;
51
52         ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
53     };
54
55     ngraph::pass::low_precision::LayerTransformation::Params params;
56     Actual actual;
57     Expected expected;
58 };
59
60 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
61     return out << "_" <<
62         actual.inputShape << "_" <<
63         actual.precisionBeforeDequantization << "_" <<
64         actual.dequantization << "_" <<
65         actual.weightsConstShape << "_" <<
66         actual.fqOnWeights;
67 }
68
69 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
70     return out << "_" <<
71         expected.weightsConstShape <<"_" <<
72         expected.dequantization << "_" <<
73         expected.precisionBeforeOperation << "_" <<
74         expected.resultDequantization << "_" <<
75         expected.fqOnWeights;
76 }
77
78 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
79     return out << "_" << values.actual << "_" << values.expected;
80 }
81
82 typedef std::tuple<
83     ngraph::element::Type,
84     size_t,
85     MatMullTransformationTestValues> MatMulTransformationParams;
86
87 class MatMulWithConstantTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
88 public:
89     void SetUp() override {
90         const ngraph::element::Type precision = std::get<0>(GetParam());
91         const size_t batch = std::get<1>(GetParam());
92
93         MatMullTransformationTestValues testValues = std::get<2>(GetParam());
94         testValues.actual.inputShape[0] = batch;
95         testValues.expected.inputShape[0] = batch;
96
97         actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
98             precision,
99             testValues.actual.inputShape,
100             testValues.actual.precisionBeforeDequantization,
101             testValues.actual.dequantization,
102             testValues.actual.weightsConstShape,
103             testValues.actual.weightsConstValues,
104             testValues.actual.fqOnWeights);
105
106         SimpleLowPrecisionTransformer transformer;
107         transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
108         transformer.transform(actualFunction);
109
110         referenceFunction = testValues.expected.fqOnWeights.empty() ?
111             ngraph::builder::subgraph::MatMulFunction::getReference(
112                 precision,
113                 testValues.expected.inputShape,
114                 testValues.expected.precisionBeforeDequantization,
115                 testValues.expected.dequantization,
116                 testValues.expected.weightsConstPrecision,
117                 testValues.expected.weightsConstShape,
118                 testValues.expected.weightsConstValues,
119                 testValues.expected.resultDequantization) :
120             ngraph::builder::subgraph::MatMulFunction::getOriginal(
121                 precision,
122                 testValues.expected.inputShape,
123                 testValues.expected.precisionBeforeDequantization,
124                 testValues.expected.dequantization,
125                 testValues.expected.weightsConstShape,
126                 testValues.expected.weightsConstValues,
127                 testValues.expected.fqOnWeights);
128     }
129
130     static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
131         ngraph::element::Type precision;
132         size_t batch;
133         MatMullTransformationTestValues testValues;
134         std::tie(precision, batch, testValues) = obj.param;
135
136         std::stringstream ss;
137         ss << precision << "_" << batch << "_" << testValues;
138         return ss.str();
139     }
140 };
141
142 TEST_P(MatMulWithConstantTransformation, CompareFunctions) {
143     InitNodeInfo().run_on_function(actualFunction);
144
145     actualFunction->validate_nodes_and_infer_types();
146
147     auto res = compare_functions(referenceFunction, actualFunction, true, true);
148     ASSERT_TRUE(res.first) << res.second;
149 }
150
151 const std::vector<ngraph::element::Type> precisions = {
152     ngraph::element::f32,
153     // ngraph::element::f16
154 };
155
156 const std::vector<size_t> batches = { 1, 4 };
157
158 std::vector<MatMullTransformationTestValues> testValues = {
159     // supported 3D: U8 & I8
160     {
161         LayerTransformation::createParamsU8I8(),
162         {
163             { 1, 384, 1024 },
164             ngraph::element::u8,
165             { ngraph::element::f32, {}, { 0.02f } },
166             { 1024, 1024 },
167             std::vector<float>(1024 * 1024, 1.f),
168             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
169         },
170         {
171             { 1, 384, 1024 },
172             ngraph::element::u8,
173             { {}, {}, {} },
174             ngraph::element::i8,
175             { 1024, 1024 },
176             std::vector<float>(1024 * 1024, -126),
177             ngraph::element::i8,
178             { {}, {}, { 0.02f * 0.1f } },
179             {}
180         }
181     },
182
183     // not supported 3D: U8 & I8
184     {
185         LayerTransformation::createParamsU8I8(),
186         {
187             { 1, 3, 4 },
188             ngraph::element::u8,
189             { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
190             { 4, 4 },
191             std::vector<float>(4 * 4, 1.f),
192             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
193         },
194         {
195             { 1, 3, 4 },
196             ngraph::element::u8,
197             { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
198             ngraph::element::i8,
199             {4, 4},
200             std::vector<float>(4 * 4, 1.f),
201             ngraph::element::f32,
202             {},
203             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
204         }
205     },
206
207     // not supported 3D: U8 & I8
208     {
209         LayerTransformation::createParamsU8I8(),
210         {
211             { 1, 3, 4 },
212             ngraph::element::u8,
213             { ngraph::element::f32, {}, { 0.02f } },
214             { 4, 4 },
215             std::vector<float>(4 * 4, 1.f),
216             {
217                 255,
218                 { 4, 1 },
219                 {0.f, 0.f, 0.f, 0.f},
220                 {254.f, 254.f, 254.f, 254.f},
221                 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
222                 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
223             },
224         },
225         {
226             { 1, 3, 4 },
227             ngraph::element::u8,
228             { ngraph::element::f32, {}, { 0.02f } },
229             ngraph::element::i8,
230             {4, 4},
231             std::vector<float>(4 * 4, 1.f),
232             ngraph::element::f32,
233             {},
234             {
235                 255,
236                 { 4, 1 },
237                 {0.f, 0.f, 0.f, 0.f},
238                 {254.f, 254.f, 254.f, 254.f},
239                 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
240                 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
241             },
242         }
243     },
244
245     // 2D: U8 & I8
246     {
247         LayerTransformation::createParamsU8I8(),
248         {
249             { 1, 2048 },
250             ngraph::element::u8,
251             { ngraph::element::f32, {}, { 0.02f } },
252             { 2048, 1000 },
253             std::vector<float>(2048 * 1000, 1.f),
254             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
255         },
256         {
257             { 1, 2048 },
258             ngraph::element::u8,
259             { {}, {}, {} },
260             ngraph::element::i8,
261             {2048, 1000},
262             std::vector<float>(2048 * 1000, -126),
263             ngraph::element::i8,
264             { {}, {}, { 0.02f * 0.1f } },
265             {}
266         }
267     },
268     // 2D: I8 & I8
269     {
270         LayerTransformation::createParamsI8I8(),
271         {
272             { 1, 2048 },
273             ngraph::element::i8,
274             { ngraph::element::f32, {}, { 0.02f } },
275             { 2048, 1000 },
276             std::vector<float>(2048 * 1000, 1.f),
277             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
278         },
279         {
280             { 1, 2048 },
281             ngraph::element::i8,
282             { {}, {}, {} },
283             ngraph::element::i8,
284             {2048, 1000},
285             std::vector<float>(2048 * 1000, -126),
286             ngraph::element::i8,
287             { {}, {}, { 0.02f * 0.1f } },
288             {}
289         }
290     },
291     // 2D: FP32 & FP328
292     {
293         LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
294         {
295             { 1, 2048 },
296             ngraph::element::f32,
297             { {}, {}, { 0.02f } },
298             { 2048, 1000 },
299             std::vector<float>(2048 * 1000, 1.f),
300             { 255, { 1, 1 },  {0.f}, {254.f}, {-12.7f}, {12.7} },
301         },
302         {
303             { 1, 2048 },
304             ngraph::element::f32,
305             { {}, {}, {} },
306             ngraph::element::f32,
307             {2048, 1000},
308             std::vector<float>(2048 * 1000, -126),
309             ngraph::element::f32,
310             { {}, {}, { 0.02f * 0.1f } },
311             {}
312         }
313     },
314 };
315
316 INSTANTIATE_TEST_CASE_P(
317     LPT,
318     MatMulWithConstantTransformation,
319     ::testing::Combine(
320         ::testing::ValuesIn(precisions),
321         ::testing::ValuesIn(batches),
322         ::testing::ValuesIn(testValues)),
323     MatMulWithConstantTransformation::getTestCaseName);
324
325 } // namespace