1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/mat_mul.hpp"
12 #include "low_precision/network_helper.hpp"
13 #include "low_precision/common/dequantization_op.hpp"
15 using namespace ngraph;
16 using namespace ngraph::pass;
17 using namespace ngraph::pass::low_precision;
// Rewrites a matched MatMul so that the dequantization operations on its two inputs are
// replaced by a single DequantizationMultiply placed after a new MatMul.
//
// context - transformation context, forwarded to canBeTransformed() and updateOutput().
// m       - matcher whose match root is expected to be an opset1::MatMul.
//
// NOTE(review): this excerpt omits several original lines (early-return bodies, closing
// braces, some call arguments and the final return), so the comments below describe only
// the statements that are visible here.
19 bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
// as_type_ptr yields nullptr when the match root is not a MatMul; that case is guarded below.
20 std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
21 if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
// Continue with the node returned by separateInStandaloneBranch(), cast back to MatMul.
// NOTE(review): the result of this cast is not null-checked in the visible lines — confirm.
25 matMul = as_type_ptr<ngraph::opset1::MatMul>(separateInStandaloneBranch(matMul));
// Dequantization found on input 1 of the MatMul.
27 FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
// When no dequantization is found on input 1, check whether the data node is a FakeQuantize
// and, if so, decompose it and query the dequantization again.
28 if (dequantization2.empty()) {
29 const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
30 as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
31 if (fakeQuantize != nullptr) {
32 const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fakeQuantize);
33 const DataPrecision dataPrecision = getDataPrecision(fakeQuantize, quantizationDetails, true);
// NOTE(review): only part of the decomposeFakeQuantize argument list is visible in this excerpt.
35 auto tuple = NetworkHelper::decomposeFakeQuantize(
37 dataPrecision.precision,
40 dataPrecision.hasZeroPoint,
// Re-read the dequantization on input 1 after the FakeQuantize decomposition.
43 dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
// Dequantization found on input 0 of the MatMul.
47 const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
48 std::shared_ptr<opset1::Subtract> subtract;
// If input 0 carries a Subtract: clean its runtime info, try to optimize it, and fall back
// to the original Subtract when optimizeSubtract() returns nullptr.
49 if (dequantization1.subtract != nullptr) {
50 std::shared_ptr<ngraph::Node> layer = dequantization1.subtract;
51 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);
53 auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization1.subtract);
54 if (optimizedSubtract == nullptr) {
55 optimizedSubtract = dequantization1.subtract;
// NOTE(review): if optimizeSubtract() can return a node that is not an opset1::Subtract,
// this cast leaves `subtract` null while it is still passed to the new MatMul below — confirm.
57 subtract = as_type_ptr<opset1::Subtract>(optimizedSubtract);
// Build the replacement MatMul as a TypeRelaxed op with both inputs overridden to f32.
// Each input is the Subtract (when the dequantization has one) or otherwise the data output;
// the transpose flags are copied from the original MatMul.
60 const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
61 std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
62 ngraph::op::TemporaryReplaceOutputType(dequantization1.subtract != nullptr ? subtract : dequantization1.data, element::f32).get(),
63 ngraph::op::TemporaryReplaceOutputType(dequantization2.subtract != nullptr ? dequantization2.subtract : dequantization2.data, element::f32).get(),
64 matMul->get_transpose_a(),
65 matMul->get_transpose_b());
// Give the new MatMul the same output element type as the original MatMul.
66 NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
// Local helper: builds an axis order [0, 1, ..., rank-1] with the last two axes swapped and
// applies it to `node` through fold<Transpose>. The branch taken for rank < 2 is not visible
// in this excerpt.
68 auto transpose = [](const std::shared_ptr<Node>& node) -> std::shared_ptr<Node> {
69 const Shape outputShape = node->get_output_shape(0);
70 if (outputShape.size() < 2ul) {
74 std::vector<uint32_t> transposeConstant(outputShape.size());
75 std::iota(transposeConstant.begin(), transposeConstant.end(), 0);
76 std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));
78 auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
79 std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
80 return transposedConstant;
// Second operand of each dequantization Multiply, transposed when the MatMul transposes
// that input.
// NOTE(review): dequantization1.multiply and dequantization2.multiply are dereferenced
// without a visible null check here — presumably guaranteed by canBeTransformed(); confirm.
83 const std::shared_ptr<Node> const1 = matMul->get_transpose_a() ?
84 transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
85 dequantization1.multiply->get_input_node_shared_ptr(1);
87 const std::shared_ptr<Node> const2 = matMul->get_transpose_b() ?
88 transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
89 dequantization2.multiply->get_input_node_shared_ptr(1);
// Create the DequantizationMultiply that follows the new MatMul. Its constant operand is a
// folded Multiply that involves const1 reduced to a scalar.
// NOTE(review): the remaining arguments are not visible in this excerpt.
91 const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(
93 NetworkHelper::toScalarIfPossible(
94 fold<ngraph::opset1::Multiply>(
95 NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(const1)),
// Swap the original MatMul for the new Multiply in the graph, then update the output via
// updateOutput().
97 replace_node(matMul, newMultiply);
99 updateOutput(context, newMultiply, matMul);
// Declares the MatMul patterns this transformation matches, given the rewrite pass and the
// transformation context.
// Two patterns are visible: MatMul(Multiply, Multiply) and MatMul(Multiply, FakeQuantize).
// NOTE(review): the calls that consume these patterns are not visible in this excerpt.
104 void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
// Pattern 1: both MatMul inputs are produced by Multiply operations.
108 make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));
// Pattern 2: input 0 is produced by a Multiply, input 1 by a FakeQuantize.
113 make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
// Presumably reports whether this transformation preserves the precision of `layer`;
// declared noexcept.
// NOTE(review): the body is not visible in this excerpt, so the returned value is not
// documented here — confirm against the full source.
116 bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
120 bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
121 if (!LayerTransformation::canBeTransformed(context, layer)) {
125 if (!canSubtractBeHandled(layer)) {
129 const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
130 if (!NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1)))) {
134 if (updatePrecisions && !dequantization1.empty() && !dequantization1.isLowPrecision()) {
138 if (updatePrecisions) {
139 const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
140 if (!dequantization2.empty() && !dequantization2.isLowPrecision()) {
145 const auto fakeQuantize = as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
146 if (fakeQuantize != nullptr) {
147 if (!QuantizationDetails::outputLayoutIsSupported(fakeQuantize)) {
151 std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(layer);
152 const size_t channelIndex1 = matMul->get_transpose_a() ? 0 : 1;
153 const size_t channelIndex2 = matMul->get_transpose_b() ? 1 : 0;
155 // for MatMul with 3D input the channel is the 3rd dimension (not the 2nd)
156 const Shape input1 = layer->input(0).get_shape();
157 const Shape input2 = layer->input(1).get_shape();
158 if ((input1[channelIndex1] != input2[channelIndex2]) &&
159 ((shape_size(dequantization1.multiply->input(1).get_shape()) > 1) ||
160 (shape_size(fakeQuantize->input(3).get_shape()) > 1) || (shape_size(fakeQuantize->input(4).get_shape()) > 1))) {