Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / core / src / op / grn.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include <algorithm>
17 #include <iterator>
18
19 #include "grn.hpp"
20 #include "ngraph/attribute_visitor.hpp"
21 #include "ngraph/axis_set.hpp"
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/norm.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/op/broadcast.hpp"
26 #include "ngraph/op/constant.hpp"
27 #include "ngraph/op/divide.hpp"
28 #include "ngraph/shape.hpp"
29
30 using namespace std;
31 using namespace ngraph;
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 constexpr NodeTypeInfo op::GRN::type_info;
36
37 op::GRN::GRN(const Output<Node>& data, float bias)
38     : FusedOp({data})
39     , m_bias(bias)
40 {
41     constructor_validate_and_infer_types();
42 }
43
44 bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
45 {
46     visitor.on_attribute("bias", m_bias);
47     return true;
48 }
49
50 void op::GRN::pre_validate_and_infer_types()
51 {
52     const auto& data_pshape = get_input_partial_shape(0);
53
54     if (data_pshape.is_static())
55     {
56         const Shape& data_shape{data_pshape.to_shape()};
57
58         // Input data must be 2, 3 or 4D tensor.
59         NODE_VALIDATION_CHECK(this,
60                               (data_shape.size() >= 2 && data_shape.size() <= 4),
61                               "Input tensor rank must be 2, 3 or 4 dimensional (actual input "
62                               "shape: ",
63                               data_shape,
64                               ").");
65     }
66 }
67
68 OutputVector op::GRN::decompose_op() const
69 {
70     Output<Node> data{input_value(0)};
71     const Shape& input_shape{data.get_shape()};
72
73     // Reshape to 4D tensor.
74     if (input_shape.size() != 4)
75     {
76         Shape data_shape(4 - input_shape.size(), 1);
77         copy(begin(input_shape), end(input_shape), back_inserter(data_shape));
78         data = builder::opset1::reshape(data, data_shape);
79     }
80
81     const auto axis_set_const = op::Constant::create(element::i64, {}, {1});
82     // Calculate l2 norm across channels.
83     shared_ptr<Node> norm = builder::opset1::l2_norm(data, axis_set_const, m_bias);
84     // Get back reduced axis.
85     data = std::make_shared<op::v1::Divide>(
86         data, builder::opset1::make_broadcast(norm, data.get_shape(), AxisSet{1}));
87
88     // get back original input tensor rank
89     if (input_shape.size() != 4)
90     {
91         data = builder::opset1::reshape(data, input_shape);
92     }
93
94     return OutputVector{data};
95 }
96
97 shared_ptr<Node> op::GRN::clone_with_new_inputs(const OutputVector& new_args) const
98 {
99     if (new_args.size() != 1)
100     {
101         throw ngraph_error("Incorrect number of new arguments");
102     }
103     return make_shared<GRN>(new_args.at(0), m_bias);
104 }