1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/fuse_convert.hpp"
11 #include "low_precision/common/ie_lpt_exception.hpp"
12 #include "low_precision/network_helper.hpp"
16 namespace low_precision {
18 void FuseConvertTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
22 make_op_pattern<opset1::Multiply>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
27 make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
32 make_op_pattern<opset1::Add>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
35 std::shared_ptr<Node> removeConvertIfPossibleForSubtract(
36 const std::shared_ptr<opset1::Convert>& convert,
37 const std::shared_ptr<opset1::Subtract>& subtract) {
38 std::shared_ptr<Node> newSubtract;
40 const element::Type precisionBeforeConvert = convert->input(0).get_element_type();
41 if (NetworkHelper::checkConstantValuePrecision(precisionBeforeConvert, subtract->get_input_node_shared_ptr(1))) {
42 newSubtract = std::make_shared<ngraph::op::TypeRelaxed<opset1::Subtract>>(
43 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
44 ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
45 ngraph::op::TemporaryReplaceOutputType(subtract->get_input_node_shared_ptr(1), element::f32).get());
46 NetworkHelper::setOutDataPrecisionForTypeRelaxed(newSubtract, subtract->get_output_element_type(0));
47 replace_node(subtract, newSubtract);
53 bool FuseConvertTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
54 auto op = m.get_match_root();
55 if (!canBeTransformed(context, op)) {
59 std::shared_ptr<opset1::Convert> convert = as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0));
61 if (convert == nullptr) {
65 std::shared_ptr<Node> parent = convert->get_input_node_shared_ptr(0);
67 if (is_type<opset1::Constant>(parent)) {
68 auto convertedConstant = fold<opset1::Convert>(parent, convert->get_convert_element_type());
69 NetworkHelper::copyInfo(parent, convertedConstant);
70 replace_node(convert, convertedConstant);
72 std::shared_ptr<Node> newOp;
73 if (is_type<opset1::Subtract>(op)) {
74 auto subtract = as_type_ptr<opset1::Subtract>(op);
75 newOp = removeConvertIfPossibleForSubtract(convert, subtract);
76 } else if (is_type<opset1::Multiply>(op)) {
77 newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Multiply>>(
78 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
79 ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
80 ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get());
81 NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
82 replace_node(op, newOp);
83 } else if (is_type<opset1::Add>(op)) {
84 newOp = std::make_shared<ngraph::op::TypeRelaxed<opset1::Add>>(
85 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
86 ngraph::op::TemporaryReplaceOutputType(convert->get_input_source_output(0), element::f32).get(),
87 ngraph::op::TemporaryReplaceOutputType(op->get_input_node_shared_ptr(1), element::f32).get());
88 NetworkHelper::setOutDataPrecisionForTypeRelaxed(newOp, op->get_output_element_type(0));
89 replace_node(op, newOp);
92 if (newOp != nullptr) {
93 NetworkHelper::copyInfo(op, newOp);
100 bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
104 bool FuseConvertTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
108 } // namespace low_precision
110 } // namespace ngraph