f1cc1b053c66cfc3c467a11ce983c9eaa5158c93
[platform/upstream/dldt.git] / ngraph / core / src / op / prelu.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/prelu.hpp"
18 #include <ngraph/runtime/reference/prelu.hpp>
19 #include "itt.hpp"
20
21 #include "ngraph/builder/autobroadcast.hpp"
22 #include "ngraph/op/add.hpp"
23 #include "ngraph/op/broadcast.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/op/convert.hpp"
26 #include "ngraph/op/greater.hpp"
27 #include "ngraph/op/less.hpp"
28 #include "ngraph/op/multiply.hpp"
29
30 using namespace std;
31 using namespace ngraph;
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 NGRAPH_RTTI_DEFINITION(op::PRelu, "PRelu", 0);
36
37 op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
38     : FusedOp({data, slope})
39 {
40     constructor_validate_and_infer_types();
41 }
42
43 bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor)
44 {
45     return true;
46 }
47
48 void ngraph::op::v0::PRelu::pre_validate_and_infer_types()
49 {
50     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
51 }
52
53 OutputVector op::PRelu::decompose_op() const
54 {
55     auto data = input_value(0);
56     auto data_shape = data.get_shape();
57     auto slope = input_value(1);
58     slope = std::make_shared<op::Convert>(slope, data.get_element_type());
59     auto slope_shape = slope.get_shape();
60
61     if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
62     {
63         auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
64         auto index = std::distance(std::begin(data_shape), it);
65         slope = builder::make_broadcast_node(slope, data.get_shape(), index);
66     }
67     else if (data_shape != slope_shape)
68     {
69         slope = builder::numpy_broadcast(slope, data.get_shape());
70     }
71
72     // x <  0 => f(x) = x * slope
73     // x >= 0 => f(x) = x
74
75     std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
76         data.get_element_type(), ngraph::Shape{}, std::vector<double>{0});
77     zero_node = builder::make_broadcast_node(zero_node, data.get_shape());
78
79     std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
80         std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
81
82     std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
83         std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());
84
85     slope = negative_map * slope + positive_map;
86
87     return {data * slope};
88 }
89
90 shared_ptr<Node> op::PRelu::clone_with_new_inputs(const OutputVector& new_args) const
91 {
92     if (new_args.size() != 2)
93     {
94         throw ngraph_error("Incorrect number of new arguments");
95     }
96     return make_shared<PRelu>(new_args.at(0), new_args.at(1));
97 }
98
99 namespace prelu
100 {
101     template <element::Type_t ET>
102     bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& slope, const HostTensorPtr& out)
103     {
104         runtime::reference::prelu(arg->get_data_ptr<ET>(),
105                                   slope->get_data_ptr<ET>(),
106                                   out->get_data_ptr<ET>(),
107                                   arg->get_shape(),
108                                   slope->get_shape());
109         return true;
110     }
111
112     bool evaluate_prelu(const HostTensorPtr& arg,
113                         const HostTensorPtr& slope,
114                         const HostTensorPtr& out)
115     {
116         bool rc = true;
117         switch (arg->get_element_type())
118         {
119             TYPE_CASE(i8)(arg, slope, out);
120             break;
121             TYPE_CASE(bf16)(arg, slope, out);
122             break;
123             TYPE_CASE(f16)(arg, slope, out);
124             break;
125             TYPE_CASE(f32)(arg, slope, out);
126             break;
127         default: rc = false; break;
128         }
129         return rc;
130     }
131 }
132
133 bool op::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
134 {
135     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::PRelu::evaluate");
136     return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]);
137 }