1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
14 #include <transformations/utils/utils.hpp>
15 #include <transformations/init_node_info.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
20 #include <low_precision/add.hpp>
21 #include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
26 using namespace ngraph::builder::subgraph;
// Parameter set for one ElementwiseWithMultiParentDequantizationTransformation test case.
// The fixture reads `actual.*` to build the graph that is transformed, `expected.*` to build
// the reference graph, and the top-level precision / inputShape / params for both.
// NOTE(review): the nested class headers, access specifiers and closing braces are not visible
// in this excerpt; the grouping below into actual/expected is presumed from field order and
// from the `testValues.actual.*` / `testValues.expected.*` accesses in the fixture.
28 class ElementwiseWithMultiParentDequantizationTransformationTestValues {
// Presumably the `actual` group: input precision and dequantization ops for each Add branch.
32 ngraph::element::Type precision1;
33 ngraph::builder::subgraph::DequantizationOperations dequantization1;
34 ngraph::element::Type precision2;
35 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// Presumably the `expected` group: the same fields describing the reference graph after the transformation.
40 ngraph::element::Type precision1;
41 ngraph::builder::subgraph::DequantizationOperations dequantization1;
42 ngraph::element::Type precision2;
43 ngraph::builder::subgraph::DequantizationOperations dequantization2;
// Shared settings, used by the test-name generator and by the transformer setup.
46 ngraph::element::Type precision;
47 ngraph::Shape inputShape;
// Passed to AddTransformation when it is registered in SetUp.
48 ngraph::pass::low_precision::LayerTransformation::Params params;
// Streams a std::vector element by element so vectors can be used in test-name generation.
// The visible condition shows special handling for every element except the last, presumably
// to write a separator only between elements. NOTE(review): the template header, loop body and
// return statement are not visible in this excerpt.
54 inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
56 for (size_t i = 0; i < values.size(); ++i) {
58 if (i != (values.size() - 1ul)) {
// Value-parameterized fixture that checks how AddTransformation handles an Add whose two inputs
// each carry their own dequantization subgraph.
66 class ElementwiseWithMultiParentDequantizationTransformation :
67 public LayerTransformation,
68 public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
// SetUp prepares two functions:
// 1) actualFunction is built from the `actual` test values, then transformed.
// 2) referenceFunction is built from the `expected` test values and left untransformed.
70 void SetUp() override {
71 const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam();
// NOTE(review): some arguments of this call are not visible in this excerpt.
73 actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
75 testValues.inputShape,
77 testValues.actual.precision1,
78 testValues.actual.dequantization1,
79 testValues.actual.precision2,
80 testValues.actual.dequantization2);
// Only AddTransformation is registered, for opset1::Add nodes, using the parameter set's params.
82 SimpleLowPrecisionTransformer transform;
83 transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
84 low_precision::LayerTransformation::Params(testValues.params));
85 transform.transform(actualFunction);
// The same builder is called with the `expected` values to produce the graph to compare against.
87 referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
89 testValues.inputShape,
91 testValues.expected.precision1,
92 testValues.expected.dequantization1,
93 testValues.expected.precision2,
94 testValues.expected.dequantization2);
// Builds the gtest case name by joining precision, input shape and the `actual` branch settings with "_".
// NOTE(review): the `expected` values are not part of the name, so two cases that differ only in
// `expected` would get the same name. The stream-insertion start and the return line are not visible here.
97 static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
98 const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param;
100 std::ostringstream result;
102 testValues.precision << "_" <<
103 testValues.inputShape << "_" <<
104 testValues.actual.precision1 << "_" <<
105 testValues.actual.dequantization1 << "_" <<
106 testValues.actual.precision2 << "_" <<
107 testValues.actual.dequantization2;
// Re-runs type inference on the transformed function, then checks it matches the reference.
// On mismatch, the diagnostic returned by compare_functions is printed.
// NOTE(review): the meaning of the three boolean flags is defined in ngraph_test_utils, which is
// not visible here. The closing brace of this test is also not visible in this excerpt.
112 TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
113 actualFunction->validate_nodes_and_infer_types();
114 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
115 ASSERT_TRUE(res.first) << res.second;
// Parameter sets fed to the fixture. Both visible cases share f32 precision, a {1, 4, 16, 16} input
// shape and U8I8 transformation params.
// The dequantization entries below are presumably written as {convert precision, subtract values,
// multiply values}; confirm against DequantizationOperations.
// NOTE(review): the entries are only partly visible in this excerpt (e.g. branch precisions and
// the other dequantization entries are missing).
118 const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
121 ngraph::element::f32,
122 ngraph::Shape{1, 4, 16, 16},
123 LayerTransformation::createParamsU8I8(),
126 { {ngraph::element::f32}, { 7.f }, { 10.f }},
132 { {ngraph::element::f32}, { 7.f }, { 10.f }},
139 ngraph::element::f32,
140 ngraph::Shape{1, 4, 16, 16},
141 LayerTransformation::createParamsU8I8(),
146 { {ngraph::element::f32}, { 7.f }, { 10.f }}
152 { {ngraph::element::f32}, { 7.f }, { 10.f }}
// Instantiates CompareFunctions once per entry of addTransformationTestValues, naming each case
// with the fixture's getTestCaseName.
// NOTE(review): the first macro argument (the instantiation prefix) is not visible in this excerpt;
// confirm it is present. Newer GoogleTest releases deprecate INSTANTIATE_TEST_CASE_P in favor of
// INSTANTIATE_TEST_SUITE_P, so consider migrating if the project's gtest version supports it.
157 INSTANTIATE_TEST_CASE_P(
159 ElementwiseWithMultiParentDequantizationTransformation,
160 ::testing::ValuesIn(addTransformationTestValues),
161 ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);