1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/builder/autobroadcast.hpp"
20 #include "ngraph/builder/reduce_ops.hpp"
21 #include "ngraph/op/add.hpp"
22 #include "ngraph/op/broadcast.hpp"
23 #include "ngraph/op/constant.hpp"
24 #include "ngraph/op/divide.hpp"
25 #include "ngraph/op/sqrt.hpp"
26 #include "ngraph/op/subtract.hpp"
29 using namespace ngraph;
// NOTE(review): presumably suppresses deprecation warnings for the code that
// follows (decompose_op/FusedOp-style APIs); no matching
// NGRAPH_SUPPRESS_DEPRECATED_END is visible here.
31 NGRAPH_SUPPRESS_DEPRECATED_START
// Registers runtime type information for op::v0::MVN: type name "MVN", version 0.
33 NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
// Constructs MVN over `data` without explicit reduction axes.
// The axes are derived in validate_and_infer_types() once the input rank is
// static: [1, rank) when `across_channels` is true (statistics per batch
// element), [2, rank) otherwise (statistics per channel).
// `normalize_variance` selects whether decompose_op() additionally divides the
// mean-centred data by sqrt(variance + m_eps) (eps presumably stored in m_eps).
35 op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_variance, double eps)
38 , m_across_channels{across_channels}
39 , m_normalize_variance{normalize_variance}
41 constructor_validate_and_infer_types();
// Constructs MVN over `data` that reduces over the explicitly given
// `reduction_axes`. validate_and_infer_types() only derives axes when this set
// is empty, in which case the per-channel default [2, rank) is used.
// NOTE(review): m_across_channels is hard-coded to false even when
// `reduction_axes` contains axis 1, so visit_attributes() reports
// across_channels=false for such nodes; consider deriving the flag from
// `reduction_axes` instead.
44 op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
47 , m_across_channels{false}
48 , m_normalize_variance{normalize_variance}
49 , m_reduction_axes{reduction_axes}
51 constructor_validate_and_infer_types();
54 // decompose_op() relies on knowing the data type of input data which might
55 // not be available at shape inference time. So do direct shape inference
56 // instead of relying on op decomposition.
57 void op::MVN::validate_and_infer_types()
59 // if m_across_channels is true we should calculate mean and variance per batch
60 // else we calculate these per channel
61 if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
63 AxisSet reduction_axes;
64 size_t start_axis = m_across_channels ? 1 : 2;
65 for (size_t i = start_axis; i < input_value(0).get_partial_shape().rank().get_length(); ++i)
67 reduction_axes.insert(i);
69 set_reduction_axes(reduction_axes);
72 set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
75 OutputVector op::MVN::decompose_op() const
77 auto data = input_value(0);
78 auto data_shape = data.get_shape(); // assume that data has n and c channels.
80 // calculate mean normalization
81 auto mean = builder::opset1::mean(data, m_reduction_axes);
82 auto mean_normalization =
83 data - builder::opset1::make_broadcast(mean, data_shape, m_reduction_axes);
85 if (!m_normalize_variance)
87 return {mean_normalization};
92 auto variance = builder::opset1::variance(data, m_reduction_axes);
94 auto eps_node = op::Constant::create(
95 data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
96 variance = std::make_shared<op::Sqrt>(variance + eps_node);
98 return OutputVector{mean_normalization / builder::opset1::make_broadcast(
99 variance, data_shape, m_reduction_axes)};
103 shared_ptr<Node> op::MVN::clone_with_new_inputs(const OutputVector& new_args) const
105 NODE_VALIDATION_CHECK(this,
106 new_args.size() == 1,
107 "Expected 1 element in new_args for the MVN op but got ",
109 return make_shared<MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
// Exposes the MVN attributes to `visitor` (presumably for serialization /
// deserialization): eps, across_channels, normalize_variance and
// reduction_axes.
// NOTE(review): nodes built with the explicit reduction-axes constructor always
// report across_channels=false here, even when axis 1 is reduced.
112 bool op::MVN::visit_attributes(AttributeVisitor& visitor)
114 visitor.on_attribute("eps", m_eps);
115 visitor.on_attribute("across_channels", m_across_channels);
116 visitor.on_attribute("normalize_variance", m_normalize_variance);
117 visitor.on_attribute("reduction_axes", m_reduction_axes);