1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
// Builds dequantization operations on top of `data` as described by `dequantizationOperations`.
// Three stages are visible, each skipped when its descriptor reports empty(): Convert, Subtract, Multiply.
// Returns the node owning `parent` at the end of the function.
// NOTE(review): this view of the file is missing lines (closing braces, `else` keywords, a few call
// arguments, and presumably the statements that re-point `parent` at each newly created node).
// Confirm against the complete file before changing any logic here.
19 std::shared_ptr<Node> makeDequantization(
20 const Output<Node>& data,
21 const DequantizationOperations& dequantizationOperations) {
// `parent` starts as the incoming data output.
22 Output<Node> parent = data;
// --- Convert stage: creates a DequantizationConvert targeting convert.outPrecision.
24 if (!dequantizationOperations.convert.empty()) {
25 std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(
27 dequantizationOperations.convert.outPrecision);
// Copies runtime info from the source node and the new node onto the new node.
28 ngraph::copy_runtime_info({ data.get_node_shared_ptr(), convert }, convert);
// --- Subtract stage.
32 if (!dequantizationOperations.subtract.empty()) {
33 std::shared_ptr<ngraph::opset1::Subtract> subtract;
// Constant shape selection: an explicitly defined constantShape is used as is.
35 std::vector<size_t> shape;
36 if (dequantizationOperations.subtract.constantShapeIsDefined) {
37 shape = dequantizationOperations.subtract.constantShape;
// A single value gets an empty shape.
39 if (dequantizationOperations.subtract.values.size() == 1ul) {
40 shape = std::vector<size_t>({});
// Presumably the `else` case (keyword not visible): a shape of parent's rank filled with 1,
// with the number of values written at index 1 (index 0 when the rank is below 2).
// NOTE(review): if parent's rank is 0 the vector is empty and index 0 is out of range — confirm callers never hit this.
42 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
43 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.subtract.values.size();
// Constant element type: subtract.constantPrecision unless it is element::undefined,
// in which case parent's element type is used.
47 const auto subtractConst = std::make_shared<ngraph::opset1::Constant>(
48 dequantizationOperations.subtract.constantPrecision != element::undefined ?
49 dequantizationOperations.subtract.constantPrecision :
50 parent.get_element_type(),
52 dequantizationOperations.subtract.values);
// When no output precision is requested, or it already equals parent's type, a plain DequantizationSubtract is created.
54 if ((dequantizationOperations.subtract.outPrecision == element::undefined) ||
55 (dequantizationOperations.subtract.outPrecision == parent.get_element_type())) {
56 subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
// Presumably the `else` case (keyword not visible): a TypeRelaxed DequantizationSubtract built with
// f32 input/output types, whose output precision is then set to subtract.outPrecision.
58 subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
59 std::vector<element::Type>{element::f32, element::f32},
60 std::vector<element::Type>{ element::f32 },
61 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
62 ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
63 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
// Runtime info is cleaned from the subtract node when the descriptor asks for no dequantization attribute.
65 if (!dequantizationOperations.subtract.addDequantizationAttribute) {
66 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(subtract);
68 ngraph::copy_runtime_info({ data.get_node_shared_ptr(), subtract }, subtract);
// --- Multiply stage. Shape selection follows the same rules as the Subtract stage above.
72 if (!dequantizationOperations.multiply.empty()) {
73 std::vector<size_t> shape;
74 if (dequantizationOperations.multiply.constantShapeIsDefined) {
75 shape = dequantizationOperations.multiply.constantShape;
77 if (dequantizationOperations.multiply.values.size() == 1ul) {
78 shape = std::vector<size_t>({});
// NOTE(review): same rank-0 concern as in the Subtract stage — index 0 of an empty vector.
80 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
81 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.multiply.values.size();
85 std::shared_ptr<ngraph::opset1::Multiply> multiply;
86 if ((dequantizationOperations.multiply.outPrecision == element::undefined) ||
87 (dequantizationOperations.multiply.outPrecision == parent.get_element_type())) {
// Plain case: the constant takes parent's element type.
// NOTE(review): multiply.constantPrecision is not consulted here, unlike in the Subtract stage — confirm this is intended.
88 const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
89 parent.get_element_type(),
91 dequantizationOperations.multiply.values);
// constantIndex == 1 puts the constant on the second input; any other value puts it on the first.
93 multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
94 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(parent, constant) :
95 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(constant, parent);
// Presumably the `else` case (keyword not visible): the constant honours multiply.constantPrecision
// (falling back to parent's element type) and a TypeRelaxed DequantizationMultiply is built with f32 types.
97 const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
98 dequantizationOperations.multiply.constantPrecision != element::undefined ?
99 dequantizationOperations.multiply.constantPrecision :
100 parent.get_element_type(),
102 dequantizationOperations.multiply.values);
// Same operand-order rule as above, driven by constantIndex.
104 multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
105 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
106 std::vector<element::Type>{element::f32, element::f32},
107 std::vector<element::Type>{ element::f32 },
108 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
109 ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get()) :
110 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
111 std::vector<element::Type>{element::f32, element::f32},
112 std::vector<element::Type>{ element::f32 },
113 ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get(),
114 ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get());
// NOTE(review): no setOutDataPrecision call is visible for the TypeRelaxed multiply, unlike the subtract — verify in the full file.
116 ngraph::copy_runtime_info({ data.get_node_shared_ptr(), multiply }, multiply);
// Returns the node behind `parent`.
120 return parent.get_node_shared_ptr();
// Creates a FakeQuantize node from a FakeQuantizeOnData descriptor by delegating to
// ngraph::builder::makeFakeQuantize and casting the result to opset1::FakeQuantize with as_type_ptr.
// The visible arguments forward the quantization level, the constant shape and the four
// input/output low/high value vectors from `fqOnData`.
// NOTE(review): the leading arguments of the delegated call are not visible in this view
// (presumably `input` and `precision`) — confirm against the full file.
123 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
124 const Output<Node>& input,
125 const ngraph::element::Type precision,
126 const FakeQuantizeOnData& fqOnData) {
127 return as_type_ptr<ngraph::opset1::FakeQuantize>(ngraph::builder::makeFakeQuantize(
130 fqOnData.quantizationLevel,
131 fqOnData.constantShape,
132 fqOnData.inputLowValues,
133 fqOnData.inputHighValues,
134 fqOnData.outputLowValues,
135 fqOnData.outputHighValues));
// Creates a type-relaxed FakeQuantize from a FakeQuantizeOnData descriptor.
// A regular FakeQuantize is first built with makeFakeQuantize; a TypeRelaxed<FakeQuantize> is then
// constructed from that node together with fqOnData.outputPrecision and returned.
138 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
139 const std::shared_ptr<ngraph::Node>& input,
140 const ngraph::element::Type precision,
141 const FakeQuantizeOnData& fqOnData) {
142 const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
143 return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
// Creates a FakeQuantize node from a FakeQuantizeOnDataWithConstant descriptor.
// Four constants (input low/high, output low/high) are built with ngraph::builder::makeConstant
// and wired into an opset1::FakeQuantize together with fqOnData.quantizationLevel.
// Each constant's shape is Shape{} when fqOnData.constantShapes is empty, otherwise constantShapes[i]
// with i = 0..3 in the order input low, input high, output low, output high.
// NOTE(review): a non-empty constantShapes is indexed up to [3], so it must hold at least four entries — confirm callers guarantee this.
// NOTE(review): the first argument of each makeConstant call and the function's return statement
// are not visible in this view — confirm against the full file.
146 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
147 const Output<Node>& input,
148 const ngraph::element::Type precision,
149 const FakeQuantizeOnDataWithConstant& fqOnData) {
// The last argument of each makeConstant call is true when the matching value vector is empty
// (presumably a request for generated values — verify against makeConstant's signature).
150 const auto inputLowNode = ngraph::builder::makeConstant(
152 fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[0],
153 fqOnData.inputLowValues,
154 fqOnData.inputLowValues.empty());
156 const auto inputHighNode = ngraph::builder::makeConstant(
158 fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[1],
159 fqOnData.inputHighValues,
160 fqOnData.inputHighValues.empty());
162 const auto outputLowNode = ngraph::builder::makeConstant(
164 fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[2],
165 fqOnData.outputLowValues,
166 fqOnData.outputLowValues.empty());
168 const auto outputHighNode = ngraph::builder::makeConstant(
170 fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[3],
171 fqOnData.outputHighValues,
172 fqOnData.outputHighValues.empty());
// Assembles the FakeQuantize from the data input and the four range constants.
174 auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(input, inputLowNode, inputHighNode, outputLowNode, outputHighNode, fqOnData.quantizationLevel);
// Creates a type-relaxed FakeQuantize from a FakeQuantizeOnDataWithConstant descriptor.
// A regular FakeQuantize is first built with makeFakeQuantize; a TypeRelaxed<FakeQuantize> is then
// constructed from that node together with fqOnData.outputPrecision and returned.
178 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
179 const std::shared_ptr<ngraph::Node>& input,
180 const ngraph::element::Type precision,
181 const FakeQuantizeOnDataWithConstant& fqOnData) {
182 const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
183 return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
// Stores a "DEQUANTIZATION" entry in the runtime info of `op`, holding a
// VariantWrapper<DequantizationAttr> built from a default-constructed DequantizationAttr.
// An existing entry under the same key is replaced by the assignment.
// NOTE(review): the return statement is not visible in this view (presumably returns `op`) — confirm against the full file.
186 std::shared_ptr<Node> addDequantizationAttribute(const std::shared_ptr<Node>& op) {
187 auto& rtInfo = op->get_rt_info();
188 rtInfo["DEQUANTIZATION"] = std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
192 } // namespace subgraph
193 } // namespace builder
194 } // namespace ngraph