Es/lpt/lpt to ngraph fixes2 with master (#2671)
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / multiply_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/multiply_function.hpp"
6
7 #include <ngraph/opsets/opset1.hpp>
8 #include <ngraph_ops/type_relaxed.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "transformations/low_precision/common/dequantization_op.hpp"
11 #include "transformations/low_precision/network_helper.hpp"
12
13 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
15
16 using namespace ngraph::pass::low_precision;
17
18 namespace ngraph {
19 namespace builder {
20 namespace subgraph {
21
22 struct BranchNodes {
23     std::shared_ptr<Node> input;
24     std::shared_ptr<Node> dequantization;
25 };
26
27 BranchNodes getBranch(const MultiplyBranch& branch) {
28     if (!branch.constant.empty()) {
29         if (branch.inputShape != branch.constant.shape) {
30             THROW_IE_EXCEPTION << "shapes are not equals: " << branch.inputShape << " & " << branch.constant.shape;
31         }
32
33         if (branch.precisionBeforeDequantization != branch.constant.outPrecision) {
34             THROW_IE_EXCEPTION << "precisions are not equals: " << branch.precisionBeforeDequantization << " & " << branch.constant.outPrecision;
35         }
36     }
37
38     const std::shared_ptr<Node> parent = branch.constant.empty() ?
39         std::make_shared<ngraph::opset1::Parameter>(branch.precisionBeforeDequantization, branch.inputShape) :
40         std::dynamic_pointer_cast<Node>(std::make_shared<ngraph::opset1::Constant>(
41             branch.constant.outPrecision,
42             branch.constant.shape,
43             branch.constant.values));
44
45     const auto dequantization = makeDequantization(parent, branch.dequantization);
46     return {parent, dequantization};
47 }
48
49 std::shared_ptr<ngraph::Function> MultiplyFunction::get(
50     const ngraph::Shape& inputShape,
51     const MultiplyValues& actualValues) {
52     const BranchNodes branchNodes1 = getBranch(actualValues.branch1);
53     const BranchNodes branchNodes2 = getBranch(actualValues.branch2);
54
55     auto multiplyOriginal = actualValues.isDequantization ?
56         DequantizationMultiply(
57             ngraph::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
58             ngraph::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get()) :
59         ngraph::opset1::Multiply(
60             ngraph::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
61             ngraph::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get());
62
63     const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
64         multiplyOriginal,
65         std::vector<element::Type>{element::f32, element::f32},
66         std::vector<element::Type>{});
67
68     multiply->set_friendly_name("output");
69
70     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
71
72     ngraph::ParameterVector inputs;
73     if (is_type<opset1::Parameter>(branchNodes1.input)) {
74         inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes1.input));
75     }
76     if (is_type<opset1::Parameter>(branchNodes2.input)) {
77         inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes2.input));
78     }
79
80     return std::make_shared<ngraph::Function>(results, inputs, "MultiplyTransformation");
81 }
82
83 }  // namespace subgraph
84 }  // namespace builder
85 }  // namespace ngraph