[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / elementwise_with_multi_parent_dequantization_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
6 #include "low_precision/network_helper.hpp"
7
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/builders.hpp"
10 #include "ngraph_functions/subgraph_builders.hpp"
11
12 using namespace ngraph::pass::low_precision;
13
14 namespace ngraph {
15 namespace builder {
16 namespace subgraph {
17
18 std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
19     const ngraph::element::Type precision,
20     const ngraph::Shape& inputShape,
21     const ngraph::pass::low_precision::LayerTransformation::Params& params,
22     const ngraph::element::Type& precision1,
23     const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
24     const ngraph::element::Type& precision2,
25     const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
26     const auto input1_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
27     const auto input1_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
28     const std::shared_ptr<ngraph::Node> multiply1 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
29         DequantizationMultiply(
30             ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(),
31             ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()),
32         std::vector<element::Type>{element::f32, element::f32},
33         std::vector<element::Type>{});
34
35     const std::shared_ptr<ngraph::Node> parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1);
36
37     const auto input2_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
38     const auto input2_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
39     const std::shared_ptr<ngraph::Node> multiply2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
40         DequantizationMultiply(
41             ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(),
42             ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()),
43         std::vector<element::Type>{element::f32, element::f32},
44         std::vector<element::Type>{});
45
46     const std::shared_ptr<ngraph::Node> parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2);
47
48     const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
49     add->set_friendly_name("output");
50     auto& rtInfo = add->get_rt_info();
51     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
52
53     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
54     ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 };
55     return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
56 }
57
58 }  // namespace subgraph
59 }  // namespace builder
60 }  // namespace ngraph