[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / relu_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/relu.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/relu_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
21
22 namespace {
23
24 using namespace testing;
25 using namespace ngraph::pass;
26
27 class ReluTransformationTestValues {
28 public:
29     class Actual {
30     public:
31         ngraph::element::Type precisionBeforeDequantization;
32         ngraph::builder::subgraph::DequantizationOperations dequantization;
33     };
34
35     class Expected {
36     public:
37         ngraph::element::Type precisionBeforeDequantization;
38         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39         ngraph::element::Type precisionAfterOperation;
40         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
41     };
42
43     ngraph::Shape shape;
44     ngraph::pass::low_precision::LayerTransformation::Params params;
45     Actual actual;
46     Expected expected;
47 };
48
49 class ReluTransformation : public LayerTransformation, public testing::WithParamInterface<ReluTransformationTestValues> {
50 public:
51     void SetUp() override {
52         const ReluTransformationTestValues testValues = GetParam();
53
54         actualFunction = ngraph::builder::subgraph::ReluFunction::getOriginal(
55             testValues.shape,
56             testValues.actual.precisionBeforeDequantization,
57             testValues.actual.dequantization);
58
59         SimpleLowPrecisionTransformer transformer;
60         transformer.add<ngraph::pass::low_precision::ReluTransformation, ngraph::opset1::Relu>(testValues.params);
61         transformer.transform(actualFunction);
62
63         referenceFunction = ngraph::builder::subgraph::ReluFunction::getReference(
64             testValues.shape,
65             testValues.expected.precisionBeforeDequantization,
66             testValues.expected.dequantizationBefore,
67             testValues.expected.precisionAfterOperation,
68             testValues.expected.dequantizationAfter);
69     }
70
71     static std::string getTestCaseName(testing::TestParamInfo<ReluTransformationTestValues> obj) {
72         const ReluTransformationTestValues testValues = obj.param;
73
74         std::ostringstream result;
75         result <<
76             toString(testValues.params) << "_" <<
77             testValues.shape << "_" <<
78             testValues.actual.precisionBeforeDequantization << "_" <<
79             testValues.actual.dequantization << "_" <<
80             testValues.expected.dequantizationBefore;
81         return result.str();
82     }
83
84 protected:
85     std::shared_ptr<ngraph::Function> actualFunction;
86     std::shared_ptr<ngraph::Function> referenceFunction;
87 };
88
89 TEST_P(ReluTransformation, CompareFunctions) {
90     actualFunction->validate_nodes_and_infer_types();
91     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
92     ASSERT_TRUE(res.first) << res.second;
93 }
94
95 const std::vector<ngraph::Shape> shapes = {
96     { 1, 3, 16, 16 }
97 };
98
// Relu test cases. Each entry is { shape, params, actual, expected } where
//   actual   = { input precision, dequantization in front of Relu }
//   expected = { input precision, dequantization left before Relu,
//                precision after Relu, dequantization after Relu }.
// Dequantization triples read as { convert, subtract, multiply }; the middle
// ("subtract") slot matches the per-case labels below.
const std::vector<ReluTransformationTestValues> testValues = {
    // U8, single scale, no subtract: Relu is expected to run on u8 and the
    // whole dequantization to end up after Relu.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {0.1f}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {0.1f}}
        }
    },
    // U8, three positive scales (presumably one per channel), no subtract:
    // same outcome as the single-scale case.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
        },
        {
            ngraph::element::u8,
            {{}, {}, {}},
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
        }
    },
    // U8, three scales including a negative one, no subtract:
    // Relu(s * x) != s * Relu(x) for s < 0, so the dequantization is expected to
    // stay in front of Relu and Relu to run in f32.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // I8, single scale, no subtract: Relu is expected to run on i8 and the
    // dequantization to end up after Relu.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsI8I8(),
        {
            ngraph::element::i8,
            {{ngraph::element::f32}, {}, {0.1f}}
        },
        {
            ngraph::element::i8,
            {{}, {}, {}},
            ngraph::element::i8,
            {{ngraph::element::f32}, {}, {0.1f}}
        }
    },
    // U8 with subtract value: Relu(x - 128) != Relu(x) - 128, so the
    // dequantization is expected to stay in front of Relu and Relu to run in f32.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, { 128 }, {0.1f}}
        },
        {
            ngraph::element::u8,
            {{ngraph::element::f32}, { 128 }, {0.1f}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // I8 with subtract value, asymmetric quantization supported:
    // dequantization is expected to stay in front of Relu (Relu in f32).
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(true),
        {
            ngraph::element::i8,
            {{ngraph::element::f32}, { 127 }, {0.1f}}
        },
        {
            ngraph::element::i8,
            {{ngraph::element::f32}, { 127 }, {0.1f}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // I8 with subtract value, asymmetric quantization NOT supported:
    // same expected graph as the supported case above.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(false),
        {
            ngraph::element::i8,
            {{ngraph::element::f32}, { 127 }, {0.1f}}
        },
        {
            ngraph::element::i8,
            {{ngraph::element::f32}, { 127 }, {0.1f}},
            ngraph::element::f32,
            {{}, {}, {}}
        }
    },
    // U8 input without any dequantization: graph expected unchanged, Relu on u8.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::u8,
            {}
        },
        {
            ngraph::element::u8,
            {},
            ngraph::element::u8,
            {}
        }
    },
    // FP32 input without any dequantization: graph expected unchanged, Relu on f32.
    {
        ngraph::Shape({ 1, 3, 16, 16 }),
        LayerTransformation::createParamsU8I8(),
        {
            ngraph::element::f32,
            {}
        },
        {
            ngraph::element::f32,
            {},
            ngraph::element::f32,
            {}
        }
    }
};
236
237 INSTANTIATE_TEST_CASE_P(
238     LPT,
239     ReluTransformation,
240     ::testing::ValuesIn(testValues),
241     ReluTransformation::getTestCaseName);
242
243 } // namespace