1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/subtract.hpp"
10 #include <unordered_set>
14 #include "low_precision/common/ie_lpt_exception.hpp"
15 #include "low_precision/network_helper.hpp"
// NOTE(review): every line of this excerpt appears to start with the original file's line
// number, and the jumps in those numbers (e.g. 21 -> 25 -> 30) show that many lines are not
// present here, including statement continuations, closing braces and return statements.
// The comments below describe only what the visible lines do.
19 namespace low_precision {
// Pattern setup for SubtractTransformation. The visible patterns are
// Subtract(Multiply, Constant) and Subtract(Convert, Constant); the calls they are passed
// into are on lines not shown (presumably registering them on `pass` -- TODO confirm).
21 void SubtractTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
25 make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
30 make_op_pattern<opset1::Subtract>({ make_op_label<opset1::Convert>(), make_op_label<opset1::Constant>() }));
// Rewrites a matched Subtract so that its second operand is pushed through / merged into
// the dequantization operations (Multiply, Subtract, Convert) found upstream of it.
// @param context  transformation context, forwarded to canBeTransformed().
// @param m        matcher whose root is cast to opset1::Subtract.
// @return bool; the return statements are on lines not shown in this excerpt.
33 bool SubtractTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
34 std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<opset1::Subtract>(m.get_match_root());
// Eligibility check; the body of this branch is not shown (presumably it returns false).
35 if (!canBeTransformed(context, subtract)) {
// Output element type of the matched Subtract, captured before any rewrite; it is
// re-applied in the Convert branch below.
39 const ngraph::element::Type originalPrecision = subtract->get_output_element_type(0);
41 const FakeQuantizeDequantization dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(subtract);
// Step 1: move the Subtract in front of the dequantization Multiply.
// New Subtract = Subtract(multiply input 0, fold<Divide>(SH, SC)), i.e. SH' = SH / SC.
42 if (dequantization.multiply != nullptr) {
43 // before: Y = X * SC - SH, after: Y = (X - SH') * SC
44 // X * SC - SH = X * SC - SH' * SC
46 std::shared_ptr<opset1::Subtract> newSubtract = as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
47 dequantization.multiply->get_input_node_shared_ptr(0),
48 ngraph::pass::low_precision::fold<ngraph::opset1::Divide>(
49 subtract->get_input_node_shared_ptr(1),
50 dequantization.multiply->get_input_node_shared_ptr(1))
// Copy of the dequantization Multiply. Its first input is on a line not shown here
// (presumably newSubtract -- TODO confirm); its second input keeps the original scale SC.
53 std::shared_ptr<Node> newMultiply = dequantization.multiply->copy_with_new_inputs({
55 dequantization.multiply->input_value(1)
58 replace_node(subtract, newMultiply);
// The remaining steps operate on the relocated Subtract.
59 subtract = newSubtract;
// Step 2: merge with an existing dequantization Subtract.
// New Subtract = Subtract(dequantization.subtract input 0, fold<Add>(current shift, its shift)),
// i.e. (X - S1) - S2 == X - (S1 + S2), assuming the current Subtract consumes the
// dequantization Subtract's output -- confirm against the full source.
62 if (dequantization.subtract != nullptr) {
63 std::shared_ptr<opset1::Subtract> newSubtract = as_type_ptr<opset1::Subtract>(subtract->copy_with_new_inputs({
64 dequantization.subtract->get_input_node_shared_ptr(0),
65 ngraph::pass::low_precision::fold<ngraph::opset1::Add>(
66 subtract->get_input_node_shared_ptr(1),
67 dequantization.subtract->get_input_node_shared_ptr(1))
70 replace_node(subtract, newSubtract);
71 subtract = newSubtract;
// Step 3: when the dequantization includes a Convert, set the Subtract's output type back
// to originalPrecision and replace it with an op::TypeRelaxed<opset1::Subtract> built on
// the same two inputs.
74 if (dequantization.convert != nullptr) {
76 // std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeElementwise(subtract);
// NOTE(review): set_output_type() mutates the node that is replaced by the very next
// statement; how this affects the TypeRelaxed replacement is not evident here -- confirm.
77 subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
79 replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
80 subtract->get_input_node_shared_ptr(0),
81 subtract->get_input_node_shared_ptr(1)));
86 } // namespace low_precision