[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / elementwise_with_multi_parent_dequantization_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <utility>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
23
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
27
28 class ElementwiseWithMultiParentDequantizationTransformationTestValues {
29 public:
30     class Actual {
31     public:
32         ngraph::element::Type precision1;
33         ngraph::builder::subgraph::DequantizationOperations dequantization1;
34         ngraph::element::Type precision2;
35         ngraph::builder::subgraph::DequantizationOperations dequantization2;
36     };
37
38     class Expected {
39     public:
40         ngraph::element::Type precision1;
41         ngraph::builder::subgraph::DequantizationOperations dequantization1;
42         ngraph::element::Type precision2;
43         ngraph::builder::subgraph::DequantizationOperations dequantization2;
44     };
45
46     ngraph::element::Type precision;
47     ngraph::Shape inputShape;
48     ngraph::pass::low_precision::LayerTransformation::Params params;
49     Actual actual;
50     Expected expected;
51 };
52
/**
 * Streams a vector as "{ a, b, c }" (an empty vector prints as "{  }").
 * Used when building parameterized test names.
 */
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";
    // The separator is empty before the first element and ", " before every later one.
    const char* separator = "";
    for (const auto& value : values) {
        os << separator << value;
        separator = ", ";
    }
    return os << " }";
}
65
66 class ElementwiseWithMultiParentDequantizationTransformation :
67     public LayerTransformation,
68     public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
69 public:
70     void SetUp() override {
71         const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam();
72
73         actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
74             testValues.precision,
75             testValues.inputShape,
76             testValues.params,
77             testValues.actual.precision1,
78             testValues.actual.dequantization1,
79             testValues.actual.precision2,
80             testValues.actual.dequantization2);
81
82         SimpleLowPrecisionTransformer transform;
83         transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
84             low_precision::LayerTransformation::Params(testValues.params));
85         transform.transform(actualFunction);
86
87         referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
88             testValues.precision,
89             testValues.inputShape,
90             testValues.params,
91             testValues.expected.precision1,
92             testValues.expected.dequantization1,
93             testValues.expected.precision2,
94             testValues.expected.dequantization2);
95     }
96
97     static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
98         const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param;
99
100         std::ostringstream result;
101         result <<
102             testValues.precision << "_" <<
103             testValues.inputShape << "_" <<
104             testValues.actual.precision1 << "_" <<
105             testValues.actual.dequantization1 << "_" <<
106             testValues.actual.precision2 << "_" <<
107             testValues.actual.dequantization2;
108         return result.str();
109     }
110 };
111
112 TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
113     actualFunction->validate_nodes_and_infer_types();
114     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
115     ASSERT_TRUE(res.first) << res.second;
116 }
117
118 const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
119     // U8
120     {
121         ngraph::element::f32,
122         ngraph::Shape{1, 4, 16, 16},
123         LayerTransformation::createParamsU8I8(),
124         {
125             ngraph::element::u8,
126             { {ngraph::element::f32},  { 7.f }, { 10.f }},
127             ngraph::element::u8,
128             {},
129         },
130         {
131             ngraph::element::u8,
132             { {ngraph::element::f32},  { 7.f }, { 10.f }},
133             ngraph::element::u8,
134             {},
135         }
136     },
137     // U8
138     {
139         ngraph::element::f32,
140         ngraph::Shape{1, 4, 16, 16},
141         LayerTransformation::createParamsU8I8(),
142         {
143             ngraph::element::u8,
144             {},
145             ngraph::element::u8,
146             { {ngraph::element::f32},  { 7.f }, { 10.f }}
147         },
148         {
149             ngraph::element::u8,
150             {},
151             ngraph::element::u8,
152             { {ngraph::element::f32},  { 7.f }, { 10.f }}
153         }
154     }
155 };
156
157 INSTANTIATE_TEST_CASE_P(
158     LPT,
159     ElementwiseWithMultiParentDequantizationTransformation,
160     ::testing::ValuesIn(addTransformationTestValues),
161     ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);