1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "transformations/low_precision/fake_quantize.hpp"
16 #include <ngraph/opsets/opset1.hpp>
18 #include "transformations/low_precision/common/ie_lpt_exception.hpp"
19 #include "transformations/low_precision/network_helper.hpp"
23 namespace low_precision {
// Registers the matcher for this transformation in `pass`: a single-node
// pattern on opset1::FakeQuantize, set up through addSingleNodePattern.
25 void FakeQuantizeTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
26 addSingleNodePattern<opset1::FakeQuantize>(pass, context);
// Processes a matched FakeQuantize node.
//
// Steps visible in this view, in order:
//   1. fuseElementwise is applied repeatedly until it returns nullptr.
//   2. The output precision (i8/u8), NetworkHelper::onWeights and a Constant
//      data input are each checked.
//   3. Output layout and quantization levels are validated through
//      QuantizationDetails.
//   4. The node is decomposed with NetworkHelper::decomposeFakeQuantize and
//      element 1 of the result is handed to updateOutput.
//
// NOTE(review): the bodies of most branches (and therefore the returned
// values) are not visible in this view, so they are not documented here.
29 bool FakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
// NOTE(review): dynamic_pointer_cast yields nullptr when the match root is
// not a FakeQuantize; no null check is visible here before `layer` is used.
30 std::shared_ptr<opset1::FakeQuantize> layer = std::dynamic_pointer_cast<opset1::FakeQuantize>(m.get_match_root());
32 std::shared_ptr<opset1::FakeQuantize> fakeQuantize = layer;
// Tail of a do/while loop: keep fusing while fuseElementwise returns a
// non-null FakeQuantize (the loop header is outside this view).
36 fakeQuantize = fuseElementwise(context, fakeQuantize);
37 } while (fakeQuantize != nullptr);
// Special-case FakeQuantize nodes whose output is already i8 or u8.
39 const ngraph::element::Type precision = layer->get_output_element_type(0);
40 if ((precision == ngraph::element::i8) || (precision == ngraph::element::u8)) {
44 // FakeQuantize on weights are used without dequantization ScaleShifts
45 if (NetworkHelper::onWeights(layer)) {
// Data input (input 0) is a Constant: look at every consumer of every output
// and clear the flag as soon as one consumer reports a non-empty list of
// precisions on activations.
49 if (as_type<opset1::Constant>(layer->get_input_node_ptr(0))) {
50 bool nextOpearionsWillBeNotHandled = true;
51 for (auto output : layer->outputs()) {
52 for (auto input : output.get_target_inputs()) {
// NOTE(review): `activations` is not read on any visible line and the same
// getPrecisionsOnActivations call is repeated in the condition below; this
// looks redundant - confirm.
53 auto activations = paramsManager->getPrecisionsOnActivations(*input.get_node());
54 if (paramsManager->getPrecisionsOnActivations(*input.get_node()).size() != 0ul) {
55 nextOpearionsWillBeNotHandled = false;
60 if (!nextOpearionsWillBeNotHandled) {
// No consumer reported activation precisions: try to fold the FakeQuantize
// and substitute it only when the folding result is a Constant.
65 if (nextOpearionsWillBeNotHandled) {
66 const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(layer);
67 if (as_type_ptr<opset1::Constant>(resultConstant)) {
68 replace_node(layer, resultConstant);
// Checks for an unsupported output layout or an unsupported number of levels.
74 if (!QuantizationDetails::outputLayoutIsSupported(layer)) {
78 if (!QuantizationDetails::isSupportedLevel(layer->get_levels())) {
// Derive the quantization details and target data precision; an `undefined`
// precision is treated as a special case (branch body not visible).
82 const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(layer);
83 const DataPrecision dataPrecision = getDataPrecision(layer, quantizationDetails, false);
84 if (dataPrecision.precision == element::undefined) {
88 // Split FakeQuantize to two parts: Quantize and Dequantize
89 auto QDQ = NetworkHelper::decomposeFakeQuantize(
90 as_type_ptr<opset1::FakeQuantize>(layer),
91 dataPrecision.precision,
94 dataPrecision.hasZeroPoint,
// Debug-only section: reads the dequantization scales from input 1 of the
// Multiply in QDQ element 1, the shifts from input 1 of an optional Subtract
// feeding that Multiply, and prints both.
// NOTE(review): `multiply`, `multiplyConst` and `subtractConst` are
// dereferenced without a visible null check.
97 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
99 const std::shared_ptr<opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(std::get<1>(QDQ));
100 const std::shared_ptr<opset1::Constant> multiplyConst = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
101 const std::vector<float> dequantizationScales = multiplyConst->cast_vector<float>();
103 const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<opset1::Subtract>(multiply->get_input_node_shared_ptr(0));
104 std::vector<float> dequantizationShifts;
105 if (subtract != nullptr) {
106 const std::shared_ptr<opset1::Constant> subtractConst = as_type_ptr<opset1::Constant>(subtract->get_input_node_shared_ptr(1));
107 dequantizationShifts = subtractConst->cast_vector<float>();
// Presumably the no-Subtract branch: value-initialized (zero) shifts, one per
// scale.
109 dequantizationShifts = std::vector<float>(dequantizationScales.size());
112 printDequantizationValues(dequantizationScales, dequantizationShifts);
// Element 1 of the decomposition result is passed to updateOutput together
// with the original layer.
116 std::shared_ptr<ngraph::Node> dequantize = std::get<1>(QDQ);
117 updateOutput(context, dequantize, layer);
// Adjusts `op` towards the rank of `targetShape`: when op's output 0 has fewer
// dimensions than targetShape but more than one, `op` is replaced by a folded
// Unsqueeze whose axes Constant holds the single value 0.
// The remaining Unsqueeze argument and the return statement are not visible
// in this view.
// NOTE(review): the axes Constant is declared element::i32 but initialized
// from a std::vector<size_t> - confirm this is intended.
124 static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
125 const Shape shape = op->get_output_shape(0);
126 if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
127 op = fold<opset1::Unsqueeze>(
129 std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
// Returns the data (non-Constant) input of `eltwise`: input 0 if it is not an
// opset1::Constant, otherwise input 1 if that one is not an opset1::Constant.
// The value returned when both inputs are Constants is not visible in this
// view (checkElementwise compares the result against nullptr).
134 static std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) {
135 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
136 return eltwise->get_input_node_shared_ptr(0);
139 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
140 return eltwise->get_input_node_shared_ptr(1);
// Returns the Constant input of `eltwise`. Input 1 is tried first; if it is
// not an opset1::Constant, the result is input 0 cast with as_type_ptr (which
// may be nullptr).
// Nodes whose input count is not 2 are handled separately; that branch, and
// the branch taken when input 1 is a Constant (presumably `return constant`),
// are not visible in this view.
146 static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
147 if (eltwise->get_input_size() != 2) {
151 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
152 if (constant != nullptr) {
156 return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
// Checks whether `eltwise` qualifies for the elementwise fusing done in
// fuseElementwise.
//
// Visible conditions: the node must have a Constant input (fq::getConstant).
// A constant that is neither scalar-shaped nor single-element has its shape
// compared against the eltwise output shape (rank difference, then the
// dimensions from index 2 onwards). The final result is whether
// fq::getData(eltwise) is non-null.
// The bodies of the rejecting branches are not visible in this view.
161 bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
162 std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
163 if (constant == nullptr) {
167 Shape shape = constant->get_output_shape(0);
// Only non-scalar constants with more than one element need the shape checks.
168 if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
169 const Shape eltwiseShape = eltwise->get_output_shape(0);
// NOTE(review): if Shape::size() is unsigned (the 1ul comparison below
// suggests so), this difference wraps around when the constant has a higher
// rank than the eltwise output - confirm that case cannot occur.
170 if ((eltwiseShape.size() - shape.size()) > 1) {
// Constant is exactly one rank short: prepend a dimension of 1 so the indices
// used below line up with the eltwise output shape.
174 if ((eltwiseShape.size() - shape.size()) == 1ul) {
175 shape.insert(shape.begin(), 1ul);
// Inspect every dimension from index 2 onwards; a value other than 1 is
// handled by a branch whose body is not visible here.
178 for (size_t i = 2ul; i < shape.size(); ++i) {
179 if (shape[i] != 1ul) {
185 return fq::getData(eltwise) != nullptr;
// Fuses the elementwise operation feeding input 0 of `fakeQuantize` into the
// FakeQuantize input interval (inputs 1 and 2).
//
// Visible handling per producer type (each guarded by checkElementwise except
// Convert):
//   Multiply  -> input low/high are divided by the constant
//   Divide    -> input low/high are multiplied by the constant
//   Subtract  -> the constant is added to input low/high
//   Add       -> the constant is subtracted from input low/high
//                (an Add fed by Convolution/GroupConvolution is special-cased)
//   Convert   -> an i32 -> f32 conversion is special-cased
// Afterwards a clone of the FakeQuantize is built on fq::getData(eltwise),
// replaces the original node, receives its info through
// NetworkHelper::copyInfo and is returned.
//
// The special-case bodies, any final else branch and clone inputs 1 and 2 are
// not visible in this view. transform() stops its fusing loop when this
// function returns nullptr.
188 std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
189 TransformationContext& context,
190 const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
191 const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
193 std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
194 std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
// `constant` may be nullptr; every dereference below sits behind a
// checkElementwise(eltwise) guard, which tests fq::getConstant for nullptr.
196 std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
197 if (is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
// Bring the constant to the eltwise output element type (a Convert is folded
// when the types differ; the same-type operand is not visible here).
198 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
200 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
// NOTE(review): the as_type_ptr result is dereferenced without a null check;
// this relies on `value` always being a Constant - confirm.
202 const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
203 // TODO: temporary fix for GPU Plugin (inverted intervals)
204 for (const float& val : valueVec) {
// x * c feeding FQ: divide both interval bounds by c.
210 inputLowConst = fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
211 inputHightConst = fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
212 } else if (is_type<opset1::Divide>(eltwise) && checkElementwise(eltwise)) {
213 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
215 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
// NOTE(review): same unchecked as_type_ptr dereference as in the Multiply
// branch.
217 const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
218 // TODO: temporary fix for GPU Plugin (inverted intervals)
219 for (const float& val : valueVec) {
// x / c feeding FQ: multiply both interval bounds by c.
225 inputLowConst = fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
226 inputHightConst = fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
227 } else if (is_type<opset1::Subtract>(eltwise) && checkElementwise(eltwise)) {
228 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
230 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
// x - c feeding FQ: add c to both interval bounds.
232 inputLowConst = fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
233 inputHightConst = fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
234 } else if (is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
// An Add whose data input is a Convolution or GroupConvolution is treated
// separately (branch body not visible).
235 if (is_type<opset1::Convolution>(fq::getData(eltwise)) ||
236 is_type<opset1::GroupConvolution>(fq::getData(eltwise))) {
240 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
242 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
// x + c feeding FQ: subtract c from both interval bounds.
244 inputLowConst = fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
245 inputHightConst = fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
246 } else if (is_type<opset1::Convert>(eltwise)) {
// An i32 -> f32 Convert is treated separately (branch body not visible).
248 if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
// Rebuild the FakeQuantize on top of the eltwise's data input, keeping the
// original inputs 3 and 4.
255 std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
256 fq::getData(eltwise),
259 fakeQuantize->input_value(3),
260 fakeQuantize->input_value(4) }));
262 replace_node(fakeQuantize, newFakeQuantize);
263 NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
265 return newFakeQuantize;
// Reports whether this transformation preserves precision for `layer`.
// Declared noexcept; the function body (and so the returned value) is not
// visible in this view.
268 bool FakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
271 } // namespace low_precision
273 } // namespace ngraph