[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / common / builders.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
6
#include <memory>
#include <queue>

#include <ngraph/except.hpp>
#include <ngraph/opsets/opset1.hpp>
#include "ngraph_ops/type_relaxed.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "low_precision/network_helper.hpp"
14
15 namespace ngraph {
16 namespace builder {
17 namespace subgraph {
18
19 std::shared_ptr<Node> makeDequantization(
20     const Output<Node>& data,
21     const DequantizationOperations& dequantizationOperations) {
22     Output<Node> parent = data;
23
24     if (!dequantizationOperations.convert.empty()) {
25         std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(
26             data,
27             dequantizationOperations.convert.outPrecision);
28         ngraph::copy_runtime_info({ data.get_node_shared_ptr(), convert }, convert);
29         parent = convert;
30     }
31
32     if (!dequantizationOperations.subtract.empty()) {
33         std::shared_ptr<ngraph::opset1::Subtract> subtract;
34
35         std::vector<size_t> shape;
36         if (dequantizationOperations.subtract.constantShapeIsDefined) {
37             shape = dequantizationOperations.subtract.constantShape;
38         } else {
39             if (dequantizationOperations.subtract.values.size() == 1ul) {
40                 shape = std::vector<size_t>({});
41             } else {
42                 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
43                 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.subtract.values.size();
44             }
45         }
46
47         const auto subtractConst = std::make_shared<ngraph::opset1::Constant>(
48             dequantizationOperations.subtract.constantPrecision != element::undefined ?
49                 dequantizationOperations.subtract.constantPrecision :
50                 parent.get_element_type(),
51             shape,
52             dequantizationOperations.subtract.values);
53
54         if ((dequantizationOperations.subtract.outPrecision == element::undefined) ||
55             (dequantizationOperations.subtract.outPrecision == parent.get_element_type())) {
56             subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
57         } else {
58             subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
59                     std::vector<element::Type>{element::f32, element::f32},
60                     std::vector<element::Type>{ element::f32 },
61                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
62                     ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
63             ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
64         }
65         if (!dequantizationOperations.subtract.addDequantizationAttribute) {
66             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(subtract);
67         }
68         ngraph::copy_runtime_info({ data.get_node_shared_ptr(), subtract }, subtract);
69         parent = subtract;
70     }
71
72     if (!dequantizationOperations.multiply.empty()) {
73         std::vector<size_t> shape;
74         if (dequantizationOperations.multiply.constantShapeIsDefined) {
75             shape = dequantizationOperations.multiply.constantShape;
76         } else {
77             if (dequantizationOperations.multiply.values.size() == 1ul) {
78                 shape = std::vector<size_t>({});
79             } else {
80                 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
81                 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.multiply.values.size();
82             }
83         }
84
85         std::shared_ptr<ngraph::opset1::Multiply> multiply;
86         if ((dequantizationOperations.multiply.outPrecision == element::undefined) ||
87             (dequantizationOperations.multiply.outPrecision == parent.get_element_type())) {
88             const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
89                 parent.get_element_type(),
90                 shape,
91                 dequantizationOperations.multiply.values);
92
93             multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
94                 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(parent, constant) :
95                 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(constant, parent);
96         } else {
97             const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
98                 dequantizationOperations.multiply.constantPrecision != element::undefined ?
99                     dequantizationOperations.multiply.constantPrecision :
100                     parent.get_element_type(),
101                 shape,
102                 dequantizationOperations.multiply.values);
103
104             multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
105                 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
106                     std::vector<element::Type>{element::f32, element::f32},
107                     std::vector<element::Type>{ element::f32 },
108                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
109                     ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get()) :
110                 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
111                     std::vector<element::Type>{element::f32, element::f32},
112                     std::vector<element::Type>{ element::f32 },
113                     ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get(),
114                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get());
115         }
116         ngraph::copy_runtime_info({ data.get_node_shared_ptr(), multiply }, multiply);
117         parent = multiply;
118     }
119
120     return parent.get_node_shared_ptr();
121 }
122
123 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
124     const Output<Node>& input,
125     const ngraph::element::Type precision,
126     const FakeQuantizeOnData& fqOnData) {
127     return as_type_ptr<ngraph::opset1::FakeQuantize>(ngraph::builder::makeFakeQuantize(
128         input,
129         precision,
130         fqOnData.quantizationLevel,
131         fqOnData.constantShape,
132         fqOnData.inputLowValues,
133         fqOnData.inputHighValues,
134         fqOnData.outputLowValues,
135         fqOnData.outputHighValues));
136 }
137
138 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
139     const std::shared_ptr<ngraph::Node>& input,
140     const ngraph::element::Type precision,
141     const FakeQuantizeOnData& fqOnData) {
142     const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
143     return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
144 }
145
146 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
147     const Output<Node>& input,
148     const ngraph::element::Type precision,
149     const FakeQuantizeOnDataWithConstant& fqOnData) {
150     const auto inputLowNode = ngraph::builder::makeConstant(
151         precision,
152         fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[0],
153         fqOnData.inputLowValues,
154         fqOnData.inputLowValues.empty());
155
156     const auto inputHighNode = ngraph::builder::makeConstant(
157         precision,
158         fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[1],
159         fqOnData.inputHighValues,
160         fqOnData.inputHighValues.empty());
161
162     const auto outputLowNode = ngraph::builder::makeConstant(
163         precision,
164         fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[2],
165         fqOnData.outputLowValues,
166         fqOnData.outputLowValues.empty());
167
168     const auto outputHighNode = ngraph::builder::makeConstant(
169         precision,
170         fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[3],
171         fqOnData.outputHighValues,
172         fqOnData.outputHighValues.empty());
173
174     auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(input, inputLowNode, inputHighNode, outputLowNode, outputHighNode, fqOnData.quantizationLevel);
175     return fq;
176 }
177
178 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
179     const std::shared_ptr<ngraph::Node>& input,
180     const ngraph::element::Type precision,
181     const FakeQuantizeOnDataWithConstant& fqOnData) {
182     const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
183     return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
184 }
185
186 std::shared_ptr<Node> addDequantizationAttribute(const std::shared_ptr<Node>& op) {
187     auto& rtInfo = op->get_rt_info();
188     rtInfo["DEQUANTIZATION"] = std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
189     return op;
190 }
191
192 } // namespace subgraph
193 } // namespace builder
194 } // namespace ngraph