f27462bf79617eefe904d39abe2f7cd5faec098a
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / include / low_precision / network_helper.hpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <cmath>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <unordered_set>

#include <ngraph/ngraph.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/opsets/opset1.hpp>
#include "ngraph_ops/type_relaxed.hpp"
#include <ngraph/rt_info.hpp>

#include "transformation_context.hpp"
#include "quantization_details.hpp"
#include "transformations/utils/utils.hpp"
#include "common/fake_quantize_dequantization.hpp"
#include "common/ie_lpt_exception.hpp"
24
25 namespace ngraph {
26 namespace pass {
27 namespace low_precision {
28
29 /**
30 * @brief NetworkHelper class encapsulates manipulations with nGraph function.
31 */
32 class TRANSFORMATIONS_API NetworkHelper {
33 public:
34     // Return true if `type` can be castable to at least one of `type`
35     static bool is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types);
36
37     static std::vector<Input<Node>> consumer_inputs(std::shared_ptr<Node> node);
38
39     // Collect and return a vector with all nodes that consumes any of the `node` output
40     static std::vector<std::shared_ptr<Node>> consumers(std::shared_ptr<Node> node);
41
42     static Shape alignShapeForChannelDim(const Shape& shape, Rank rank);
43
44     // return true if at least one child uses layer on weights
45     static bool onWeights(std::shared_ptr<Node> layer);
46
47     template <typename OperationType>
48     static std::shared_ptr<Node> setOutDataPrecisionForTypeRelaxed(std::shared_ptr<OperationType> operation, const element::Type& precision);
49
50     template <typename OperationType>
51     static std::shared_ptr<Node> setOutDataPrecision(std::shared_ptr<OperationType> operation, const element::Type& precision);
52
53     static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false);
54
55     static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(
56         std::shared_ptr<Node> layer,
57         const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes = {},
58         const int portIndex = -1);
59
60     static size_t getInputChannelsCount(std::shared_ptr<Node> layer);
61
62     static size_t getGroupsCount(std::shared_ptr<Node> layer);
63
64     // Remove node by connecting its 0th input with 0th output
65     static void removeLayer(std::shared_ptr<Node> node);
66
67     static std::shared_ptr<Node> swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch);
68
69     static void copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target);
70
71     static void cleanRunTimeInfo(const std::shared_ptr<Node>& layer);
72
73     static bool isScalarLike(std::shared_ptr<opset1::Constant> constant);
74
75     static bool isZero(std::shared_ptr<opset1::Constant> constant);
76
77     static std::shared_ptr<opset1::Constant> toScalar(std::shared_ptr<opset1::Constant> constant);
78
79     static std::shared_ptr<Node> getConstantInput(std::shared_ptr<Node> node);
80
81     // Optimizes the series of multiplies after a given output port
82     static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);
83
84     static std::shared_ptr<opset1::Constant> roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance = 0.1);
85
86     static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
87         std::shared_ptr<opset1::FakeQuantize> fq,
88         const element::Type precision,
89         const float min,
90         const float max,
91         const bool hasZeroPoint,
92         const bool updatePrecision);
93
94     static std::shared_ptr<opset1::FakeQuantize> updateFakeQuantize(
95         std::shared_ptr<opset1::FakeQuantize> fq,
96         element::Type precision,
97         float min,
98         float max);
99
100     static FakeQuantizeDequantization makeDequantization(
101         const float dequantizationMul,
102         const float dequantizationSub,
103         const ngraph::element::Type originalPrecision,
104         const ngraph::Shape dataNodeOutputShape,
105         element::Type precision,
106         float min,
107         float max);
108
109     static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
110         std::shared_ptr<opset1::FakeQuantize> fq,
111         element::Type precision,
112         float min,
113         float max,
114         const bool hasZeroPoint,
115         const bool updatePrecision);
116
117     static FakeQuantizeDequantization getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex = 0ul, const bool inPlace = false);
118
119     static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> add);
120
121     class InsertDequantizationResult {
122     public:
123         InsertDequantizationResult(
124             const std::shared_ptr<Node>& newOperation,
125             const std::shared_ptr<Node>& lastDequantization) : newOperation(newOperation), lastDequantization(lastDequantization) {}
126
127         std::shared_ptr<Node> newOperation;
128         std::shared_ptr<Node> lastDequantization;
129     };
130
131     static InsertDequantizationResult moveDequantizationAfter(
132         const std::shared_ptr<ngraph::Node>& operation,
133         const FakeQuantizeDequantization& dequantization,
134         const bool updatePrecision,
135         const bool moveSubtract);
136
137     // TODO: rename: fuseConvertIfPossible
138     static void removeConvertIfPossible(
139         const std::shared_ptr<ngraph::Node>& operation,
140         const FakeQuantizeDequantization& dequantization);
141
142     static bool checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant);
143
144     static size_t getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child);
145
146     static size_t getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child);
147
148     static std::vector<Output<Node>> getInputs(const std::shared_ptr<ngraph::Node>& node);
149
150     static FakeQuantizeDequantizationValues createEmptyValues(const FakeQuantizeDequantization& dequantization);
151
152     static bool isZeroConst(const std::shared_ptr<Node>& node);
153
154     static std::shared_ptr<Node> toScalarIfPossible(std::shared_ptr<Node> node);
155
156     static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq);
157     static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues);
158
159     // multi-precision constant folding
160     // handles only specific case: Constant -> [dequantization operations] -> [node]
161     static void foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace = false);
162
163 private:
164     static std::shared_ptr<Node> foldFakeQuantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues, const bool roundValuesWasSet);
165
166     // 1  - on weights
167     // 0  - weightable layer was not found
168     // -1 - on activations
169     static int onWeightsInDepth(std::shared_ptr<Node> layer);
170 };
171
172 template <typename OperationType>
173 std::shared_ptr<Node> NetworkHelper::setOutDataPrecisionForTypeRelaxed(std::shared_ptr<OperationType> layer, const element::Type& precision) {
174     // check if it already exteded operation node
175     if (auto relaxed_layer = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(layer)) {
176         relaxed_layer->set_overridden_output_type(precision);
177         std::dynamic_pointer_cast<ngraph::Node>(layer)->validate_and_infer_types();
178         return layer;
179     } else {
180         THROW_IE_LPT_EXCEPTION(*layer) << "TypeRelaxed type is expected";
181     }
182 }
183
184 template <typename OperationType>
185 std::shared_ptr<Node> NetworkHelper::setOutDataPrecision(std::shared_ptr<OperationType> layer, const element::Type& precision) {
186     // check if it already exteded operation node
187     if (auto relaxed_layer = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(layer)) {
188         relaxed_layer->set_overridden_output_type(precision);
189         std::dynamic_pointer_cast<ngraph::Node>(layer)->validate_and_infer_types();
190         return layer;
191     } else {
192         // Make such replacements in advance for all supported polymorphic layer types
193         // extend a node with new semantics: overriden output data_type
194         // OperationType should be a real type of an object, otherwise it will lead to undefined behavior
195         auto replacement = std::make_shared<ngraph::op::TypeRelaxed<OperationType>>(*layer, precision);
196         copy_runtime_info(layer, replacement);
197         replace_node(layer, replacement);
198         return replacement;
199     }
200 }
201
202 template <typename T>
203 std::shared_ptr<Node> make_op_pattern(const ngraph::NodeVector& args) {
204     return std::make_shared<ngraph::pattern::op::Any>(element::undefined, PartialShape{}, [](std::shared_ptr<Node> n) {return !!as_type_ptr<T>(n); }, args);
205 }
206
207 template <typename T>
208 std::shared_ptr<Node> make_op_label() {
209     return std::make_shared<ngraph::pattern::op::Label>(
210             element::undefined,
211             PartialShape{},
212             [](std::shared_ptr<Node> n) {return !!as_type_ptr<T>(n); });
213 }
214
215 template <typename T, typename... Args>
216 std::shared_ptr<Node> fold(Args&&... args) {
217     auto node = std::make_shared<T>(std::forward<Args>(args)...);
218     if (node->get_output_size() == 1) {
219         OutputVector folded(node->get_output_size());
220         if (node->constant_fold(folded, node->input_values())) {
221             return folded[0].get_node_shared_ptr();
222         }
223     }
224     return node;
225 }
226
227 template <typename T, typename... Args>
228 std::shared_ptr<Node> fold_reshape(Args&&... args) {
229     std::shared_ptr<Node> node = std::make_shared<T>(std::forward<Args>(args)...);
230     if (node->get_output_size() == 1) {
231         OutputVector folded;
232         if (is_type<opset1::Constant>(node->input_value(0).get_node_shared_ptr()) &&
233             is_type<opset1::Constant>(node->input_value(1).get_node_shared_ptr())) {
234             return std::make_shared<opset1::Constant>(
235                     node->get_input_element_type(0),
236                     Shape(as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr())->template cast_vector<size_t>()),
237                     as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr())->get_data_ptr());
238         }
239     }
240     return node;
241 }
242
243 }  // namespace low_precision
244 }  // namespace pass
245 }  // namespace ngraph