9195d1345bd4cb4eb1a26a369cc04c4dcace415e
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / common / fuse_convert.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/fuse_convert.hpp"
6
7 #include <memory>
8 #include <string>
9 #include <vector>
10
11 #include "low_precision/common/ie_lpt_exception.hpp"
12 #include "low_precision/network_helper.hpp"
13
14 namespace ngraph {
15 namespace pass {
16 namespace low_precision {
17
18 void FuseConvertTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
19     addPattern(
20         pass,
21         context,
22         make_op_pattern<opset1::Multiply>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
23
24     addPattern(
25         pass,
26         context,
27         make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
28
29     addPattern(
30         pass,
31         context,
32         make_op_pattern<opset1::Add>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
33 }
34
35 std::shared_ptr<Node> removeConvertIfPossibleForSubtract(
36     const std::shared_ptr<opset1::Convert>& convert,
37     const std::shared_ptr<opset1::Subtract>& subtract) {
38     std::shared_ptr<Node> newSubtract;
39
40     const element::Type precisionBeforeConvert = convert->input(0).get_element_type();
41     if (NetworkHelper::checkConstantValuePrecision(precisionBeforeConvert, subtract->get_input_node_shared_ptr(1))) {
42         newSubtract = std::make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>(
43             std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
44             ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
45             ngraph::op::TemporaryReplaceOutputType(subtract->get_input_node_shared_ptr(1), element::f32).get());
46         NetworkHelper::setOutDataPrecisionForTypeRelaxed(newSubtract, subtract->get_output_element_type(0));
47         replace_node(subtract, newSubtract);
48     }
49
50     return newSubtract;
51 }
52
53 bool FuseConvertTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
54     auto op = m.get_match_root();
55     if (!canBeTransformed(context, op)) {
56         return false;
57     }
58
59     std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0));
60     // issue #40395
61     if (convert == nullptr) {
62         return false;
63     }
64
65     std::shared_ptr<Node> parent = convert->get_input_node_shared_ptr(0);
66
67     if (is_type<opset1::Constant>(parent)) {
68         auto convertedConstant = fold<opset1::Convert>(parent, convert->get_convert_element_type());
69         NetworkHelper::copyInfo(parent, convertedConstant);
70         replace_node(convert, convertedConstant);
71     } else {
72         std::shared_ptr<Node> newOp;
73         if (is_type<opset1::Subtract>(op)) {
74             auto subtract = as_type_ptr<opset1::Subtract>(op);
75             newOp = removeConvertIfPossibleForSubtract(convert, subtract);
76         } else if (is_type<opset1::Multiply>(op)) {
77             newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Multiply>>(
78                     std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
79                     ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
80                     ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get());
81             NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
82             replace_node(op, newOp);
83         } else if (is_type<opset1::Add>(op)) {
84             newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Add>>(
85                     std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
86                     ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
87                     ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get());
88             NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
89             replace_node(op, newOp);
90         }
91
92         if (newOp != nullptr) {
93             NetworkHelper::copyInfo(op, newOp);
94         }
95     }
96
97     return true;
98 }
99
100 bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
101     return true;
102 }
103
104 bool FuseConvertTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
105     return false;
106 }
107
108 } // namespace low_precision
109 } // namespace pass
110 } // namespace ngraph