1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/convolution.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "simple_low_precision_transformer.hpp"
19 #include "ngraph_functions/low_precision_transformations/convolution_function.hpp"
21 using namespace testing;
22 using namespace ngraph;
23 using namespace ngraph::pass;
// Value bundle describing one parameterized case of the ConvolutionTransformation test.
// NOTE(review): the declarations enclosing the two field groups below are not visible
// in this excerpt; SetUp() reads the first group through `testValues.actual` and the
// second group through `testValues.expected`.
25 class ConvolutionTransformationTestValues {
// Fields read through `testValues.actual`; SetUp() forwards them to
// ConvolutionFunction::getOriginal() to build the graph under test.
29 ngraph::element::Type precisionBeforeDequantization;
30 ngraph::builder::subgraph::DequantizationOperations dequantization;
31 std::shared_ptr<ngraph::opset1::Constant> weights;
32 builder::subgraph::FakeQuantizeOnWeights fakeQuantizeOnWeights;
// Fields read through `testValues.expected`; SetUp() forwards them to
// ConvolutionFunction::getReference() to build the reference graph.
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39 std::shared_ptr<ngraph::opset1::Constant> weights;
40 builder::subgraph::FakeQuantizeOnWeights fakeQuantizeOnWeights;
41 ngraph::element::Type precisionAfterOperation;
42 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
43 ngraph::element::Type precisionAfterDequantization;
// Transformation parameters; SetUp() passes them when registering ConvolutionTransformation.
46 ngraph::pass::low_precision::LayerTransformation::Params params;
// Parameter type of the fixture below (used with testing::WithParamInterface).
// SetUp() reads element 0 as the input shape and element 1 as the test values.
// NOTE(review): the opening of this declaration is not visible in this excerpt.
53 ConvolutionTransformationTestValues> ConvolutionTransformationParams;
// Value-parameterized fixture: for each (input shape, test values) parameter it builds
// the graph under test and the reference graph that the test body compares.
55 class ConvolutionTransformation : public LayerTransformation, public testing::WithParamInterface<ConvolutionTransformationParams> {
// Builds actualFunction with getOriginal(), applies the low-precision
// ConvolutionTransformation to it, then builds referenceFunction with getReference().
// NOTE(review): some call arguments are not visible in this excerpt.
57 void SetUp() override {
58 const auto inputShape = std::get<0>(GetParam());
59 const auto testValues = std::get<1>(GetParam());
// Graph under test, built from the `actual` values.
61 actualFunction = ngraph::builder::subgraph::ConvolutionFunction::getOriginal(
62 testValues.actual.precisionBeforeDequantization,
64 testValues.actual.dequantization,
65 testValues.actual.weights,
66 testValues.actual.fakeQuantizeOnWeights);
// Register the transformation for opset1::Convolution with the case's params and run it on actualFunction.
68 SimpleLowPrecisionTransformer transform;
69 transform.add<ngraph::pass::low_precision::ConvolutionTransformation, ngraph::opset1::Convolution>(testValues.params);
70 transform.transform(actualFunction);
// Reference graph, built from the `expected` values.
72 referenceFunction = ngraph::builder::subgraph::ConvolutionFunction::getReference(
73 testValues.expected.precisionBeforeDequantization,
75 testValues.expected.dequantizationBefore,
76 testValues.expected.weights,
77 testValues.expected.fakeQuantizeOnWeights,
78 testValues.expected.precisionAfterOperation,
79 testValues.expected.dequantizationAfter,
80 testValues.expected.precisionAfterDequantization);
// Streams a test-case name built from the transformation params and the `actual` values.
// NOTE(review): part of the stream expression and the return statement are not visible
// in this excerpt.
83 static std::string getTestCaseName(testing::TestParamInfo<ConvolutionTransformationParams> obj) {
84 auto inputShape = std::get<0>(obj.param);
85 ConvolutionTransformationTestValues testValues = std::get<1>(obj.param);
87 std::ostringstream result;
88 result << toString(testValues.params) << "_" <<
90 testValues.actual.precisionBeforeDequantization << "_" <<
91 testValues.actual.dequantization << "_" << "_weights_" <<
92 testValues.actual.weights->get_element_type() << "_" << "{ " <<
// Only the first element of the weights constant is written into the name.
93 testValues.actual.weights->cast_vector<float>()[0] << " }_" <<
94 testValues.actual.fakeQuantizeOnWeights << "_";
// Checks that the transformed graph (actualFunction) matches the reference graph (referenceFunction).
99 TEST_P(ConvolutionTransformation, CompareFunctions) {
// Re-run node validation and type inference on the transformed graph before comparing.
100 actualFunction->validate_nodes_and_infer_types();
// NOTE(review): the meaning of the three boolean flags is defined by compare_functions(),
// which is not visible in this file — confirm against its declaration.
101 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
// res.first is the asserted condition; res.second is streamed as the failure message.
102 ASSERT_TRUE(res.first) << res.second;
// Input shapes the test is instantiated with (fed to ::testing::ValuesIn in the
// instantiation below). The two shapes differ only in the first dimension (1 vs 4).
105 const std::vector<ngraph::Shape> shapes = {
106 ngraph::Shape({ 1, 3, 72, 48 }),
107 ngraph::Shape({ 4, 3, 72, 48 })
// Test cases fed to ::testing::ValuesIn in the instantiation below.
// Each case visibly starts with a LayerTransformation params expression
// (createParamsU8I8(), optionally followed by a setter), then lists values for the
// `actual` graph and then for the `expected` graph.
// NOTE(review): the brace lines that delimit cases and their sub-objects are not
// visible in this excerpt, so this grouping is inferred from the field order used
// in SetUp() — confirm against the full file.
110 const std::vector<ConvolutionTransformationTestValues> testValues = {
// params: asymmetric quantization support enabled
113 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
117 {{ngraph::element::f32}, { 128.f }, { 0.02f }},
118 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
119 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
124 {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
125 op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
127 ngraph::element::f32,
128 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
// params: asymmetric quantization support disabled.
// The dequantization, weights and fake-quantize values below appear twice with the
// same contents — presumably the graph is expected to stay unchanged; confirm.
133 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
137 {{ngraph::element::f32}, { 128.f }, { 0.02f }},
138 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
139 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
144 {{ ngraph::element::f32 }, { 128.f }, { 0.02f }},
145 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
146 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
147 ngraph::element::f32,
151 // with zero point, precisions are not updated
153 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
156 ngraph::element::f32,
157 {{ngraph::element::f32}, { 128.f }, { 0.02f }},
158 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
159 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
163 ngraph::element::f32,
164 {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
165 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ -125.f }),
167 ngraph::element::f32,
168 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
171 // without zero point
173 LayerTransformation::createParamsU8I8(),
177 {{ngraph::element::f32}, {}, { 0.02f }},
178 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
179 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
185 op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
187 ngraph::element::f32,
188 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
191 // without zero point
193 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
197 {{ngraph::element::f32}, {}, { 0.02f }},
198 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
199 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
205 op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
207 ngraph::element::f32,
208 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
211 // without zero point, precisions are not updated
213 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
216 ngraph::element::f32,
217 {{ngraph::element::f32}, {}, { 0.02f }},
218 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
219 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
223 ngraph::element::f32,
225 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ -125.f }),
227 ngraph::element::f32,
228 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
231 // with zero point, per-channel quantization with the same values
233 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
237 {{ngraph::element::f32}, { { 128.f }, ngraph::element::f32, {1, 3, 1, 1} }, { { 0.02f }, ngraph::element::f32, {1, 3, 1, 1} }},
238 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
239 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
244 {{}, { { 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {}},
245 op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
247 ngraph::element::f32,
248 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
251 // with zero point, per-channel quantization with different values
253 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
258 {ngraph::element::f32},
259 {{ 128.f, 0.f, 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }},
260 {{ 0.02f, 0.01f, 0.03f }, ngraph::element::f32, {1, 3, 1, 1}}
262 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
263 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
269 {ngraph::element::f32},
270 {{ 128.f, 0.f, 128.f }, ngraph::element::f32, { 1, 3, 1, 1 }},
271 {{ 0.02f, 0.01f, 0.03f }, ngraph::element::f32, {1, 3, 1, 1}}
273 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
274 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
275 ngraph::element::f32,
279 // dequantization in second dimension
281 LayerTransformation::createParamsU8I8(),
284 ngraph::element::f32,
286 {ngraph::element::f32},
287 {{ 128.f }, ngraph::element::f32, { 1, 1, 1, 1 }},
288 {{ 0.02f }, ngraph::element::f32, {1, 1, 1, 1}}
290 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
291 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
295 ngraph::element::f32,
297 {ngraph::element::f32},
298 {{ 128.f }, ngraph::element::f32, { 1, 1, 1, 1 }},
299 {{ 0.02f }, ngraph::element::f32, {1, 1, 1, 1}}
301 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
302 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
303 ngraph::element::f32,
307 // without dequantization operations
309 LayerTransformation::createParamsU8I8(),
312 ngraph::element::f32,
314 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
315 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
319 ngraph::element::f32,
321 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
322 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
323 ngraph::element::f32,
327 // without zero point, without convert
329 LayerTransformation::createParamsU8I8(),
332 ngraph::element::f32,
334 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
335 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
339 ngraph::element::f32,
340 {{}, {}, { {0.02f}, element::f32 }},
341 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
342 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
343 ngraph::element::f32,
347 // without zero point
349 LayerTransformation::createParamsU8I8(),
353 {{element::f32}, {}, { {0.02f}, element::f32 }},
354 op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
355 { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
361 op::Constant::create(ngraph::element::i8, ngraph::Shape{}, std::vector<float>{ -125.f }),
363 ngraph::element::f32,
364 {{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1 }}}
// Instantiates the ConvolutionTransformation fixture from the `shapes` and `testValues`
// lists; test names are produced by ConvolutionTransformation::getTestCaseName.
// NOTE(review): the instantiation name and the generator that wraps the two ValuesIn
// calls are not visible in this excerpt (presumably ::testing::Combine — confirm).
369 INSTANTIATE_TEST_CASE_P(
371 ConvolutionTransformation,
373 ::testing::ValuesIn(shapes),
374 ::testing::ValuesIn(testValues)),
375 ConvolutionTransformation::getTestCaseName);