54a595d1c943079f8ee13bb703a5f8185cf7e2bc
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / mvn_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
6
7 #include "ngraph_functions/subgraph_builders.hpp"
8 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
9 #include "ngraph_ops/type_relaxed.hpp"
10
11 namespace ngraph {
12 namespace builder {
13 namespace subgraph {
14
15 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
16     const ngraph::Shape& inputShape,
17     const AxisSet& reductionAxes,
18     const bool& normalizeVariance,
19     const ngraph::element::Type precisionBeforeDequantization,
20     const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
21     const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
22         precisionBeforeDequantization,
23         ngraph::Shape(inputShape));
24
25     const auto dequantizationOp = makeDequantization(input, dequantization);
26     const auto mvn = std::make_shared<ngraph::op::MVN>(dequantizationOp, reductionAxes, normalizeVariance);
27     mvn->set_friendly_name("output");
28
29     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
30     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
31 }
32
33 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
34     const ngraph::element::Type precision,
35     const ngraph::Shape& inputShape,
36     const AxisSet& reductionAxes,
37     const bool& normalizeVariance) {
38     float k = 50.f;
39
40     const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
41     const auto fakeQuantizeOnActivations = ngraph::builder::makeFakeQuantize(
42         input, precision, 256ul, { 1ul },
43         { 0.f }, { 255.f / k }, { 0.f }, { 255.f / k });
44     const auto mvn = std::make_shared<ngraph::op::MVN>(fakeQuantizeOnActivations, reductionAxes, normalizeVariance);
45
46     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
47     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
48 }
49
50 std::shared_ptr<ngraph::Function> MVNFunction::getReference(
51     const ngraph::Shape& inputShape,
52     const AxisSet& reductionAxes,
53     const bool& normalizeVariance,
54     const ngraph::element::Type precisionBeforeDequantization,
55     const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
56     const ngraph::element::Type precisionAfterOperation,
57     const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
58     const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
59         precisionBeforeDequantization,
60         ngraph::Shape(inputShape));
61
62     const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
63     const auto mvn = std::make_shared<ngraph::op::TypeRelaxed<ngraph::op::MVN>>(
64         op::MVN(dequantizationOpBefore, reductionAxes, normalizeVariance), precisionAfterOperation);
65     const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(mvn, dequantizationAfter);
66     dequantizationOpAfter->set_friendly_name("output");
67
68     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOpAfter) };
69     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
70 }
71
72 }  // namespace subgraph
73 }  // namespace builder
74 }  // namespace ngraph