58d01b732a43129605bdea9aae22d25356cd52bb
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / reshape.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/reshape.hpp"
6
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 #include <unordered_set>
11 #include <utility>
12 #include <vector>
13
14 #include "low_precision/common/ie_lpt_exception.hpp"
15 #include "low_precision/network_helper.hpp"
16
17 namespace ngraph {
18 namespace pass {
19 namespace low_precision {
20
21 void ReshapeTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
22     addPattern(
23         pass,
24         context,
25         make_op_pattern<opset1::Reshape>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
26 }
27
28 void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& reshape) {
29     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(reshape, 0);
30     if (dequantization.multiply->get_input_node_ptr(1)->get_output_shape(0).size() > 1ul) {
31         auto replaceConstant = [](const std::shared_ptr<opset1::Reshape>& reshape, const std::shared_ptr<Node>& op) {
32             if (reshape->output(0).get_shape().size() == 2ul) {
33                 const auto inputShape = reshape->input(0).get_shape();
34
35                 Shape shape(inputShape);
36                 shape[0] = 1ul;
37
38                 const std::shared_ptr<Node> broadcastedConstant = fold<opset1::Broadcast>(
39                     op->get_input_node_shared_ptr(1),
40                     std::make_shared<opset1::Constant>(element::i32, Shape{ shape.size() }, shape));
41
42                 const std::shared_ptr<Node> reshapedConstant = fold<opset1::Reshape>(
43                     broadcastedConstant,
44                     reshape->get_input_node_shared_ptr(1),
45                     reshape->get_special_zero());
46
47                 replace_node(op->get_input_node_shared_ptr(1), reshapedConstant);
48             } else {
49                 // Original Reshape operation is used to update operation Constant.
50                 // But original Reshape operation output data shape constant should be changed before reshape.
51
52                 // simple broadcast operation Constant shape to shape on activations
53                 auto newOperationConstantShape = op->input(1).get_shape();
54                 auto const reshapeInputShape = reshape->input(0).get_shape();
55                 if ((reshapeInputShape.size() - newOperationConstantShape.size()) == 1ul) {
56                     newOperationConstantShape.insert(newOperationConstantShape.begin(), 1ul);
57                 }
58                 const std::shared_ptr<opset1::Constant> originalConstant = as_type_ptr<opset1::Constant>(op->get_input_node_shared_ptr(1));
59                 const std::shared_ptr<opset1::Constant> newOperationConstant = std::make_shared<opset1::Constant>(
60                     op->input(1).get_element_type(),
61                     newOperationConstantShape,
62                     originalConstant->cast_vector<float>());
63
64                 // update Reshape constant
65                 const std::vector<int> reshapeConstValues = as_type_ptr<opset1::Constant>(reshape->get_input_node_shared_ptr(1))->cast_vector<int>();
66                 std::vector<int> newReshapeConstValues(reshapeConstValues);
67                 for (int i = static_cast<int>(newReshapeConstValues.size() - 1); i >= 0; --i) {
68                     if (newOperationConstantShape.size() <= i) {
69                         newReshapeConstValues[i] = 1;
70                     } else if (newOperationConstantShape[i] == 1ul) {
71                         // not used dimension
72                         newReshapeConstValues[i] = 1;
73                     } else {
74                         break;
75                     }
76                 }
77
78                 const std::shared_ptr<opset1::Constant> newReshapedConstant = std::make_shared<opset1::Constant>(
79                     reshape->input(1).get_element_type(),
80                     Shape({ newReshapeConstValues.size() }),
81                     newReshapeConstValues);
82
83                 const std::shared_ptr<Node> resultConstant = fold<opset1::Reshape>(
84                     newOperationConstant,
85                     newReshapedConstant,
86                     reshape->get_special_zero());
87
88                 replace_node(op->get_input_node_shared_ptr(1), resultConstant);
89             }
90         };
91
92         if (dequantization.subtract != nullptr) {
93             replaceConstant(reshape, dequantization.subtract);
94         }
95
96         if (dequantization.multiply != nullptr) {
97             replaceConstant(reshape, dequantization.multiply);
98         }
99     }
100 }
101
102 bool ReshapeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
103     std::shared_ptr<opset1::Reshape> reshape = as_type_ptr<opset1::Reshape>(m.get_match_root());
104     if ((reshape == nullptr) || (!canBeTransformed(context, reshape))) {
105         return false;
106     }
107
108     reshape = as_type_ptr<opset1::Reshape>(separateInStandaloneBranch(reshape));
109     reshapeDequantizationConstant(reshape);
110     moveDequantizationAfter(context, reshape, NetworkHelper::getDequantization(reshape, 0), false);
111     return true;
112 }
113
114 bool ReshapeTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const noexcept {
115     return true;
116 }
117
118 size_t getLastNotBroadcastedChannel(const Shape& shape) {
119     for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
120         if (shape[i] != 1ul) {
121             return i;
122         }
123     }
124     return 0;
125 }
126
127 size_t getFirstChangedChannel(const Shape& shape1, const Shape& shape2) {
128     const size_t minSize = std::min(shape1.size(), shape2.size());
129     size_t i = 0;
130     for (; i < minSize; ++i) {
131         if (shape1[i] != shape2[i]) {
132             return i;
133         }
134     }
135     return i;
136 }
137
138 bool ReshapeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
139     if (!LayerTransformation::canBeTransformed(context, op)) {
140         return false;
141     }
142
143     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op);
144     if (dequantization.empty()) {
145         return false;
146     }
147
148     const Shape subtractShape = dequantization.subtract == nullptr ? Shape{} : dequantization.subtract->input(1).get_shape();
149     Shape subtractShapeWithBatch = subtractShape;
150     const Shape inputShape = op->get_input_shape(0);
151     if ((dequantization.subtract != nullptr) &&
152         (subtractShapeWithBatch.size() > 1) &&
153         (subtractShapeWithBatch.size() < inputShape.size())) {
154         subtractShapeWithBatch.insert(subtractShapeWithBatch.begin(), inputShape[0]);
155     }
156
157     const Shape multiplyShape = dequantization.multiply == nullptr ? Shape{} : dequantization.multiply->input(1).get_shape();
158     Shape multiplyShapeWithBatch = multiplyShape;
159     if ((dequantization.multiply != nullptr) &&
160         (multiplyShapeWithBatch.size() > 1) &&
161         (multiplyShapeWithBatch.size() < inputShape.size())) {
162         multiplyShapeWithBatch.insert(multiplyShapeWithBatch.begin(), inputShape[0]);
163     }
164
165     const Shape outputShape = op->get_output_shape(0);
166     return canBeTransformed(subtractShapeWithBatch, multiplyShapeWithBatch, inputShape, outputShape);
167 }
168
169 size_t getChannelVolume(const Shape& shape) {
170     size_t volume = 1ul;
171     for (size_t i = 2; i < shape.size(); ++i) {
172         volume = volume * shape[i];
173     }
174     return volume;
175 }
176
177 bool ReshapeTransformation::canBeTransformed(
178     const ngraph::Shape& subtractShape,
179     const ngraph::Shape& multiplyShape,
180     const ngraph::Shape& inputShape,
181     const ngraph::Shape& outputShape) {
182     if ((inputShape.size() < 2ul) || (outputShape.size() < 2ul) || (inputShape[0] != outputShape[0])) {
183         return false;
184     }
185
186     // TODO: story 38439
187     if ((inputShape.size() == 4ul) && (outputShape.size() == 2ul)) {
188         auto checkSpatialDimensions = [](const Shape& dequantizationConstShape) {
189             for (size_t i = (dequantizationConstShape.size() - 2); i < dequantizationConstShape.size(); ++i) {
190                 if (dequantizationConstShape[i] != 1ul) {
191                     return false;
192                 }
193             }
194             return true;
195         };
196
197         if (((subtractShape.size() >= 3ul) && (!checkSpatialDimensions(subtractShape))) ||
198             ((multiplyShape.size() >= 3ul) && (!checkSpatialDimensions(multiplyShape)))) {
199             return false;
200         }
201
202         // custom validation for Layout::NCHW => Layout::NC
203         const size_t inputChannelsCount = inputShape.size() > 1ul ? inputShape[1] : inputShape[0];
204         const size_t outputChannelsCount = outputShape.size() > 1ul ? outputShape[1] : outputShape[0];
205         if ((inputShape[0] != outputShape[0]) || ((inputChannelsCount * getChannelVolume(inputShape)) != outputChannelsCount)) {
206             return false;
207         }
208     } else {
209         for (size_t i = 0; i < 2ul; ++i) {
210             if (inputShape[i] != outputShape[i]) {
211                 return false;
212             }
213         }
214
215         const size_t lastNotBroadcastedChannel = std::max(getLastNotBroadcastedChannel(subtractShape), getLastNotBroadcastedChannel(multiplyShape));
216         const size_t firstChangedChannel = getFirstChangedChannel(inputShape, outputShape);
217         if (lastNotBroadcastedChannel >= firstChangedChannel) {
218             return false;
219         }
220     }
221
222     return true;
223 }
224
225 } // namespace low_precision
226 } // namespace pass
227 } // namespace ngraph