Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / core / src / op / prelu.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/prelu.hpp"
18 #include <ngraph/runtime/reference/prelu.hpp>
19 #include "itt.hpp"
20
21 #include "ngraph/builder/autobroadcast.hpp"
22 #include "ngraph/op/add.hpp"
23 #include "ngraph/op/broadcast.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/op/convert.hpp"
26 #include "ngraph/op/greater.hpp"
27 #include "ngraph/op/less.hpp"
28 #include "ngraph/op/multiply.hpp"
29
30 using namespace std;
31 using namespace ngraph;
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 NGRAPH_RTTI_DEFINITION(op::PRelu, "PRelu", 0);
36
37 op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
38     : FusedOp({data, slope})
39 {
40     constructor_validate_and_infer_types();
41 }
42
43 bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor)
44 {
45     return true;
46 }
47
48 void ngraph::op::v0::PRelu::pre_validate_and_infer_types()
49 {
50     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
51 }
52
53 OutputVector op::PRelu::decompose_op() const
54 {
55     auto data = input_value(0);
56     auto data_shape = data.get_shape();
57     auto slope = input_value(1);
58     slope = std::make_shared<op::Convert>(slope, data.get_element_type());
59     auto slope_shape = slope.get_shape();
60
61     if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
62     {
63         auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
64         auto index = std::distance(std::begin(data_shape), it);
65         slope = builder::make_broadcast_node(slope, data.get_shape(), index);
66     }
67     else if (data_shape != slope_shape)
68     {
69         slope = builder::numpy_broadcast(slope, data.get_shape());
70     }
71
72     // x <  0 => f(x) = x * slope
73     // x >= 0 => f(x) = x
74
75     std::shared_ptr<ngraph::Node> zero_node = make_zero(data.get_element_type(), data.get_shape());
76
77     std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
78         std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
79
80     std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
81         std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());
82
83     slope = negative_map * slope + positive_map;
84
85     return {data * slope};
86 }
87
88 shared_ptr<Node> op::PRelu::clone_with_new_inputs(const OutputVector& new_args) const
89 {
90     if (new_args.size() != 2)
91     {
92         throw ngraph_error("Incorrect number of new arguments");
93     }
94     return make_shared<PRelu>(new_args.at(0), new_args.at(1));
95 }
96
97 namespace prelu
98 {
99     template <element::Type_t ET>
100     bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& slope, const HostTensorPtr& out)
101     {
102         runtime::reference::prelu(arg->get_data_ptr<ET>(),
103                                   slope->get_data_ptr<ET>(),
104                                   out->get_data_ptr<ET>(),
105                                   arg->get_shape(),
106                                   slope->get_shape());
107         return true;
108     }
109
110     bool evaluate_prelu(const HostTensorPtr& arg,
111                         const HostTensorPtr& slope,
112                         const HostTensorPtr& out)
113     {
114         bool rc = true;
115         switch (arg->get_element_type())
116         {
117             TYPE_CASE(i8)(arg, slope, out);
118             break;
119             TYPE_CASE(bf16)(arg, slope, out);
120             break;
121             TYPE_CASE(f16)(arg, slope, out);
122             break;
123             TYPE_CASE(f32)(arg, slope, out);
124             break;
125         default: rc = false; break;
126         }
127         return rc;
128     }
129 }
130
131 bool op::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
132 {
133     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::PRelu::evaluate");
134     return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]);
135 }