1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/utils/utils.hpp>
14 #include <transformations/init_node_info.hpp>
15 #include <low_precision/transformer.hpp>
16 #include <low_precision/mat_mul.hpp>
18 #include "common_test_utils/ngraph_test_utils.hpp"
19 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
20 #include "ngraph_functions/subgraph_builders.hpp"
21 #include "simple_low_precision_transformer.hpp"
22 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
26 using namespace testing;
27 using namespace ngraph::pass;
29 class MatMullTransformationTestValues {
33 ngraph::element::Type precisionBeforeDequantization1;
34 ngraph::builder::subgraph::DequantizationOperations dequantization1;
35 ngraph::element::Type precisionBeforeDequantization2;
36 ngraph::builder::subgraph::DequantizationOperations dequantization2;
41 ngraph::element::Type precisionBeforeDequantization1;
42 ngraph::builder::subgraph::DequantizationOperations dequantization1;
43 ngraph::element::Type precisionBeforeDequantization2;
44 ngraph::builder::subgraph::DequantizationOperations dequantization2;
45 ngraph::element::Type precisionBeforeOperation1;
46 ngraph::element::Type precisionBeforeOperation2;
47 ngraph::builder::subgraph::DequantizationOperations result;
50 ngraph::pass::low_precision::LayerTransformation::Params params;
55 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
56 return out << "_" << actual.dequantization1 << "_" << actual.dequantization2;
59 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
61 expected.precisionBeforeDequantization1 << "_" <<
62 expected.dequantization1 << "_" <<
63 expected.precisionBeforeDequantization2 << "_" <<
64 expected.dequantization2 << "_" <<
65 expected.precisionBeforeOperation1 << "_" <<
66 expected.precisionBeforeOperation2 << "_" <<
70 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
72 values.params.supportAsymmetricQuantization << "_" <<
73 values.params.updatePrecisions << "_" <<
74 values.actual << "_" <<
79 ngraph::element::Type,
80 std::pair<ngraph::Shape, ngraph::Shape>,
81 MatMullTransformationTestValues> MatMulTransformationParams;
83 class MatMulTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
85 void SetUp() override {
86 const ngraph::element::Type precision = std::get<0>(GetParam());
87 const std::pair<ngraph::Shape, ngraph::Shape> shapes = std::get<1>(GetParam());
88 const MatMullTransformationTestValues testValues = std::get<2>(GetParam());
90 actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
92 testValues.actual.precisionBeforeDequantization1,
93 testValues.actual.dequantization1,
95 testValues.actual.precisionBeforeDequantization2,
96 testValues.actual.dequantization2);
98 SimpleLowPrecisionTransformer transformer;
99 transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
100 transformer.transform(actualFunction);
103 (testValues.expected.precisionBeforeOperation1 == ngraph::element::f32) && testValues.expected.result.empty() ?
104 ngraph::builder::subgraph::MatMulFunction::getOriginal(
106 testValues.actual.precisionBeforeDequantization1,
107 testValues.actual.dequantization1,
109 testValues.actual.precisionBeforeDequantization2,
110 testValues.actual.dequantization2) :
111 ngraph::builder::subgraph::MatMulFunction::getReference(
114 testValues.expected.precisionBeforeDequantization1,
115 testValues.expected.dequantization1,
117 testValues.expected.precisionBeforeDequantization2,
118 testValues.expected.dequantization2,
119 testValues.expected.result);
122 static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
123 ngraph::element::Type precision;
124 std::pair<ngraph::Shape, ngraph::Shape> shapes;
125 MatMullTransformationTestValues testValues;
126 std::tie(precision, shapes, testValues) = obj.param;
128 std::stringstream ss;
129 ss << precision << "_" << shapes.first << "_" << shapes.second << "_" << testValues;
134 TEST_P(MatMulTransformation, CompareFunctions) {
135 actualFunction->validate_nodes_and_infer_types();
136 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
137 ASSERT_TRUE(res.first) << res.second;
140 const std::vector<ngraph::element::Type> precisions = {
141 ngraph::element::f32,
142 // ngraph::element::f16
145 const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> shapes = {
146 { { 1, 16, 384, 64 }, { 1, 16, 64, 384 } },
147 { { 4, 16, 384, 64 }, { 4, 16, 64, 384 } }
150 const std::vector<bool> updatePrecisions = { true, false };
// Test cases: each entry supplies the transformation params, the `actual` dequantization set-up
// used to build the original function, and the `expected` set-up used to build the reference.
// NOTE(review): the embedded line numbers skip values inside every entry, so some initializer
// lines (braces, bare precision entries) are not visible here - confirm nesting against upstream.
// Each dequantization initializer presumably lists {convert, subtract, multiply} - TODO confirm
// against DequantizationOperations.
152 std::vector<MatMullTransformationTestValues> testValues = {
// Disabled case, kept commented out:
153 // U8 + I8: Constant on dequantization operations on 0 branch
155 // LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
157 // ngraph::element::u8,
158 // { ngraph::element::f32, { 127.f }, { {0.02f}, ngraph::element::f32, {}, true, 0 } },
159 // ngraph::element::i8,
160 // { ngraph::element::f32, {}, { 0.03f } },
163 // ngraph::element::u8,
164 // { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
165 // ngraph::element::i8,
167 // ngraph::element::f32,
168 // ngraph::element::f32,
169 // { {}, {}, { 0.0006f } },
// Asymmetric quantization enabled: 127.f stays on the first input in `expected`, and the
// expected result value 0.0006f equals 0.02f * 0.03f.
174 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
177 { ngraph::element::f32, { 127.f }, { 0.02f } },
179 { ngraph::element::f32, {}, { 0.03f } },
183 { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
186 ngraph::element::f32,
187 ngraph::element::f32,
188 { {}, {}, { 0.0006f } },
// Same visible values as the previous case; presumably the non-visible precision entries differ - verify.
193 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
196 { ngraph::element::f32, { 127.f }, { 0.02f } },
198 { ngraph::element::f32, {}, { 0.03f } },
202 { {}, {{127.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
205 ngraph::element::f32,
206 ngraph::element::f32,
207 { {}, {}, { 0.0006f } },
210 // U8 + I8, Subtract with a non-integer value (127.5f in actual, 128.f in expected)
212 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
215 { ngraph::element::f32, { 127.5f }, { 0.02f } },
217 { ngraph::element::f32, {}, { 0.03f } },
221 { {}, {{128.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
224 ngraph::element::f32,
225 ngraph::element::f32,
226 { {}, {}, { 0.0006f } },
// Second input is f32 with only a 0.03f value in the last slot: the visible `expected`
// entries repeat the `actual` ones.
231 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
234 { ngraph::element::f32, { 127.f }, { 0.02f } },
235 ngraph::element::f32,
236 { {}, {}, { 0.03f } },
240 { ngraph::element::f32, { 127.f }, { 0.02f } },
241 ngraph::element::f32,
242 { {}, {}, { 0.03f } },
243 ngraph::element::f32,
244 ngraph::element::f32,
// First input is f32: the visible `expected` entries repeat the `actual` ones.
250 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
252 ngraph::element::f32,
253 { {}, { 127.f }, { 0.02f } },
255 { ngraph::element::f32, {}, { 0.03f } },
258 ngraph::element::f32,
259 { {}, { 127.f }, { 0.02f } },
261 { ngraph::element::f32, {}, { 0.03f } },
262 ngraph::element::f32,
263 ngraph::element::f32,
// Asymmetric quantization disabled while the first input carries 127.f: the visible
// `expected` entries repeat the `actual` ones.
268 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(false),
271 { ngraph::element::f32, { 127.f }, { 0.02f } },
273 { ngraph::element::f32, {}, { 0.03f } },
277 { ngraph::element::f32, { 127.f }, { 0.02f } },
279 { ngraph::element::f32, {}, { 0.03f } },
280 ngraph::element::f32,
281 ngraph::element::f32,
// Asymmetric quantization disabled and no 127.f value: expected result carries 0.02f * 0.03f.
286 LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(false),
289 { ngraph::element::f32, {}, { 0.02f } },
291 { ngraph::element::f32, {}, { 0.03f } },
300 { {}, {}, { 0.02f * 0.03f } },
// Default createParamsU8U8(): expected result carries 0.02f * 0.03f.
304 LayerTransformation::createParamsU8U8(),
307 { ngraph::element::f32, {}, { 0.02f } },
309 { ngraph::element::f32, {}, { 0.03f } },
318 { {}, {}, { 0.02f * 0.03f } },
// Same visible values as the previous case; presumably the non-visible precision entries differ - verify.
322 LayerTransformation::createParamsU8U8(),
325 { ngraph::element::f32, {}, { 0.02f } },
327 { ngraph::element::f32, {}, { 0.03f } },
336 { {}, {}, { 0.02f * 0.03f } },
// createParamsI8I8() with updatePrecisions enabled.
340 LayerTransformation::createParamsI8I8().setUpdatePrecisions(true),
343 { ngraph::element::f32, {}, { 0.02f } },
345 { ngraph::element::f32, {}, { 0.03f } },
354 { {}, {}, { 0.02f * 0.03f } },
// createParamsI8I8() with updatePrecisions disabled: every visible precision is f32.
358 LayerTransformation::createParamsI8I8().setUpdatePrecisions(false),
360 ngraph::element::f32,
361 { {}, {}, { 0.02f } },
362 ngraph::element::f32,
363 { {}, {}, { 0.03f } },
366 ngraph::element::f32,
368 ngraph::element::f32,
370 ngraph::element::f32,
371 ngraph::element::f32,
372 { {}, {}, { 0.02f * 0.03f } },
// Instantiates the MatMulTransformation suite from `precisions`, `shapes` and `testValues`;
// MatMulTransformation::getTestCaseName supplies the generated test names.
// NOTE(review): the instantiation prefix and the line wrapping the three generators
// (presumably ::testing::Combine) are not visible here - confirm against upstream.
// NOTE(review): newer GoogleTest releases deprecate INSTANTIATE_TEST_CASE_P in favour of
// INSTANTIATE_TEST_SUITE_P; switch only if the bundled GoogleTest provides it.
377 INSTANTIATE_TEST_CASE_P(
379 MatMulTransformation,
381 ::testing::ValuesIn(precisions),
382 ::testing::ValuesIn(shapes),
383 ::testing::ValuesIn(testValues)),
384 MatMulTransformation::getTestCaseName);