1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "transformations/low_precision/fuse_fake_quantize.hpp"
7 #include <ngraph/ngraph.hpp>
8 #include "transformations/low_precision/common/ie_lpt_exception.hpp"
9 #include "transformations/low_precision/network_helper.hpp"
13 namespace low_precision {
15 void FuseFakeQuantizeTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
16 addSingleNodePattern<opset1::FakeQuantize>(pass, context);
19 bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
20 std::shared_ptr<opset1::FakeQuantize> fakeQuantize = as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
22 fakeQuantize = handle(context, fakeQuantize);
23 } while (fakeQuantize != nullptr);
29 std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
30 const Shape shape = op->get_output_shape(0);
31 if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
32 op = fold<opset1::Unsqueeze>(
34 std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
39 std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) {
40 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
41 return eltwise->get_input_node_shared_ptr(0);
44 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
45 return eltwise->get_input_node_shared_ptr(1);
51 std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
52 if (eltwise->get_input_size() != 2) {
56 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
57 if (constant != nullptr) {
61 return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
64 bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
65 std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
66 if (constant == nullptr) {
70 Shape shape = constant->get_output_shape(0);
71 if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
72 const Shape eltwiseShape = eltwise->get_output_shape(0);
73 if ((eltwiseShape.size() - shape.size()) > 1) {
77 if ((eltwiseShape.size() - shape.size()) == 1ul) {
78 shape.insert(shape.begin(), 1ul);
81 for (size_t i = 2ul; i < shape.size(); ++i) {
82 if (shape[i] != 1ul) {
88 return getData(eltwise) != nullptr;
91 } // namespace fuse_fq
93 std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
94 TransformationContext& context,
95 const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
96 const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
98 std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
99 std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
101 std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
102 if (is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
103 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
105 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
107 inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
108 inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
109 } else if (is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
110 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
112 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
114 inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
115 inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
116 } else if (is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
117 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
119 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
121 inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
122 inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
123 } else if (is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
124 if (is_type<opset1::Convolution>(fuse_fq::getData(eltwise)) ||
125 is_type<opset1::GroupConvolution>(fuse_fq::getData(eltwise))) {
129 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
131 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
133 inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
134 inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
135 } else if (is_type<opset1::Convert>(eltwise)) {
137 if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
144 std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
145 fuse_fq::getData(eltwise),
148 fakeQuantize->input_value(3),
149 fakeQuantize->input_value(4) }));
151 replace_node(fakeQuantize, newFakeQuantize);
152 NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
153 return newFakeQuantize;
156 bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
160 } // namespace low_precision
162 } // namespace ngraph