e30526f052f21264c7b71d026b508623e99f8490
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / mul_add_to_scaleshift_or_power_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/mul_add_to_scaleshift_or_power_function.hpp"
6
7 #include <ngraph/opsets/opset1.hpp>
8 #include "ngraph_ops/type_relaxed.hpp"
9 #include "low_precision/network_helper.hpp"
10 #include "low_precision/common/dequantization_op.hpp"
11
12 #include <legacy/ngraph_ops/power.hpp>
13 #include <legacy/ngraph_ops/scaleshift.hpp>
14
15 #include "ngraph_functions/subgraph_builders.hpp"
16 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
17
18 namespace ngraph {
19 namespace builder {
20 namespace subgraph {
21     using namespace ngraph::pass;
22     std::shared_ptr<ngraph::Function> MulAddToScaleshiftOrPowerFunction::getOriginal(
23         const ngraph::element::Type precision,
24         const ngraph::Shape& inputShape,
25         bool isDequantization,
26         const ngraph::builder::subgraph::DequantizationOperations::Multiply& mulValues,
27         const ngraph::builder::subgraph::Add& addValues) {
28         const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
29
30         const auto mulConst = ngraph::op::Constant::create(ngraph::element::f32, mulValues.constantShape, mulValues.values);
31         const auto mul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::pass::low_precision::DequantizationMultiply>>(
32             std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{ element::f32 },
33             ngraph::op::TemporaryReplaceOutputType(input, element::f32).get(),
34             ngraph::op::TemporaryReplaceOutputType(mulConst, element::f32).get());
35
36         const auto addConst = ngraph::op::Constant::create(ngraph::element::f32, addValues.constantShape, addValues.values);
37         const auto add = std::make_shared<ngraph::pass::low_precision::DequantizationAdd>(mul, addConst);
38         add->set_friendly_name("add");
39
40         if (!isDequantization) {
41             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(mul);
42             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(add);
43         }
44
45         ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
46         return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MulAddToScaleshiftOrPowerFunction");
47     }
48
49     std::shared_ptr<ngraph::Function> MulAddToScaleshiftOrPowerFunction::getReference(
50         const ngraph::element::Type precision,
51         const ngraph::Shape& inputShape,
52         bool isDequantization,
53         const ngraph::builder::subgraph::DequantizationOperations::Multiply& weightsValues,
54         const ngraph::builder::subgraph::Add& biasesValues,
55         const ngraph::element::Type precisionAfterOperation) {
56         const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
57
58         ngraph::Shape constShape = { 1, inputShape[1], 1, 1 };
59         const auto weights = ngraph::op::Constant::create(ngraph::element::f32, constShape, weightsValues.values);
60         const auto biases = ngraph::op::Constant::create(ngraph::element::f32, constShape, biasesValues.values);
61
62         std::shared_ptr<ngraph::Node> lastNode;
63         if (isDequantization) {
64             std::shared_ptr<Node> scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(input, weights, biases, precisionAfterOperation);
65             scaleshift = low_precision::NetworkHelper::markAsDequantizationOp(scaleshift);
66             scaleshift->set_friendly_name("add");
67             lastNode = scaleshift;
68         } else {
69             float scale = weightsValues.values[0];
70             float shift = biasesValues.values[0];
71             const auto power = std::make_shared<ngraph::op::PowerIE>(input, 1.f, scale, shift, precisionAfterOperation);
72             power->set_friendly_name("add");
73             lastNode = power;
74         }
75
76
77         ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(lastNode) };
78         return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MulAddToScaleshiftOrPowerFunction");
79     }
80 }  // namespace subgraph
81 }  // namespace builder
82 }  // namespace ngraph