inference-engine/src/low_precision_transformations/src/common/mat_mul.cpp
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/mat_mul.hpp"

#include <numeric>
#include <memory>
#include <string>
#include <vector>

#include "low_precision/network_helper.hpp"
#include "low_precision/common/dequantization_op.hpp"

using namespace ngraph;
using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;

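// Moves the dequantization scales on both MatMul inputs past the MatMul and
// fuses them into a single Multiply on its output. A rough sketch of the
// intended rewrite (assuming typical quantized patterns on both branches):
//
//   A(u8) -> Convert -> Subtract -> Multiply(s1) \
//                                                 MatMul
//   B(i8) -> Convert -> Multiply(s2)             /
// becomes
//   A(u8) [-> Subtract] \
//                        TypeRelaxed<MatMul> -> Multiply(s1 * s2)
//   B(i8)               /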
bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
    std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
    if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
        return false;
    }

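    // Detach the dequantization branches feeding this MatMul so that ops
    // shared with other consumers are left untouched.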
    matMul = as_type_ptr<ngraph::opset1::MatMul>(separateInStandaloneBranch(matMul));

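    // The second input may still carry an undecomposed FakeQuantize rather than
    // dequantization ops; decompose it first so a dequantization is available.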
    FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
    if (dequantization2.empty()) {
        const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
            as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
        if (fakeQuantize != nullptr) {
            const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fakeQuantize);
            const DataPrecision dataPrecision = getDataPrecision(fakeQuantize, quantizationDetails, true);

            auto tuple = NetworkHelper::decomposeFakeQuantize(
                fakeQuantize,
                dataPrecision.precision,
                dataPrecision.min,
                dataPrecision.max,
                dataPrecision.hasZeroPoint,
                updatePrecisions);

            dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
        }
    }

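    // Try to optimize away the zero-point Subtract on the first input; if that
    // is not possible, keep the original Subtract as the new MatMul input.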
    const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
    std::shared_ptr<opset1::Subtract> subtract;
    if (dequantization1.subtract != nullptr) {
        std::shared_ptr<ngraph::Node> layer = dequantization1.subtract;
        ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);

        auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization1.subtract);
        if (optimizedSubtract == nullptr) {
            optimizedSubtract = dequantization1.subtract;
        }
        subtract = as_type_ptr<opset1::Subtract>(optimizedSubtract);
    }

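    // Rebuild the MatMul as TypeRelaxed so it can consume the (possibly
    // low-precision) inputs directly while keeping the original output type.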
    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
        std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
        ngraph::op::TemporaryReplaceOutputType(dequantization1.subtract != nullptr ? subtract : dequantization1.data, element::f32).get(),
        ngraph::op::TemporaryReplaceOutputType(dequantization2.subtract != nullptr ? dequantization2.subtract : dequantization2.data, element::f32).get(),
        matMul->get_transpose_a(),
        matMul->get_transpose_b());
    NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
    NetworkHelper::copyInfo(matMul, newMatMul);

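    // A scale constant feeding a transposed MatMul input must be transposed the
    // same way: swap its last two dimensions (a no-op for ranks below 2).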
    auto transpose = [](const std::shared_ptr<Node>& node) -> std::shared_ptr<Node> {
        const Shape outputShape = node->get_output_shape(0);
        if (outputShape.size() < 2ul) {
            return node;
        }

        std::vector<uint32_t> transposeConstant(outputShape.size());
        std::iota(transposeConstant.begin(), transposeConstant.end(), 0);
        std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));

        auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
        std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
        return transposedConstant;
    };

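    // Take the scale constants from both dequantization Multiply ops,
    // transposing them when the corresponding MatMul input is transposed.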
    const std::shared_ptr<Node> const1 = matMul->get_transpose_a() ?
        transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
        dequantization1.multiply->get_input_node_shared_ptr(1);

    const std::shared_ptr<Node> const2 = matMul->get_transpose_b() ?
        transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
        dequantization2.multiply->get_input_node_shared_ptr(1);

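    // The first scale is scalar-like (enforced in canBeTransformed), so both
    // scales commute with the MatMul and fold into one Multiply on the output.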
    const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(
        newMatMul,
        NetworkHelper::toScalarIfPossible(
            fold<ngraph::opset1::Multiply>(
                NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(const1)),
                const2)));
    replace_node(matMul, newMultiply);
    ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);

    updateOutput(context, newMultiply, matMul);

    return true;
}

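// Two patterns are matched: MatMul with dequantization Multiply ops on both
// inputs, and MatMul with a Multiply on the first input and a FakeQuantize on
// the second (e.g. a weights branch that has not been decomposed yet).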
void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
    addPattern(
        pass,
        context,
        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));

    addPattern(
        pass,
        context,
        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
}

bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
    return false;
}

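// The transformation applies only when the first input's dequantization scale
// is scalar-like and, when precisions are updated, both dequantizations are in
// low precision; an undecomposed FakeQuantize on the second input additionally
// requires a supported output layout and per-channel shapes consistent with
// the MatMul channel dimensions.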
bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
    if (!LayerTransformation::canBeTransformed(context, layer)) {
        return false;
    }

    if (!canSubtractBeHandled(layer)) {
        return false;
    }

    const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
    if (!NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1)))) {
        return false;
    }

    if (updatePrecisions && !dequantization1.empty() && !dequantization1.isLowPrecision()) {
        return false;
    }

    if (updatePrecisions) {
        const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
        if (!dequantization2.empty() && !dequantization2.isLowPrecision()) {
            return false;
        }
    }

    const auto fakeQuantize = as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
    if (fakeQuantize != nullptr) {
        if (!QuantizationDetails::outputLayoutIsSupported(fakeQuantize)) {
            return false;
        }

        std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(layer);
        const size_t channelIndex1 = matMul->get_transpose_a() ? 0 : 1;
        const size_t channelIndex2 = matMul->get_transpose_b() ? 1 : 0;

        // for MatMul with a 3D input the channel is the 3rd dimension (not the 2nd)
        const Shape input1 = layer->input(0).get_shape();
        const Shape input2 = layer->input(1).get_shape();
        if ((input1[channelIndex1] != input2[channelIndex2]) &&
            ((shape_size(dequantization1.multiply->input(1).get_shape()) > 1) ||
            (shape_size(fakeQuantize->input(3).get_shape()) > 1) || (shape_size(fakeQuantize->input(4).get_shape()) > 1))) {
            return false;
        }
    }

    return true;
}