[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / add_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <utility>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/add_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
23
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
27
28 class AddTransformationTestValues {
29 public:
30     class Actual {
31     public:
32         ngraph::element::Type precision1;
33         ngraph::builder::subgraph::DequantizationOperations dequantization1;
34         ngraph::element::Type precision2;
35         ngraph::builder::subgraph::DequantizationOperations dequantization2;
36         std::vector<float> constValues;
37     };
38
39     class Expected {
40     public:
41         ngraph::element::Type precision1;
42         ngraph::builder::subgraph::DequantizationOperations dequantization1;
43         ngraph::element::Type precision2;
44         ngraph::builder::subgraph::DequantizationOperations dequantization2;
45         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
46         std::vector<float> constValues;
47         std::string operationType;
48
49         Expected(const ngraph::element::Type& precision1,
50                  ngraph::builder::subgraph::DequantizationOperations dequantization1,
51                  const ngraph::element::Type& precision2,
52                  ngraph::builder::subgraph::DequantizationOperations dequantization2,
53                  ngraph::builder::subgraph::DequantizationOperations dequantizationAfter,
54                  std::vector<float> constValues,
55                  std::string operationType = "Add"): precision1(precision1), dequantization1(std::move(dequantization1)),
56                                          precision2(precision2), dequantization2(std::move(dequantization2)),
57                                          dequantizationAfter(std::move(dequantizationAfter)), constValues(std::move(constValues)),
58                                          operationType(std::move(operationType)) {}
59     };
60
61     ngraph::element::Type precision;
62     ngraph::Shape inputShape;
63     bool broadcast;
64     int constInput;
65     ngraph::pass::low_precision::LayerTransformation::Params params;
66     Actual actual;
67     Expected expected;
68     std::string additionalLayer;
69 };
70
/// Streams a vector as "{ a, b, c }" (an empty vector prints "{  }"); used for test-case names.
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";
    // Emit the separator before every element except the first one.
    const char* separator = "";
    for (const auto& value : values) {
        os << separator << value;
        separator = ", ";
    }
    return os << " }";
}
83
84 class AddTransformation : public LayerTransformation, public testing::WithParamInterface<AddTransformationTestValues> {
85 public:
86     void SetUp() override {
87         const AddTransformationTestValues testValues = GetParam();
88
89         actualFunction = AddFunction::getOriginal(
90             testValues.precision,
91             testValues.inputShape,
92             testValues.broadcast,
93             testValues.params,
94             testValues.actual.precision1,
95             testValues.actual.dequantization1,
96             testValues.actual.precision2,
97             testValues.actual.dequantization2,
98             testValues.constInput,
99             testValues.actual.constValues,
100             testValues.additionalLayer);
101
102         SimpleLowPrecisionTransformer transform;
103         transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
104             low_precision::LayerTransformation::Params(testValues.params));
105         transform.transform(actualFunction);
106
107         referenceFunction = AddFunction::getReference(
108             testValues.precision,
109             testValues.inputShape,
110             testValues.broadcast,
111             testValues.params,
112             testValues.expected.precision1,
113             testValues.expected.dequantization1,
114             testValues.expected.precision2,
115             testValues.expected.dequantization2,
116             testValues.expected.dequantizationAfter,
117             // Constant operations after transformations are on 1 input only
118             testValues.constInput == -1 ? -1 : 1,
119             testValues.expected.constValues,
120             testValues.additionalLayer,
121             testValues.expected.operationType);
122     }
123
124     static std::string getTestCaseName(testing::TestParamInfo<AddTransformationTestValues> obj) {
125         const AddTransformationTestValues testValues = obj.param;
126
127         std::ostringstream result;
128         result <<
129             testValues.precision << "_" <<
130             testValues.inputShape << "_" <<
131             testValues.broadcast << "_" <<
132             testValues.actual.precision1 << "_" <<
133             testValues.actual.dequantization1 << "_" <<
134             testValues.actual.precision2 << "_" <<
135             testValues.actual.dequantization2 << "_" <<
136             testValues.constInput << "_" <<
137             testValues.actual.constValues << "_" <<
138             testValues.additionalLayer;
139         return result.str();
140     }
141 };
142
143 TEST_P(AddTransformation, CompareFunctions) {
144     actualFunction->validate_nodes_and_infer_types();
145     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
146     ASSERT_TRUE(res.first) << res.second;
147 }
148
// Table of AddTransformation cases. Each entry is brace-initialized in AddTransformationTestValues
// member order:
//   precision, inputShape, broadcast, constInput, params,
//   actual   { precision1, dequantization1, precision2, dequantization2, constValues },
//   expected { precision1, dequantization1, precision2, dequantization2, dequantizationAfter,
//              constValues [, operationType (defaults to "Add")] },
//   additionalLayer
// NOTE(review): each dequantization appears to be { convert precision, subtract values, multiply
// values }; confirm against DequantizationOperations in dequantization_operations.hpp.
// In the non-constant cases one branch loses its dequantization, and its Multiply moves to
// dequantizationAfter. The other branch is rescaled, e.g. for the first U8 case:
//   subtract 7 + 3 * 5 / 10 = 8.5, multiply 10 / 5 = 2.
149 const std::vector<AddTransformationTestValues> addTransformationTestValues = {
150     // Multiply with zero on the first branch
151     {
152         ngraph::element::f32,
153         ngraph::Shape{1, 4, 16, 16},
154         false,
155         -1,
156         LayerTransformation::createParamsU8I8(),
157         {
158             ngraph::element::f32,
159             { },
160             ngraph::element::u8,
161             { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
162             { }
163         },
164         {
165             ngraph::element::f32,
166             { },
167             ngraph::element::u8,
168             { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
169             { },
170             { }
171         },
172         ""
173     },
174     // Multiply with zero on the second branch
175     {
176         ngraph::element::f32,
177         ngraph::Shape{1, 4, 16, 16},
178         false,
179         -1,
180         LayerTransformation::createParamsU8I8(),
181         {
182             ngraph::element::u8,
183             { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
184             ngraph::element::f32,
185             { },
186             { }
187         },
188         {
189             ngraph::element::u8,
190             { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
191             ngraph::element::f32,
192             { },
193             { },
194             { }
195         },
196         ""
197     },
198     // U8: Subtract and Multiply on both branches
199     {
200         ngraph::element::f32,
201         ngraph::Shape{1, 4, 16, 16},
202         false,
203         -1,
204         LayerTransformation::createParamsU8I8(),
205         {
206             ngraph::element::u8,
207             { {ngraph::element::f32},  { 7.f }, { 10.f }},
208             ngraph::element::u8,
209             { {ngraph::element::f32},  { 3.f }, { 5.f } },
210             {}
211         },
212         {
213             ngraph::element::u8,
214             { {ngraph::element::f32},  { 8.5f }, { 2.f }},
215             ngraph::element::u8,
216             { {},  {}, {} },
217             { {},  {}, {5.f} },
218             {}
219         },
220         ""
221     },
    // U8: Subtract on the first branch only (second branch: Multiply only)
222     {
223         ngraph::element::f32,
224         ngraph::Shape{1, 4, 16, 16},
225         false,
226         -1,
227         LayerTransformation::createParamsU8I8(),
228         {
229             ngraph::element::u8,
230             { {ngraph::element::f32},  { 2.f }, { 10.f }},
231             ngraph::element::u8,
232             { {ngraph::element::f32},  { }, { 5.f } },
233             {}
234         },
235         {
236             ngraph::element::u8,
237             { {ngraph::element::f32},  { 2.f }, { 2.f }},
238             ngraph::element::u8,
239             { {},  {}, {} },
240             { {},  {}, {5.f} },
241             {}
242         },
243         ""
244     },
    // U8: Multiply only (no Subtract) on both branches
245     {
246         ngraph::element::f32,
247         ngraph::Shape{1, 4, 16, 16},
248         false,
249         -1,
250         LayerTransformation::createParamsU8I8(),
251         {
252             ngraph::element::u8,
253             { {ngraph::element::f32},  { }, { 10.f }},
254             ngraph::element::u8,
255             { {ngraph::element::f32},  { }, { 5.f } },
256             {}
257         },
258         {
259             ngraph::element::u8,
260             { {ngraph::element::f32},  { }, { 2.f }},
261             ngraph::element::u8,
262             { {},  {}, {} },
263             { {},  {}, {5.f} },
264             {}
265         },
266         ""
267     },
    // U8: no Multiply on the first branch, Multiply only on the second
268     {
269         ngraph::element::f32,
270         ngraph::Shape{1, 4, 16, 16},
271         false,
272         -1,
273         LayerTransformation::createParamsU8I8(),
274         {
275             ngraph::element::u8,
276             { {ngraph::element::f32},  { 2.f }, { }},
277             ngraph::element::u8,
278             { {ngraph::element::f32},  { }, { 5.f } },
279             {}
280         },
281         {
282             ngraph::element::u8,
283             { {ngraph::element::f32},  { 2.f }, { 0.2f }},
284             ngraph::element::u8,
285             { {},  {}, {} },
286             { {},  {}, {5.f} },
287             {}
288         },
289         ""
290     },
    // U8: Subtract on both branches, no Multiply on the first branch
291     {
292         ngraph::element::f32,
293         ngraph::Shape{1, 4, 16, 16},
294         false,
295         -1,
296         LayerTransformation::createParamsU8I8(),
297         {
298             ngraph::element::u8,
299             { {ngraph::element::f32},  { 2.f }, { }},
300             ngraph::element::u8,
301             { {ngraph::element::f32},  { 3.f }, { 5.f } },
302             {}
303         },
304         {
305             ngraph::element::u8,
306             { {ngraph::element::f32},  { 17.f }, { 0.2f }},
307             ngraph::element::u8,
308             { {},  {}, {} },
309             { {},  {}, {5.f} },
310             {}
311         },
312         ""
313     },
314
315     // I8 + broadcast: same dequantization patterns as the U8 cases, with broadcast = true
316
317     {
318         ngraph::element::f32,
319         ngraph::Shape{1, 4, 16, 16},
320         true,
321         -1,
322         LayerTransformation::createParamsU8I8(),
323         {
324             ngraph::element::i8,
325             { {ngraph::element::f32},  { 7.f }, { 10.f }},
326             ngraph::element::i8,
327             { {ngraph::element::f32},  { 3.f }, { 5.f } },
328             {}
329         },
330         {
331             ngraph::element::i8,
332             { {ngraph::element::f32},  { 8.5f }, { 2.f }},
333             ngraph::element::i8,
334             { {},  {}, {} },
335             { {},  {}, {5.f} },
336             {}
337         },
338         ""
339     },
    // I8: Subtract on the first branch only (second branch: Multiply only)
340     {
341         ngraph::element::f32,
342         ngraph::Shape{1, 4, 16, 16},
343         true,
344         -1,
345         LayerTransformation::createParamsU8I8(),
346         {
347             ngraph::element::i8,
348             { {ngraph::element::f32},  { 2.f }, { 10.f }},
349             ngraph::element::i8,
350             { {ngraph::element::f32},  { }, { 5.f } },
351             {}
352         },
353         {
354             ngraph::element::i8,
355             { {ngraph::element::f32},  { 2.f }, { 2.f }},
356             ngraph::element::i8,
357             { {},  {}, {} },
358             { {},  {}, {5.f} },
359             {}
360         },
361         ""
362     },
    // I8: Multiply only (no Subtract) on both branches
363     {
364         ngraph::element::f32,
365         ngraph::Shape{1, 4, 16, 16},
366         true,
367         -1,
368         LayerTransformation::createParamsU8I8(),
369         {
370             ngraph::element::i8,
371             { {ngraph::element::f32},  { }, { 10.f }},
372             ngraph::element::i8,
373             { {ngraph::element::f32},  { }, { 5.f } },
374             {}
375         },
376         {
377             ngraph::element::i8,
378             { {ngraph::element::f32},  { }, { 2.f }},
379             ngraph::element::i8,
380             { {},  {}, {} },
381             { {},  {}, {5.f} },
382             {}
383         },
384         ""
385     },
    // I8: no Multiply on the first branch, Multiply only on the second
386     {
387         ngraph::element::f32,
388         ngraph::Shape{1, 4, 16, 16},
389         true,
390         -1,
391         LayerTransformation::createParamsU8I8(),
392         {
393             ngraph::element::i8,
394             { {ngraph::element::f32},  { 2.f }, { }},
395             ngraph::element::i8,
396             { {ngraph::element::f32},  { }, { 5.f } },
397             {}
398         },
399         {
400             ngraph::element::i8,
401             { {ngraph::element::f32},  { 2.f }, { 0.2f }},
402             ngraph::element::i8,
403             { {},  {}, {} },
404             { {},  {}, {5.f} },
405             {}
406         },
407         ""
408     },
    // I8: Subtract on both branches, no Multiply on the first branch
409     {
410         ngraph::element::f32,
411         ngraph::Shape{1, 4, 16, 16},
412         true,
413         -1,
414         LayerTransformation::createParamsU8I8(),
415         {
416             ngraph::element::i8,
417             { {ngraph::element::f32},  { 2.f }, { }},
418             ngraph::element::i8,
419             { {ngraph::element::f32},  { 3.f }, { 5.f } },
420             {}
421         },
422         {
423             ngraph::element::i8,
424             { {ngraph::element::f32},  { 17.f }, { 0.2f }},
425             ngraph::element::i8,
426             { {},  {}, {} },
427             { {},  {}, {5.f} },
428             {}
429         },
430         ""
431     },
432
    // 2D input {4, 1}: Multiply constant of shape {4, 1} on the first branch, f32 second branch
    // without dequantization; the expected graph keeps the first-branch dequantization unchanged
    // and has no dequantization after the Add.
433     {
434         ngraph::element::f32,
435         ngraph::Shape{4, 1},
436         false,
437         -1,
438         LayerTransformation::createParamsU8I8(),
439         {
440             ngraph::element::u8,
441             { {ngraph::element::f32},  { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
442             ngraph::element::f32,
443             {},
444             { 5.f, 6.f, 7.f, 8.f }
445         },
446         {
447             ngraph::element::u8,
448             { {ngraph::element::f32},  { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
449             ngraph::element::f32,
450             { {},  {}, {} },
451             { {},  {}, {} },
452             { 5.f, 6.f, 7.f, 8.f }
453         },
454         ""
455     },
456
457     // constant input (constInput = 1): Add -> Subtract
458     {
459     ngraph::element::f32,
460         ngraph::Shape{ 1, 2, 2, 2 },
461         false,
462         1,
463         LayerTransformation::createParamsU8I8(),
464         {
465             ngraph::element::i8,
466             { {ngraph::element::f32},  {}, {5.f}},
467             ngraph::element::i8,
468             { {},  {}, {} },
469             { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
470         },
471         {
472             ngraph::element::i8,
473             { {ngraph::element::f32},  { }, { }},
474             ngraph::element::f32,
475             { {},  {}, {} },
476             { {},  {}, {5.f} },
477             { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
478             "Subtract"
479         },
480         ""
481     },
482
483     // constant input (constInput = 0): Add -> Subtract
484     {
485         ngraph::element::f32,
486         ngraph::Shape{1, 2, 2, 2},
487         false,
488         0,
489         LayerTransformation::createParamsU8I8(),
490         {
491             ngraph::element::i8,
492             { {},  {}, {}},
493             ngraph::element::i8,
494             { {ngraph::element::f32},  {}, { 5.f } },
495             { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
496         },
497         {
498             ngraph::element::i8,
499             { {ngraph::element::f32},  {}, {} },
500             ngraph::element::f32,
501             { {},  {}, { }},
502
503             { {},  {}, {5.f} },
504             { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
505             "Subtract"
506         },
507         "",
508     },
509     // convolution before FQ (choose that branch)
510     {
511         ngraph::element::f32,
512         ngraph::Shape{1, 4, 16, 16},
513         false,
514         -1,
515         LayerTransformation::createParamsU8I8(),
516         {
517             ngraph::element::u8,
518             { {ngraph::element::f32},  { 7.f }, { 10.f }},
519             ngraph::element::u8,
520             { {ngraph::element::f32},  { 3.f }, { 5.f } },
521             {}
522         },
523         {
524             ngraph::element::u8,
525             { {},  {}, {} },
526             ngraph::element::u8,
527             { {ngraph::element::f32},  { 17.f }, { 0.5f }},
528             { {},  {}, {10.f} },
529             {}
530         },
531         "convolution"
532     },
533     // group convolution before FQ (choose that branch)
534     {
535         ngraph::element::f32,
536         ngraph::Shape{1, 4, 16, 16},
537         false,
538         -1,
539         LayerTransformation::createParamsU8I8(),
540         {
541             ngraph::element::u8,
542             { {ngraph::element::f32},  { 7.f }, { 10.f }},
543             ngraph::element::u8,
544             { {ngraph::element::f32},  { 3.f }, { 5.f } },
545             {}
546         },
547         {
548             ngraph::element::u8,
549             { {},  {}, {} },
550             ngraph::element::u8,
551             { {ngraph::element::f32},  { 17.f }, { 0.5f }},
552             { {},  {}, {10.f} },
553             {}
554         },
555         "group_convolution"
556     },
557 };
558
559 INSTANTIATE_TEST_CASE_P(
560     LPT,
561     AddTransformation,
562     ::testing::ValuesIn(addTransformationTestValues),
563     AddTransformation::getTestCaseName);