// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/mat_mul.hpp"

#include <numeric>
#include <memory>
#include <string>
#include <vector>

#include "low_precision/network_helper.hpp"
#include "low_precision/common/dequantization_op.hpp"

using namespace ngraph;
using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;

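// Matches a MatMul whose inputs carry dequantization operations (or a not-yet-decomposed
// FakeQuantize on the second input) and moves the dequantization below the MatMul:
//
//   Multiply(sc1)   Multiply(sc2)           MatMul (low-precision inputs)
//          \          /             ==>              |
//            MatMul                           Multiply(sc1 * sc2)
//
// so that the MatMul itself can execute on the low-precision tensors.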
bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
    std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
    if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
        return false;
    }

    matMul = as_type_ptr<ngraph::opset1::MatMul>(separateInStandaloneBranch(matMul));

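    // If the second input has no dequantization yet, it may still be a FakeQuantize:
    // decompose it into a quantize part and a dequantization part, then re-query the
    // dequantization operations.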
    FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
    if (dequantization2.empty()) {
        const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
            as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
        if (fakeQuantize != nullptr) {
            const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fakeQuantize);
            const DataPrecision dataPrecision = getDataPrecision(fakeQuantize, quantizationDetails, true);

            // the returned tuple is not needed here: the decomposition rewrites the graph in place
            NetworkHelper::decomposeFakeQuantize(
                fakeQuantize,
                dataPrecision.precision,
                dataPrecision.min,
                dataPrecision.max,
                dataPrecision.hasZeroPoint,
                updatePrecisions);

            dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
        }
    }

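    // Zero point on the first input: clear run-time info on the Subtract and try to fold
    // it away; if it cannot be optimized out, the original Subtract stays as the MatMul input.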
    const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
    std::shared_ptr<opset1::Subtract> subtract;
    if (dequantization1.subtract != nullptr) {
        std::shared_ptr<ngraph::Node> layer = dequantization1.subtract;
        ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);

        auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization1.subtract);
        if (optimizedSubtract == nullptr) {
            optimizedSubtract = dequantization1.subtract;
        }
        subtract = as_type_ptr<opset1::Subtract>(optimizedSubtract);
    }

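    // Rebuild the MatMul as TypeRelaxed so it can consume the (possibly integer) tensors
    // below the dequantization Multiply ops while its math is still declared as f32; the
    // original output element type is restored right after construction.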
    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
        std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
        ngraph::op::TemporaryReplaceOutputType(dequantization1.subtract != nullptr ? subtract : dequantization1.data, element::f32).get(),
        ngraph::op::TemporaryReplaceOutputType(dequantization2.subtract != nullptr ? dequantization2.subtract : dequantization2.data, element::f32).get(),
        matMul->get_transpose_a(),
        matMul->get_transpose_b());
    NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));

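    // Swaps the last two axes of a constant so its layout matches the corresponding
    // transposed MatMul input; constants of rank < 2 are returned unchanged.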
    auto transpose = [](const std::shared_ptr<Node>& node) -> std::shared_ptr<Node> {
        const Shape outputShape = node->get_output_shape(0);
        if (outputShape.size() < 2ul) {
            return node;
        }

        std::vector<uint32_t> transposeConstant(outputShape.size());
        std::iota(transposeConstant.begin(), transposeConstant.end(), 0);
        std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));

        auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
        std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
        return transposedConstant;
    };

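    // The scale constants must follow the transpose flags of the original MatMul so that
    // they still align with the corresponding dimension after transposition.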
    const std::shared_ptr<Node> const1 = matMul->get_transpose_a() ?
        transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
        dequantization1.multiply->get_input_node_shared_ptr(1);

    const std::shared_ptr<Node> const2 = matMul->get_transpose_b() ?
        transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
        dequantization2.multiply->get_input_node_shared_ptr(1);

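    // Fold both scales into a single dequantization Multiply above the new MatMul:
    // canBeTransformed guarantees the first scale is scalar-like, so the product
    // const1 * const2 can be computed at transformation time.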
    const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(
        newMatMul,
        NetworkHelper::toScalarIfPossible(
            fold<ngraph::opset1::Multiply>(
                NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(const1)),
                const2)));
    replace_node(matMul, newMultiply);

    updateOutput(context, newMultiply, matMul);

    return true;
}

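// Two patterns are matched: MatMul with dequantization (Multiply) on both inputs, and
// MatMul with dequantization on the first input and a not-yet-decomposed FakeQuantize
// on the second input (typically the weights).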
void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
    addPattern(
        pass,
        context,
        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));

    addPattern(
        pass,
        context,
        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
}

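// Returning false keeps MatMul out of the set of precision-preserving operations:
// transform() sets the output precision explicitly via setOutDataPrecisionForTypeRelaxed
// rather than propagating an input precision through the node.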
bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
    return false;
}

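// The transformation is applied only when:
//  * the generic layer checks and the Subtract (zero point) checks pass,
//  * the scale on the first input is scalar-like (it is folded via toScalar() in transform()),
//  * with updatePrecisions enabled, any existing dequantization is already in low precision,
//  * a FakeQuantize on the second input is per-tensor, or the channel dimensions agree.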
bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
    if (!LayerTransformation::canBeTransformed(context, layer)) {
        return false;
    }

    if (!canSubtractBeHandled(layer)) {
        return false;
    }

    const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
    if (!NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1)))) {
        return false;
    }

    if (updatePrecisions && !dequantization1.empty() && !dequantization1.isLowPrecision()) {
        return false;
    }

    if (updatePrecisions) {
        const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
        if (!dequantization2.empty() && !dequantization2.isLowPrecision()) {
            return false;
        }
    }

    const auto fakeQuantize = as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
    if (fakeQuantize != nullptr) {
        if (!QuantizationDetails::outputLayoutIsSupported(fakeQuantize)) {
            return false;
        }

        std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(layer);
        const size_t channelIndex1 = matMul->get_transpose_a() ? 0 : 1;
        const size_t channelIndex2 = matMul->get_transpose_b() ? 1 : 0;

        // for MatMul with a 3D input the channel is the 3rd dimension (not the 2nd)
        const Shape input1 = layer->input(0).get_shape();
        const Shape input2 = layer->input(1).get_shape();
        if ((input1[channelIndex1] != input2[channelIndex2]) &&
            ((shape_size(dequantization1.multiply->input(1).get_shape()) > 1) ||
            (shape_size(fakeQuantize->input(3).get_shape()) > 1) || (shape_size(fakeQuantize->input(4).get_shape()) > 1))) {
            return false;
        }
    }

    return true;
}