1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/transformer.hpp>
15 #include <low_precision/mat_mul.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
19 #include "ngraph_functions/subgraph_builders.hpp"
20 #include "simple_low_precision_transformer.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
25 using namespace testing;
26 using namespace ngraph::pass;
// Bundle of values describing one MatMul-with-constant-weights test case.
// The fixture and the stream operators in this file read it through the
// members `actual`, `expected` and `params`, and refer to the nested types
// MatMullTransformationTestValues::Actual and ::Expected.
28 class MatMullTransformationTestValues {
// Fields read as testValues.actual.* when the fixture builds the function
// that is handed to the transformation.
32 ngraph::Shape inputShape;
33 ngraph::element::Type precisionBeforeDequantization;
34 ngraph::builder::subgraph::DequantizationOperations dequantization;
35 ngraph::Shape weightsConstShape;
36 std::vector<float> weightsConstValues;
37 ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
// Fields read as testValues.expected.* when the fixture builds the reference
// function. Compared with the group above it additionally carries the weights
// constant precision, the precision before the operation and the dequantization
// expected on the result.
42 ngraph::Shape inputShape;
43 ngraph::element::Type precisionBeforeDequantization;
44 ngraph::builder::subgraph::DequantizationOperations dequantization;
45 ngraph::element::Type weightsConstPrecision;
46 ngraph::Shape weightsConstShape;
47 std::vector<float> weightsConstValues;
49 ngraph::element::Type precisionBeforeOperation;
50 ngraph::builder::subgraph::DequantizationOperations resultDequantization;
// A non-empty value here makes the fixture build the reference with
// MatMulFunction::getOriginal instead of getReference.
52 ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
// Parameters passed to the MatMulTransformation registered by the fixture.
55 ngraph::pass::low_precision::LayerTransformation::Params params;
// Streams the fields of an Actual value group (input shape, precision before
// dequantization, dequantization, weights constant shape, ...) joined by "_".
// Reached through the operator<< for the whole test-values object, which
// getTestCaseName uses to compose the test name.
60 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
62 actual.inputShape << "_" <<
63 actual.precisionBeforeDequantization << "_" <<
64 actual.dequantization << "_" <<
65 actual.weightsConstShape << "_" <<
// Streams the fields of an Expected value group (weights constant shape,
// dequantization, precision before the operation, result dequantization, ...)
// joined by "_". Reached through the operator<< for the whole test-values
// object, which getTestCaseName uses to compose the test name.
69 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
71 expected.weightsConstShape <<"_" <<
72 expected.dequantization << "_" <<
73 expected.precisionBeforeOperation << "_" <<
74 expected.resultDequantization << "_" <<
// Streams a whole test-values object as "_<actual>_<expected>" and returns
// the stream so that further output can be chained.
// Note that `values.params` is not part of the printed text.
78 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
79 return out << "_" << values.actual << "_" << values.expected;
// Parameter pack of the value-parameterized test. The fixture reads element 0
// as an ngraph::element::Type (precision), element 1 as a size_t (batch) and
// element 2 as the MatMullTransformationTestValues of the case.
83 ngraph::element::Type,
85 MatMullTransformationTestValues> MatMulTransformationParams;
// Value-parameterized fixture: for each MatMulTransformationParams tuple it
// prepares `actualFunction` (original graph after MatMulTransformation) and
// `referenceFunction` (the graph the result is compared against).
87 class MatMulWithConstantTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
// Builds both functions from the current test parameter.
89 void SetUp() override {
90 const ngraph::element::Type precision = std::get<0>(GetParam());
91 const size_t batch = std::get<1>(GetParam());
// Work on a copy so the batch can be patched in without touching the shared
// test-value table.
93 MatMullTransformationTestValues testValues = std::get<2>(GetParam());
// The batch parameter overrides dimension 0 of both input shapes.
94 testValues.actual.inputShape[0] = batch;
95 testValues.expected.inputShape[0] = batch;
// Graph under test, built from the "actual" value group.
97 actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
99 testValues.actual.inputShape,
100 testValues.actual.precisionBeforeDequantization,
101 testValues.actual.dequantization,
102 testValues.actual.weightsConstShape,
103 testValues.actual.weightsConstValues,
104 testValues.actual.fqOnWeights);
// Register MatMulTransformation for opset1::MatMul with the case's params and
// apply it to the actual function.
106 SimpleLowPrecisionTransformer transformer;
107 transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
108 transformer.transform(actualFunction);
// Reference graph: when the expected group carries no FakeQuantize-on-weights
// data it is built with getReference; otherwise it is rebuilt with getOriginal
// from the expected values.
// NOTE(review): the second branch presumably models "transformation not
// applied" cases - confirm against the test-value table.
110 referenceFunction = testValues.expected.fqOnWeights.empty() ?
111 ngraph::builder::subgraph::MatMulFunction::getReference(
113 testValues.expected.inputShape,
114 testValues.expected.precisionBeforeDequantization,
115 testValues.expected.dequantization,
116 testValues.expected.weightsConstPrecision,
117 testValues.expected.weightsConstShape,
118 testValues.expected.weightsConstValues,
119 testValues.expected.resultDequantization) :
120 ngraph::builder::subgraph::MatMulFunction::getOriginal(
122 testValues.expected.inputShape,
123 testValues.expected.precisionBeforeDequantization,
124 testValues.expected.dequantization,
125 testValues.expected.weightsConstShape,
126 testValues.expected.weightsConstValues,
127 testValues.expected.fqOnWeights);
// Name generator for the instantiation: unpacks the parameter tuple and
// composes "<precision>_<batch>_<testValues>" through a string stream.
130 static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
131 ngraph::element::Type precision;
133 MatMullTransformationTestValues testValues;
134 std::tie(precision, batch, testValues) = obj.param;
136 std::stringstream ss;
137 ss << precision << "_" << batch << "_" << testValues;
// Re-validates the transformed function, then compares it with the reference.
// `res.first` is the comparison verdict and `res.second` the text printed when
// the assertion fails.
// NOTE(review): the three `true` arguments enable extra checks in
// compare_functions; their exact meaning is defined by that helper - confirm.
142 TEST_P(MatMulWithConstantTransformation, CompareFunctions) {
143 actualFunction->validate_nodes_and_infer_types();
144 auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
145 ASSERT_TRUE(res.first) << res.second;
// Precisions the suite is instantiated with. Only f32 is active; the f16 entry
// is commented out.
148 const std::vector<ngraph::element::Type> precisions = {
149 ngraph::element::f32,
150 // ngraph::element::f16
// Batch sizes the suite is instantiated with; the fixture writes each value
// into dimension 0 of the actual and expected input shapes.
153 const std::vector<size_t> batches = { 1, 4 };
// Table of test cases. Each entry starts with the transformation params and is
// followed by the values for the actual graph and then for the expected graph.
// Dimension 0 of the input shapes is replaced by the batch parameter in the
// fixture's SetUp.
155 std::vector<MatMullTransformationTestValues> testValues = {
156 // supported 3D: U8 & I8
158 LayerTransformation::createParamsU8I8(),
162 { ngraph::element::f32, {}, { 0.02f } },
164 std::vector<float>(1024 * 1024, 1.f),
165 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
173 std::vector<float>(1024 * 1024, -126),
175 { {}, {}, { 0.02f * 0.1f } },
180 // not supported 3D: U8 & I8
// NOTE(review): this case uses three distinct values {0.01, 0.02, 0.03} in the
// dequantization, and the same dequantization, f32 weights and
// FakeQuantize-on-weights data appear again for the expected graph -
// presumably meaning the graph is left untransformed; confirm.
182 LayerTransformation::createParamsU8I8(),
186 { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
188 std::vector<float>(4 * 4, 1.f),
189 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
194 { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
197 std::vector<float>(4 * 4, 1.f),
198 ngraph::element::f32,
200 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
204 // not supported 3D: U8 & I8
// NOTE(review): unlike the previous case, the FakeQuantize-on-weights data
// here lists four values per interval; the same four-value intervals are
// repeated for the expected graph - presumably per-channel quantization of
// weights that the transformation does not handle; confirm.
206 LayerTransformation::createParamsU8I8(),
210 { ngraph::element::f32, {}, { 0.02f } },
212 std::vector<float>(4 * 4, 1.f),
216 {0.f, 0.f, 0.f, 0.f},
217 {254.f, 254.f, 254.f, 254.f},
218 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
219 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
225 { ngraph::element::f32, {}, { 0.02f } },
228 std::vector<float>(4 * 4, 1.f),
229 ngraph::element::f32,
234 {0.f, 0.f, 0.f, 0.f},
235 {254.f, 254.f, 254.f, 254.f},
236 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
237 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
// Same value pattern as the first case, with 2048 * 1000 weight values.
244 LayerTransformation::createParamsU8I8(),
248 { ngraph::element::f32, {}, { 0.02f } },
250 std::vector<float>(2048 * 1000, 1.f),
251 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
259 std::vector<float>(2048 * 1000, -126),
261 { {}, {}, { 0.02f * 0.1f } },
// Same values as the previous case, run with the I8I8 params.
267 LayerTransformation::createParamsI8I8(),
271 { ngraph::element::f32, {}, { 0.02f } },
273 std::vector<float>(2048 * 1000, 1.f),
274 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
282 std::vector<float>(2048 * 1000, -126),
284 { {}, {}, { 0.02f * 0.1f } },
// Params with setUpdatePrecisions(false): every precision listed for this case,
// actual and expected, is f32.
290 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
293 ngraph::element::f32,
294 { {}, {}, { 0.02f } },
296 std::vector<float>(2048 * 1000, 1.f),
297 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
301 ngraph::element::f32,
303 ngraph::element::f32,
305 std::vector<float>(2048 * 1000, -126),
306 ngraph::element::f32,
307 { {}, {}, { 0.02f * 0.1f } },
// Instantiates MatMulWithConstantTransformation over the `precisions`,
// `batches` and `testValues` tables; each generated case is named by
// MatMulWithConstantTransformation::getTestCaseName.
313 INSTANTIATE_TEST_CASE_P(
315 MatMulWithConstantTransformation,
317 ::testing::ValuesIn(precisions),
318 ::testing::ValuesIn(batches),
319 ::testing::ValuesIn(testValues)),
320 MatMulWithConstantTransformation::getTestCaseName);