1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/relu.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/relu_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
// Parameters of one ReluTransformation test case. Part of this class is not visible in this
// listing; the groupings below follow how the fixture reads the values
// (testValues.actual.*, testValues.expected.*, testValues.params, testValues.shape).
27 class ReluTransformationTestValues {
// Original graph (read as testValues.actual.*): the input element type and the dequantization
// operations applied to that input.
31 ngraph::element::Type precisionBeforeDequantization;
32 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Reference graph (read as testValues.expected.*): the input element type, the dequantization
// expected before Relu, the element type expected after Relu, and the dequantization expected
// after Relu (passed in this order to ReluFunction::getReference).
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39 ngraph::element::Type precisionAfterOperation;
40 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Parameters handed to the LPT ReluTransformation when the fixture registers it in SetUp().
44 ngraph::pass::low_precision::LayerTransformation::Params params;
// Parameterized fixture. SetUp() builds the original Relu function, runs the low-precision
// ReluTransformation on it, and builds the expected reference function. The CompareFunctions
// test then compares the two.
49 class ReluTransformation : public LayerTransformation, public testing::WithParamInterface<ReluTransformationTestValues> {
51 void SetUp() override {
52 const ReluTransformationTestValues testValues = GetParam();
// Function under test, built from the `actual` description.
54 actualFunction = ngraph::builder::subgraph::ReluFunction::getOriginal(
56 testValues.actual.precisionBeforeDequantization,
57 testValues.actual.dequantization);
// Register only the LPT ReluTransformation (for opset1::Relu) and run it on actualFunction.
// The transformation type is fully qualified because, inside this fixture, the bare name
// ReluTransformation would refer to the fixture class itself.
59 SimpleLowPrecisionTransformer transformer;
60 transformer.add<ngraph::pass::low_precision::ReluTransformation, ngraph::opset1::Relu>(testValues.params);
61 transformer.transform(actualFunction);
// Expected result, built from the `expected` description.
63 referenceFunction = ngraph::builder::subgraph::ReluFunction::getReference(
65 testValues.expected.precisionBeforeDequantization,
66 testValues.expected.dequantizationBefore,
67 testValues.expected.precisionAfterOperation,
68 testValues.expected.dequantizationAfter);
// Builds the name suffix of each instantiated test case from its test values.
// NOTE(review): params, expected.precisionAfterOperation and expected.dequantizationAfter are not
// part of the name. Two cases differing only in those fields would get the same name, and recent
// googletest versions reject duplicate parameterized test names. Keep new cases distinguishable.
71 static std::string getTestCaseName(testing::TestParamInfo<ReluTransformationTestValues> obj) {
72 const ReluTransformationTestValues testValues = obj.param;
74 std::ostringstream result;
76 testValues.shape << "_" <<
77 testValues.actual.precisionBeforeDequantization << "_" <<
78 testValues.actual.dequantization << "_" <<
79 testValues.expected.dequantizationBefore;
// Set in SetUp() and compared by the CompareFunctions test.
84 std::shared_ptr<ngraph::Function> actualFunction;
85 std::shared_ptr<ngraph::Function> referenceFunction;
// Validates the transformed function (validate_nodes_and_infer_types), then compares it with the
// reference function. On mismatch, the assertion prints compare_functions' description.
// NOTE(review): the three trailing `true` arguments are compare_functions options; confirm their
// meaning in common_test_utils/ngraph_test_utils.hpp before changing them.
88 TEST_P(ReluTransformation, CompareFunctions) {
89 actualFunction->validate_nodes_and_infer_types();
90 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
91 ASSERT_TRUE(res.first) << res.second;
// NOTE(review): the elements of this vector are not visible here, and each testValues entry below
// already carries its own ngraph::Shape. Confirm whether `shapes` is still used anywhere; if not,
// it can be removed.
94 const std::vector<ngraph::Shape> shapes = {
// Test cases. Each entry lists, in order: the shape, the LPT params, the original-graph
// description, and the expected-graph description (see ReluTransformationTestValues).
98 const std::vector<ReluTransformationTestValues> testValues = {
101 ngraph::Shape({ 1, 3, 16, 16 }),
102 LayerTransformation::createParamsU8I8(),
105 {{ngraph::element::f32}, {}, {0.1f}}
111 {{ngraph::element::f32}, {}, {0.1f}}
116 ngraph::Shape({ 1, 3, 16, 16 }),
117 LayerTransformation::createParamsU8I8(),
120 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
126 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
// Per-channel scales including a negative value: the whole dequantization is expected to stay
// before Relu.
131 ngraph::Shape({ 1, 3, 16, 16 }),
132 LayerTransformation::createParamsU8I8(),
135 {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}}
139 {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}},
140 ngraph::element::f32,
146 ngraph::Shape({ 1, 3, 16, 16 }),
147 LayerTransformation::createParamsI8I8(),
150 {{ngraph::element::f32}, {}, {0.1f}}
156 {{ngraph::element::f32}, {}, {0.1f}}
159 // U8: with subtract value
161 ngraph::Shape({ 1, 3, 16, 16 }),
162 LayerTransformation::createParamsU8I8(),
165 {{ngraph::element::f32}, { 128 }, {0.1f}}
169 {{}, { {128}, ngraph::element::f32 }, {}},
170 ngraph::element::f32,
174 // I8: with subtract value, asymmetric quantization supported
176 ngraph::Shape({ 1, 3, 16, 16 }),
177 LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(true),
180 {{ngraph::element::f32}, { 127 }, {0.1f}}
184 {{}, { {127}, ngraph::element::f32 }, {}},
185 ngraph::element::f32,
189 // I8: with subtract value, asymmetric quantization not supported: the whole dequantization is expected to stay before Relu
191 ngraph::Shape({ 1, 3, 16, 16 }),
192 LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(false),
195 {{ngraph::element::f32}, { 127 }, {0.1f}}
199 {{ngraph::element::f32}, { 127 }, {0.1f}},
200 ngraph::element::f32,
206 ngraph::Shape({ 1, 3, 16, 16 }),
207 LayerTransformation::createParamsU8I8(),
221 ngraph::Shape({ 1, 3, 16, 16 }),
222 LayerTransformation::createParamsU8I8(),
224 ngraph::element::f32,
228 ngraph::element::f32,
230 ngraph::element::f32,
// Instantiates the parameterized ReluTransformation tests once per testValues entry. Each case is
// named with ReluTransformation::getTestCaseName.
// NOTE(review): newer googletest versions deprecate INSTANTIATE_TEST_CASE_P in favor of
// INSTANTIATE_TEST_SUITE_P; switch once the bundled googletest supports it.
236 INSTANTIATE_TEST_CASE_P(
239 ::testing::ValuesIn(testValues),
240 ReluTransformation::getTestCaseName);