1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
20 #include "ngraph/attribute_visitor.hpp"
21 #include "ngraph/axis_set.hpp"
22 #include "ngraph/builder/norm.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/op/broadcast.hpp"
25 #include "ngraph/op/constant.hpp"
26 #include "ngraph/op/divide.hpp"
27 #include "ngraph/shape.hpp"
30 using namespace ngraph;
32 NGRAPH_SUPPRESS_DEPRECATED_START
34 constexpr NodeTypeInfo op::GRN::type_info;
36 op::GRN::GRN(const Output<Node>& data, float bias)
40 constructor_validate_and_infer_types();
43 bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
45 visitor.on_attribute("bias", m_bias);
49 void op::GRN::pre_validate_and_infer_types()
51 const auto& data_pshape = get_input_partial_shape(0);
53 if (data_pshape.is_static())
55 const Shape& data_shape{data_pshape.to_shape()};
57 // Input data must be 2, 3 or 4D tensor.
58 NODE_VALIDATION_CHECK(this,
59 (data_shape.size() >= 2 && data_shape.size() <= 4),
60 "Input tensor rank must be 2, 3 or 4 dimensional (actual input "
67 OutputVector op::GRN::decompose_op() const
69 Output<Node> data{input_value(0)};
70 const Shape& input_shape{data.get_shape()};
72 // Reshape to 4D tensor.
73 if (input_shape.size() != 4)
75 Shape data_shape(4 - input_shape.size(), 1);
76 copy(begin(input_shape), end(input_shape), back_inserter(data_shape));
77 data = builder::opset1::reshape(data, data_shape);
80 const auto axis_set_const = op::Constant::create(element::i64, {}, {1});
81 // Calculate l2 norm across channels.
82 shared_ptr<Node> norm = builder::opset1::l2_norm(data, axis_set_const, m_bias);
83 // Get back reduced axis.
84 norm = std::make_shared<Broadcast>(norm, data.get_shape(), AxisSet{1});
87 // get back original input tensor rank
88 if (input_shape.size() != 4)
90 data = builder::opset1::reshape(data, input_shape);
93 return OutputVector{data};
96 shared_ptr<Node> op::GRN::clone_with_new_inputs(const OutputVector& new_args) const
98 if (new_args.size() != 1)
100 throw ngraph_error("Incorrect number of new arguments");
102 return make_shared<GRN>(new_args.at(0), m_bias);