1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
11 #include <unordered_set>
13 #include <ngraph/ngraph.hpp>
14 #include <ngraph/pattern/matcher.hpp>
15 #include <ngraph/opsets/opset1.hpp>
16 #include "ngraph_ops/type_relaxed.hpp"
17 #include <ngraph/rt_info.hpp>
19 #include "transformation_context.hpp"
20 #include "quantization_details.hpp"
21 #include "transformations/utils/utils.hpp"
22 #include "common/fake_quantize_dequantization.hpp"
23 #include "common/ie_lpt_exception.hpp"
27 namespace low_precision {
/** @brief NetworkHelper class encapsulates manipulations with nGraph functions. */
32 class TRANSFORMATIONS_API NetworkHelper {
// NOTE(review): this copy of the declaration is incomplete - the embedded line numbers skip lines.
// Missing are (presumably) the access specifiers, the parameters between `precision` and
// `hasZeroPoint` in decomposeFakeQuantize / createDequantizationFromFakeQuantize, the trailing
// parameters of updateFakeQuantize / makeDequantization (their last visible parameter ends with a
// comma), and the closing `};` of InsertDequantizationResult and of this class. Restore them from
// the upstream header before building.
34 // Return true if `type` can be cast to at least one of `types`
35 static bool is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types);
37 static std::vector<Input<Node>> consumer_inputs(std::shared_ptr<Node> node);
39 // Collect and return a vector with all nodes that consume any of the `node` outputs
40 static std::vector<std::shared_ptr<Node>> consumers(std::shared_ptr<Node> node);
42 static Shape alignShapeForChannelDim(const Shape& shape, Rank rank);
44 // Return true if at least one child uses `layer` on weights
45 static bool onWeights(std::shared_ptr<Node> layer);
47 template <typename OperationType>
48 static std::shared_ptr<Node> setOutDataPrecisionForTypeRelaxed(std::shared_ptr<OperationType> operation, const element::Type& precision);
50 template <typename OperationType>
51 static std::shared_ptr<Node> setOutDataPrecision(std::shared_ptr<OperationType> operation, const element::Type& precision);
53 static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false);
// NOTE(review): `portIndex = -1` presumably means "all input ports" - confirm in the implementation.
55 static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(
56 std::shared_ptr<Node> layer,
57 const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes = {},
58 const int portIndex = -1);
60 static size_t getInputChannelsCount(std::shared_ptr<Node> layer);
62 static size_t getGroupsCount(std::shared_ptr<Node> layer);
64 // Remove `node` by connecting its 0th input with its 0th output
65 static void removeLayer(std::shared_ptr<Node> node);
67 static std::shared_ptr<Node> swapMultiplyAndAdd(std::shared_ptr<opset1::Add> addAfterMultiply, const int multiplyBranch);
69 static void copyInfo(const std::shared_ptr<Node>& source, const std::shared_ptr<Node>& target);
71 static void cleanRunTimeInfo(const std::shared_ptr<Node>& layer);
73 static bool isScalarLike(std::shared_ptr<opset1::Constant> constant);
75 static bool isZero(std::shared_ptr<opset1::Constant> constant);
77 static std::shared_ptr<opset1::Constant> toScalar(std::shared_ptr<opset1::Constant> constant);
79 static std::shared_ptr<Node> getConstantInput(std::shared_ptr<Node> node);
81 // Optimizes a series of Multiply operations after the given `multiply` node
82 static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);
84 static std::shared_ptr<opset1::Constant> roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance = 0.1);
86 static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
87 std::shared_ptr<opset1::FakeQuantize> fq,
88 const element::Type precision,
91 const bool hasZeroPoint,
92 const bool updatePrecision);
94 static std::shared_ptr<opset1::FakeQuantize> updateFakeQuantize(
95 std::shared_ptr<opset1::FakeQuantize> fq,
96 element::Type precision,
100 static FakeQuantizeDequantization makeDequantization(
101 const float dequantizationMul,
102 const float dequantizationSub,
103 const ngraph::element::Type originalPrecision,
104 const ngraph::Shape dataNodeOutputShape,
105 element::Type precision,
109 static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
110 std::shared_ptr<opset1::FakeQuantize> fq,
111 element::Type precision,
114 const bool hasZeroPoint,
115 const bool updatePrecision);
117 static FakeQuantizeDequantization getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex = 0ul, const bool inPlace = false);
119 static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> add);
// Result of moveDequantizationAfter: the new operation and the last dequantization node.
121 class InsertDequantizationResult {
123 InsertDequantizationResult(
124 const std::shared_ptr<Node>& newOperation,
125 const std::shared_ptr<Node>& lastDequantization) : newOperation(newOperation), lastDequantization(lastDequantization) {}
127 std::shared_ptr<Node> newOperation;
128 std::shared_ptr<Node> lastDequantization;
131 static InsertDequantizationResult moveDequantizationAfter(
132 const std::shared_ptr<ngraph::Node>& operation,
133 const FakeQuantizeDequantization& dequantization,
134 const bool updatePrecision,
135 const bool moveSubtract);
137 // TODO: rename: fuseConvertIfPossible
138 static void removeConvertIfPossible(
139 const std::shared_ptr<ngraph::Node>& operation,
140 const FakeQuantizeDequantization& dequantization);
142 static bool checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr<Node>& constant);
144 static size_t getChildInputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child);
146 static size_t getParentOutputIndex(const std::shared_ptr<ngraph::Node>& parent, const std::shared_ptr<ngraph::Node>& child);
148 static std::vector<Output<Node>> getInputs(const std::shared_ptr<ngraph::Node>& node);
150 static FakeQuantizeDequantizationValues createEmptyValues(const FakeQuantizeDequantization& dequantization);
152 static bool isZeroConst(const std::shared_ptr<Node>& node);
154 static std::shared_ptr<Node> toScalarIfPossible(std::shared_ptr<Node> node);
156 static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq);
157 static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues);
159 // Multi-precision constant folding;
160 // handles only the specific case: Constant -> [dequantization operations] -> [node]
161 static void foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace = false);
164 static std::shared_ptr<Node> foldFakeQuantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues, const bool roundValuesWasSet);
// Return codes of onWeightsInDepth:
// NOTE(review): the first line of this legend is missing here (presumably `1 - on weights`) - confirm.
167 // 0 - weightable layer was not found
168 // -1 - on activations
169 static int onWeightsInDepth(std::shared_ptr<Node> layer);
172 template <typename OperationType>
173 std::shared_ptr<Node> NetworkHelper::setOutDataPrecisionForTypeRelaxed(std::shared_ptr<OperationType> layer, const element::Type& precision) {
174 // check if it already exteded operation node
175 if (auto relaxed_layer = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(layer)) {
176 relaxed_layer->set_overridden_output_type(precision);
177 std::dynamic_pointer_cast<ngraph::Node>(layer)->validate_and_infer_types();
180 THROW_IE_LPT_EXCEPTION(*layer) << "TypeRelaxed type is expected";
184 template <typename OperationType>
185 std::shared_ptr<Node> NetworkHelper::setOutDataPrecision(std::shared_ptr<OperationType> layer, const element::Type& precision) {
186 // check if it already exteded operation node
187 if (auto relaxed_layer = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(layer)) {
188 relaxed_layer->set_overridden_output_type(precision);
189 std::dynamic_pointer_cast<ngraph::Node>(layer)->validate_and_infer_types();
192 // Make such replacements in advance for all supported polymorphic layer types
193 // extend a node with new semantics: overriden output data_type
194 // OperationType should be a real type of an object, otherwise it will lead to undefined behavior
195 auto replacement = std::make_shared<ngraph::op::TypeRelaxed<OperationType>>(*layer, precision);
196 copy_runtime_info(layer, replacement);
197 replace_node(layer, replacement);
202 template <typename T>
203 std::shared_ptr<Node> make_op_pattern(const ngraph::NodeVector& args) {
204 return std::make_shared<ngraph::pattern::op::Any>(element::undefined, PartialShape{}, [](std::shared_ptr<Node> n) {return !!as_type_ptr<T>(n); }, args);
207 template <typename T>
208 std::shared_ptr<Node> make_op_label() {
209 return std::make_shared<ngraph::pattern::op::Label>(
212 [](std::shared_ptr<Node> n) {return !!as_type_ptr<T>(n); });
215 template <typename T, typename... Args>
216 std::shared_ptr<Node> fold(Args&&... args) {
217 auto node = std::make_shared<T>(std::forward<Args>(args)...);
218 if (node->get_output_size() == 1) {
219 OutputVector folded(node->get_output_size());
220 if (node->constant_fold(folded, node->input_values())) {
221 return folded[0].get_node_shared_ptr();
227 template <typename T, typename... Args>
228 std::shared_ptr<Node> fold_reshape(Args&&... args) {
229 std::shared_ptr<Node> node = std::make_shared<T>(std::forward<Args>(args)...);
230 if (node->get_output_size() == 1) {
232 if (is_type<opset1::Constant>(node->input_value(0).get_node_shared_ptr()) &&
233 is_type<opset1::Constant>(node->input_value(1).get_node_shared_ptr())) {
234 return std::make_shared<opset1::Constant>(
235 node->get_input_element_type(0),
236 Shape(as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr())->template cast_vector<size_t>()),
237 as_type_ptr<opset1::Constant>(node->input_value(0).get_node_shared_ptr())->get_data_ptr());
243 } // namespace low_precision
245 } // namespace ngraph