1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
// Builds dequantization operations on top of `data`, driven by `dequantizationOperations`.
// Three optional stages are considered in order - convert, subtract, multiply - and each one is
// entered only when the matching member reports !empty().
// Returns parent.get_node_shared_ptr(); `parent` is initialised from `data`.
// NOTE(review): several short lines of this function (closing braces, `else` keywords, some call
// arguments and the statements that re-point `parent`) are not visible in this view; the comments
// below describe only what the visible lines establish - confirm against the full file.
19 std::shared_ptr<Node> makeDequantization(
20 const Output<Node>& data,
21 const DequantizationOperations& dequantizationOperations) {
22 Output<Node> parent = data;
// Convert stage: creates a DequantizationConvert targeting convert.outPrecision.
24 if (!dequantizationOperations.convert.empty()) {
25 std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(
27 dequantizationOperations.convert.outPrecision);
// Subtract stage.
31 if (!dequantizationOperations.subtract.empty()) {
32 std::shared_ptr<ngraph::opset1::Subtract> subtract;
// Shape of the subtract constant:
//  - the explicit constantShape when constantShapeIsDefined is set;
//  - an empty shape when there is exactly one value;
//  - a vector of ones with one entry per parent dimension, where the entry at index 1
//    (index 0 when there are fewer than two dimensions) is set to the number of values.
//    Presumably index 1 is the channel axis - confirm.
34 std::vector<size_t> shape;
35 if (dequantizationOperations.subtract.constantShapeIsDefined) {
36 shape = dequantizationOperations.subtract.constantShape;
38 if (dequantizationOperations.subtract.values.size() == 1ul) {
39 shape = std::vector<size_t>({});
41 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
// NOTE(review): if parent.get_shape() is empty, `shape` is empty too and the write below
// indexes element 0 of an empty vector (out of bounds) - confirm callers never reach this.
42 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.subtract.values.size();
// The constant takes subtract.constantPrecision unless that is element::undefined, in which
// case the parent's element type is used.
46 const auto subtractConst = std::make_shared<ngraph::opset1::Constant>(
47 dequantizationOperations.subtract.constantPrecision != element::undefined ?
48 dequantizationOperations.subtract.constantPrecision :
49 parent.get_element_type(),
51 dequantizationOperations.subtract.values);
// A plain DequantizationSubtract is created when no output precision is requested
// (element::undefined) or the requested one equals the parent's element type.
53 if ((dequantizationOperations.subtract.outPrecision == element::undefined) ||
54 (dequantizationOperations.subtract.outPrecision == parent.get_element_type())) {
55 subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
// TypeRelaxed variant: constructed with f32 for both input types and for the output type,
// with parent and subtractConst passed through TemporaryReplaceOutputType(..., element::f32).
// The requested outPrecision is then applied via NetworkHelper::setOutDataPrecision.
57 subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
58 std::vector<element::Type>{element::f32, element::f32},
59 std::vector<element::Type>{ element::f32 },
60 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
61 ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
62 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
// When addDequantizationAttribute is false, NetworkHelper::cleanRunTimeInfo is called on the
// subtract; presumably this removes the dequantization marker from its runtime info - confirm.
64 if (!dequantizationOperations.subtract.addDequantizationAttribute) {
65 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(subtract);
// Multiply stage. The constant shape is derived with the same rules as in the subtract stage.
70 if (!dequantizationOperations.multiply.empty()) {
71 std::vector<size_t> shape;
72 if (dequantizationOperations.multiply.constantShapeIsDefined) {
73 shape = dequantizationOperations.multiply.constantShape;
75 if (dequantizationOperations.multiply.values.size() == 1ul) {
76 shape = std::vector<size_t>({});
78 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
// NOTE(review): same out-of-bounds risk as in the subtract stage when the parent shape is empty.
79 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.multiply.values.size();
83 std::shared_ptr<ngraph::opset1::Multiply> multiply;
// Plain DequantizationMultiply when no output precision is requested or it equals the parent's type.
84 if ((dequantizationOperations.multiply.outPrecision == element::undefined) ||
85 (dequantizationOperations.multiply.outPrecision == parent.get_element_type())) {
// NOTE(review): this constant always takes the parent's element type; multiply.constantPrecision
// is only consulted by the second constant creation further down - confirm this is intended.
86 const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
87 parent.get_element_type(),
89 dequantizationOperations.multiply.values);
// constantIndex == 1 places the constant as the second input; any other value places it first.
91 multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
92 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(parent, constant) :
93 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(constant, parent);
// TypeRelaxed variant: the constant takes multiply.constantPrecision unless it is
// element::undefined, in which case the parent's element type is used.
95 const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
96 dequantizationOperations.multiply.constantPrecision != element::undefined ?
97 dequantizationOperations.multiply.constantPrecision :
98 parent.get_element_type(),
100 dequantizationOperations.multiply.values);
// Same input-order selection as above; both alternatives are constructed with f32 input/output
// types and TemporaryReplaceOutputType(..., element::f32) wrappers.
// NOTE(review): unlike the subtract stage, no setOutDataPrecision call is visible here for
// multiply.outPrecision - confirm whether that is intended.
102 multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
103 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
104 std::vector<element::Type>{element::f32, element::f32},
105 std::vector<element::Type>{ element::f32 },
106 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
107 ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get()) :
108 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
109 std::vector<element::Type>{element::f32, element::f32},
110 std::vector<element::Type>{ element::f32 },
111 ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get(),
112 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get());
// Hands back the node that owns `parent`.
118 return parent.get_node_shared_ptr();
// Creates a FakeQuantize by delegating to ngraph::builder::makeFakeQuantize and converting the
// result to opset1::FakeQuantize with as_type_ptr.
// The quantization level, constant shape and the input/output low/high values are taken
// from `fqOnData`.
// NOTE(review): the leading arguments of the builder call are not visible in this view;
// presumably `input` and `precision` are forwarded first - confirm against the full file.
// NOTE(review): as_type_ptr presumably yields a null pointer if the builder returns a node of
// another type - confirm; callers in this file use the result without a null check.
121 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
122 const Output<Node>& input,
123 const ngraph::element::Type precision,
124 const FakeQuantizeOnData& fqOnData) {
125 return as_type_ptr<ngraph::opset1::FakeQuantize>(ngraph::builder::makeFakeQuantize(
128 fqOnData.quantizationLevel,
129 fqOnData.constantShape,
130 fqOnData.inputLowValues,
131 fqOnData.inputHighValues,
132 fqOnData.outputLowValues,
133 fqOnData.outputHighValues));
// Creates a type-relaxed FakeQuantize: a regular FakeQuantize is first built with
// makeFakeQuantize(input, precision, fqOnData), then a TypeRelaxed<opset1::FakeQuantize> is
// constructed from that node together with fqOnData.outputPrecision and returned.
// NOTE(review): `fq` is dereferenced without a null check - confirm makeFakeQuantize cannot
// return an empty pointer here.
136 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
137 const std::shared_ptr<ngraph::Node>& input,
138 const ngraph::element::Type precision,
139 const FakeQuantizeOnData& fqOnData) {
// Intermediate, non-relaxed node used only as the source for the relaxed copy.
140 const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
141 return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
144 } // namespace subgraph
145 } // namespace builder
146 } // namespace ngraph