1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/prelu.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/prelu_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
// Parameters of one PRelu low-precision-transformation test case.
// NOTE(review): this view omits several lines of the definition (the declarations
// that group the fields below into `actual` / `expected`, and the member read as
// `testValues.shape` in getTestCaseName); comments describe only visible members.
27 class PReluTransformationTestValues {
// Inputs of the original function, read as `testValues.actual.*` and passed to
// PReluFunction::getOriginal in SetUp:
//   precisionBeforeDequantization - element type of the data before dequantization;
//   dequantization - dequantization operations on the input (the test data suggest
//   a {convert, subtract, multiply} layout - confirm against DequantizationOperations).
31 ngraph::element::Type precisionBeforeDequantization;
32 ngraph::builder::subgraph::DequantizationOperations dequantization;
// Expected values, read as `testValues.expected.*` and passed to
// PReluFunction::getReference in SetUp. Going by the names, dequantizationBefore /
// dequantizationAfter are presumably the operations expected before / after the
// PRelu node - confirm against getReference.
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39 ngraph::element::Type precisionAfterOperation;
40 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Transformation parameters (createParamsU8I8 / createParamsI8I8 in the test data)
// handed to transformer.add<...>() in SetUp.
44 ngraph::pass::low_precision::LayerTransformation::Params params;
// Value-parameterized fixture: SetUp builds the original function, runs the
// low-precision PRelu transformation on it, and builds the expected reference
// function, so the test body only has to compare the two.
49 class PReluTransformation : public LayerTransformation, public testing::WithParamInterface<PReluTransformationTestValues> {
// Builds actualFunction from testValues.actual, runs SimpleLowPrecisionTransformer
// on it (low_precision::PReluTransformation registered for opset1::PRelu and
// configured with testValues.params), then builds referenceFunction from
// testValues.expected.
51 void SetUp() override {
// NOTE(review): this copies the parameter; `const auto& testValues = GetParam();`
// would bind to the value gtest already stores instead.
52 const PReluTransformationTestValues testValues = GetParam();
54 actualFunction = ngraph::builder::subgraph::PReluFunction::getOriginal(
56 testValues.actual.precisionBeforeDequantization,
57 testValues.actual.dequantization);
59 SimpleLowPrecisionTransformer transformer;
60 transformer.add<ngraph::pass::low_precision::PReluTransformation, ngraph::opset1::PRelu>(testValues.params);
61 transformer.transform(actualFunction);
63 referenceFunction = ngraph::builder::subgraph::PReluFunction::getReference(
65 testValues.expected.precisionBeforeDequantization,
66 testValues.expected.dequantizationBefore,
67 testValues.expected.precisionAfterOperation,
68 testValues.expected.dequantizationAfter);
// Name generator passed to INSTANTIATE_TEST_CASE_P. It streams the shape, the actual
// input precision, the actual dequantization and the expected dequantizationBefore,
// separated by "_" (the `result <<` and `return` lines are not visible in this view).
// NOTE(review): testValues.params (U8I8 vs I8I8) and the remaining expected fields are
// not part of the name. Two cases that differ only there would get the same name,
// which recent googletest versions report as an error - consider appending the params.
71 static std::string getTestCaseName(testing::TestParamInfo<PReluTransformationTestValues> obj) {
72 const PReluTransformationTestValues testValues = obj.param;
74 std::ostringstream result;
76 testValues.shape << "_" <<
77 testValues.actual.precisionBeforeDequantization << "_" <<
78 testValues.actual.dequantization << "_" <<
79 testValues.expected.dequantizationBefore;
// Function after the transformation (actual) and the expected result (reference),
// both populated in SetUp.
84 std::shared_ptr<ngraph::Function> actualFunction;
85 std::shared_ptr<ngraph::Function> referenceFunction;
// Runs the InitNodeInfo pass on the transformed function, re-validates it and
// re-infers types, then checks that it matches the reference built in SetUp.
// The two `true` arguments enable extra checks in compare_functions (presumably
// constant values and names - see ngraph_test_utils.hpp to confirm).
// res.first is the match result; res.second is the diagnostic printed on failure.
88 TEST_P(PReluTransformation, CompareFunctions) {
89 InitNodeInfo().run_on_function(actualFunction);
90 actualFunction->validate_nodes_and_infer_types();
91 auto res = compare_functions(referenceFunction, actualFunction, true, true);
92 ASSERT_TRUE(res.first) << res.second;
// NOTE(review): in this view `shapes` is never referenced. The instantiation below
// uses only testValues, and each entry carries its own shape. Remove it, or feed
// it through ::testing::Combine, if it is unused in the full file as well.
95 const std::vector<ngraph::Shape> shapes = {
// Test cases. Each entry supplies, in order, the input shape, the transformation
// params, then the actual and expected values (see PReluTransformationTestValues).
// The dequantization entries appear to use a {convert, subtract, multiply} layout.
99 const std::vector<PReluTransformationTestValues> testValues = {
// U8: without subtract value
102 ngraph::Shape({ 1, 3, 16, 16 }),
103 LayerTransformation::createParamsU8I8(),
106 {{ngraph::element::f32}, {}, {0.1f}}
111 ngraph::element::f32,
// I8: without subtract value
117 ngraph::Shape({ 1, 3, 16, 16 }),
118 LayerTransformation::createParamsI8I8(),
121 {{ngraph::element::f32}, {}, {0.1f}}
126 ngraph::element::f32,
130 // U8: with positive subtract value
132 ngraph::Shape({ 1, 3, 16, 16 }),
133 LayerTransformation::createParamsU8I8(),
136 {{ngraph::element::f32}, { 128 }, {0.1f}}
140 {{}, { {128}, ngraph::element::f32 }, {}},
141 ngraph::element::f32,
145 // I8: with positive subtract value
147 ngraph::Shape({ 1, 3, 16, 16 }),
148 LayerTransformation::createParamsI8I8(),
151 {{ngraph::element::f32}, { 127 }, {0.1f}}
155 {{}, { {127}, ngraph::element::f32 }, {}},
156 ngraph::element::f32,
160 // U8: with negative subtract value: Convert is still here
// (unlike the positive-subtract cases, whose expected convert slot is empty,
// the expected dequantizationBefore below keeps the f32 Convert)
162 ngraph::Shape({ 1, 3, 16, 16 }),
163 LayerTransformation::createParamsU8I8(),
166 {{ngraph::element::f32}, { -128 }, {0.1f}}
170 {{ngraph::element::f32}, { {-128}, ngraph::element::f32 }, {}},
171 ngraph::element::f32,
// Instantiates the PReluTransformation TEST_Ps (here CompareFunctions) once per
// entry of testValues, naming each case with PReluTransformation::getTestCaseName.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the legacy spelling. googletest 1.10+
// deprecates it in favour of INSTANTIATE_TEST_SUITE_P; switch once the bundled
// googletest supports the new macro.
177 INSTANTIATE_TEST_CASE_P(
180 ::testing::ValuesIn(testValues),
181 PReluTransformation::getTestCaseName);