Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / core / src / op / mvn.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include <algorithm>
17
18 #include "mvn.hpp"
19 #include "ngraph/builder/autobroadcast.hpp"
20 #include "ngraph/builder/reduce_ops.hpp"
21 #include "ngraph/op/add.hpp"
22 #include "ngraph/op/broadcast.hpp"
23 #include "ngraph/op/constant.hpp"
24 #include "ngraph/op/divide.hpp"
25 #include "ngraph/op/sqrt.hpp"
26 #include "ngraph/op/subtract.hpp"
27
28 using namespace std;
29 using namespace ngraph;
30
31 NGRAPH_SUPPRESS_DEPRECATED_START
32
33 NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
34
35 op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_variance, double eps)
36     : FusedOp({data})
37     , m_eps{eps}
38     , m_across_channels{across_channels}
39     , m_normalize_variance{normalize_variance}
40 {
41     constructor_validate_and_infer_types();
42 }
43
44 op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
45     : FusedOp({data})
46     , m_eps{eps}
47     , m_across_channels{false}
48     , m_normalize_variance{normalize_variance}
49     , m_reduction_axes{reduction_axes}
50 {
51     constructor_validate_and_infer_types();
52 }
53
54 // decompose_op() relies on knowing the data type of input data which might
55 // not be available at shape inference time. So do direct shape inference
56 // instead of relying on op decomposition.
57 void op::MVN::validate_and_infer_types()
58 {
59     // if m_across_channels is true we should calculate mean and variance per batch
60     // else we calculate these per channel
61     if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
62     {
63         AxisSet reduction_axes;
64         size_t start_axis = m_across_channels ? 1 : 2;
65         for (size_t i = start_axis; i < input_value(0).get_partial_shape().rank().get_length(); ++i)
66         {
67             reduction_axes.insert(i);
68         }
69         set_reduction_axes(reduction_axes);
70     }
71
72     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
73 }
74
75 OutputVector op::MVN::decompose_op() const
76 {
77     auto data = input_value(0);
78     auto data_shape = data.get_shape(); // assume that data has n and c channels.
79
80     // calculate mean normalization
81     auto mean = builder::opset1::mean(data, m_reduction_axes);
82     auto mean_normalization =
83         data - builder::opset1::make_broadcast(mean, data_shape, m_reduction_axes);
84
85     if (!m_normalize_variance)
86     {
87         return {mean_normalization};
88     }
89     else
90     {
91         // calculate variance
92         auto variance = builder::opset1::variance(data, m_reduction_axes);
93         // add epsilon
94         auto eps_node = op::Constant::create(
95             data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
96         variance = std::make_shared<op::Sqrt>(variance + eps_node);
97
98         return OutputVector{mean_normalization / builder::opset1::make_broadcast(
99                                                      variance, data_shape, m_reduction_axes)};
100     }
101 }
102
103 shared_ptr<Node> op::MVN::clone_with_new_inputs(const OutputVector& new_args) const
104 {
105     NODE_VALIDATION_CHECK(this,
106                           new_args.size() == 1,
107                           "Expected 1 element in new_args for the MVN op but got ",
108                           new_args.size());
109     return make_shared<MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
110 }
111
112 bool op::MVN::visit_attributes(AttributeVisitor& visitor)
113 {
114     visitor.on_attribute("eps", m_eps);
115     visitor.on_attribute("across_channels", m_across_channels);
116     visitor.on_attribute("normalize_variance", m_normalize_variance);
117     visitor.on_attribute("reduction_axes", m_reduction_axes);
118     return true;
119 }