[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / convolution.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/convolution.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 #include <cassert>
12
13 #include "low_precision/network_helper.hpp"
14 #include "low_precision/common/dequantization_op.hpp"
15
16 namespace ngraph {
17 namespace pass {
18 namespace low_precision {
19
20 ConvolutionTransformation::ConvolutionTransformation(const Params& params) : WeightableLayerTransformation(params) {
21 }
22
23 void ConvolutionTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
24     addPattern(
25         pass,
26         context,
27         make_op_pattern<opset1::Convolution>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>()}));
28 }
29
30 bool ConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
31     return WeightableLayerTransformation::isQuantized(layer, false);
32 }
33
34 bool ConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
35     auto convolution = m.get_match_root();
36
37     if (!WeightableLayerTransformation::canBeTransformed(context, convolution)) {
38         return false;
39     }
40
41     FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(convolution);
42     if (!canSubtractBeHandled(convolution, dequantization)) {
43         return false;
44     }
45
46     if ((!supportAsymmetricQuantization) && getDataPrecisionOnWeights(convolution).hasZeroPoint) {
47         return false;
48     }
49
50     if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
51         return false;
52     }
53
54     convolution = separateInStandaloneBranch(convolution);
55     dequantization = NetworkHelper::getDequantization(convolution);
56
57     {
58         std::shared_ptr<opset1::Subtract> subtract;
59         if (dequantization.subtract != nullptr) {
60             std::shared_ptr<ngraph::Node> layer = dequantization.subtract;
61             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);
62
63             auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
64             if (optimizedSubtract == nullptr) {
65                 optimizedSubtract = dequantization.subtract;
66             }
67             subtract = as_type_ptr<opset1::Subtract>(optimizedSubtract);
68         }
69
70         // workaround normalizes shape of Subtract to match CPU plugin expectations
71         if (subtract && subtract->get_output_partial_shape(0) != subtract->get_input_partial_shape(1)) {
72             size_t length = subtract->get_output_partial_shape(0).rank().get_length();
73
74             // Insert explicit broadcast for channel dimension [1] and immediately fold it
75             Shape broadcastShape(subtract->get_output_partial_shape(0).rank().get_length(), 1);
76             broadcastShape[1] = subtract->get_output_shape(0)[1];
77
78             std::shared_ptr<Node> newShift = fold<opset1::Broadcast>(
79                 subtract->input_value(1).get_node_shared_ptr(),
80                 std::make_shared<opset1::Constant>(
81                     element::i64,
82                     Shape{ length },
83                     broadcastShape));
84
85             const auto newSubtract = as_type_ptr<opset1::Subtract>(subtract->clone_with_new_inputs({
86                 subtract->input_value(0).get_node_shared_ptr(),
87                 newShift }));
88             replace_node(subtract, newSubtract);
89
90             newSubtract->set_output_type(0, subtract->get_output_element_type(0), newSubtract->get_output_partial_shape(0));
91             subtract = newSubtract;
92         }
93
94         const size_t groupsCount = NetworkHelper::getGroupsCount(convolution);
95         std::shared_ptr<Node> newMultiplyAfterConst;
96         if (groupsCount > 1ul) {
97             std::shared_ptr<opset1::Constant> multiplyConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
98
99             const std::vector<float> scales = multiplyConst->cast_vector<float>();
100             if (scales.size() == 1ul) {
101                 newMultiplyAfterConst = dequantization.multiply->input_value(1).get_node_shared_ptr()->clone_with_new_inputs({});
102             } else {
103                 const ngraph::Shape inputShape = convolution->get_input_shape(0);
104                 const size_t inputChannelsInGroup = inputShape[1] / groupsCount;
105                 const ngraph::Shape outputShape = convolution->get_output_shape(0);
106                 std::vector<float> outputScales(outputShape[1]);
107
108                 const size_t outputChannelsInGroup = outputShape[1] / groupsCount;
109                 for (size_t group = 0; group < groupsCount; ++group) {
110                     const float scaleValue = scales[group * inputChannelsInGroup];
111
112                     for (size_t i = 0; i < outputChannelsInGroup; ++i) {
113                         size_t index = group * outputChannelsInGroup + i;
114                         outputScales[index] = scaleValue;
115                     }
116                 }
117
118                 auto newMulShape = Shape{ outputScales.size() };
119                 for (size_t i = 0; i < convolution->get_output_shape(0).size() - 2; ++i) {
120                     newMulShape.push_back(1ul);
121                 }
122
123                 newMultiplyAfterConst = std::make_shared<opset1::Constant>(
124                     dequantization.multiply->get_output_element_type(0),
125                     newMulShape,
126                     outputScales);
127             }
128         } else {
129             std::shared_ptr<opset1::Constant> reducedConstant = as_type_ptr<opset1::Constant>(
130                 dequantization.multiply->input_value(1).get_node_shared_ptr());
131             newMultiplyAfterConst = std::make_shared<opset1::Constant>(
132                 reducedConstant->get_output_element_type(0),
133                 Shape{ 1 },
134                 reducedConstant->cast_vector<float>()[0]);
135         }
136
137         auto newConvolution = convolution->copy_with_new_inputs({ dequantization.multiply->input_value(0), convolution->input_value(1) });
138         std::shared_ptr<ngraph::opset1::Multiply> newMultiplyAfter = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
139             std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{ element::f32 },
140             ngraph::op::TemporaryReplaceOutputType(newConvolution, element::f32).get(),
141             ngraph::op::TemporaryReplaceOutputType(newMultiplyAfterConst, element::f32).get());
142
143         replace_node(convolution, newMultiplyAfter);
144         convolution = newMultiplyAfter->input_value(0).get_node_shared_ptr();
145
146         if (is_type<opset1::Convert>(convolution->get_input_node_ptr(0))) {
147             auto newConvolution = convolution->clone_with_new_inputs({
148                 convolution->get_input_node_ptr(0)->get_input_node_shared_ptr(0),
149                 convolution->get_input_node_shared_ptr(1) });
150             replace_node(convolution, newConvolution);
151             convolution = newConvolution;
152         }
153     }
154
155     {
156         decomposeFakeQuantizeForWeightsPath(convolution);
157
158         std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(convolution->input_value(1).get_node_shared_ptr());
159         std::shared_ptr<opset1::Multiply> multiplyFromWeights = as_type_ptr<opset1::Multiply>(
160             reshapeFromWeights == nullptr ?
161             convolution->input_value(1).get_node_shared_ptr() :
162             convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0));
163         std::shared_ptr<opset1::Subtract> subtractFromWeights = as_type_ptr<opset1::Subtract>(multiplyFromWeights->get_input_node_shared_ptr(0));
164
165         {
166             Shape newScaleShape = multiplyFromWeights->get_input_shape(1);
167             // that's all we need: [C, 1, 1, 1] => [C, 1, 1]
168             newScaleShape.pop_back();
169
170             if (reshapeFromWeights != nullptr) {
171                 reshapeFromWeights = as_type_ptr<opset1::Reshape>(reshapeFromWeights->copy_with_new_inputs({
172                     multiplyFromWeights->input_value(0),
173                     reshapeFromWeights->input_value(1) }));
174             }
175
176             auto newMultiplyAfter = std::make_shared<DequantizationMultiply>(
177                 convolution->copy_with_new_inputs({
178                     convolution->input_value(0),
179                     reshapeFromWeights != nullptr ?
180                         reshapeFromWeights :
181                         multiplyFromWeights->input_value(0)
182                     }),
183                 fold_reshape<opset1::Reshape>(
184                     multiplyFromWeights->input_value(1),
185                     std::make_shared<opset1::Constant>(element::u64, Shape{ newScaleShape.size() }, newScaleShape),
186                     false));
187             replace_node(convolution, newMultiplyAfter);
188             convolution = newMultiplyAfter->input_value(0).get_node_shared_ptr();
189         }
190
191         if (subtractFromWeights != nullptr) {
192             auto optimizedSubtract = NetworkHelper::optimizeSubtract(subtractFromWeights);
193             // TODO: handle optimizedSubtract == nullptr;
194             if (optimizedSubtract != nullptr) {
195                 subtractFromWeights = as_type_ptr<opset1::Subtract>(optimizedSubtract);
196
197                 const Shape weightsShape = subtractFromWeights->input(0).get_shape();
198                 Shape zeroPointShape(weightsShape.size(), 1ul);
199                 zeroPointShape[0] = weightsShape[0];
200
201                 auto zeroPointConstant = fold<opset1::Broadcast>(
202                     subtractFromWeights->get_input_node_shared_ptr(1),
203                     std::make_shared<opset1::Constant>(element::i32, Shape{ zeroPointShape.size() }, zeroPointShape));
204                 replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
205             }
206         }
207
208         std::shared_ptr<opset1::Convert> convertFromWeights = as_type_ptr<opset1::Convert>(subtractFromWeights == nullptr ?
209             multiplyFromWeights->get_input_node_shared_ptr(0) :
210             subtractFromWeights->get_input_node_shared_ptr(0));
211
212         if (convertFromWeights != nullptr) {
213             std::shared_ptr<Node> childNode = reshapeFromWeights == nullptr ? convolution : reshapeFromWeights;
214
215             auto newConvolution = convolution->clone_with_new_inputs({
216                 convolution->get_input_node_shared_ptr(0),
217                 childNode.get() == convolution.get() ?
218                     convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0) :
219                     childNode->copy_with_new_inputs({convertFromWeights->input_value(0), childNode->input_value(1)})});
220             replace_node(convolution, newConvolution);
221             convolution = newConvolution;
222         }
223
224         reshapeFromWeights = as_type_ptr<opset1::Reshape>(convolution->get_input_node_shared_ptr(1));
225         if (reshapeFromWeights != nullptr) {
226             const std::shared_ptr<Node> newWeights = fold_reshape<opset1::Reshape>(
227                 reshapeFromWeights->input_value(0),
228                 reshapeFromWeights->input_value(1),
229                 false);
230
231             replace_node(reshapeFromWeights, newWeights);
232         }
233     }
234
235     std::shared_ptr<ngraph::opset1::Multiply> finalDequantization = NetworkHelper::optimizeMultipliesAfter(
236         convolution->output(0).get_target_inputs().begin()->get_node()->shared_from_this());
237     ngraph::copy_runtime_info({ convolution, finalDequantization }, finalDequantization);
238     updateOutput(context, finalDequantization, convolution);
239
240     auto onWeights = convolution->get_input_node_shared_ptr(1);
241     if (is_type<opset1::Reshape>(onWeights)) {
242         onWeights = onWeights->get_input_node_shared_ptr(0);
243     }
244
245     if (is_type<opset1::Subtract>(onWeights)) {
246         auto& rt = onWeights->get_rt_info();
247         rt["DISABLED_CONSTANT_FOLDING"] = std::make_shared<ngraph::VariantWrapper<std::string>>("");
248     }
249     return true;
250 }
251
252 } // namespace low_precision
253 } // namespace pass
254 } // namespace ngraph