// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "transformations/low_precision/fake_quantize.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset1.hpp>

#include "transformations/low_precision/common/ie_lpt_exception.hpp"
#include "transformations/low_precision/network_helper.hpp"
23 namespace low_precision {
25 void FakeQuantizeTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
26 addSingleNodePattern<opset1::FakeQuantize>(pass, context);
29 bool FakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
30 std::shared_ptr<opset1::FakeQuantize> layer = std::dynamic_pointer_cast<opset1::FakeQuantize>(m.get_match_root());
32 std::shared_ptr<opset1::FakeQuantize> fakeQuantize = layer;
36 fakeQuantize = fuseElementwise(context, fakeQuantize);
37 } while (fakeQuantize != nullptr);
39 const ngraph::element::Type precision = layer->get_output_element_type(0);
40 if ((precision == ngraph::element::i8) || (precision == ngraph::element::u8)) {
44 // FakeQuantize on weights are used without dequantization ScaleShifts
45 if (NetworkHelper::onWeights(layer)) {
49 if (as_type<opset1::Constant>(layer->get_input_node_ptr(0))) {
50 bool nextOpearionsWillBeNotHandled = true;
51 for (auto output : layer->outputs()) {
52 for (auto input : output.get_target_inputs()) {
53 auto activations = paramsManager->getPrecisionsOnActivations(*input.get_node());
54 if (paramsManager->getPrecisionsOnActivations(*input.get_node()).size() != 0ul) {
55 nextOpearionsWillBeNotHandled = false;
60 if (!nextOpearionsWillBeNotHandled) {
65 if (nextOpearionsWillBeNotHandled) {
66 const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(layer);
67 if (as_type_ptr<opset1::Constant>(resultConstant)) {
68 replace_node(layer, resultConstant);
74 if (!QuantizationDetails::outputLayoutIsSupported(layer)) {
78 if (!QuantizationDetails::isSupportedLevel(layer->get_levels())) {
82 const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(layer);
83 const DataPrecision dataPrecision = getDataPrecision(layer, quantizationDetails, false);
84 if (dataPrecision.precision == element::undefined) {
88 // Split FakeQuantize to two parts: Quantize and Dequantize
89 auto QDQ = NetworkHelper::decomposeFakeQuantize(
90 as_type_ptr<opset1::FakeQuantize>(layer),
91 dataPrecision.precision,
94 dataPrecision.hasZeroPoint,
97 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
99 const std::shared_ptr<opset1::Multiply> multiply = as_type_ptr<opset1::Multiply>(std::get<1>(QDQ));
100 const std::shared_ptr<opset1::Constant> multiplyConst = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
101 const std::vector<float> dequantizationScales = multiplyConst->cast_vector<float>();
103 const std::shared_ptr<opset1::Subtract> subtract = as_type_ptr<opset1::Subtract>(multiply->get_input_node_shared_ptr(0));
104 std::vector<float> dequantizationShifts;
105 if (subtract != nullptr) {
106 const std::shared_ptr<opset1::Constant> subtractConst = as_type_ptr<opset1::Constant>(subtract->get_input_node_shared_ptr(1));
107 dequantizationShifts = subtractConst->cast_vector<float>();
109 dequantizationShifts = std::vector<float>(dequantizationScales.size());
112 printDequantizationValues(dequantizationScales, dequantizationShifts);
116 std::shared_ptr<ngraph::Node> dequantize = std::get<1>(QDQ);
117 updateOutput(context, dequantize, layer);
122 static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
123 const Shape shape = op->get_output_shape(0);
124 if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
125 op = fold<opset1::Unsqueeze>(
127 std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
132 static std::shared_ptr<Node> getData(const std::shared_ptr<Node>& eltwise) {
133 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
134 return eltwise->get_input_node_shared_ptr(0);
137 if (!is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
138 return eltwise->get_input_node_shared_ptr(1);
144 static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
145 if (eltwise->get_input_size() != 2) {
149 std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
150 if (constant != nullptr) {
154 return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
157 bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
158 std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
159 if (constant == nullptr) {
163 Shape shape = constant->get_output_shape(0);
164 if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
165 const Shape eltwiseShape = eltwise->get_output_shape(0);
166 if ((eltwiseShape.size() - shape.size()) > 1) {
170 if ((eltwiseShape.size() - shape.size()) == 1ul) {
171 shape.insert(shape.begin(), 1ul);
174 for (size_t i = 2ul; i < shape.size(); ++i) {
175 if (shape[i] != 1ul) {
181 return getData(eltwise) != nullptr;
184 std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
185 TransformationContext& context,
186 const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
187 const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
189 std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
190 std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
192 std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
193 if (is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
194 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
196 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
198 const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
199 // TODO: temporary fix for GPU Plugin (inverted intervals)
200 for (const float& val : valueVec) {
206 inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
207 inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
208 } else if (is_type<opset1::Divide>(eltwise) && checkElementwise(eltwise)) {
209 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
211 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
213 const auto valueVec = as_type_ptr<opset1::Constant>(value)->cast_vector<float>();
214 // TODO: temporary fix for GPU Plugin (inverted intervals)
215 for (const float& val : valueVec) {
221 inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
222 inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
223 } else if (is_type<opset1::Subtract>(eltwise) && checkElementwise(eltwise)) {
224 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
226 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
228 inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
229 inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
230 } else if (is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
231 if (is_type<opset1::Convolution>(getData(eltwise)) ||
232 is_type<opset1::GroupConvolution>(getData(eltwise))) {
236 const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
238 fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
240 inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
241 inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
242 } else if (is_type<opset1::Convert>(eltwise)) {
244 if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
251 std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
255 fakeQuantize->input_value(3),
256 fakeQuantize->input_value(4) }));
258 replace_node(fakeQuantize, newFakeQuantize);
259 NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
261 return newFakeQuantize;
264 bool FakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
267 } // namespace low_precision
269 } // namespace ngraph