db9dd07aec5ae1ec13d22102d4c45571c7d094ac
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / common / builders.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
6
#include <memory>
#include <queue>
#include <stdexcept>

#include <ngraph/opsets/opset1.hpp>
#include "ngraph_ops/type_relaxed.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "low_precision/network_helper.hpp"
14
15 namespace ngraph {
16 namespace builder {
17 namespace subgraph {
18
19 std::shared_ptr<Node> makeDequantization(
20     const Output<Node>& data,
21     const DequantizationOperations& dequantizationOperations) {
22     Output<Node> parent = data;
23
24     if (!dequantizationOperations.convert.empty()) {
25         std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(
26             data,
27             dequantizationOperations.convert.outPrecision);
28         parent = convert;
29     }
30
31     if (!dequantizationOperations.subtract.empty()) {
32         std::shared_ptr<ngraph::opset1::Subtract> subtract;
33
34         std::vector<size_t> shape;
35         if (dequantizationOperations.subtract.constantShapeIsDefined) {
36             shape = dequantizationOperations.subtract.constantShape;
37         } else {
38             if (dequantizationOperations.subtract.values.size() == 1ul) {
39                 shape = std::vector<size_t>({});
40             } else {
41                 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
42                 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.subtract.values.size();
43             }
44         }
45
46         const auto subtractConst = std::make_shared<ngraph::opset1::Constant>(
47             dequantizationOperations.subtract.constantPrecision != element::undefined ?
48                 dequantizationOperations.subtract.constantPrecision :
49                 parent.get_element_type(),
50             shape,
51             dequantizationOperations.subtract.values);
52
53         if ((dequantizationOperations.subtract.outPrecision == element::undefined) ||
54             (dequantizationOperations.subtract.outPrecision == parent.get_element_type())) {
55             subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
56         } else {
57             subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
58                     std::vector<element::Type>{element::f32, element::f32},
59                     std::vector<element::Type>{ element::f32 },
60                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
61                     ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
62             ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
63         }
64         if (!dequantizationOperations.subtract.addDequantizationAttribute) {
65             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(subtract);
66         }
67         parent = subtract;
68     }
69
70     if (!dequantizationOperations.multiply.empty()) {
71         std::vector<size_t> shape;
72         if (dequantizationOperations.multiply.constantShapeIsDefined) {
73             shape = dequantizationOperations.multiply.constantShape;
74         } else {
75             if (dequantizationOperations.multiply.values.size() == 1ul) {
76                 shape = std::vector<size_t>({});
77             } else {
78                 shape = std::vector<size_t>(parent.get_shape().size(), 1ul);
79                 shape[shape.size() >= 2 ? 1ul : 0] = dequantizationOperations.multiply.values.size();
80             }
81         }
82
83         std::shared_ptr<ngraph::opset1::Multiply> multiply;
84         if ((dequantizationOperations.multiply.outPrecision == element::undefined) ||
85             (dequantizationOperations.multiply.outPrecision == parent.get_element_type())) {
86             const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
87                 parent.get_element_type(),
88                 shape,
89                 dequantizationOperations.multiply.values);
90
91             multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
92                 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(parent, constant) :
93                 std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(constant, parent);
94         } else {
95             const std::shared_ptr<ngraph::opset1::Constant> constant = std::make_shared<ngraph::opset1::Constant>(
96                 dequantizationOperations.multiply.constantPrecision != element::undefined ?
97                     dequantizationOperations.multiply.constantPrecision :
98                     parent.get_element_type(),
99                 shape,
100                 dequantizationOperations.multiply.values);
101
102             multiply = dequantizationOperations.multiply.constantIndex == 1ul ?
103                 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
104                     std::vector<element::Type>{element::f32, element::f32},
105                     std::vector<element::Type>{ element::f32 },
106                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
107                     ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get()) :
108                 std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
109                     std::vector<element::Type>{element::f32, element::f32},
110                     std::vector<element::Type>{ element::f32 },
111                     ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get(),
112                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get());
113         }
114
115         parent = multiply;
116     }
117
118     return parent.get_node_shared_ptr();
119 }
120
121 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
122     const Output<Node>& input,
123     const ngraph::element::Type precision,
124     const FakeQuantizeOnData& fqOnData) {
125     return as_type_ptr<ngraph::opset1::FakeQuantize>(ngraph::builder::makeFakeQuantize(
126         input,
127         precision,
128         fqOnData.quantizationLevel,
129         fqOnData.constantShape,
130         fqOnData.inputLowValues,
131         fqOnData.inputHighValues,
132         fqOnData.outputLowValues,
133         fqOnData.outputHighValues));
134 }
135
136 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
137     const std::shared_ptr<ngraph::Node>& input,
138     const ngraph::element::Type precision,
139     const FakeQuantizeOnData& fqOnData) {
140     const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
141     return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
142 }
143
144 } // namespace subgraph
145 } // namespace builder
146 } // namespace ngraph