1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
6 #include "low_precision/network_helper.hpp"
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/builders.hpp"
10 #include "ngraph_functions/subgraph_builders.hpp"
12 using namespace ngraph::pass::low_precision;
18 std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
19 const ngraph::element::Type precision,
20 const ngraph::Shape& inputShape,
21 const ngraph::pass::low_precision::LayerTransformation::Params& params,
22 const ngraph::element::Type& precision1,
23 const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
24 const ngraph::element::Type& precision2,
25 const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
26 const auto input1_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
27 const auto input1_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
28 const std::shared_ptr<ngraph::Node> multiply1 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
29 DequantizationMultiply(
30 ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(),
31 ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()),
32 std::vector<element::Type>{element::f32, element::f32},
33 std::vector<element::Type>{});
35 const std::shared_ptr<ngraph::Node> parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1);
37 const auto input2_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
38 const auto input2_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
39 const std::shared_ptr<ngraph::Node> multiply2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
40 DequantizationMultiply(
41 ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(),
42 ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()),
43 std::vector<element::Type>{element::f32, element::f32},
44 std::vector<element::Type>{});
46 const std::shared_ptr<ngraph::Node> parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2);
48 const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
49 add->set_friendly_name("output");
50 auto& rtInfo = add->get_rt_info();
51 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
53 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
54 ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 };
55 return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
58 } // namespace subgraph
59 } // namespace builder