b176dfe75c1dbc51a1854eb634f6a8031363a98c
[platform/upstream/dldt.git] / ngraph / core / src / op / grn.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include <algorithm>
17 #include <iterator>
18
19 #include "grn.hpp"
20 #include "ngraph/attribute_visitor.hpp"
21 #include "ngraph/axis_set.hpp"
22 #include "ngraph/builder/norm.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/op/broadcast.hpp"
25 #include "ngraph/op/constant.hpp"
26 #include "ngraph/op/divide.hpp"
27 #include "ngraph/shape.hpp"
28
29 using namespace std;
30 using namespace ngraph;
31
32 NGRAPH_SUPPRESS_DEPRECATED_START
33
34 constexpr NodeTypeInfo op::GRN::type_info;
35
36 op::GRN::GRN(const Output<Node>& data, float bias)
37     : FusedOp({data})
38     , m_bias(bias)
39 {
40     constructor_validate_and_infer_types();
41 }
42
43 bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
44 {
45     visitor.on_attribute("bias", m_bias);
46     return true;
47 }
48
49 void op::GRN::pre_validate_and_infer_types()
50 {
51     const auto& data_pshape = get_input_partial_shape(0);
52
53     if (data_pshape.is_static())
54     {
55         const Shape& data_shape{data_pshape.to_shape()};
56
57         // Input data must be 2, 3 or 4D tensor.
58         NODE_VALIDATION_CHECK(this,
59                               (data_shape.size() >= 2 && data_shape.size() <= 4),
60                               "Input tensor rank must be 2, 3 or 4 dimensional (actual input "
61                               "shape: ",
62                               data_shape,
63                               ").");
64     }
65 }
66
67 OutputVector op::GRN::decompose_op() const
68 {
69     Output<Node> data{input_value(0)};
70     const Shape& input_shape{data.get_shape()};
71
72     // Reshape to 4D tensor.
73     if (input_shape.size() != 4)
74     {
75         Shape data_shape(4 - input_shape.size(), 1);
76         copy(begin(input_shape), end(input_shape), back_inserter(data_shape));
77         data = builder::opset1::reshape(data, data_shape);
78     }
79
80     const auto axis_set_const = op::Constant::create(element::i64, {}, {1});
81     // Calculate l2 norm across channels.
82     shared_ptr<Node> norm = builder::opset1::l2_norm(data, axis_set_const, m_bias);
83     // Get back reduced axis.
84     norm = std::make_shared<Broadcast>(norm, data.get_shape(), AxisSet{1});
85     data = data / norm;
86
87     // get back original input tensor rank
88     if (input_shape.size() != 4)
89     {
90         data = builder::opset1::reshape(data, input_shape);
91     }
92
93     return OutputVector{data};
94 }
95
96 shared_ptr<Node> op::GRN::clone_with_new_inputs(const OutputVector& new_args) const
97 {
98     if (new_args.size() != 1)
99     {
100         throw ngraph_error("Incorrect number of new arguments");
101     }
102     return make_shared<GRN>(new_args.at(0), m_bias);
103 }