fc63ed4e55984213bff3eb41e8b862db112c92ba
[platform/upstream/dldt.git] / ngraph / core / src / op / softplus.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/softplus.hpp"
18 #include "itt.hpp"
19 #include "ngraph/attribute_visitor.hpp"
20 #include "ngraph/runtime/host_tensor.hpp"
21 #include "ngraph/runtime/reference/softplus.hpp"
22
23 using namespace std;
24 using namespace ngraph;
25
26 NGRAPH_RTTI_DEFINITION(op::v4::SoftPlus, "SoftPlus", 4);
27
28 op::v4::SoftPlus::SoftPlus(const Output<Node>& arg)
29     : Op({arg})
30 {
31     constructor_validate_and_infer_types();
32 }
33
34 bool op::v4::SoftPlus::visit_attributes(AttributeVisitor& visitor)
35 {
36     return true;
37 }
38
39 void op::v4::SoftPlus::validate_and_infer_types()
40 {
41     set_output_size(1);
42     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
43 }
44
45 shared_ptr<Node> op::v4::SoftPlus::clone_with_new_inputs(const OutputVector& new_args) const
46 {
47     check_new_args_count(this, new_args);
48     return make_shared<op::v4::SoftPlus>(new_args.at(0));
49 }
50
51 namespace
52 {
53     template <element::Type_t ET>
54     inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
55     {
56         using T = typename element_type_traits<ET>::value_type;
57         runtime::reference::softplus<T>(arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
58         return true;
59     }
60
61     bool evaluate_softplus(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
62     {
63         bool rc = true;
64         out->set_unary(arg);
65
66         switch (arg->get_element_type())
67         {
68             TYPE_CASE(bf16)(arg, out, count);
69             break;
70             TYPE_CASE(f16)(arg, out, count);
71             break;
72             TYPE_CASE(f32)(arg, out, count);
73             break;
74         default: rc = false; break;
75         }
76         return rc;
77     }
78 }
79
80 bool op::v4::SoftPlus::evaluate(const HostTensorVector& outputs,
81                                 const HostTensorVector& inputs) const
82 {
83     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::SoftPlus::evaluate");
84     return evaluate_softplus(inputs[0], outputs[0], shape_size(get_output_shape(0)));
85 }