[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / add_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/add_function.hpp"
6 #include "low_precision/network_helper.hpp"
7
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/builders.hpp"
10 #include "ngraph_functions/subgraph_builders.hpp"
11
12 using namespace ngraph::pass::low_precision;
13
14 namespace ngraph {
15 namespace builder {
16 namespace subgraph {
17
18 std::shared_ptr<ngraph::Function> AddFunction::getOriginal(
19     const ngraph::element::Type precision,
20     const ngraph::Shape& inputShape,
21     const bool broadcast,
22     const ngraph::pass::low_precision::LayerTransformation::Params& params,
23     const ngraph::element::Type& precision1,
24     const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
25     const ngraph::element::Type& precision2,
26     const ngraph::builder::subgraph::DequantizationOperations& dequantization2,
27     const int constInput,
28     const std::vector<float>& constValues,
29     const std::string& additionalLayer) {
30     std::shared_ptr<ngraph::Node> input1;
31     if (constInput == 0) {
32         input1 = std::make_shared<ngraph::opset1::Constant>(
33             precision,
34             inputShape,
35             constValues);
36     } else {
37         input1 = std::make_shared<ngraph::opset1::Parameter>(
38             precision1,
39             broadcast ? ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }) : ngraph::Shape(inputShape));
40     }
41
42     const auto dequantizationOp1 = is_type<ngraph::opset1::Constant>(input1) ? input1 : makeDequantization(input1, dequantization1);
43
44     std::shared_ptr<ngraph::Node> input2;
45     if (constInput == 1) {
46         input2 = std::make_shared<ngraph::opset1::Constant>(
47             precision,
48             inputShape,
49             constValues);
50     } else {
51         input2 = std::make_shared<ngraph::opset1::Parameter>(
52             precision2, ngraph::Shape(inputShape));
53     }
54     auto parent = input2;
55     if (additionalLayer == "convolution") {
56         parent = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Convolution>>(
57             std::vector<element::Type>{ element::f32, element::f32 },
58             std::vector<element::Type>{ element::f32 },
59             ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
60             ngraph::op::TemporaryReplaceOutputType(
61                 std::make_shared<ngraph::opset1::Constant>(element::i8, Shape{ 1, 4, 1, 1 }, std::vector<float>{0.8f, 0.8f, 0.8f, 0.8f}),
62                 element::f32).get(),
63             ngraph::Strides{ 1, 1 },
64             ngraph::CoordinateDiff{ 0, 0 },
65             ngraph::CoordinateDiff{ 0, 0 },
66             ngraph::Strides{ 1, 1 });
67     }
68     if (additionalLayer == "group_convolution") {
69         parent = std::make_shared< ngraph::op::TypeRelaxed<ngraph::opset1::GroupConvolution>>(
70             std::vector<element::Type>{ element::f32, element::f32 },
71             std::vector<element::Type>{ element::f32 },
72             ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
73             ngraph::op::TemporaryReplaceOutputType(
74                 std::make_shared<ngraph::opset1::Constant>(element::i8, Shape{ 4, 1, 1, 1, 1 }, std::vector<float>{0.8f, 0.8f, 0.8f, 0.8f}),
75                 element::f32).get(),
76             ngraph::Strides{ 1, 1 },
77             ngraph::CoordinateDiff{ 0, 0 },
78             ngraph::CoordinateDiff{ 0, 0 },
79             ngraph::Strides{ 1, 1 });
80     }
81     if (additionalLayer != "") {
82         parent = std::make_shared<ngraph::opset1::Add>(
83             parent,
84             std::make_shared<ngraph::opset1::Constant>(element::f32, Shape{ 1, 1, 1, 1 }, std::vector<float>{1.f}));
85         parent = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(
86                 parent,
87                 ngraph::element::f32,
88                 {256, Shape{}, { 0 }, { 255 }, { 0 }, { 255 }, element::u8});
89     }
90     const auto dequantizationOp2 = is_type<ngraph::opset1::Constant>(parent) ? parent : makeDequantization(parent, dequantization2);
91
92     const auto add = std::make_shared<ngraph::opset1::Add>(dequantizationOp1, dequantizationOp2);
93     add->set_friendly_name("output");
94     auto& rtInfo = add->get_rt_info();
95     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
96
97     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
98     ngraph::ParameterVector parameters;
99     if (constInput == -1) {
100         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input1), as_type_ptr<ngraph::opset1::Parameter>(input2) };
101     } else if (constInput == 0) {
102         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input2) };
103     } else if (constInput == 1) {
104         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input1) };
105     } else {
106         THROW_IE_EXCEPTION << "Unexpected constant input index";
107     }
108     return std::make_shared<ngraph::Function>(results, parameters, "AddTransformation");
109 }
110
111 std::shared_ptr<ngraph::Function> AddFunction::getReference(
112     const ngraph::element::Type precision,
113     const ngraph::Shape& inputShape,
114     const bool broadcast,
115     const ngraph::pass::low_precision::LayerTransformation::Params& params,
116     const ngraph::element::Type& precision1,
117     const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
118     const ngraph::element::Type& precision2,
119     const ngraph::builder::subgraph::DequantizationOperations& dequantization2,
120     const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
121     const int constInputIndex,
122     const std::vector<float>& constValues,
123     const std::string& additionalLayer,
124     const std::string& operationType) {
125     std::shared_ptr<ngraph::Node> input1;
126     if (constInputIndex == 0) {
127         input1 = std::make_shared<ngraph::opset1::Constant>(
128             precision,
129             inputShape,
130             constValues);
131     } else {
132         input1 = std::make_shared<ngraph::opset1::Parameter>(
133             precision1,
134             broadcast ? ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }) : ngraph::Shape(inputShape));
135     }
136
137     const auto dequantizationOp1 = is_type<ngraph::opset1::Constant>(input1) ? input1 : makeDequantization(input1, dequantization1);
138
139     std::shared_ptr<ngraph::Node> input2;
140     if (constInputIndex == 1) {
141         input2 = std::make_shared<ngraph::opset1::Constant>(
142             precision,
143             inputShape,
144             constValues);
145     } else {
146         input2 = std::make_shared<ngraph::opset1::Parameter>(
147             precision2, ngraph::Shape(inputShape));
148     }
149     auto parent = input2;
150     if (additionalLayer == "convolution") {
151         parent = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Convolution>>(
152             std::vector<element::Type>{ element::f32, element::f32 },
153             std::vector<element::Type>{ element::f32 },
154             ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
155             ngraph::op::TemporaryReplaceOutputType(
156                 std::make_shared<ngraph::opset1::Constant>(element::i8, Shape{ 1, 4, 1, 1 }, std::vector<float>{0.8f, 0.8f, 0.8f, 0.8f}),
157                 element::f32).get(),
158             ngraph::Strides{ 1, 1 },
159             ngraph::CoordinateDiff{ 0, 0 },
160             ngraph::CoordinateDiff{ 0, 0 },
161             ngraph::Strides{ 1, 1 });
162     }
163     if (additionalLayer == "group_convolution") {
164         parent = std::make_shared< ngraph::op::TypeRelaxed<ngraph::opset1::GroupConvolution>>(
165             std::vector<element::Type>{ element::f32, element::f32 },
166             std::vector<element::Type>{ element::f32 },
167             ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
168             ngraph::op::TemporaryReplaceOutputType(
169                 std::make_shared<ngraph::opset1::Constant>(element::i8, Shape{ 4, 1, 1, 1, 1 }, std::vector<float>{0.8f, 0.8f, 0.8f, 0.8f}),
170                 element::f32).get(),
171             ngraph::Strides{ 1, 1 },
172             ngraph::CoordinateDiff{ 0, 0 },
173             ngraph::CoordinateDiff{ 0, 0 },
174             ngraph::Strides{ 1, 1 });
175     }
176     if (additionalLayer != "") {
177         parent = std::make_shared<ngraph::opset1::Add>(
178             parent,
179             std::make_shared<ngraph::opset1::Constant>(element::f32, Shape{ 1, 1, 1, 1 }, std::vector<float>{1.f}));
180         parent = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(
181                 parent,
182                 ngraph::element::f32,
183                 {256, Shape{}, { 0 }, { 255 }, { 0 }, { 255 }, element::u8});
184     }
185     const auto dequantizationOp2 = is_type<ngraph::opset1::Constant>(parent) ? parent : makeDequantization(parent, dequantization2);
186
187     const std::shared_ptr<Node> add = operationType == "Add" ?
188         std::dynamic_pointer_cast<Node>(std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(
189             std::vector<element::Type>{ element::f32, element::f32 },
190             std::vector<element::Type>{},
191             ngraph::op::TemporaryReplaceOutputType(dequantizationOp1, element::f32).get(),
192             ngraph::op::TemporaryReplaceOutputType(dequantizationOp2, element::f32).get())) :
193         std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(
194             std::vector<element::Type>{ element::f32, element::f32 },
195             std::vector<element::Type>{},
196             ngraph::op::TemporaryReplaceOutputType(dequantizationOp1, element::f32).get(),
197             ngraph::op::TemporaryReplaceOutputType(dequantizationOp2, element::f32).get());
198
199     NetworkHelper::setOutDataPrecisionForTypeRelaxed(add, precision);
200     auto& rtInfo = add->get_rt_info();
201     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
202
203     const auto dequantizationOpAfter = makeDequantization(add, dequantizationAfter);
204
205     dequantizationOpAfter->set_friendly_name("output");
206
207     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOpAfter) };
208     ngraph::ParameterVector parameters;
209     if (constInputIndex == -1) {
210         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input1), as_type_ptr<ngraph::opset1::Parameter>(input2) };
211     } else if (constInputIndex == 0) {
212         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input2) };
213     } else if (constInputIndex == 1) {
214         parameters = { as_type_ptr<ngraph::opset1::Parameter>(input1) };
215     } else {
216         THROW_IE_EXCEPTION << "Unexpected constant input index";
217     }
218     return std::make_shared<ngraph::Function>(results, parameters, "AddTransformation");
219 }
220
221 }  // namespace subgraph
222 }  // namespace builder
223 }  // namespace ngraph