1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/reshape.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
19 #include "ngraph_functions/low_precision_transformations/reshape_function.hpp"
20 #include "simple_low_precision_transformer.hpp"
24 using namespace testing;
25 using namespace ngraph::pass;
// Parameters of one ReshapeTransformation test case: the original graph
// description ("actual"), the reference graph description ("expected"),
// the input shape, the Reshape target-shape values and the low-precision
// transformation parameters.
27 class ReshapeTransformationTestValues {
// "actual" part (read as testValues.actual.* in SetUp): precision of the data
// feeding the dequantization and the dequantization operations of the
// original graph.
31 ngraph::element::Type precisionBeforeDequantization;
32 ngraph::builder::subgraph::DequantizationOperations dequantization;
// "expected" part (read as testValues.expected.* in SetUp): data precision,
// the dequantization expected before Reshape, the precision after Reshape
// (presumably Reshape's output precision), and the dequantization expected
// after Reshape in the reference graph.
37 ngraph::element::Type precisionBeforeDequantization;
38 ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
39 ngraph::element::Type precisionAfterOperation;
40 ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
// Input shape passed to ReshapeFunction::getOriginal / getReference.
43 ngraph::Shape inputShape;
// Reshape target-shape values passed to ReshapeFunction::getOriginal / getReference.
44 std::vector<int> reshapeConstValues;
// Parameters forwarded to low_precision::ReshapeTransformation in SetUp.
45 ngraph::pass::low_precision::LayerTransformation::Params params;
// Streams the elements of a std::vector<int>; getTestCaseName relies on it to
// print reshapeConstValues into the test name.
50 inline std::ostream& operator<<(std::ostream& os, const std::vector<int>& values) {
// For an empty vector the loop body never runs, so the unsigned
// `values.size() - 1ul` below is never evaluated and cannot wrap around.
52 for (size_t i = 0; i < values.size(); ++i) {
// Presumably emits a separator after every element except the last one.
54 if (i != (values.size() - 1ul)) {
// Value-parameterized fixture: for each ReshapeTransformationTestValues case it
// builds the transformed "actual" function and the expected reference function.
62 class ReshapeTransformation : public LayerTransformation, public testing::WithParamInterface<ReshapeTransformationTestValues> {
// Builds actualFunction from the "actual" description, runs the low-precision
// ReshapeTransformation on it in place (registered for opset1::Reshape), and
// builds referenceFunction from the "expected" description.
64 void SetUp() override {
65 const ReshapeTransformationTestValues testValues = GetParam();
67 actualFunction = ngraph::builder::subgraph::ReshapeFunction::getOriginal(
68 testValues.inputShape,
69 testValues.reshapeConstValues,
70 testValues.actual.precisionBeforeDequantization,
71 testValues.actual.dequantization);
73 SimpleLowPrecisionTransformer transformer;
74 transformer.add<ngraph::pass::low_precision::ReshapeTransformation, ngraph::opset1::Reshape>(testValues.params);
75 transformer.transform(actualFunction);
77 referenceFunction = ngraph::builder::subgraph::ReshapeFunction::getReference(
78 testValues.inputShape,
79 testValues.reshapeConstValues,
80 testValues.expected.precisionBeforeDequantization,
81 testValues.expected.dequantizationBefore,
82 testValues.expected.precisionAfterOperation,
83 testValues.expected.dequantizationAfter);
// Builds a readable test name by joining the case fields with "_".
// NOTE(review): expected.precisionBeforeDequantization and params are not part
// of the name below; two cases that differ only in those fields would get the
// same name, which newer GoogleTest versions reject as a duplicate — confirm.
86 static std::string getTestCaseName(testing::TestParamInfo<ReshapeTransformationTestValues> obj) {
87 const ReshapeTransformationTestValues testValues = obj.param;
89 std::ostringstream result;
91 testValues.inputShape << "_" <<
92 testValues.reshapeConstValues << "_" <<
93 testValues.actual.precisionBeforeDequantization << "_" <<
94 testValues.actual.dequantization << "_" <<
95 testValues.expected.precisionAfterOperation << "_" <<
96 testValues.expected.dequantizationAfter << "_" <<
97 testValues.expected.dequantizationBefore;
// Test cases for ReshapeTransformation. In the labels, "with subtract" means the
// dequantization contains a Subtract; "per channel" means the Subtract/Multiply
// constants hold one value per channel, "per tensor" means a single value.
102 const std::vector<ReshapeTransformationTestValues> testValues = {
103 // U8: no subtract 3D -> 4D: channels are not affected: per tensor quantization
105 ngraph::Shape({ 1, 384, 1024 }),
107 LayerTransformation::createParamsU8I8(),
110 {{ngraph::element::f32}, {}, {0.1f}}
116 {{ngraph::element::f32}, {}, {0.1f}}
119 // U8: no subtract 3D -> 4D: channels are not affected: per tensor quantization: batch 4
121 ngraph::Shape({ 4, 384, 1024 }),
123 LayerTransformation::createParamsU8I8(),
126 {{ngraph::element::f32}, {}, {0.1f}}
132 {{ngraph::element::f32}, {}, {0.1f}}
135 // U8: no subtract 3D -> 4D: channels are not affected: per channel quantization
137 ngraph::Shape({ 1, 3, 20 }),
139 LayerTransformation::createParamsU8I8(),
142 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}}
148 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
151 // U8: no subtract 3D -> 4D: channels are not affected: per channel quantization: batch 4
153 ngraph::Shape({ 4, 3, 20 }),
155 LayerTransformation::createParamsU8I8(),
158 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}}
164 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
167 // U8: with subtract 3D -> 4D: channels are not affected: per channel quantization
169 ngraph::Shape({ 1, 3, 20 }),
171 LayerTransformation::createParamsU8I8(),
175 {ngraph::element::f32},
176 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1}},
177 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}
185 {ngraph::element::f32},
186 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
187 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
191 // U8: with subtract 3D -> 4D: channels are not affected: per channel quantization
// NOTE(review): input shape and dequantization are identical to the previous
// case; confirm that reshapeConstValues differ, otherwise this is a duplicate.
193 ngraph::Shape({ 1, 3, 20 }),
195 LayerTransformation::createParamsU8I8(),
199 {ngraph::element::f32},
200 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1}},
201 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}
209 {ngraph::element::f32},
210 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
211 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
215 // U8: no subtract 4D -> 6D: channels are not affected: per channel quantization
217 ngraph::Shape({ 1, 3, 4, 5 }),
218 { 1, 3, 20, 1, 1, 1},
219 LayerTransformation::createParamsU8I8(),
222 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
226 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}},
227 ngraph::element::f32,
231 // U8: with subtract 4D -> 6D: channels are not affected: per channel quantization
233 ngraph::Shape({ 1, 3, 4, 5 }),
234 { 1, 3, 20, 1, 1, 1},
235 LayerTransformation::createParamsU8I8(),
239 {ngraph::element::f32},
240 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
241 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
247 { ngraph::element::f32 },
248 {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
249 {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
251 ngraph::element::f32,
255 // U8: no subtract: 4D input: channels are affected: per tensor quantization
258 ngraph::Shape({ 1, 16, 384, 384 }),
260 LayerTransformation::createParamsU8I8(),
263 {{ngraph::element::f32}, {}, {0.1f}}
267 {{ngraph::element::f32}, {}, {0.1f}},
268 ngraph::element::f32,
272 // U8: no subtract: 4D input: channels are affected: per channel quantization
274 ngraph::Shape({ 1, 3, 4, 5 }),
276 LayerTransformation::createParamsU8I8(),
279 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}}
283 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}}},
284 ngraph::element::f32,
288 // U8: with subtract: 4D input: channels are affected: per channel quantization
290 ngraph::Shape({ 1, 3, 4, 8 }),
292 LayerTransformation::createParamsU8I8(),
295 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}}, {{0.1f, 0.2f, 0.3f}}}
299 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32}, {{0.1f, 0.2f, 0.3f}}},
300 ngraph::element::f32,
306 ngraph::Shape({ 1, 3, 4, 8 }),
308 LayerTransformation::createParamsU8I8(),
310 ngraph::element::f32,
314 ngraph::element::f32,
316 ngraph::element::f32,
322 ngraph::Shape({ 1, 3, 4, 8 }),
324 LayerTransformation::createParamsU8I8(),
336 // U8: no subtract 4D -> 6D: channels are not affected: per channel quantization
338 ngraph::Shape({ 1, 3, 1, 1 }),
339 { 1, 3, 1, 1, 1, 1 },
340 LayerTransformation::createParamsU8I8(),
343 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
347 {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}},
348 ngraph::element::f32,
352 // U8: with subtract 4D -> 2D: channels are not affected: per tensor quantization
355 ngraph::Shape({ 1, 3, 4, 5 }),
357 LayerTransformation::createParamsU8I8(),
360 {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
366 {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
369 // U8: with subtract 4D -> 2D: channels are merged with spatial dims: per channel quantization expanded to {1, 12}
371 ngraph::Shape({ 1, 3, 2, 2 }),
373 LayerTransformation::createParamsU8I8(),
376 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
383 {ngraph::element::f32},
384 {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
385 {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
389 // U8: with subtract 4D -> 2D: channels are merged with spatial dims: per channel quantization expanded to {1, 12}: batch 4
391 ngraph::Shape({ 4, 3, 2, 2 }),
393 LayerTransformation::createParamsU8I8(),
396 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
403 {ngraph::element::f32},
404 {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
405 {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
409 // U8: with subtract 4D -> 2D: channels are not affected: per channel quantization: case #1: dequantization operation constant needs broadcast
411 ngraph::Shape({ 1, 3, 1, 1 }),
413 LayerTransformation::createParamsU8I8(),
416 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
422 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
425 // U8: with subtract 4D -> 2D: channels are not affected: per channel quantization: case #2: dequantization operation constant doesn't need broadcast
427 ngraph::Shape({ 1, 3, 1, 1 }),
429 LayerTransformation::createParamsU8I8(),
432 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
438 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
441 // U8: with subtract 4D -> 3D: channels are not affected: per channel quantization: case #1: dequantization operation constant needs broadcast
443 ngraph::Shape({ 1, 3, 4, 5 }),
445 LayerTransformation::createParamsU8I8(),
448 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {3, 1, 1}}}
454 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
457 // U8: with subtract 4D -> 3D: channels are not affected: per channel quantization: case #2: dequantization operation constant doesn't need broadcast
459 ngraph::Shape({ 1, 3, 4, 5 }),
461 LayerTransformation::createParamsU8I8(),
464 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
470 {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
473 // U8: no subtract 4D -> 2D: per tensor quantization: scalar scale constant
475 ngraph::Shape({ 1, 2048, 1, 1 }),
477 LayerTransformation::createParamsU8I8(),
480 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
486 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
489 // U8: no subtract 4D -> 2D: per tensor quantization: scale constant of shape {1}: batch 2
491 ngraph::Shape({ 2, 2048, 1, 1 }),
493 LayerTransformation::createParamsU8I8(),
496 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
502 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
505 // U8: no subtract 4D -> 2D: per tensor quantization: scale constant {1, 1, 1, 1} becomes {1, 1}
507 ngraph::Shape({ 1, 2048, 1, 1 }),
509 LayerTransformation::createParamsU8I8(),
512 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
518 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
521 // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization: batch 2
523 ngraph::Shape({ 2, 2048, 1, 1 }),
525 LayerTransformation::createParamsU8I8(),
528 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
534 {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
// Initializes node info and re-runs type inference on the transformed function,
// then compares it against the reference function. The two trailing `true`
// arguments are forwarded to compare_functions (presumably enabling extra
// precision/constant checks — confirm against the helper). On mismatch the
// failure message is res.second.
539 TEST_P(ReshapeTransformation, CompareFunctions) {
540 InitNodeInfo().run_on_function(actualFunction);
541 actualFunction->validate_nodes_and_infer_types();
542 auto res = compare_functions(referenceFunction, actualFunction, true, true);
543 ASSERT_TRUE(res.first) << res.second;
// Instantiates one CompareFunctions test per entry of testValues, using
// ReshapeTransformation::getTestCaseName to name each instance.
// NOTE(review): GoogleTest 1.10+ spells this macro INSTANTIATE_TEST_SUITE_P and
// deprecates the _CASE_ form; switch once the project's gtest version allows.
546 INSTANTIATE_TEST_CASE_P(
548 ReshapeTransformation,
549 ::testing::ValuesIn(testValues),
550 ReshapeTransformation::getTestCaseName);