1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/multiply_function.hpp"
7 #include <ngraph/opsets/opset1.hpp>
8 #include <ngraph_ops/type_relaxed.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "transformations/low_precision/common/dequantization_op.hpp"
11 #include "transformations/low_precision/network_helper.hpp"
13 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
16 using namespace ngraph::pass::low_precision;
23 std::shared_ptr<Node> input;
24 std::shared_ptr<Node> dequantization;
27 BranchNodes getBranch(const MultiplyBranch& branch) {
28 if (!branch.constant.empty()) {
29 if (branch.inputShape != branch.constant.shape) {
30 THROW_IE_EXCEPTION << "shapes are not equals: " << branch.inputShape << " & " << branch.constant.shape;
33 if (branch.precisionBeforeDequantization != branch.constant.outPrecision) {
34 THROW_IE_EXCEPTION << "precisions are not equals: " << branch.precisionBeforeDequantization << " & " << branch.constant.outPrecision;
38 const std::shared_ptr<Node> parent = branch.constant.empty() ?
39 std::make_shared<ngraph::opset1::Parameter>(branch.precisionBeforeDequantization, branch.inputShape) :
40 std::dynamic_pointer_cast<Node>(std::make_shared<ngraph::opset1::Constant>(
41 branch.constant.outPrecision,
42 branch.constant.shape,
43 branch.constant.values));
45 const auto dequantization = makeDequantization(parent, branch.dequantization);
46 return {parent, dequantization};
49 std::shared_ptr<ngraph::Function> MultiplyFunction::get(
50 const ngraph::Shape& inputShape,
51 const MultiplyValues& actualValues) {
52 const BranchNodes branchNodes1 = getBranch(actualValues.branch1);
53 const BranchNodes branchNodes2 = getBranch(actualValues.branch2);
55 auto multiplyOriginal = actualValues.isDequantization ?
56 DequantizationMultiply(
57 ngraph::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
58 ngraph::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get()) :
59 ngraph::opset1::Multiply(
60 ngraph::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
61 ngraph::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get());
63 const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
65 std::vector<element::Type>{element::f32, element::f32},
66 std::vector<element::Type>{});
68 multiply->set_friendly_name("output");
70 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
72 ngraph::ParameterVector inputs;
73 if (is_type<opset1::Parameter>(branchNodes1.input)) {
74 inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes1.input));
76 if (is_type<opset1::Parameter>(branchNodes2.input)) {
77 inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes2.input));
80 return std::make_shared<ngraph::Function>(results, inputs, "MultiplyTransformation");
83 } // namespace subgraph
84 } // namespace builder