b66d0f1c46bf23f13f45956c0fde162783e6c98c
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / add_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <utility>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/add_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
23
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
27
28 class AddTransformationTestValues {
29 public:
30     class Actual {
31     public:
32         ngraph::element::Type precision1;
33         ngraph::builder::subgraph::DequantizationOperations dequantization1;
34         ngraph::element::Type precision2;
35         ngraph::builder::subgraph::DequantizationOperations dequantization2;
36         std::vector<float> constValues;
37     };
38
39     class Expected {
40     public:
41         ngraph::element::Type precision1;
42         ngraph::builder::subgraph::DequantizationOperations dequantization1;
43         ngraph::element::Type precision2;
44         ngraph::builder::subgraph::DequantizationOperations dequantization2;
45         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
46         std::vector<float> constValues;
47         std::string operationType;
48
49         Expected(const ngraph::element::Type& precision1,
50                  ngraph::builder::subgraph::DequantizationOperations dequantization1,
51                  const ngraph::element::Type& precision2,
52                  ngraph::builder::subgraph::DequantizationOperations dequantization2,
53                  ngraph::builder::subgraph::DequantizationOperations dequantizationAfter,
54                  std::vector<float> constValues,
55                  std::string operationType = "Add"): precision1(precision1), dequantization1(std::move(dequantization1)),
56                                          precision2(precision2), dequantization2(std::move(dequantization2)),
57                                          dequantizationAfter(std::move(dequantizationAfter)), constValues(std::move(constValues)),
58                                          operationType(std::move(operationType)) {}
59     };
60
61     ngraph::element::Type precision;
62     ngraph::Shape inputShape;
63     bool broadcast;
64     int constInput;
65     ngraph::pass::low_precision::LayerTransformation::Params params;
66     Actual actual;
67     Expected expected;
68     std::string additionalLayer;
69 };
70
// Streams a vector as "{ a, b, c }" (an empty vector yields "{  }"); used to build readable
// parameterized test names.
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";
    const char* separator = "";
    for (const auto& value : values) {
        os << separator << value;
        separator = ", ";
    }
    return os << " }";
}
83
84 class AddTransformation : public LayerTransformation, public testing::WithParamInterface<AddTransformationTestValues> {
85 public:
86     void SetUp() override {
87         const AddTransformationTestValues testValues = GetParam();
88
89         actualFunction = AddFunction::getOriginal(
90             testValues.precision,
91             testValues.inputShape,
92             testValues.broadcast,
93             testValues.params,
94             testValues.actual.precision1,
95             testValues.actual.dequantization1,
96             testValues.actual.precision2,
97             testValues.actual.dequantization2,
98             testValues.constInput,
99             testValues.actual.constValues,
100             testValues.additionalLayer);
101
102         SimpleLowPrecisionTransformer transform;
103         transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
104             low_precision::LayerTransformation::Params(testValues.params));
105         transform.transform(actualFunction);
106
107         referenceFunction = AddFunction::getReference(
108             testValues.precision,
109             testValues.inputShape,
110             testValues.broadcast,
111             testValues.params,
112             testValues.expected.precision1,
113             testValues.expected.dequantization1,
114             testValues.expected.precision2,
115             testValues.expected.dequantization2,
116             testValues.expected.dequantizationAfter,
117             // Constant operations after transformations are on 1 input only
118             testValues.constInput == -1 ? -1 : 1,
119             testValues.expected.constValues,
120             testValues.additionalLayer,
121             testValues.expected.operationType);
122     }
123
124     static std::string getTestCaseName(testing::TestParamInfo<AddTransformationTestValues> obj) {
125         const AddTransformationTestValues testValues = obj.param;
126
127         std::ostringstream result;
128         result <<
129             testValues.precision << "_" <<
130             testValues.inputShape << "_" <<
131             testValues.broadcast << "_" <<
132             testValues.actual.precision1 << "_" <<
133             testValues.actual.dequantization1 << "_" <<
134             testValues.actual.precision2 << "_" <<
135             testValues.actual.dequantization2 << "_" <<
136             testValues.constInput << "_" <<
137             testValues.actual.constValues << "_" <<
138             testValues.additionalLayer;
139         return result.str();
140     }
141 };
142
143 TEST_P(AddTransformation, CompareFunctions) {
144     actualFunction->validate_nodes_and_infer_types();
145     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
146     ASSERT_TRUE(res.first) << res.second;
147 }
148
// Test cases for AddTransformation. Every entry is a positional aggregate initializer of
// AddTransformationTestValues:
//   { precision, inputShape, broadcast, constInput, params,
//     Actual   { precision1, dequantization1, precision2, dequantization2, constValues },
//     Expected { precision1, dequantization1, precision2, dequantization2, dequantizationAfter,
//                constValues [, operationType = "Add"] },
//     additionalLayer }
// NOTE(review): each dequantization initializer appears to be { convert, subtract, multiply };
// the expected values below are consistent with that reading. In the per-case comments x and y
// denote the first and second Add inputs after conversion, and c a constant value.
const std::vector<AddTransformationTestValues> addTransformationTestValues = {
    // U8
    // 10 * (x - 7) + 5 * (y - 3) == 5 * (2 * (x - 8.5) + y)
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 7.f }, { 10.f }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 8.5f }, { 2.f }},
            ngraph::element::u8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    // 10 * (x - 2) + 5 * y == 5 * (2 * (x - 2) + y)
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 2.f }, { 10.f }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 2.f }, { 2.f }},
            ngraph::element::u8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    // 10 * x + 5 * y == 5 * (2 * x + y)
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { 10.f }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { 2.f }},
            ngraph::element::u8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    // (x - 2) + 5 * y == 5 * (0.2 * (x - 2) + y)
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 2.f }, { }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 2.f }, { 0.2f }},
            ngraph::element::u8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    // (x - 2) + 5 * (y - 3) == 5 * (0.2 * (x - 17) + y)
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 2.f }, { }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 17.f }, { 0.2f }},
            ngraph::element::u8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },

    // I8 + broadcast
    // Same five arithmetic cases as the U8 group above, with i8 inputs and broadcast = true.

    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        true,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 7.f }, { 10.f }},
            ngraph::element::i8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 8.5f }, { 2.f }},
            ngraph::element::i8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        true,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 2.f }, { 10.f }},
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 2.f }, { 2.f }},
            ngraph::element::i8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        true,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { 10.f }},
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { 2.f }},
            ngraph::element::i8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        true,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 2.f }, { }},
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { 5.f } },
            {}
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 2.f }, { 0.2f }},
            ngraph::element::i8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        true,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 2.f }, { }},
            ngraph::element::i8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { 17.f }, { 0.2f }},
            ngraph::element::i8,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            {}
        },
        ""
    },

    // Per-channel multiply constant of shape {4, 1} on the first input; the second input is f32
    // without dequantization. Expected mirrors actual, i.e. the graph should be left unchanged.
    {
        ngraph::element::f32,
        ngraph::Shape{4, 1},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
            ngraph::element::f32,
            {},
            { 5.f, 6.f, 7.f, 8.f }
        },
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { }, { {1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {4, 1}, true, 0ul }},
            ngraph::element::f32,
            { {},  {}, {} },
            { {},  {}, {} },
            { 5.f, 6.f, 7.f, 8.f }
        },
        ""
    },

    // constant input: Add -> Subtract (constInput = 1, dequantization on the first input)
    // 5 * x + c == 5 * (x - (-c / 5)), so the expected Subtract constants are -c / 5.
    {
    ngraph::element::f32,
        ngraph::Shape{ 1, 2, 2, 2 },
        false,
        1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  {}, {5.f}},
            ngraph::element::i8,
            { {},  {}, {} },
            { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  { }, { }},
            ngraph::element::f32,
            { {},  {}, {} },
            { {},  {}, {5.f} },
            { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
            "Subtract"
        },
        ""
    },

    // constant input: Add -> Subtract (constInput = 0, dequantization on the second input)
    // c + 5 * y == 5 * (y - (-c / 5)), so the expected Subtract constants are -c / 5.
    {
        ngraph::element::f32,
        ngraph::Shape{1, 2, 2, 2},
        false,
        0,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::i8,
            { {},  {}, {}},
            ngraph::element::i8,
            { {ngraph::element::f32},  {}, { 5.f } },
            { 10.f, 5.f, 2.f, 4.f, 3.f, 12.f, 8.f, 14.f }
        },
        {
            ngraph::element::i8,
            { {ngraph::element::f32},  {}, {} },
            ngraph::element::f32,
            { {},  {}, { }},

            { {},  {}, {5.f} },
            { -2.f, -1.f, -0.4f, -0.8f, -0.6f, -2.4f, -1.6f, -2.8f },
            "Subtract"
        },
        "",
    },
    // convolution before FQ (choose that branch)
    // 10 * (x - 7) + 5 * (y - 3) == 10 * (x + 0.5 * (y - 17)): the dequantization is removed
    // from the first input and folded into the second one.
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 7.f }, { 10.f }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {},  {}, {} },
            ngraph::element::u8,
            { {ngraph::element::f32},  { 17.f }, { 0.5f }},
            { {},  {}, {10.f} },
            {}
        },
        "convolution"
    },
    // group convolution before FQ (choose that branch)
    // Same arithmetic as the convolution case above.
    {
        ngraph::element::f32,
        ngraph::Shape{1, 4, 16, 16},
        false,
        -1,
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            { {ngraph::element::f32},  { 7.f }, { 10.f }},
            ngraph::element::u8,
            { {ngraph::element::f32},  { 3.f }, { 5.f } },
            {}
        },
        {
            ngraph::element::u8,
            { {},  {}, {} },
            ngraph::element::u8,
            { {ngraph::element::f32},  { 17.f }, { 0.5f }},
            { {},  {}, {10.f} },
            {}
        },
        "group_convolution"
    },
};
510
511 INSTANTIATE_TEST_CASE_P(
512     LPT,
513     AddTransformation,
514     ::testing::ValuesIn(addTransformationTestValues),
515     AddTransformation::getTestCaseName);