1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "layer_transformation.hpp"
11 #include <gtest/gtest.h>
13 #include <transformations/init_node_info.hpp>
14 #include <low_precision/transformer.hpp>
15 #include <low_precision/mat_mul.hpp>
17 #include "common_test_utils/ngraph_test_utils.hpp"
18 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
19 #include "ngraph_functions/subgraph_builders.hpp"
20 #include "simple_low_precision_transformer.hpp"
21 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
25 using namespace testing;
26 using namespace ngraph::pass;
// Parameter bundle describing one MatMul low-precision transformation test case.
// Elsewhere in this file the members are reached as `testValues.actual.*`,
// `testValues.expected.*` and `testValues.params`, and the stream operators refer to
// `MatMullTransformationTestValues::Actual` / `::Expected`. The nested declarations
// that group the fields below are not visible in this view.
28 class MatMullTransformationTestValues {
// First group: presumably the `Actual` description — the same field names are read
// through `testValues.actual` in SetUp() to build the function under test.
32 ngraph::Shape inputShape;
33 ngraph::element::Type precisionBeforeDequantization;
34 ngraph::builder::subgraph::DequantizationOperations dequantization;
35 ngraph::Shape weightsConstShape;
36 std::vector<float> weightsConstValues;
37 ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
// Second group: presumably the `Expected` description — the same field names are read
// through `testValues.expected` in SetUp() to build the reference function.
42 ngraph::Shape inputShape;
43 ngraph::element::Type precisionBeforeDequantization;
44 ngraph::builder::subgraph::DequantizationOperations dequantization;
45 ngraph::element::Type weightsConstPrecision;
46 ngraph::Shape weightsConstShape;
47 std::vector<float> weightsConstValues;
49 ngraph::element::Type precisionBeforeOperation;
50 ngraph::builder::subgraph::DequantizationOperations resultDequantization;
// When this is non-empty, SetUp() builds the reference with getOriginal() instead of getReference().
52 ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
// Transformation parameters passed to the MatMulTransformation registration in SetUp().
55 ngraph::pass::low_precision::LayerTransformation::Params params;
// Streams the fields of an `Actual` description to `out`, separated by "_".
// Reached through the test-values operator below, which getTestCaseName() uses when
// building the test name. Only part of the streamed expression is visible in this view.
60 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Actual& actual) {
62 actual.inputShape << "_" <<
63 actual.precisionBeforeDequantization << "_" <<
64 actual.dequantization << "_" <<
65 actual.weightsConstShape << "_" <<
// Streams the fields of an `Expected` description to `out`, separated by "_".
// Reached through the test-values operator below. The beginning and end of the
// streamed expression are not visible in this view.
69 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
71 expected.weightsConstShape <<"_" <<
72 expected.dequantization << "_" <<
73 expected.precisionBeforeOperation << "_" <<
74 expected.resultDequantization << "_" <<
// Streams a whole test-values bundle as "_<actual>_<expected>", delegating to the
// `Actual` and `Expected` stream operators. `values.params` is not written.
78 inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
79 return out << "_" << values.actual << "_" << values.expected;
// Parameter type of the value-parameterized test. SetUp() reads element 0 as an
// ngraph::element::Type (precision), element 1 as a size_t batch and element 2 as the
// test-values bundle. The head of the declaration and the middle element are not
// visible in this view — presumably a std::tuple, since std::get/std::tie are used on it.
83 ngraph::element::Type,
85 MatMullTransformationTestValues> MatMulTransformationParams;
// Value-parameterized fixture: builds a MatMul function from the `actual` description,
// runs MatMulTransformation over it, and builds the reference function it is compared to.
87 class MatMulWithConstantTransformation : public LayerTransformation, public testing::WithParamInterface<MatMulTransformationParams> {
// Prepares `actualFunction` (already transformed) and `referenceFunction` for the test body.
89 void SetUp() override {
90 const ngraph::element::Type precision = std::get<0>(GetParam());
91 const size_t batch = std::get<1>(GetParam());
// Work on a copy so that the batch parameter can overwrite dimension 0 of both input shapes.
93 MatMullTransformationTestValues testValues = std::get<2>(GetParam());
94 testValues.actual.inputShape[0] = batch;
95 testValues.expected.inputShape[0] = batch;
// Build the function under test. The first argument of this call is not visible in this view.
97 actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
99 testValues.actual.inputShape,
100 testValues.actual.precisionBeforeDequantization,
101 testValues.actual.dequantization,
102 testValues.actual.weightsConstShape,
103 testValues.actual.weightsConstValues,
104 testValues.actual.fqOnWeights);
// Register MatMulTransformation for opset1::MatMul with the case's params and apply it in place.
106 SimpleLowPrecisionTransformer transformer;
107 transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
108 transformer.transform(actualFunction);
// Reference selection: an empty expected.fqOnWeights selects getReference();
// otherwise the reference is built with getOriginal(), keeping the FakeQuantize on weights.
110 referenceFunction = testValues.expected.fqOnWeights.empty() ?
111 ngraph::builder::subgraph::MatMulFunction::getReference(
113 testValues.expected.inputShape,
114 testValues.expected.precisionBeforeDequantization,
115 testValues.expected.dequantization,
116 testValues.expected.weightsConstPrecision,
117 testValues.expected.weightsConstShape,
118 testValues.expected.weightsConstValues,
119 testValues.expected.resultDequantization) :
120 ngraph::builder::subgraph::MatMulFunction::getOriginal(
122 testValues.expected.inputShape,
123 testValues.expected.precisionBeforeDequantization,
124 testValues.expected.dequantization,
125 testValues.expected.weightsConstShape,
126 testValues.expected.weightsConstValues,
127 testValues.expected.fqOnWeights);
// Name generator passed to the test instantiation: unpacks the parameter tuple and streams
// "<precision>_<batch>_<testValues>" into a string stream. The declaration of `batch` and
// the return statement are not visible in this view.
130 static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
131 ngraph::element::Type precision;
133 MatMullTransformationTestValues testValues;
134 std::tie(precision, batch, testValues) = obj.param;
136 std::stringstream ss;
137 ss << precision << "_" << batch << "_" << testValues;
// Checks that the transformed function matches the reference function built in SetUp().
142 TEST_P(MatMulWithConstantTransformation, CompareFunctions) {
// Run InitNodeInfo on the transformed function before comparing.
143 InitNodeInfo().run_on_function(actualFunction);
145 actualFunction->validate_nodes_and_infer_types();
// `res.first` is the comparison verdict and `res.second` the message reported on failure.
// NOTE(review): the meaning of the two `true` flags is defined by compare_functions — confirm there.
147 auto res = compare_functions(referenceFunction, actualFunction, true, true);
148 ASSERT_TRUE(res.first) << res.second;
// Precisions fed to the test instantiation; only f32 is enabled, the f16 entry is commented out.
151 const std::vector<ngraph::element::Type> precisions = {
152 ngraph::element::f32,
153 // ngraph::element::f16
// Batch sizes fed to the test instantiation; SetUp() writes each into dimension 0 of the input shapes.
156 const std::vector<size_t> batches = { 1, 4 };
// Test cases fed to the test instantiation. Each entry is a MatMullTransformationTestValues
// initializer; the visible pieces are the transformation params, then the values for the
// function under test, then the values for the reference. Shapes and several other
// initializer lines are not visible in this view.
158 std::vector<MatMullTransformationTestValues> testValues = {
159 // supported 3D: U8 & I8
161 LayerTransformation::createParamsU8I8(),
165 { ngraph::element::f32, {}, { 0.02f } },
167 std::vector<float>(1024 * 1024, 1.f),
// NOTE(review): the last value `{12.7}` lacks the `f` suffix used by the other literals here and below.
168 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
176 std::vector<float>(1024 * 1024, -126),
178 { {}, {}, { 0.02f * 0.1f } },
183 // not supported 3D: U8 & I8
185 LayerTransformation::createParamsU8I8(),
// The dequantization multiply carries three values here; the same dequantization is repeated
// on the reference side below, together with an unchanged FakeQuantize on weights.
189 { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
191 std::vector<float>(4 * 4, 1.f),
192 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
197 { ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
200 std::vector<float>(4 * 4, 1.f),
201 ngraph::element::f32,
203 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
// NOTE(review): same label as the previous case; the visible difference is that the
// FakeQuantize-on-weights ranges below are given as four values each.
207 // not supported 3D: U8 & I8
209 LayerTransformation::createParamsU8I8(),
213 { ngraph::element::f32, {}, { 0.02f } },
215 std::vector<float>(4 * 4, 1.f),
219 {0.f, 0.f, 0.f, 0.f},
220 {254.f, 254.f, 254.f, 254.f},
221 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
222 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
228 { ngraph::element::f32, {}, { 0.02f } },
231 std::vector<float>(4 * 4, 1.f),
232 ngraph::element::f32,
237 {0.f, 0.f, 0.f, 0.f},
238 {254.f, 254.f, 254.f, 254.f},
239 {-12.7f / 4.f, -12.7f / 3.f, -12.7f / 2.f, -12.7f},
240 {12.7f / 4.f, 12.7f / 3.f, 12.7f / 2.f, 12.7f}
// Unlabelled case: U8I8 params with 2048 * 1000 weight values.
247 LayerTransformation::createParamsU8I8(),
251 { ngraph::element::f32, {}, { 0.02f } },
253 std::vector<float>(2048 * 1000, 1.f),
254 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
262 std::vector<float>(2048 * 1000, -126),
264 { {}, {}, { 0.02f * 0.1f } },
// Unlabelled case: same visible values as the previous one, but with I8I8 params.
270 LayerTransformation::createParamsI8I8(),
274 { ngraph::element::f32, {}, { 0.02f } },
276 std::vector<float>(2048 * 1000, 1.f),
277 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
285 std::vector<float>(2048 * 1000, -126),
287 { {}, {}, { 0.02f * 0.1f } },
// Unlabelled case: U8I8 params with setUpdatePrecisions(false); every precision visible
// in this case is f32.
293 LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
296 ngraph::element::f32,
297 { {}, {}, { 0.02f } },
299 std::vector<float>(2048 * 1000, 1.f),
300 { 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
304 ngraph::element::f32,
306 ngraph::element::f32,
308 std::vector<float>(2048 * 1000, -126),
309 ngraph::element::f32,
310 { {}, {}, { 0.02f * 0.1f } },
// Instantiates MatMulWithConstantTransformation from the `precisions`, `batches` and
// `testValues` vectors, naming each instance with getTestCaseName. The instantiation
// prefix and the generator that wraps the three ValuesIn() calls are not visible in this view.
316 INSTANTIATE_TEST_CASE_P(
318 MatMulWithConstantTransformation,
320 ::testing::ValuesIn(precisions),
321 ::testing::ValuesIn(batches),
322 ::testing::ValuesIn(testValues)),
323 MatMulWithConstantTransformation::getTestCaseName);