bf16f8ff8c7c2ac1a368b902399f58c84f793983
[platform/upstream/dldt.git] / ngraph / core / src / op / relu.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "itt.hpp"
18
19 #include "ngraph/op/multiply.hpp"
20 #include "ngraph/op/relu.hpp"
21
22 #include "ngraph/runtime/host_tensor.hpp"
23 #include "ngraph/runtime/reference/relu.hpp"
24
25 using namespace std;
26 using namespace ngraph;
27
28 NGRAPH_RTTI_DEFINITION(op::Relu, "Relu", 0);
29
30 op::Relu::Relu(const Output<Node>& arg)
31     : UnaryElementwiseArithmetic(arg)
32 {
33     constructor_validate_and_infer_types();
34 }
35
36 shared_ptr<Node> op::Relu::clone_with_new_inputs(const OutputVector& new_args) const
37 {
38     check_new_args_count(this, new_args);
39     return make_shared<Relu>(new_args.at(0));
40 }
41
42 namespace
43 {
44     template <element::Type_t ET>
45     inline bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
46     {
47         using T = typename element_type_traits<ET>::value_type;
48         runtime::reference::relu<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
49         return true;
50     }
51
52     bool evaluate_relu(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
53     {
54         bool rc = true;
55         out->set_unary(arg0);
56
57         switch (arg0->get_element_type())
58         {
59             TYPE_CASE(boolean)(arg0, out, count);
60             break;
61             TYPE_CASE(i32)(arg0, out, count);
62             break;
63             TYPE_CASE(i64)(arg0, out, count);
64             break;
65             TYPE_CASE(u32)(arg0, out, count);
66             break;
67             TYPE_CASE(u64)(arg0, out, count);
68             break;
69             TYPE_CASE(f16)(arg0, out, count);
70             break;
71             TYPE_CASE(f32)(arg0, out, count);
72             break;
73         default: rc = false; break;
74         }
75         return rc;
76     }
77 }
78
79 bool op::Relu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
80 {
81     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::Relu::evaluate");
82     return evaluate_relu(inputs[0], outputs[0], shape_size(get_output_shape(0)));
83 }