ce9d9278ac06316c6817e79d4f90c2c4687da348
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / relu_transformation.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "layer_transformation.hpp"
6
7 #include <string>
8 #include <sstream>
9 #include <memory>
10
11 #include <gtest/gtest.h>
12
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/relu.hpp>
16
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/relu_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
21
22 namespace {
23
24 using namespace testing;
25 using namespace ngraph::pass;
26
27 class ReluTransformationTestValues {
28 public:
29     class Actual {
30     public:
31         ngraph::element::Type precisionBeforeDequantization;
32         ngraph::builder::subgraph::DequantizationOperations dequantization;
33     };
34
35     class Expected {
36     public:
37         ngraph::element::Type precisionBeforeDequantization;
38         ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39         ngraph::element::Type precisionAfterOperation;
40         ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
41     };
42
43     ngraph::Shape shape;
44     ngraph::pass::low_precision::LayerTransformation::Params params;
45     Actual actual;
46     Expected expected;
47 };
48
49 class ReluTransformation : public LayerTransformation, public testing::WithParamInterface<ReluTransformationTestValues> {
50 public:
51     void SetUp() override {
52         const ReluTransformationTestValues testValues = GetParam();
53
54         actualFunction = ngraph::builder::subgraph::ReluFunction::getOriginal(
55             testValues.shape,
56             testValues.actual.precisionBeforeDequantization,
57             testValues.actual.dequantization);
58
59         SimpleLowPrecisionTransformer transformer;
60         transformer.add<ngraph::pass::low_precision::ReluTransformation, ngraph::opset1::Relu>(testValues.params);
61         transformer.transform(actualFunction);
62
63         referenceFunction = ngraph::builder::subgraph::ReluFunction::getReference(
64             testValues.shape,
65             testValues.expected.precisionBeforeDequantization,
66             testValues.expected.dequantizationBefore,
67             testValues.expected.precisionAfterOperation,
68             testValues.expected.dequantizationAfter);
69     }
70
71     static std::string getTestCaseName(testing::TestParamInfo<ReluTransformationTestValues> obj) {
72         const ReluTransformationTestValues testValues = obj.param;
73
74         std::ostringstream result;
75         result <<
76             testValues.shape << "_" <<
77             testValues.actual.precisionBeforeDequantization << "_" <<
78             testValues.actual.dequantization << "_" <<
79             testValues.expected.dequantizationBefore;
80         return result.str();
81     }
82
83 protected:
84     std::shared_ptr<ngraph::Function> actualFunction;
85     std::shared_ptr<ngraph::Function> referenceFunction;
86 };
87
88 TEST_P(ReluTransformation, CompareFunctions) {
89     actualFunction->validate_nodes_and_infer_types();
90     auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
91     ASSERT_TRUE(res.first) << res.second;
92 }
93
// Input shapes are given per test case inside `testValues`; no separate shape list is needed.
97
98 const std::vector<ReluTransformationTestValues> testValues = {
99     // U8: no subtract
100     {
101         ngraph::Shape({ 1, 3, 16, 16 }),
102         LayerTransformation::createParamsU8I8(),
103         {
104             ngraph::element::u8,
105             {{ngraph::element::f32}, {}, {0.1f}}
106         },
107         {
108             ngraph::element::u8,
109             {{}, {}, {}},
110             ngraph::element::u8,
111             {{ngraph::element::f32}, {}, {0.1f}}
112         }
113     },
114     // U8: no subtract
115     {
116         ngraph::Shape({ 1, 3, 16, 16 }),
117         LayerTransformation::createParamsU8I8(),
118         {
119             ngraph::element::u8,
120             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
121         },
122         {
123             ngraph::element::u8,
124             {{}, {}, {}},
125             ngraph::element::u8,
126             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
127         }
128     },
129     // U8: no subtract
130     {
131         ngraph::Shape({ 1, 3, 16, 16 }),
132         LayerTransformation::createParamsU8I8(),
133         {
134             ngraph::element::u8,
135             {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}}
136         },
137         {
138             ngraph::element::u8,
139             {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}},
140             ngraph::element::f32,
141             {{}, {}, {}}
142         }
143     },
144     // I8: no subtract
145     {
146         ngraph::Shape({ 1, 3, 16, 16 }),
147         LayerTransformation::createParamsI8I8(),
148         {
149             ngraph::element::i8,
150             {{ngraph::element::f32}, {}, {0.1f}}
151         },
152         {
153             ngraph::element::i8,
154             {{}, {}, {}},
155             ngraph::element::i8,
156             {{ngraph::element::f32}, {}, {0.1f}}
157         }
158     },
159     // U8: with subtract value
160     {
161         ngraph::Shape({ 1, 3, 16, 16 }),
162         LayerTransformation::createParamsU8I8(),
163         {
164             ngraph::element::u8,
165             {{ngraph::element::f32}, { 128 }, {0.1f}}
166         },
167         {
168             ngraph::element::u8,
169             {{}, { {128}, ngraph::element::f32 }, {}},
170             ngraph::element::f32,
171             {{}, {}, {0.1f}}
172         }
173     },
174     // I8: with subtract value
175     {
176         ngraph::Shape({ 1, 3, 16, 16 }),
177         LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(true),
178         {
179             ngraph::element::i8,
180             {{ngraph::element::f32}, { 127 }, {0.1f}}
181         },
182         {
183             ngraph::element::i8,
184             {{}, { {127}, ngraph::element::f32 }, {}},
185             ngraph::element::f32,
186             {{}, {}, {0.1f}}
187         }
188     },
189     // I8: with subtract value
190     {
191         ngraph::Shape({ 1, 3, 16, 16 }),
192         LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(false),
193         {
194             ngraph::element::i8,
195             {{ngraph::element::f32}, { 127 }, {0.1f}}
196         },
197         {
198             ngraph::element::i8,
199             {{ngraph::element::f32}, { 127 }, {0.1f}},
200             ngraph::element::f32,
201             {{}, {}, {}}
202         }
203     },
204     // U8: empty
205     {
206         ngraph::Shape({ 1, 3, 16, 16 }),
207         LayerTransformation::createParamsU8I8(),
208         {
209             ngraph::element::u8,
210             {}
211         },
212         {
213             ngraph::element::u8,
214             {},
215             ngraph::element::u8,
216             {}
217         }
218     },
219     // FP32: empty
220     {
221         ngraph::Shape({ 1, 3, 16, 16 }),
222         LayerTransformation::createParamsU8I8(),
223         {
224             ngraph::element::f32,
225             {}
226         },
227         {
228             ngraph::element::f32,
229             {},
230             ngraph::element::f32,
231             {}
232         }
233     }
234 };
235
236 INSTANTIATE_TEST_CASE_P(
237     LPT,
238     ReluTransformation,
239     ::testing::ValuesIn(testValues),
240     ReluTransformation::getTestCaseName);
241
242 } // namespace