[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / reshape.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/reshape.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 #include <unordered_set>
11 #include <utility>
12 #include <vector>
13
14 #include "low_precision/common/ie_lpt_exception.hpp"
15 #include "low_precision/network_helper.hpp"
16
17 namespace ngraph {
18 namespace pass {
19 namespace low_precision {
20
21 void ReshapeTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
22     addPattern(
23         pass,
24         context,
25         make_op_pattern<opset1::Reshape>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
26 }
27
28 void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& reshape) {
29     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(reshape, 0);
30     if (dequantization.multiply->get_input_node_ptr(1)->get_output_shape(0).size() > 1ul) {
31         // Reshape Subtract or Multiply operation Constant.
32         //    1. modify reshape parameters to avoid reshape by spatial dimensions
33         //    2. broadcast element-wise constant if channels are changed
34         //    3. reshape element-wise constant with modified reshape parameters
35         auto replaceConstant = [](const std::shared_ptr<opset1::Reshape>& reshape, const std::shared_ptr<Node>& op) {
36             const size_t constantIndex = as_type<ngraph::opset1::Constant>(op->get_input_node_ptr(1)) ? 1 : 0;
37             const Shape constantShape = op->input(constantIndex).get_shape();
38             // reshape for element-wise constant is not required
39             if (constantShape.empty() || (constantShape.size() == 1ul)) {
40                 return;
41             }
42
43             // simple broadcast operation Constant shape to shape on activations
44             auto newOperationConstantShape = op->input(1).get_shape();
45             auto const reshapeInputShape = reshape->input(0).get_shape();
46             Shape newOperationConstantBroadcastedShape(reshapeInputShape);
47             newOperationConstantBroadcastedShape[0] = 1ul;
48
49             if ((reshapeInputShape.size() - newOperationConstantShape.size()) == 1ul) {
50                 newOperationConstantShape.insert(newOperationConstantShape.begin(), 1ul);
51             }
52             const std::shared_ptr<opset1::Constant> originalConstant = as_type_ptr<opset1::Constant>(op->get_input_node_shared_ptr(1));
53             const std::shared_ptr<opset1::Constant> newOperationConstant = std::make_shared<opset1::Constant>(
54                 op->input(1).get_element_type(),
55                 newOperationConstantShape,
56                 originalConstant->cast_vector<float>());
57
58             // reshape -1 value hanling
59             auto getOverallValue = [](const Shape& shape, const std::vector<int>& reshapeValues, const bool specialZero) -> size_t {
60                 size_t overallValue = shape_size(shape);
61                 for (size_t i = 0; i < reshapeValues.size(); ++i) {
62                     auto reshapeValue = reshapeValues[i];
63                     if ((reshapeValue == 1ul) || (reshapeValue == -1) || ((reshapeValue == 0ul) && !specialZero)) {
64                         continue;
65                     }
66
67                     if ((reshapeValue == 0ul) && specialZero) {
68                         reshapeValue = shape[i];
69                     }
70
71                     overallValue = overallValue / reshapeValue;
72                 }
73                 return overallValue;
74             };
75
76             // modify reshape constant for element-wise constant reshape
77             // element-wise constant doesn't have spatial dimensions, as result we should remove spatial dimensions from reshape parameters
78             const std::vector<int> reshapeConstValues = as_type_ptr<opset1::Constant>(reshape->get_input_node_shared_ptr(1))->cast_vector<int>();
79
80             size_t overallValue = 0;
81             for (size_t i = 0; i < reshapeConstValues.size(); ++i) {
82                 if (reshapeConstValues[i] == -1) {
83                     overallValue = getOverallValue(
84                         reshapeInputShape,
85                         reshapeConstValues,
86                         as_type_ptr<opset1::Reshape>(reshape)->get_special_zero());
87                     break;
88                 }
89             }
90
91             std::vector<int> newReshapeConstValues(reshapeConstValues);
92             for (int i = static_cast<int>(newReshapeConstValues.size() - 1); i >= 0; --i) {
93                 if (newOperationConstantShape.size() <= i) {
94                     // new dimension was added
95                     newReshapeConstValues[i] = 1;
96                 } else if (newOperationConstantShape[i] == 1ul) {
97                     // keep the same
98                     newReshapeConstValues[i] = 1;
99                 } else if (newReshapeConstValues[i] == -1) {
100                     // modified reshape parameters are different, but value instead '-1' has to be equal as original reshape
101                     newReshapeConstValues[i] = overallValue;
102                 }
103             }
104
105             const std::shared_ptr<opset1::Constant> newReshapeConstant = std::make_shared<opset1::Constant>(
106                 reshape->input(1).get_element_type(),
107                 Shape({ newReshapeConstValues.size() }),
108                 newReshapeConstValues);
109
110             // if channels are different then broadcast spatial dimensions to reshape channels correctly
111             // limitation which has to be covered by canBeTransformed:
112             //    1. spatial dimensions have to be absent or equal to 1 after reshape
113             //    2. only second dimension can be changed
114
115             const bool shouldBroadcast = (shape_size(newReshapeConstValues) != 1ul) && (reshapeConstValues[1] != 0) &&
116                 (((reshapeConstValues[1] != -1) && (constantShape[1] != reshapeConstValues[1])) ||
117                 ((reshapeConstValues[1] == -1) && (constantShape[1] != overallValue)));
118
119             const std::shared_ptr<Node> broadcastedConstant = shouldBroadcast ?
120                 fold<opset1::Broadcast>(
121                     newOperationConstant,
122                     std::make_shared<opset1::Constant>(
123                         element::i32,
124                         Shape({newOperationConstantBroadcastedShape.size()}),
125                         newOperationConstantBroadcastedShape)) :
126                 newOperationConstant;
127
128             const std::shared_ptr<Node> resultConstant = fold<opset1::Reshape>(
129                 broadcastedConstant,
130                 newReshapeConstant,
131                 reshape->get_special_zero());
132
133             replace_node(op->get_input_node_shared_ptr(1), resultConstant);
134         };
135
136         if (dequantization.subtract != nullptr) {
137             replaceConstant(reshape, dequantization.subtract);
138         }
139
140         if (dequantization.multiply != nullptr) {
141             replaceConstant(reshape, dequantization.multiply);
142         }
143     }
144 }
145
146 bool ReshapeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
147     std::shared_ptr<opset1::Reshape> reshape = as_type_ptr<opset1::Reshape>(m.get_match_root());
148     if ((reshape == nullptr) || (!canBeTransformed(context, reshape))) {
149         return false;
150     }
151
152     reshape = as_type_ptr<opset1::Reshape>(separateInStandaloneBranch(reshape));
153     reshapeDequantizationConstant(reshape);
154     moveDequantizationAfter(context, reshape, NetworkHelper::getDequantization(reshape, 0), false);
155     return true;
156 }
157
158 bool ReshapeTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const noexcept {
159     return true;
160 }
161
162 size_t getLastNotBroadcastedChannel(const Shape& shape) {
163     for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
164         if (shape[i] != 1ul) {
165             return i;
166         }
167     }
168     return 0;
169 }
170
171 size_t getFirstChangedChannel(const Shape& shape1, const Shape& shape2) {
172     const size_t minSize = std::min(shape1.size(), shape2.size());
173     size_t i = 0;
174     for (; i < minSize; ++i) {
175         if (shape1[i] != shape2[i]) {
176             return i;
177         }
178     }
179     return i;
180 }
181
182 bool ReshapeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
183     if (!LayerTransformation::canBeTransformed(context, op)) {
184         return false;
185     }
186
187     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op);
188     if (dequantization.empty()) {
189         return false;
190     }
191
192     const Shape subtractShape = dequantization.subtract == nullptr ? Shape{} : dequantization.subtract->input(1).get_shape();
193     Shape subtractShapeWithBatch = subtractShape;
194     const Shape inputShape = op->get_input_shape(0);
195     if ((dequantization.subtract != nullptr) &&
196         (subtractShapeWithBatch.size() > 1) &&
197         (subtractShapeWithBatch.size() < inputShape.size())) {
198         subtractShapeWithBatch.insert(subtractShapeWithBatch.begin(), inputShape[0]);
199     }
200
201     const Shape multiplyShape = dequantization.multiply == nullptr ? Shape{} : dequantization.multiply->input(1).get_shape();
202     Shape multiplyShapeWithBatch = multiplyShape;
203     if ((dequantization.multiply != nullptr) &&
204         (multiplyShapeWithBatch.size() > 1) &&
205         (multiplyShapeWithBatch.size() < inputShape.size())) {
206         multiplyShapeWithBatch.insert(multiplyShapeWithBatch.begin(), inputShape[0]);
207     }
208
209     const Shape outputShape = op->get_output_shape(0);
210     return canBeTransformed(subtractShapeWithBatch, multiplyShapeWithBatch, inputShape, outputShape);
211 }
212
213 size_t getChannelVolume(const Shape& shape) {
214     size_t volume = 1ul;
215     for (size_t i = 2; i < shape.size(); ++i) {
216         volume = volume * shape[i];
217     }
218     return volume;
219 }
220
// Validates that the reshape is compatible with the given dequantization constant
// shapes. Supported cases:
//    - 4D -> 2D (NCHW -> NC): the dequantization constants must have trivial
//      (== 1) spatial dimensions and the output channels must absorb the whole
//      channel volume (C * H * W);
//    - otherwise: batch and channel dimensions must stay unchanged and the
//      reshape may only modify dimensions strictly to the right of the last
//      per-channel (not broadcasted) dequantization value.
bool ReshapeTransformation::canBeTransformed(
    const ngraph::Shape& subtractShape,
    const ngraph::Shape& multiplyShape,
    const ngraph::Shape& inputShape,
    const ngraph::Shape& outputShape) {
    // both shapes must have at least rank 2 and the batch must be preserved
    if ((inputShape.size() < 2ul) || (outputShape.size() < 2ul) || (inputShape[0] != outputShape[0])) {
        return false;
    }

    // TODO: story 38439
    if ((inputShape.size() == 4ul) && (outputShape.size() == 2ul)) {
        // true when the last two (spatial) dimensions of the constant are trivial (== 1)
        auto checkSpatialDimensions = [](const Shape& dequantizationConstShape) {
            for (size_t i = (dequantizationConstShape.size() - 2); i < dequantizationConstShape.size(); ++i) {
                if (dequantizationConstShape[i] != 1ul) {
                    return false;
                }
            }
            return true;
        };

        // per-spatial dequantization values cannot survive a 4D -> 2D flatten
        if (((subtractShape.size() >= 3ul) && (!checkSpatialDimensions(subtractShape))) ||
            ((multiplyShape.size() >= 3ul) && (!checkSpatialDimensions(multiplyShape)))) {
            return false;
        }

        // custom validation for Layout::NCHW => Layout::NC
        const size_t inputChannelsCount = inputShape.size() > 1ul ? inputShape[1] : inputShape[0];
        const size_t outputChannelsCount = outputShape.size() > 1ul ? outputShape[1] : outputShape[0];
        // NOTE(review): getChannelVolume returns H*W, so this requires
        // outputChannels == C * H * W; the batch equality re-check duplicates the
        // guard at the top of the function and is redundant.
        if ((inputShape[0] != outputShape[0]) || ((inputChannelsCount * getChannelVolume(inputShape)) != outputChannelsCount)) {
            return false;
        }
    } else {
        // batch and channel dimensions have to be unchanged
        for (size_t i = 0; i < 2ul; ++i) {
            if (inputShape[i] != outputShape[i]) {
                return false;
            }
        }

        // every per-channel dequantization value must lie to the left of the
        // first dimension modified by the reshape
        const size_t lastNotBroadcastedChannel = std::max(getLastNotBroadcastedChannel(subtractShape), getLastNotBroadcastedChannel(multiplyShape));
        const size_t firstChangedChannel = getFirstChangedChannel(inputShape, outputShape);
        if (lastNotBroadcastedChannel >= firstChangedChannel) {
            return false;
        }
    }

    return true;
}
268
269 } // namespace low_precision
270 } // namespace pass
271 } // namespace ngraph