1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
9 #include <gtest/gtest.h>
11 #include <transformations/init_node_info.hpp>
12 #include <low_precision/clamp.hpp>
14 #include "common_test_utils/ngraph_test_utils.hpp"
15 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
16 #include "ngraph_functions/low_precision_transformations/clamp_function.hpp"
17 #include "simple_low_precision_transformer.hpp"
21 using namespace testing;
22 using namespace ngraph::pass;
24 class ClampTransformationTestValues {
28 ngraph::element::Type precisionBeforeDequantization;
29 ngraph::builder::subgraph::DequantizationOperations dequantization;
34 ngraph::element::Type precisionBeforeDequantization;
35 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
36 ngraph::element::Type precisionAfterOperation;
37 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
40 ngraph::Shape inputShape;
41 ngraph::pass::low_precision::LayerTransformation::Params params;
46 class ClampTransformation : public LayerTransformation, public testing::WithParamInterface<ClampTransformationTestValues> {
48 void SetUp() override {
49 const ClampTransformationTestValues testValues = GetParam();
51 actualFunction = ngraph::builder::subgraph::ClampFunction::getOriginal(
52 testValues.inputShape,
53 testValues.actual.precisionBeforeDequantization,
54 testValues.actual.dequantization);
56 SimpleLowPrecisionTransformer transformer;
57 transformer.add<ngraph::pass::low_precision::ClampTransformation, ngraph::opset1::Clamp>(testValues.params);
58 transformer.transform(actualFunction);
60 referenceFunction = ngraph::builder::subgraph::ClampFunction::getReference(
61 testValues.inputShape,
62 testValues.expected.precisionBeforeDequantization,
63 testValues.expected.dequantizationBefore,
64 testValues.expected.precisionAfterOperation,
65 testValues.expected.dequantizationAfter);
68 static std::string getTestCaseName(testing::TestParamInfo<ClampTransformationTestValues> obj) {
69 const ClampTransformationTestValues testValues = obj.param;
71 std::ostringstream result;
72 result << toString(testValues.params) << "_" <<
73 testValues.inputShape << "_" <<
74 testValues.actual.precisionBeforeDequantization << "_" <<
75 testValues.actual.dequantization << "_" <<
76 testValues.expected.dequantizationBefore;
81 TEST_P(ClampTransformation, CompareFunctions) {
82 InitNodeInfo().run_on_function(actualFunction);
83 actualFunction->validate_nodes_and_infer_types();
85 auto res = compare_functions(referenceFunction, actualFunction, true, true);
86 ASSERT_TRUE(res.first) << res.second;
89 const std::vector<ClampTransformationTestValues> testValues = {
90 // U8 per tensor quantization
92 ngraph::Shape({ 1, 3, 224, 224 }),
93 LayerTransformation::createParamsU8I8(),
97 {{ngraph::element::f32}, {128.f}, {3.f}}
103 ngraph::element::f32,
107 // I8 per tensor quantization
109 ngraph::Shape({ 1, 3, 224, 224 }),
110 LayerTransformation::createParamsI8I8(),
113 {{ngraph::element::f32}, {128.f}, {-5.f}}
118 ngraph::element::f32,
119 {{}, {128.f}, {-5.f}}
122 // U8 without convert
124 ngraph::Shape({ 1, 3, 224, 224 }),
125 LayerTransformation::createParamsU8I8(),
127 ngraph::element::f32,
131 ngraph::element::f32,
133 ngraph::element::f32,
137 // I8 without convert
139 ngraph::Shape({ 1, 3, 224, 224 }),
140 LayerTransformation::createParamsI8I8(),
142 ngraph::element::f32,
146 ngraph::element::f32,
148 ngraph::element::f32,
152 // U8 without subtract
154 ngraph::Shape({ 1, 3, 224, 224 }),
155 LayerTransformation::createParamsU8I8(),
158 {{ngraph::element::f32}, {}, {3.f}}
163 ngraph::element::f32,
167 // I8 without subtract
169 ngraph::Shape({ 1, 3, 224, 224 }),
170 LayerTransformation::createParamsI8I8(),
173 {{ngraph::element::f32}, {}, {3.f}}
178 ngraph::element::f32,
182 // U8 per channel quantization with different values
184 ngraph::Shape({ 1, 3, 224, 224 }),
185 LayerTransformation::createParamsU8I8(),
189 {ngraph::element::f32},
190 {{128.f, 0.f, 128.f / 2}},
197 {ngraph::element::f32},
198 {{128.f, 0.f, 128.f / 2}},
201 ngraph::element::f32,
205 // I8 per channel quantization with different values
207 ngraph::Shape({ 1, 3, 224, 224 }),
208 LayerTransformation::createParamsI8I8(),
212 {ngraph::element::f32},
213 {{128.f, 0.f, 128.f / 2}},
220 {ngraph::element::f32},
221 {{128.f, 0.f, 128.f / 2}},
224 ngraph::element::f32,
228 // U8 per channel quantization with the same values
230 ngraph::Shape({ 1, 3, 224, 224 }),
231 LayerTransformation::createParamsU8I8(),
235 {ngraph::element::f32},
236 {{128.f, 128.f, 128.f}},
243 ngraph::element::f32,
246 {{128.f, 128.f, 128.f}},
251 // I8 per channel quantization with the same values
253 ngraph::Shape({ 1, 3, 224, 224 }),
254 LayerTransformation::createParamsI8I8(),
258 {ngraph::element::f32},
259 {{128.f, 128.f, 128.f}},
266 ngraph::element::f32,
269 {{128.f, 128.f, 128.f}},
274 // U8 dequantization in second dimension
276 ngraph::Shape({ 1, 3, 4, 4 }),
277 LayerTransformation::createParamsU8I8(),
281 {ngraph::element::f32},
282 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
283 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
289 {ngraph::element::f32},
290 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
291 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
293 ngraph::element::f32,
297 // I8 dequantization in second dimension
299 ngraph::Shape({ 1, 3, 4, 4 }),
300 LayerTransformation::createParamsI8I8(),
304 {ngraph::element::f32},
305 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
306 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
312 {ngraph::element::f32},
313 {{128.f, 128.f, 128.f, 128.f}, ngraph::element::f32, {1, 1, 4, 1}},
314 {{3.f, 3.f, 3.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
316 ngraph::element::f32,
320 // U8 asymmetric quantization
322 ngraph::Shape({ 1, 3, 224, 224 }),
323 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
327 {ngraph::element::f32},
328 {{ 128.f, 0.f, 128.f }},
335 {ngraph::element::f32},
336 {{ 128.f, 0.f, 128.f }},
339 ngraph::element::f32,
343 // U8 without asymmetric quantization
345 ngraph::Shape({ 1, 3, 224, 224 }),
346 LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(false),
350 {ngraph::element::f32},
351 {{ 128.f, 0.f, 128.f }},
358 {ngraph::element::f32},
359 {{ 128.f, 0.f, 128.f }},
362 ngraph::element::f32,
366 // per channel quantization with small values
368 ngraph::Shape({ 1, 3, 224, 224 }),
369 LayerTransformation::createParamsU8I8(),
373 {ngraph::element::f32},
374 {{1e-14, 1e-12, 1e-15}},
375 {{1e-14, 1e-12, 1e-15}}
381 {ngraph::element::f32},
382 {{1e-14, 1e-12, 1e-15}},
383 {{1e-14, 1e-12, 1e-15}}
385 ngraph::element::f32,
390 INSTANTIATE_TEST_CASE_P(
393 ::testing::ValuesIn(testValues),
394 ClampTransformation::getTestCaseName);