1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include "ngraph_functions/low_precision_transformations/mul_add_to_scaleshift_or_power_function.hpp"

#include <memory>
#include <stdexcept>
#include <vector>

#include <ngraph/opsets/opset1.hpp>
#include "ngraph_ops/type_relaxed.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/common/dequantization_op.hpp"

#include <legacy/ngraph_ops/power.hpp>
#include <legacy/ngraph_ops/scaleshift.hpp>

#include "ngraph_functions/subgraph_builders.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
21 using namespace ngraph::pass;
22 std::shared_ptr<ngraph::Function> MulAddToScaleshiftOrPowerFunction::getOriginal(
23 const ngraph::element::Type precision,
24 const ngraph::Shape& inputShape,
25 bool isDequantization,
26 const ngraph::builder::subgraph::DequantizationOperations::Multiply& mulValues,
27 const ngraph::builder::subgraph::Add& addValues) {
28 const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
30 const auto mulConst = ngraph::op::Constant::create(ngraph::element::f32, mulValues.constantShape, mulValues.values);
31 const auto mul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
32 std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
33 ngraph::op::TemporaryReplaceOutputType(input, element::f32).get(),
34 ngraph::op::TemporaryReplaceOutputType(mulConst, element::f32).get());
36 const auto addConst = ngraph::op::Constant::create(ngraph::element::f32, addValues.constantShape, addValues.values);
37 const auto add = std::make_shared<ngraph::pass::low_precision::DequantizationAdd>(mul, addConst);
38 add->set_friendly_name("add");
40 if (!isDequantization) {
41 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(mul);
42 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(add);
45 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
46 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MulAddToScaleshiftOrPowerFunction");
49 std::shared_ptr<ngraph::Function> MulAddToScaleshiftOrPowerFunction::getReference(
50 const ngraph::element::Type precision,
51 const ngraph::Shape& inputShape,
52 bool isDequantization,
53 const ngraph::builder::subgraph::DequantizationOperations::Multiply& weightsValues,
54 const ngraph::builder::subgraph::Add& biasesValues,
55 const ngraph::element::Type precisionAfterOperation) {
56 const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
58 ngraph::Shape constShape = { 1, inputShape[1], 1, 1 };
59 const auto weights = ngraph::op::Constant::create(ngraph::element::f32, constShape, weightsValues.values);
60 const auto biases = ngraph::op::Constant::create(ngraph::element::f32, constShape, biasesValues.values);
62 std::shared_ptr<ngraph::Node> lastNode;
63 if (isDequantization) {
64 std::shared_ptr<Node> scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(input, weights, biases, precisionAfterOperation);
65 scaleshift = low_precision::NetworkHelper::markAsDequantizationOp(scaleshift);
66 scaleshift->set_friendly_name("add");
67 lastNode = scaleshift;
69 float scale = weightsValues.values[0];
70 float shift = biasesValues.values[0];
71 const auto power = std::make_shared<ngraph::op::PowerIE>(input, 1.f, scale, shift, precisionAfterOperation);
72 power->set_friendly_name("add");
77 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(lastNode) };
78 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MulAddToScaleshiftOrPowerFunction");
80 } // namespace subgraph
81 } // namespace builder