[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / mvn_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
6
7 #include "ngraph_functions/subgraph_builders.hpp"
8 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
9 #include "ngraph_ops/type_relaxed.hpp"
10
11 namespace ngraph {
12 namespace builder {
13 namespace subgraph {
14
15 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
16     const ngraph::Shape& inputShape,
17     const AxisSet& reductionAxes,
18     const bool& normalizeVariance,
19     const ngraph::element::Type precisionBeforeDequantization,
20     const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
21     const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
22         precisionBeforeDequantization,
23         ngraph::Shape(inputShape));
24
25     const auto dequantizationOp = makeDequantization(input, dequantization);
26     const auto mvn = std::make_shared<ngraph::op::MVN>(dequantizationOp, reductionAxes, normalizeVariance);
27     mvn->set_friendly_name("output");
28     auto& rtInfo = mvn->get_rt_info();
29     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
30
31     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
32     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
33 }
34
35 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
36     const ngraph::element::Type precision,
37     const ngraph::Shape& inputShape,
38     const AxisSet& reductionAxes,
39     const bool& normalizeVariance) {
40     float k = 50.f;
41
42     const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
43     const auto fakeQuantizeOnActivations = ngraph::builder::makeFakeQuantize(
44         input, precision, 256ul, { 1ul },
45         { 0.f }, { 255.f / k }, { 0.f }, { 255.f / k });
46     const auto mvn = std::make_shared<ngraph::op::MVN>(fakeQuantizeOnActivations, reductionAxes, normalizeVariance);
47
48     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
49     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
50 }
51
52 std::shared_ptr<ngraph::Function> MVNFunction::getReference(
53     const ngraph::Shape& inputShape,
54     const AxisSet& reductionAxes,
55     const bool& normalizeVariance,
56     const ngraph::element::Type precisionBeforeDequantization,
57     const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
58     const ngraph::element::Type precisionAfterOperation,
59     const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
60     const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
61         precisionBeforeDequantization,
62         ngraph::Shape(inputShape));
63
64     const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
65     const auto mvn = std::make_shared<ngraph::op::TypeRelaxed<ngraph::op::MVN>>(
66         op::MVN(dequantizationOpBefore, reductionAxes, normalizeVariance), precisionAfterOperation);
67     auto& rtInfo = mvn->get_rt_info();
68     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
69
70     const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(mvn, dequantizationAfter);
71     dequantizationOpAfter->set_friendly_name("output");
72
73     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOpAfter) };
74     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
75 }
76
77 }  // namespace subgraph
78 }  // namespace builder
79 }  // namespace ngraph