1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "ngraph/op/prelu.hpp"
18 #include <ngraph/runtime/reference/prelu.hpp>
21 #include "ngraph/builder/autobroadcast.hpp"
22 #include "ngraph/op/add.hpp"
23 #include "ngraph/op/broadcast.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/op/convert.hpp"
26 #include "ngraph/op/greater.hpp"
27 #include "ngraph/op/less.hpp"
28 #include "ngraph/op/multiply.hpp"
31 using namespace ngraph;
// Suppress deprecated-API warnings for the code that follows.
// NOTE(review): no matching NGRAPH_SUPPRESS_DEPRECATED_END is visible here — confirm it closes later in the file.
33 NGRAPH_SUPPRESS_DEPRECATED_START
// Register runtime type information for op::PRelu: type name "PRelu", version 0.
35 NGRAPH_RTTI_DEFINITION(op::PRelu, "PRelu", 0);
37 op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
38 : FusedOp({data, slope})
40 constructor_validate_and_infer_types();
// Attribute-visitor hook (serialization / inspection).
// NOTE(review): PRelu is constructed from its two inputs only (data, slope), so there are
// presumably no attributes to visit; the function body is not present in this copy — confirm.
43 bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor)
48 void ngraph::op::v0::PRelu::pre_validate_and_infer_types()
50 set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
53 OutputVector op::PRelu::decompose_op() const
55 auto data = input_value(0);
56 auto data_shape = data.get_shape();
57 auto slope = input_value(1);
58 slope = std::make_shared<op::Convert>(slope, data.get_element_type());
59 auto slope_shape = slope.get_shape();
61 if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
63 auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
64 auto index = std::distance(std::begin(data_shape), it);
65 slope = builder::make_broadcast_node(slope, data.get_shape(), index);
67 else if (data_shape != slope_shape)
69 slope = builder::numpy_broadcast(slope, data.get_shape());
72 // x < 0 => f(x) = x * slope
75 std::shared_ptr<ngraph::Node> zero_node = make_zero(data.get_element_type(), data.get_shape());
77 std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
78 std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
80 std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
81 std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());
83 slope = negative_map * slope + positive_map;
85 return {data * slope};
88 shared_ptr<Node> op::PRelu::clone_with_new_inputs(const OutputVector& new_args) const
90 if (new_args.size() != 2)
92 throw ngraph_error("Incorrect number of new arguments");
94 return make_shared<PRelu>(new_args.at(0), new_args.at(1));
// Typed host kernel: reads the argument, slope and output buffers as element type ET and
// forwards the raw pointers to runtime::reference::prelu.
// NOTE(review): all three tensors are accessed as ET (the type chosen from arg's element type
// by evaluate_prelu), so slope and out are assumed to share arg's element type — confirm.
// NOTE(review): the trailing arguments of the reference call (presumably the arg/slope shapes)
// and the return statement are not present in this copy of the file — verify.
99 template <element::Type_t ET>
100 bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& slope, const HostTensorPtr& out)
102 runtime::reference::prelu(arg->get_data_ptr<ET>(),
103 slope->get_data_ptr<ET>(),
104 out->get_data_ptr<ET>(),
// Dispatches on arg's element type to the typed evaluate<ET> kernel.
// Handled element types: i8, bf16, f16, f32. Any other type takes the default branch, which
// sets rc to false (i.e. the op is not evaluated on the host).
// NOTE(review): TYPE_CASE is a macro defined elsewhere; presumably it expands to
// `case element::Type_t::X: rc = evaluate<element::Type_t::X>` — confirm. The rc
// declaration, the per-case `break`s and the final return are not present in this copy.
110 bool evaluate_prelu(const HostTensorPtr& arg,
111 const HostTensorPtr& slope,
112 const HostTensorPtr& out)
115 switch (arg->get_element_type())
117 TYPE_CASE(i8)(arg, slope, out);
119 TYPE_CASE(bf16)(arg, slope, out);
121 TYPE_CASE(f16)(arg, slope, out);
123 TYPE_CASE(f32)(arg, slope, out);
125 default: rc = false; break;
131 bool op::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
133 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::PRelu::evaluate");
134 return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]);