f4733ac89f78ae6d7d50325e3bcc7a67517b96c7
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / convert_mul_or_add_finally_with_dequantization_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/convert_mul_or_add_finally_with_dequantization_function.hpp"
6
7 #include <memory>
8 #include <vector>
9 #include <ngraph/ngraph.hpp>
10
11
12 #include <ngraph/opsets/opset1.hpp>
13 #include "ngraph_functions/subgraph_builders.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
15 #include "low_precision/network_helper.hpp"
16 #include <legacy/ngraph_ops/scaleshift.hpp>
17 #include "low_precision/common/dequantization_op.hpp"
18
19 namespace ngraph {
20 namespace builder {
21 namespace subgraph {
22
23 using namespace ngraph::pass;
24
25 std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getOriginal(
26     const ngraph::Shape& inputShape,
27     const ngraph::element::Type inputPrecision,
28     const std::vector<float>& multiplyConst) {
29     const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
30     const auto reluOriginal = ngraph::opset1::Relu(
31         ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());
32
33     std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
34         reluOriginal,
35         std::vector<element::Type>{ element::f32, element::f32 },
36         std::vector<element::Type>{});
37
38
39     const auto multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(relu,
40                                                             std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst));
41
42     multiply->set_friendly_name("output");
43
44     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
45     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input },
46                                               "ConvertMulOrAddTransformationWithDequantization");
47 }
48
49 std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getReference(
50     const ngraph::Shape& inputShape,
51     const ngraph::element::Type inputPrecision,
52     const std::vector<float>& multiplyConst) {
53     const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
54     const auto reluOriginal = ngraph::opset1::Relu(
55         ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());
56
57     std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
58         reluOriginal,
59         std::vector<element::Type>{ element::f32, element::f32 },
60         std::vector<element::Type>{});
61
62     const auto weights = std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst);
63     const auto bias = std::make_shared<opset1::Constant>(element::f32, inputShape, 0.0);
64     std::shared_ptr<Node> scaleShift = std::make_shared<ngraph::op::ScaleShiftIE>(relu, weights, bias);
65
66     scaleShift = low_precision::NetworkHelper::markAsDequantizationOp(scaleShift);
67
68     scaleShift->set_friendly_name("output");
69
70     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(scaleShift) };
71     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "ConvertMulOrAddTransformationWithDequantization");
72 }
73
74 }  // namespace subgraph
75 }  // namespace builder
76 }  // namespace ngraph