[IE] Fix UNITY build (#2799)
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / low_precision / fuse_fake_quantize.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/low_precision/fuse_fake_quantize.hpp"
6 #include <memory>
7 #include <ngraph/ngraph.hpp>
8 #include "transformations/low_precision/common/ie_lpt_exception.hpp"
9 #include "transformations/low_precision/network_helper.hpp"
10
11 namespace ngraph {
12 namespace pass {
13 namespace low_precision {
14
15 void FuseFakeQuantizeTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
16     addSingleNodePattern<opset1::FakeQuantize>(pass, context);
17 }
18
19 bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
20     std::shared_ptr<opset1::FakeQuantize> fakeQuantize = as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
21     do {
22         fakeQuantize = handle(context, fakeQuantize);
23     } while (fakeQuantize != nullptr);
24     return true;
25 }
26
27 namespace fuse_fq {
28
29 std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
30     const Shape shape = op->get_output_shape(0);
31     if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
32         op = fold<opset1::Unsqueeze>(
33             op,
34             std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
35     }
36     return op;
37 }
38
39 std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) {
40     if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
41         return eltwise->get_input_node_shared_ptr(0);
42     }
43
44     if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
45         return eltwise->get_input_node_shared_ptr(1);
46     }
47
48     return nullptr;
49 }
50
51 std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
52     if (eltwise->get_input_size() != 2) {
53         return nullptr;
54     }
55
56     std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
57     if (constant != nullptr) {
58         return constant;
59     }
60
61     return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
62 }
63
64 bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
65     std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
66     if (constant == nullptr) {
67         return false;
68     }
69
70     Shape shape = constant->get_output_shape(0);
71     if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
72         const Shape eltwiseShape = eltwise->get_output_shape(0);
73         if ((eltwiseShape.size() - shape.size()) > 1) {
74             return false;
75         }
76
77         if ((eltwiseShape.size() - shape.size()) == 1ul) {
78             shape.insert(shape.begin(), 1ul);
79         }
80
81         for (size_t i = 2ul; i < shape.size(); ++i) {
82             if (shape[i] != 1ul) {
83                 return false;
84             }
85         }
86     }
87
88     return getData(eltwise) != nullptr;
89 }
90
91 }  // namespace fuse_fq
92
93 std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
94     TransformationContext& context,
95     const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
96     const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
97
98     std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
99     std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
100
101     std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
102     if (is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
103         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
104             constant :
105             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
106
107         inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
108         inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
109     } else if (is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
110         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
111             constant :
112             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
113
114         inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
115         inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
116     } else if (is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
117         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
118             constant :
119             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
120
121         inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
122         inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
123     } else if (is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
124         if (is_type<opset1::Convolution>(fuse_fq::getData(eltwise)) ||
125             is_type<opset1::GroupConvolution>(fuse_fq::getData(eltwise))) {
126             return nullptr;
127         }
128
129         const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
130             constant :
131             fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
132
133         inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
134         inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
135     } else if (is_type<opset1::Convert>(eltwise)) {
136         // issue #40611
137         if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
138             return nullptr;
139         }
140     } else {
141         return nullptr;
142     }
143
144     std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
145         fuse_fq::getData(eltwise),
146         inputLowConst,
147         inputHightConst,
148         fakeQuantize->input_value(3),
149         fakeQuantize->input_value(4) }));
150
151     replace_node(fakeQuantize, newFakeQuantize);
152     NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
153     return newFakeQuantize;
154 }
155
156 bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
157     return false;
158 }
159
160 } // namespace low_precision
161 } // namespace pass
162 } // namespace ngraph