1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "transformations/convert_precision.hpp"
10 #include <ngraph/opsets/opset5.hpp>
11 #include <ngraph/opsets/opset4.hpp>
12 #include <ngraph/opsets/opset3.hpp>
13 #include <ngraph/opsets/opset1.hpp>
14 #include <ngraph_ops/type_relaxed.hpp>
16 using namespace ngraph;
// Forward declarations of the per-operation handlers defined later in this file.
// Common parameters: `node` is the node being processed, `to` is the requested element type and
// `idx` is the index of the node output (fuse_* handlers) or node input (extend_* handlers) concerned.
// fuse_type_to_constant is the exception: instead of an index it receives the consumers to redirect.
18 bool fuse_type_to_constant(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, const std::vector<ngraph::Input<ngraph::Node>> & consumers);
19 bool fuse_type_to_shapeof(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
20 bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
21 bool fuse_type_to_parameter(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
22 bool fuse_type_to_convert(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
23 bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
24 bool fuse_type_to_nms4(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
25 bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
26 bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
27 bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
28 bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
29 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
// Handler that relaxes the input types accepted by a node instead of changing its output type.
31 bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
// Makes a binary comparison node produce element type `to`.
// T is the concrete operation type the handler is instantiated for (e.g. opset4::Equal).
// `idx` is not used by the statements below.
34 bool fuse_type_to_binary_comparision(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
// Case 1: the node is already type-relaxed, so only its overridden output type is updated.
35 if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
36 type_relaxed->set_overridden_output_type(to);
// Case 2: a plain T node is wrapped into TypeRelaxed<T> (no input type overrides, output type `to`)
// and the original node is replaced by the wrapper.
38 } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
39 auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted, element::TypeVector{}, element::TypeVector{to});
40 replace_node(node, relaxed_op);
// Makes a logical node produce element type `to` while it keeps treating its inputs as boolean.
// T is the concrete operation type the handler is instantiated for (e.g. opset4::LogicalAnd).
// NOTE(review): the dispatch table also registers this handler for opset4::LogicalNot; confirm that
// declaring an origin type for input 1 is harmless for an operation that presumably has one input.
47 bool fuse_type_to_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
// Already type-relaxed: override the output type and pin the origin type of inputs 0 and 1 to boolean.
48 if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
49 type_relaxed->set_overridden_output_type(to);
50 type_relaxed->set_origin_input_type(element::boolean, 0);
51 type_relaxed->set_origin_input_type(element::boolean, 1);
// Plain T node: wrap it into TypeRelaxed<T> with two boolean input types and output type `to`,
// then replace the original node with the wrapper.
53 } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
54 auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
55 element::TypeVector{element::boolean, element::boolean}, element::TypeVector{to});
56 replace_node(node, relaxed_op);
// Makes a logical reduction node produce element type `to` while input 0 is still treated as boolean.
// T is the concrete operation type the handler is instantiated for (e.g. opset4::ReduceLogicalAnd).
63 bool fuse_type_to_reduce_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
// Already type-relaxed: override the output type and pin the origin type of input 0 to boolean.
64 if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
65 type_relaxed->set_overridden_output_type(to);
66 type_relaxed->set_origin_input_type(element::boolean, 0);
// Plain T node: wrap it into TypeRelaxed<T> with one boolean input type and output type `to`,
// then replace the original node with the wrapper.
68 } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
69 auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
70 element::TypeVector{element::boolean}, element::TypeVector{to});
71 replace_node(node, relaxed_op);
// RTTI definition of the pass, registered under the name "ConvertPrecision".
77 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPrecision, "ConvertPrecision", 0);
// Entry point of the pass: rewrites function `f` so that values of element type m_from become m_to.
// The work is organised around two static dispatch tables and a set of local lambdas declared below.
79 bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
// Dispatch table for outputs: node type info -> handler that fuses the new element type into the node.
// Each handler is called as handler(node, m_to, output_index).
80 static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
81 {opset4::Parameter::type_info, fuse_type_to_parameter},
82 {opset4::Convert::type_info, fuse_type_to_convert},
83 {opset4::ShapeOf::type_info, fuse_type_to_shapeof},
84 {opset3::NonMaxSuppression::type_info, fuse_type_to_nms3},
85 {opset4::NonMaxSuppression::type_info, fuse_type_to_nms4},
86 {opset5::NonMaxSuppression::type_info, fuse_type_to_nms5},
87 {opset4::TopK::type_info, fuse_type_to_topk},
88 {opset4::NonZero::type_info, fuse_type_to_nonzero},
89 {opset4::Bucketize::type_info, fuse_type_to_bucketize},
90 {NodeTypeInfo("GenericIE", 1), fuse_type_to_generic_ie},
91 {opset4::Equal::type_info, fuse_type_to_binary_comparision<opset4::Equal>},
92 {opset4::NotEqual::type_info, fuse_type_to_binary_comparision<opset4::NotEqual>},
93 {opset4::Greater::type_info, fuse_type_to_binary_comparision<opset4::Greater>},
94 {opset4::GreaterEqual::type_info, fuse_type_to_binary_comparision<opset4::GreaterEqual>},
95 {opset4::Less::type_info, fuse_type_to_binary_comparision<opset4::Less>},
96 {opset4::LessEqual::type_info, fuse_type_to_binary_comparision<opset4::LessEqual>},
97 {opset4::LogicalAnd::type_info, fuse_type_to_logical<opset4::LogicalAnd>},
98 {opset4::LogicalOr::type_info, fuse_type_to_logical<opset4::LogicalOr>},
99 {opset4::LogicalXor::type_info, fuse_type_to_logical<opset4::LogicalXor>},
100 {opset4::LogicalNot::type_info, fuse_type_to_logical<opset4::LogicalNot>},
101 {opset4::ReduceLogicalAnd::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalAnd>},
102 {opset4::ReduceLogicalOr::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalOr>},
103 {opset1::ShapeOf::type_info, fuse_type_to_shapeof_v0}
// Dispatch table for inputs: node type info -> handler that extends the input types a node accepts.
// Each handler is called as handler(node, m_to, input_index).
106 static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_extend {
107 {opset4::Select::type_info, extend_select_type},
110 // Constant operations can be shared between multiple nGraph Functions, so before
111 // changing precision we need to determine which Constant consumers belong
112 // to the current nGraph Function.
// Maps each Constant node to the inputs that consume it.
113 std::map<std::shared_ptr<Node>, std::vector<Input<Node>>> const_to_internal_output;
// Walks all ordered operations of function `f` and records, for every node input that is fed by a
// Constant, that input under its Constant in const_to_internal_output. Entries are appended, so
// repeated calls accumulate consumers.
// The capture list takes both the map and the std::function itself by reference.
115 std::function<void(const std::shared_ptr<Function> &)> register_constants =
116 [&const_to_internal_output, &register_constants](const std::shared_ptr<Function> & f) {
117 for (auto & node : f->get_ordered_ops()) {
118 for (auto & input : node->inputs()) {
// Only inputs whose producing node is a Constant are registered.
119 if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(input.get_source_output().get_node_shared_ptr())) {
120 const_to_internal_output[const_node].emplace_back(input);
// Visits every output of `node` and, for outputs whose element type equals m_from, tries to switch
// them to m_to: registered Constants go through fuse_type_to_constant, every other node goes through
// the type_to_fuse dispatch table (looked up by the node's type info).
126 auto convert_node_output_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
127 for (auto output : node->outputs()) {
128 if (output.get_element_type() == m_from) {
129 // Handle case with Constants as they can have consumers from other nGraph Function object
// Only the consumers recorded for this Constant in const_to_internal_output are passed on.
130 if (ngraph::op::is_constant(node) && const_to_internal_output.count(node)) {
131 fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
135 // Check that node type exists in map and we can fuse type into node
// The handler receives the index of the output being converted.
136 if (type_to_fuse.count(node->get_type_info()) &&
137 type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index())) {
138 // We need to break if original node was replaced
// Visits every input of `node` and, for inputs whose element type equals m_from, invokes the
// matching handler from type_to_extend (if one is registered for the node's type info).
// The handler receives the index of the input being processed.
145 auto convert_node_input_precision = [this](std::shared_ptr<Node> & node) {
146 for (auto input : node->inputs()) {
147 if (input.get_element_type() == m_from) {
148 // For some operations we need to extend their input types to support new type
149 if (type_to_extend.count(node->get_type_info()) &&
150 type_to_extend.at(node->get_type_info())(node, m_to, input.get_index())) {
// Converts one function `f`. It first walks the ordered operations, recursing into sub-graph bodies
// and fixing node inputs; then it registers the function's Constant consumers; finally it walks the
// ordered operations again and converts node outputs.
// Declared as std::function so that the lambda can call itself recursively.
// NOTE(review): register_constants is invoked below; verify that it is part of the capture list.
157 std::function<void(const std::shared_ptr<Function> &)> convert_function_precision =
158 [this, &const_to_internal_output,
160 &convert_node_output_precision,
161 &convert_node_input_precision,
162 &convert_function_precision] (const std::shared_ptr<Function> & f) {
163 // Iterate over all nodes in topological order and then iterate over node outputs.
164 // If output type mismatch given type we try to fuse type into this operation
165 // otherwise we insert Convert operation.
166 for (auto &node : f->get_ordered_ops()) {
// The pass's transformation callback is invoked for every node; its result is not used here.
167 m_transformation_callback(node);
168 // Recursively apply transformation for sub-graph based operations
169 if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
170 if (auto sub_graph = sub_graph_node->get_function()) {
171 convert_function_precision(sub_graph);
174 convert_node_input_precision(node);
176 // Register internal constants only after fixing input type that could lead to nodes replacement
177 register_constants(f);
// Second walk over a freshly obtained operation order: convert the outputs.
179 for (auto &node : f->get_ordered_ops()) {
180 convert_node_output_precision(node);
// Run the conversion on the root function, then re-run node validation and type inference on it.
184 convert_function_precision(f);
185 f->validate_nodes_and_infer_types();
187 // TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
// Remove Convert nodes that no longer change the type: the input element type already equals the
// Convert's destination type. The Convert's output is replaced by its input value.
188 for (auto &node : f->get_ordered_ops()) {
189 if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
190 // WA for topK, don't remove fake convert
// The Convert is only removed when its producer node has exactly one output.
191 if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
192 convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
193 replace_output_update_name(convert->output(0), convert->input_value(0));
// Handler for opset4::ShapeOf: sets the node's output type to `to`.
// The type is only applied when `to` is i32 or i64; `idx` is not used by the statements below.
200 bool fuse_type_to_shapeof(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
201 if (auto shapeof = as_type_ptr<opset4::ShapeOf>(node)) {
202 if (to == element::i32 || to == element::i64) {
203 shapeof->set_output_type(to);
// Handler for opset4::Parameter: sets the parameter's element type to `to` and re-runs its
// validation / type inference so the change is reflected on the node.
210 bool fuse_type_to_parameter(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
211 if (auto param = as_type_ptr<opset4::Parameter>(node)) {
212 param->set_element_type(to);
213 param->validate_and_infer_types();
// Handler for opset4::Convert: retargets the Convert so that its destination element type is `to`.
219 bool fuse_type_to_convert(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
220 if (auto convert = as_type_ptr<opset4::Convert>(node)) {
221 convert->set_convert_element_type(to);
// Handler for opset3::NonMaxSuppression: sets the node's output type to `to`.
// No restriction on `to` is applied by the statements below.
227 bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
228 if (auto nms = as_type_ptr<opset3::NonMaxSuppression>(node)) {
229 nms->set_output_type(to);
// Handler for opset4::NonMaxSuppression: sets the node's output type to `to`.
// No restriction on `to` is applied by the statements below.
235 bool fuse_type_to_nms4(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
236 if (auto nms = as_type_ptr<opset4::NonMaxSuppression>(node)) {
237 nms->set_output_type(to);
// Handler for opset5::NonMaxSuppression: sets the node's output type to `to`.
// No restriction on `to` is applied by the statements below.
243 bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
244 if (auto nms = as_type_ptr<opset5::NonMaxSuppression>(node)) {
245 nms->set_output_type(to);
// Handler for opset4::TopK: changes the index element type of the node to `to`.
// Applied only for output index 1 and only when `to` is i32 or i64.
251 bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
252 if (auto topk = as_type_ptr<opset4::TopK>(node)) {
253 if (idx == 1 && (to == element::i32 || to == element::i64)) {
254 topk->set_index_element_type(to);
// Handler for opset4::NonZero: sets the node's output type to `to`.
// The type is only applied when `to` is i32 or i64.
261 bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
262 if (auto nonzero = as_type_ptr<opset4::NonZero>(node)) {
263 if (to == element::i32 || to == element::i64) {
264 nonzero->set_output_type(to);
// Handler for opset4::Bucketize: sets the node's output type to `to`.
// The type is only applied when `to` is i32 or i64.
271 bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
272 if (auto b = as_type_ptr<opset4::Bucketize>(node)) {
273 if (to == element::i32 || to == element::i64) {
274 b->set_output_type(to);
// Handler registered for the "GenericIE" node type: sets the element type of output `idx` to `to`
// while keeping that output's current partial shape. The node is modified in place.
281 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
282 node->set_output_type(idx, to, node->output(idx).get_partial_shape());
283 // return false as we do not replace original node
// Handler for opset1::ShapeOf: makes the node produce element type `to` through the TypeRelaxed wrapper.
287 bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
// Already type-relaxed: only the overridden output type is updated.
288 if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
289 type_relaxed->set_overridden_output_type(to);
// Plain opset1::ShapeOf: wrap it into TypeRelaxed (no input type overrides, output type `to`)
// and replace the original node with the wrapper.
291 } else if (auto casted = std::dynamic_pointer_cast<opset1::ShapeOf>(node)) {
292 auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<opset1::ShapeOf>>(*casted,
293 element::TypeVector{}, element::TypeVector{to});
294 replace_node(node, relaxed_op);
// Handler for opset4::Select: declares boolean as the origin type of input 0 so that the node can be
// fed a different element type on that input. `to` and `idx` are not used by the statements below.
300 bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
// Already type-relaxed: only the origin type of input 0 is set.
301 if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
302 type_relaxed->set_origin_input_type(element::boolean, 0);
// Plain Select: wrap it into TypeRelaxed<Select> with one boolean input type and no output type
// overrides, then replace the original node with the wrapper.
304 } else if (auto casted = std::dynamic_pointer_cast<opset4::Select>(node)) {
305 auto relaxed_op = std::make_shared<op::TypeRelaxed<opset4::Select>>(*casted,
306 element::TypeVector{element::boolean},
307 element::TypeVector{});
308 replace_node(node, relaxed_op);
// Converts a single value from src_type to dst_type with saturation: values above the largest
// dst_type value map to that maximum, values below the lowest dst_type value map to that lowest
// value, and everything else is converted with static_cast.
314 template <typename src_type, typename dst_type>
315 inline dst_type convert_value(src_type val) {
316 if (val > std::numeric_limits<dst_type>::max()) {
317 return std::numeric_limits<dst_type>::max();
318 } else if (val < std::numeric_limits<dst_type>::lowest()) {
319 return std::numeric_limits<dst_type>::lowest();
321 return static_cast<dst_type>(val);
324 // We need to treat U64->I32 and U32->I32 as a separate case, because of C++'s implicit promotion from signed to unsigned,
325 // and we don't need to compare and clamp the input to std::numeric_limits<int32_t>::lowest()
// uint64_t -> int32_t: only the upper bound is clamped; values above the int32_t maximum saturate to it.
327 inline int32_t convert_value<uint64_t, int32_t>(uint64_t val) {
328 if (val > std::numeric_limits<int32_t>::max()) {
329 return std::numeric_limits<int32_t>::max();
331 return static_cast<int32_t>(val);
// uint32_t -> int32_t: the comparison is done in uint32_t (the int32_t maximum is cast first);
// values above the int32_t maximum saturate to it, no lower-bound check is performed.
335 inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
336 if (val > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
337 return std::numeric_limits<int32_t>::max();
339 return static_cast<int32_t>(val);
// Builds a new Constant of element type PREC_TO with the same shape as `constant` and fills it with
// the source values converted element by element through convert_value (saturating conversion).
// Throws ngraph_error when the destination data pointer of the new Constant is null.
342 template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
343 static std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
// C++ value types that correspond to the source and destination element types.
344 using src_type = typename element_type_traits<PREC_FROM>::value_type;
345 using dst_type = typename element_type_traits<PREC_TO>::value_type;
347 const auto * src_data = constant->get_data_ptr<src_type>();
// Number of elements, derived from the source shape.
348 const auto size = shape_size(constant->get_shape());
350 auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
// The new Constant exposes its data as a const pointer; constness is cast away so the converted
// values can be written directly into its buffer.
351 auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
352 if (dst_data == nullptr)
353 throw ngraph_error("Can't get destination data pointer");
// NOTE(review): final_data is not referenced by the statements below; it looks like a leftover
// that could be removed - confirm.
355 std::vector<dst_type> final_data;
356 for (size_t i = 0; i < size; ++i) {
357 dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
// Re-creates the Constant `node` in element type `to` and redirects the given `consumers` to the new
// Constant. Consumers that are not listed keep reading from the original Constant.
// Supported (from -> to) pairs: u64, i64, u8, u16, u32 -> i32; f16 -> f32; boolean -> u8; boolean -> i32.
// Throws ngraph_error("not supported") for any other pair.
// The friendly name of the original Constant is transferred to the new one only when the listed
// consumers are all of the original Constant's consumers.
362 bool fuse_type_to_constant(std::shared_ptr<Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {
363 if (auto constant = as_type_ptr<opset4::Constant>(node)) {
364 auto from = constant->get_element_type();
// Decide about the name transfer before any rewiring: replace_source_output() moves a consumer from
// the old Constant to the new one, so counting the old Constant's target inputs after the rewiring
// loop would compare against the remaining consumers instead of the original ones.
const bool all_consumers_internal = constant->get_output_target_inputs(0).size() == consumers.size();
365 std::shared_ptr<Node> new_const;
366 if (from == element::u64 && to == element::i32) {
367 new_const = change_constant_precision<element::Type_t::u64, element::Type_t::i32>(constant);
368 } else if (from == element::i64 && to == element::i32) {
369 new_const = change_constant_precision<element::Type_t::i64, element::Type_t::i32>(constant);
370 } else if (from == element::u8 && to == element::i32) {
371 new_const = change_constant_precision<element::Type_t::u8, element::Type_t::i32>(constant);
372 } else if (from == element::u16 && to == element::i32) {
373 new_const = change_constant_precision<element::Type_t::u16, element::Type_t::i32>(constant);
374 } else if (from == element::u32 && to == element::i32) {
375 new_const = change_constant_precision<element::Type_t::u32, element::Type_t::i32>(constant);
376 } else if (from == element::f16 && to == element::f32) {
377 new_const = change_constant_precision<element::Type_t::f16, element::Type_t::f32>(constant);
378 } else if (from == element::boolean && to == element::u8) {
379 new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::u8>(constant);
380 } else if (from == element::boolean && to == element::i32) {
381 new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::i32>(constant);
383 throw ngraph_error("not supported");
// Rewire every listed consumer (a node input) to read from the new Constant.
385 for (auto & output : consumers) {
386 output.replace_source_output(new_const);
389 new_const->validate_and_infer_types();
// Keep the original name only if the old Constant was fully taken over by the new one.
390 if (all_consumers_internal) {
391 new_const->set_friendly_name(constant->get_friendly_name());