1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/relu.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/relu_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
// Inputs and expectations for one parameterized ReluTransformation test case.
// Besides the members shown, a `shape` member exists (getTestCaseName reads
// testValues.shape) on a line not visible in this listing.
// NOTE(review): this listing omits original lines 28-30, 33-36, 41-43 and 45-48, so the
// nested class headers, the `actual`/`expected` member declarations and the closing
// braces are not shown; do not edit this fragment without the full source.
27 class ReluTransformationTestValues {
// Fields read as testValues.actual.* in SetUp: the element type the original graph
// starts from and the dequantization operations applied to it.
31 ngraph::element::Type precisionBeforeDequantization;
32 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Fields read as testValues.expected.* in SetUp and passed, in this order, to
// ReluFunction::getReference: input precision, dequantization expected before the
// operation, the operation's output precision, and dequantization expected after it.
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39 ngraph::element::Type precisionAfterOperation;
40 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Low-precision transformation parameters handed to ReluTransformation in SetUp.
44 ngraph::pass::low_precision::LayerTransformation::Params params;
// Parameterized fixture: for each ReluTransformationTestValues entry it builds the
// original function, applies the low-precision ReluTransformation to it, and builds
// the reference function that the CompareFunctions test compares it against.
// NOTE(review): original lines 50, 53, 55, 58, 62, 64, 69-70, 73, 75 and 81-84 are not
// visible in this listing (access specifiers, braces and some arguments).
49 class ReluTransformation : public LayerTransformation, public testing::WithParamInterface<ReluTransformationTestValues> {
51 void SetUp() override {
// NOTE(review): GoogleTest's GetParam() returns a const reference, so binding
// `const ReluTransformationTestValues&` here would avoid copying the test values.
52 const ReluTransformationTestValues testValues = GetParam();
// Graph under test, built from the `actual` description. Original line 55 (presumably
// the shape argument) is not visible; confirm against the full source.
54 actualFunction = ngraph::builder::subgraph::ReluFunction::getOriginal(
56 testValues.actual.precisionBeforeDequantization,
57 testValues.actual.dequantization);
// Register only ReluTransformation, matched on opset1::Relu nodes, and apply it to
// actualFunction.
59 SimpleLowPrecisionTransformer transformer;
60 transformer.add<ngraph::pass::low_precision::ReluTransformation, ngraph::opset1::Relu>(testValues.params);
61 transformer.transform(actualFunction);
// Expected graph, built from the `expected` description. Original line 64 (presumably
// the shape argument) is not visible; confirm against the full source.
63 referenceFunction = ngraph::builder::subgraph::ReluFunction::getReference(
65 testValues.expected.precisionBeforeDequantization,
66 testValues.expected.dequantizationBefore,
67 testValues.expected.precisionAfterOperation,
68 testValues.expected.dequantizationAfter);
// Builds the per-instance test name from the params, shape, actual input precision and
// dequantization, and the expected dequantizationBefore, separated by "_".
// NOTE(review): original line 75 (presumably `result <<`, which starts the chain below)
// and lines 81-84 (presumably `return result.str();` and the closing brace) are not
// visible here; confirm.
71 static std::string getTestCaseName(testing::TestParamInfo<ReluTransformationTestValues> obj) {
72 const ReluTransformationTestValues testValues = obj.param;
74 std::ostringstream result;
76 toString(testValues.params) << "_" <<
77 testValues.shape << "_" <<
78 testValues.actual.precisionBeforeDequantization << "_" <<
79 testValues.actual.dequantization << "_" <<
80 testValues.expected.dequantizationBefore;
// Function produced by applying the transformation in SetUp.
85 std::shared_ptr<ngraph::Function> actualFunction;
// Expected function built from testValues.expected in SetUp.
86 std::shared_ptr<ngraph::Function> referenceFunction;
// Checks that ReluTransformation turned the original function into the expected
// reference function.
// NOTE(review): the closing brace (original line 93) is not visible in this listing.
89 TEST_P(ReluTransformation, CompareFunctions) {
// Re-validate nodes and re-infer types on the transformed function before comparing.
90 actualFunction->validate_nodes_and_infer_types();
// `res.first` is the match flag and `res.second` the message reported on failure.
// NOTE(review): the three boolean flags are defined by compare_functions in
// common_test_utils (not visible here); they presumably enable extra comparison
// checks. Confirm their meaning before changing them.
91 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
92 ASSERT_TRUE(res.first) << res.second;
// NOTE(review): the initializer of `shapes` (original lines 96-98) is not visible, and
// no visible line reads `shapes`; each test case below carries its own ngraph::Shape.
95 const std::vector<ngraph::Shape> shapes = {
// Test cases for ReluTransformation. Dequantization initializers appear to follow
// {convert, subtract, multiply} (the "with subtract value" cases put the subtract
// constant in the middle slot). Some original lines per case are not visible here.
99 const std::vector<ReluTransformationTestValues> testValues = {
// U8: per-tensor scale 0.1, no subtract.
102 ngraph::Shape({ 1, 3, 16, 16 }),
103 LayerTransformation::createParamsU8I8(),
106 {{ngraph::element::f32}, {}, {0.1f}}
112 {{ngraph::element::f32}, {}, {0.1f}}
// U8: per-channel positive scales, no subtract.
117 ngraph::Shape({ 1, 3, 16, 16 }),
118 LayerTransformation::createParamsU8I8(),
121 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
127 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
// U8: per-channel scales including a negative one; the expected dequantization is
// followed by an f32 precision, i.e. it appears to stay as dequantizationBefore.
132 ngraph::Shape({ 1, 3, 16, 16 }),
133 LayerTransformation::createParamsU8I8(),
136 {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}}
140 {{ngraph::element::f32}, {}, {{0.1f, -0.2f, 0.3f}}},
141 ngraph::element::f32,
// I8: per-tensor scale 0.1, no subtract.
147 ngraph::Shape({ 1, 3, 16, 16 }),
148 LayerTransformation::createParamsI8I8(),
151 {{ngraph::element::f32}, {}, {0.1f}}
157 {{ngraph::element::f32}, {}, {0.1f}}
160 // U8: with subtract value (dequantization expected to stay before the operation)
162 ngraph::Shape({ 1, 3, 16, 16 }),
163 LayerTransformation::createParamsU8I8(),
166 {{ngraph::element::f32}, { 128 }, {0.1f}}
170 {{ngraph::element::f32}, { 128 }, {0.1f}},
171 ngraph::element::f32,
175 // I8: with subtract value, asymmetric quantization supported
177 ngraph::Shape({ 1, 3, 16, 16 }),
178 LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(true),
181 {{ngraph::element::f32}, { 127 }, {0.1f}}
185 {{ngraph::element::f32}, { 127 }, {0.1f}},
186 ngraph::element::f32,
190 // I8: with subtract value, asymmetric quantization not supported
192 ngraph::Shape({ 1, 3, 16, 16 }),
193 LayerTransformation::createParamsI8I8().setSupportAsymmetricQuantization(false),
196 {{ngraph::element::f32}, { 127 }, {0.1f}}
200 {{ngraph::element::f32}, { 127 }, {0.1f}},
201 ngraph::element::f32,
// U8 case whose actual/expected values (original lines 209-221) are not visible here.
207 ngraph::Shape({ 1, 3, 16, 16 }),
208 LayerTransformation::createParamsU8I8(),
// U8 case with f32 precisions; its dequantization arguments are on lines not visible here.
222 ngraph::Shape({ 1, 3, 16, 16 }),
223 LayerTransformation::createParamsU8I8(),
225 ngraph::element::f32,
229 ngraph::element::f32,
231 ngraph::element::f32,
// Instantiates the parameterized ReluTransformation tests once per entry of testValues,
// naming each instance with ReluTransformation::getTestCaseName.
// NOTE(review): original lines 238-239 (presumably the instantiation prefix and the
// fixture name) are not visible here. INSTANTIATE_TEST_CASE_P is the older spelling;
// GoogleTest 1.10+ names it INSTANTIATE_TEST_SUITE_P. Switch only if the bundled gtest
// supports it.
237 INSTANTIATE_TEST_CASE_P(
240 ::testing::ValuesIn(testValues),
241 ReluTransformation::getTestCaseName);