1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/mvn_function.hpp"
7 #include "ngraph_functions/subgraph_builders.hpp"
8 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
9 #include "ngraph_ops/type_relaxed.hpp"
15 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
16 const ngraph::Shape& inputShape,
17 const AxisSet& reductionAxes,
18 const bool& normalizeVariance,
19 const ngraph::element::Type precisionBeforeDequantization,
20 const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
21 const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
22 precisionBeforeDequantization,
23 ngraph::Shape(inputShape));
25 const auto dequantizationOp = makeDequantization(input, dequantization);
26 const auto mvn = std::make_shared<ngraph::op::MVN>(dequantizationOp, reductionAxes, normalizeVariance);
27 mvn->set_friendly_name("output");
28 auto& rtInfo = mvn->get_rt_info();
29 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
31 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
32 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
35 std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
36 const ngraph::element::Type precision,
37 const ngraph::Shape& inputShape,
38 const AxisSet& reductionAxes,
39 const bool& normalizeVariance) {
42 const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
43 const auto fakeQuantizeOnActivations = ngraph::builder::makeFakeQuantize(
44 input, precision, 256ul, { 1ul },
45 { 0.f }, { 255.f / k }, { 0.f }, { 255.f / k });
46 const auto mvn = std::make_shared<ngraph::op::MVN>(fakeQuantizeOnActivations, reductionAxes, normalizeVariance);
48 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
49 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
52 std::shared_ptr<ngraph::Function> MVNFunction::getReference(
53 const ngraph::Shape& inputShape,
54 const AxisSet& reductionAxes,
55 const bool& normalizeVariance,
56 const ngraph::element::Type precisionBeforeDequantization,
57 const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
58 const ngraph::element::Type precisionAfterOperation,
59 const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
60 const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
61 precisionBeforeDequantization,
62 ngraph::Shape(inputShape));
64 const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
65 const auto mvn = std::make_shared<ngraph::op::TypeRelaxed<ngraph::op::MVN>>(
66 op::MVN(dequantizationOpBefore, reductionAxes, normalizeVariance), precisionAfterOperation);
67 auto& rtInfo = mvn->get_rt_info();
68 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
70 const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(mvn, dequantizationAfter);
71 dequantizationOpAfter->set_friendly_name("output");
73 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOpAfter) };
74 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
77 } // namespace subgraph
78 } // namespace builder