Fixed static analysis issues (transformations) (#3276)
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / convert_precision.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/convert_precision.hpp"
6
7 #include <memory>
8 #include <vector>
9
10 #include <ngraph/opsets/opset5.hpp>
11 #include <ngraph/opsets/opset4.hpp>
12 #include <ngraph/opsets/opset3.hpp>
13 #include <ngraph/opsets/opset1.hpp>
14 #include <ngraph_ops/type_relaxed.hpp>
15
16 using namespace ngraph;
17
18 bool fuse_type_to_constant(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, const std::vector<ngraph::Input<ngraph::Node>> & consumers);
19 bool fuse_type_to_shapeof(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
20 bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
21 bool fuse_type_to_parameter(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
22 bool fuse_type_to_convert(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
23 bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
24 bool fuse_type_to_nms4(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
25 bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
26 bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
27 bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
28 bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
29 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
30
31 bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
32
33 template <typename T>
34 bool fuse_type_to_binary_comparision(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
35     if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
36         type_relaxed->set_overridden_output_type(to);
37         return true;
38     } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
39         auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted, element::TypeVector{}, element::TypeVector{to});
40         replace_node(node, relaxed_op);
41         return true;
42     }
43     return false;
44 }
45
46 template <typename T>
47 bool fuse_type_to_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
48     if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
49         type_relaxed->set_overridden_output_type(to);
50         type_relaxed->set_origin_input_type(element::boolean, 0);
51         type_relaxed->set_origin_input_type(element::boolean, 1);
52         return true;
53     } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
54         auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
55                 element::TypeVector{element::boolean, element::boolean}, element::TypeVector{to});
56         replace_node(node, relaxed_op);
57         return true;
58     }
59     return false;
60 }
61
62 template <class T>
63 bool fuse_type_to_reduce_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
64     if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
65         type_relaxed->set_overridden_output_type(to);
66         type_relaxed->set_origin_input_type(element::boolean, 0);
67         return true;
68     } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
69         auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
70                 element::TypeVector{element::boolean}, element::TypeVector{to});
71         replace_node(node, relaxed_op);
72         return true;
73     }
74     return false;
75 }
76
77 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPrecision, "ConvertPrecision", 0);
78
79 bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
80     static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
81         {opset4::Parameter::type_info, fuse_type_to_parameter},
82         {opset4::Convert::type_info, fuse_type_to_convert},
83         {opset4::ShapeOf::type_info, fuse_type_to_shapeof},
84         {opset3::NonMaxSuppression::type_info, fuse_type_to_nms3},
85         {opset4::NonMaxSuppression::type_info, fuse_type_to_nms4},
86         {opset5::NonMaxSuppression::type_info, fuse_type_to_nms5},
87         {opset4::TopK::type_info, fuse_type_to_topk},
88         {opset4::NonZero::type_info, fuse_type_to_nonzero},
89         {opset4::Bucketize::type_info, fuse_type_to_bucketize},
90         {NodeTypeInfo("GenericIE", 1), fuse_type_to_generic_ie},
91         {opset4::Equal::type_info, fuse_type_to_binary_comparision<opset4::Equal>},
92         {opset4::NotEqual::type_info, fuse_type_to_binary_comparision<opset4::NotEqual>},
93         {opset4::Greater::type_info, fuse_type_to_binary_comparision<opset4::Greater>},
94         {opset4::GreaterEqual::type_info, fuse_type_to_binary_comparision<opset4::GreaterEqual>},
95         {opset4::Less::type_info, fuse_type_to_binary_comparision<opset4::Less>},
96         {opset4::LessEqual::type_info, fuse_type_to_binary_comparision<opset4::LessEqual>},
97         {opset4::LogicalAnd::type_info, fuse_type_to_logical<opset4::LogicalAnd>},
98         {opset4::LogicalOr::type_info, fuse_type_to_logical<opset4::LogicalOr>},
99         {opset4::LogicalXor::type_info, fuse_type_to_logical<opset4::LogicalXor>},
100         {opset4::LogicalNot::type_info, fuse_type_to_logical<opset4::LogicalNot>},
101         {opset4::ReduceLogicalAnd::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalAnd>},
102         {opset4::ReduceLogicalOr::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalOr>},
103         {opset1::ShapeOf::type_info, fuse_type_to_shapeof_v0}
104     };
105
106     static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_extend {
107             {opset4::Select::type_info, extend_select_type},
108     };
109
110     // As Constant operations can be shared between multiple nGraph Functions so before
111     // changing precision we need to understand which Constant consumers belongs
112     // to the current nGraph Function
113     std::map<std::shared_ptr<Node>, std::vector<Input<Node>>> const_to_internal_output;
114
115     std::function<void(const std::shared_ptr<Function> &)> register_constants =
116             [&const_to_internal_output, &register_constants](const std::shared_ptr<Function> & f) {
117         for (auto & node : f->get_ordered_ops()) {
118             for (auto & input : node->inputs()) {
119                 if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(input.get_source_output().get_node_shared_ptr())) {
120                     const_to_internal_output[const_node].emplace_back(input);
121                 }
122             }
123         }
124     };
125
126     auto convert_node_output_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
127         for (auto output : node->outputs()) {
128             if (output.get_element_type() == m_from) {
129                 // Handle case with Constants as they can have consumers from other nGraph Function object
130                 if (ngraph::op::is_constant(node) && const_to_internal_output.count(node)) {
131                     fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
132                     break;
133                 }
134
135                 // Check that node type exists in map and we can fuse type into node
136                 if (type_to_fuse.count(node->get_type_info()) &&
137                     type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index())) {
138                     // We need to break if original node was replaced
139                     break;
140                 }
141             }
142         }
143     };
144
145     auto convert_node_input_precision = [this](std::shared_ptr<Node> & node) {
146         for (auto input : node->inputs()) {
147             if (input.get_element_type() == m_from) {
148                 // For some operations we need to extend their input types to support new type
149                 if (type_to_extend.count(node->get_type_info()) &&
150                     type_to_extend.at(node->get_type_info())(node, m_to, input.get_index())) {
151                     break;
152                 }
153             }
154         }
155     };
156
157     std::function<void(const std::shared_ptr<Function> &)> convert_function_precision =
158             [this, &const_to_internal_output,
159                    &register_constants,
160                    &convert_node_output_precision,
161                    &convert_node_input_precision,
162                    &convert_function_precision] (const std::shared_ptr<Function> & f) {
163         // Iterate over all nodes in topological order and then iterate over node outputs.
164         // If output type mismatch given type we try to fuse type into this operation
165         // otherwise we insert Convert operation.
166         for (auto &node : f->get_ordered_ops()) {
167             m_transformation_callback(node);
168             // Recursively apply transformation for sub-graph based operations
169             if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
170                 if (auto sub_graph = sub_graph_node->get_function()) {
171                     convert_function_precision(sub_graph);
172                 }
173             }
174             convert_node_input_precision(node);
175         }
176         // Register internal constants only after fixing input type that could lead to nodes replacement
177         register_constants(f);
178
179         for (auto &node : f->get_ordered_ops()) {
180             convert_node_output_precision(node);
181         }
182     };
183
184     convert_function_precision(f);
185     f->validate_nodes_and_infer_types();
186
187     // TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
188     for (auto &node : f->get_ordered_ops()) {
189         if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
190             // WA for topK, dont remove fake convert
191             if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
192                 convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
193                 replace_output_update_name(convert->output(0), convert->input_value(0));
194             }
195         }
196     }
197     return true;
198 }
199
200 bool fuse_type_to_shapeof(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
201     if (auto shapeof = as_type_ptr<opset4::ShapeOf>(node)) {
202         if (to == element::i32 || to == element::i64) {
203             shapeof->set_output_type(to);
204             return true;
205         }
206     }
207     return false;
208 }
209
210 bool fuse_type_to_parameter(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
211     if (auto param = as_type_ptr<opset4::Parameter>(node)) {
212         param->set_element_type(to);
213         param->validate_and_infer_types();
214         return true;
215     }
216     return false;
217 }
218
219 bool fuse_type_to_convert(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
220     if (auto convert = as_type_ptr<opset4::Convert>(node)) {
221         convert->set_convert_element_type(to);
222         return true;
223     }
224     return false;
225 }
226
227 bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
228     if (auto nms = as_type_ptr<opset3::NonMaxSuppression>(node)) {
229         nms->set_output_type(to);
230         return true;
231     }
232     return false;
233 }
234
235 bool fuse_type_to_nms4(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
236     if (auto nms = as_type_ptr<opset4::NonMaxSuppression>(node)) {
237         nms->set_output_type(to);
238         return true;
239     }
240     return false;
241 }
242
243 bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
244     if (auto nms = as_type_ptr<opset5::NonMaxSuppression>(node)) {
245         nms->set_output_type(to);
246         return true;
247     }
248     return false;
249 }
250
251 bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
252     if (auto topk = as_type_ptr<opset4::TopK>(node)) {
253         if (idx == 1 && (to == element::i32 || to == element::i64)) {
254             topk->set_index_element_type(to);
255             return true;
256         }
257     }
258     return false;
259 }
260
261 bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
262     if (auto nonzero = as_type_ptr<opset4::NonZero>(node)) {
263         if (to == element::i32 || to == element::i64) {
264             nonzero->set_output_type(to);
265             return true;
266         }
267     }
268     return false;
269 }
270
271 bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
272     if (auto b = as_type_ptr<opset4::Bucketize>(node)) {
273         if (to == element::i32 || to == element::i64) {
274             b->set_output_type(to);
275             return true;
276         }
277     }
278     return false;
279 }
280
281 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
282     node->set_output_type(idx, to, node->output(idx).get_partial_shape());
283     // return false as we do not replace original node
284     return false;
285 }
286
287 bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
288     if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
289         type_relaxed->set_overridden_output_type(to);
290         return true;
291     } else if (auto casted = std::dynamic_pointer_cast<opset1::ShapeOf>(node)) {
292         auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<opset1::ShapeOf>>(*casted,
293                 element::TypeVector{}, element::TypeVector{to});
294         replace_node(node, relaxed_op);
295         return true;
296     }
297     return false;
298 }
299
300 bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
301     if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
302         type_relaxed->set_origin_input_type(element::boolean, 0);
303         return true;
304     } else if (auto casted = std::dynamic_pointer_cast<opset4::Select>(node)) {
305         auto relaxed_op = std::make_shared<op::TypeRelaxed<opset4::Select>>(*casted,
306                 element::TypeVector{element::boolean},
307                 element::TypeVector{});
308         replace_node(node, relaxed_op);
309         return true;
310     }
311     return false;
312 }
313
// Converts `val` to dst_type, saturating at the limits of dst_type: values above
// max() become max(), values below lowest() become lowest().
template <typename src_type, typename dst_type>
inline dst_type convert_value(src_type val) {
    using dst_limits = std::numeric_limits<dst_type>;

    if (val > dst_limits::max()) {
        return dst_limits::max();
    }
    if (val < dst_limits::lowest()) {
        return dst_limits::lowest();
    }
    return static_cast<dst_type>(val);
}

// U64->I32 and U32->I32 are handled as separate cases because of C++'s implicit
// signed-to-unsigned conversion in mixed comparisons; an unsigned source also never
// needs to be compared with (or clamped to) std::numeric_limits<int32_t>::lowest().
template <>
inline int32_t convert_value<uint64_t, int32_t>(uint64_t val) {
    constexpr int32_t upper_bound = std::numeric_limits<int32_t>::max();
    return val > static_cast<uint64_t>(upper_bound) ? upper_bound : static_cast<int32_t>(val);
}

template <>
inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
    constexpr int32_t upper_bound = std::numeric_limits<int32_t>::max();
    return val > static_cast<uint32_t>(upper_bound) ? upper_bound : static_cast<int32_t>(val);
}
341
342 template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
343 static std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
344     using src_type = typename element_type_traits<PREC_FROM>::value_type;
345     using dst_type = typename element_type_traits<PREC_TO>::value_type;
346
347     const auto * src_data = constant->get_data_ptr<src_type>();
348     const auto size = shape_size(constant->get_shape());
349
350     auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
351     auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
352     if (dst_data == nullptr)
353         throw ngraph_error("Can't get destination data pointer");
354
355     std::vector<dst_type> final_data;
356     for (size_t i = 0; i < size; ++i) {
357         dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
358     }
359     return new_constant;
360 }
361
362 bool fuse_type_to_constant(std::shared_ptr<Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {
363     if (auto constant = as_type_ptr<opset4::Constant>(node)) {
364         auto from = constant->get_element_type();
365         std::shared_ptr<Node> new_const;
366         if (from == element::u64 && to == element::i32) {
367             new_const = change_constant_precision<element::Type_t::u64, element::Type_t::i32>(constant);
368         } else if (from == element::i64 && to == element::i32) {
369             new_const = change_constant_precision<element::Type_t::i64, element::Type_t::i32>(constant);
370         } else if (from == element::u8 && to == element::i32) {
371             new_const = change_constant_precision<element::Type_t::u8, element::Type_t::i32>(constant);
372         } else if (from == element::u16 && to == element::i32) {
373             new_const = change_constant_precision<element::Type_t::u16, element::Type_t::i32>(constant);
374         } else if (from == element::u32 && to == element::i32) {
375             new_const = change_constant_precision<element::Type_t::u32, element::Type_t::i32>(constant);
376         } else if (from == element::f16 && to == element::f32) {
377             new_const = change_constant_precision<element::Type_t::f16, element::Type_t::f32>(constant);
378         } else if (from == element::boolean && to == element::u8) {
379             new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::u8>(constant);
380         } else if (from == element::boolean && to == element::i32) {
381             new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::i32>(constant);
382         } else {
383             throw ngraph_error("not supported");
384         }
385         for (auto & output : consumers) {
386             output.replace_source_output(new_const);
387         }
388
389         new_const->validate_and_infer_types();
390         if (constant->get_output_target_inputs(0).size() == consumers.size()) {
391             new_const->set_friendly_name(constant->get_friendly_name());
392         }
393     }
394     return false;
395 }