1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/mvn.hpp"
13 #include "ngraph/type/element_type.hpp"
14 #include "ngraph/type/element_type_traits.hpp"
15 #include "low_precision/network_helper.hpp"
16 #include "low_precision/common/dequantization_op.hpp"
18 using namespace ngraph;
19 using namespace ngraph::pass;
20 using namespace ngraph::pass::low_precision;
// Builds a replacement scales constant with the same element type and shape as
// `originalConst`. Each element becomes T{-1} when the original value compares
// `< 0`, and T{1} otherwise. So zero, and any value for which that comparison is
// false, becomes 1.
// T is the C++ type the constant's data is read as (via cast_vector<T>()). The
// callers in this excerpt instantiate it with the f16 and f32 value types.
// NOTE(review): the template header and closing braces of this definition are not
// part of this excerpt; confirm against the full file.
25 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
26 std::vector<T> source = originalConst.cast_vector<T>();
28 std::vector<T> newData(source.size());
29 for (size_t i = 0; i < source.size(); ++i) {
// Keep only the sign of each scale. Presumably this works because, with variance
// normalization and a scale `a` that is constant over each reduced group,
// MVN(a * x) equals sign(a) * MVN(x) up to the effect of eps.
// TODO: confirm against the MVN op definition.
30 newData[i] = source[i] < 0 ? T{-1} : T{1};
// Reuse the original constant's element type and shape so the new constant can
// stand in for it.
33 const ngraph::element::Type type = originalConst.get_output_element_type(0);
34 return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
// Decides whether `transform` may be applied to `operation`.
// Visible checks, in order:
//  1. the generic LayerTransformation::canBeTransformed check;
//  2. canSubtractBeHandled(operation);
//  3. a Constant must be found on input 1 (or, failing that, input 0) of the node
//     feeding the MVN's input 0;
//  4. when the MVN's reduction axes include axis 1, that constant must be
//     scalar-like according to NetworkHelper::isScalarLike.
// NOTE(review): the `return` statements and closing braces of this function are
// not part of this excerpt. The comments here assume each failed check returns
// false; confirm against the full file.
39 bool MVNTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
40 if (!LayerTransformation::canBeTransformed(context, operation)) {
44 if (!canSubtractBeHandled(operation)) {
// NOTE(review): the result of as_type_ptr is dereferenced below without a null
// check. This relies on `operation` actually being an op::MVN. The registered
// pattern matches op::MVN, but this method's signature accepts any Node.
48 auto mvn = as_type_ptr<op::MVN>(operation);
// Node feeding the MVN. The registered pattern expects an opset1::Multiply here,
// but the lines visible in this function do not verify the node type.
50 const std::shared_ptr<Node> multiply = mvn->get_input_node_shared_ptr(0);
// The scales constant may be on either side of the Multiply: try input 1 first,
// then input 0.
51 auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
52 if (scalesConst == nullptr) {
53 scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
55 if (scalesConst == nullptr) {
// True when the reduction axes include axis 1 (presumably the channel axis;
// TODO confirm the layout).
59 const bool acrossChannels = mvn->get_reduction_axes().count(1) > 0;
// NOTE(review): normalizeVariance is not referenced in the lines visible here.
// If the full function does not use it either, it can be removed.
60 const bool normalizeVariance = mvn->get_normalize_variance();
// A non-scalar-like scale is rejected when the reduction spans axis 1. Presumably
// this is because different scales within one reduced group cannot be moved
// across the normalization unchanged.
62 if (!NetworkHelper::isScalarLike(scalesConst) && acrossChannels) {
// Registers this transformation for op::MVN nodes whose single input is an
// opset1::Multiply (the pattern built below).
// NOTE(review): the lines between the signature and the pattern argument are not
// part of this excerpt. Presumably they hand `pass` and `context` to a pattern
// registration helper; confirm against the full file.
68 void MVNTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
72 make_op_pattern<ngraph::op::MVN>({ make_op_label<ngraph::opset1::Multiply>() }));
// Moves the dequantization Multiply that feeds an MVN to after it:
//   MVN(Multiply(y, scales))  ->  Multiply(MVN'(y), newScales)
// newScales is the original scales constant, or, when normalize_variance is set,
// a constant holding only the sign (+1 / -1) of each scale (built by
// mvn::createNewScalesConst).
// NOTE(review): several lines of this function are not part of this excerpt:
// the early return, the `switch` header and `break`s, part of the new MVN's
// constructor arguments, and the final return. Confirm against the full file.
75 bool MVNTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
76 std::shared_ptr<Node> operation = m.get_match_root();
77 if (!canBeTransformed(context, operation)) {
// separateInStandaloneBranch is applied before rewriting. The result is cast to
// op::MVN without a null check.
81 auto mvn = as_type_ptr<op::MVN>(separateInStandaloneBranch(operation));
83 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(mvn);
// As in canBeTransformed, look for the scales constant on input 1, then input 0,
// of the dequantization Multiply.
// NOTE(review): neither `dequantization.multiply` nor the resulting `scalesConst`
// is null-checked here. This relies on canBeTransformed having found a constant
// on the original operation.
84 auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
85 if (scalesConst == nullptr) {
86 scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
// NOTE(review): acrossChannels is not referenced in the lines visible here.
89 const bool acrossChannels = mvn->get_reduction_axes().count(1) > 0;
90 const bool normalizeVariance = mvn->get_normalize_variance();
// Without variance normalization the original scales are reused unchanged.
// With it, they are replaced by their signs below.
92 auto newScalesConst = scalesConst;
93 const auto type = scalesConst->get_output_element_type(0);
94 if (normalizeVariance) {
// Dispatch on the scale constant's element type: f16 and f32 are handled, and
// other types presumably reach the THROW_TRANSFORMATION_EXCEPTION below.
// NOTE(review): `mvn::` here names the helper namespace even though a local
// variable `mvn` is in scope. Qualified lookup before `::` ignores variables, so
// this compiles, but renaming the local would make the intent clearer.
96 case ngraph::element::Type_t::f16: {
97 newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
100 case ngraph::element::Type_t::f32: {
101 newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
105 THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
// The replacement MVN takes the dequantization Subtract as input when one exists.
// Otherwise it uses the operand on a line not included in this excerpt. It keeps
// the original reduction axes and normalize_variance flag. TypeRelaxed is
// presumably used so the new MVN's output element type can be overridden
// (TODO confirm).
110 auto newMVN = std::make_shared<op::TypeRelaxed<op::MVN>>(
111 op::MVN(dequantization.subtract ?
112 dequantization.subtract :
114 mvn->get_reduction_axes(),
115 mvn->get_normalize_variance(),
118 NetworkHelper::copyInfo(mvn, newMVN);
// Re-apply the (possibly sign-only) scales after the new MVN, then splice the
// new Multiply into the graph in place of the original MVN.
120 auto newMultiply = std::make_shared<DequantizationMultiply>(newMVN, newScalesConst);
121 ngraph::copy_runtime_info({ mvn, newMultiply }, newMultiply);
123 replace_node(mvn, newMultiply);
125 updateOutput(context, newMultiply, newMVN);
129 bool MVNTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {