7434068638983f25af71611c891152019c490706
[platform/upstream/dldt.git] / ngraph / core / src / op / greater.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/greater.hpp"
18 #include "itt.hpp"
19 #include "ngraph/runtime/host_tensor.hpp"
20 #include "ngraph/runtime/reference/greater.hpp"
21
22 NGRAPH_SUPPRESS_DEPRECATED_START
23
24 using namespace std;
25 using namespace ngraph;
26
27 //-------------------------------------- v0 ------------------------------------
28
29 constexpr NodeTypeInfo op::v0::Greater::type_info;
30
31 op::v0::Greater::Greater(const Output<Node>& arg0,
32                          const Output<Node>& arg1,
33                          const AutoBroadcastSpec& auto_broadcast)
34     : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
35 {
36     constructor_validate_and_infer_types();
37 }
38
39 shared_ptr<Node> op::v0::Greater::clone_with_new_inputs(const OutputVector& new_args) const
40 {
41     check_new_args_count(this, new_args);
42     return make_shared<op::v0::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
43 }
44
45 namespace
46 {
47     template <element::Type_t ET>
48     bool evaluate(const HostTensorPtr& arg0,
49                   const HostTensorPtr& arg1,
50                   const HostTensorPtr& out,
51                   const op::AutoBroadcastSpec& broadcast_spec)
52     {
53         runtime::reference::greater(arg0->get_data_ptr<ET>(),
54                                     arg1->get_data_ptr<ET>(),
55                                     out->get_data_ptr<element::Type_t::boolean>(),
56                                     arg0->get_shape(),
57                                     arg1->get_shape(),
58                                     broadcast_spec);
59         return true;
60     }
61
62     bool evaluate_greater(const HostTensorPtr& arg0,
63                           const HostTensorPtr& arg1,
64                           const HostTensorPtr& out,
65                           const op::AutoBroadcastSpec& broadcast_spec)
66     {
67         bool rc = true;
68         out->set_broadcast(broadcast_spec, arg0, arg1, element::boolean);
69         switch (arg0->get_element_type())
70         {
71             TYPE_CASE(boolean)(arg0, arg1, out, broadcast_spec);
72             break;
73             TYPE_CASE(i32)(arg0, arg1, out, broadcast_spec);
74             break;
75             TYPE_CASE(i64)(arg0, arg1, out, broadcast_spec);
76             break;
77             TYPE_CASE(u32)(arg0, arg1, out, broadcast_spec);
78             break;
79             TYPE_CASE(u64)(arg0, arg1, out, broadcast_spec);
80             break;
81             TYPE_CASE(f16)(arg0, arg1, out, broadcast_spec);
82             break;
83             TYPE_CASE(f32)(arg0, arg1, out, broadcast_spec);
84             break;
85         default: rc = false; break;
86         }
87         return rc;
88     }
89 }
90
91 bool op::v0::Greater::evaluate(const HostTensorVector& outputs,
92                                const HostTensorVector& inputs) const
93 {
94     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Greater::evaluate");
95     return evaluate_greater(inputs[0], inputs[1], outputs[0], get_autob());
96 }
97
98 //-------------------------------------- v1 ------------------------------------
99
100 NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1);
101
102 op::v1::Greater::Greater(const Output<Node>& arg0,
103                          const Output<Node>& arg1,
104                          const AutoBroadcastSpec& auto_broadcast)
105     : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
106 {
107     constructor_validate_and_infer_types();
108 }
109
110 shared_ptr<Node> op::v1::Greater::clone_with_new_inputs(const OutputVector& new_args) const
111 {
112     check_new_args_count(this, new_args);
113     return make_shared<op::v1::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
114 }
115
116 bool op::v1::Greater::evaluate(const HostTensorVector& outputs,
117                                const HostTensorVector& inputs) const
118 {
119     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v1::Greater::evaluate");
120     return evaluate_greater(inputs[0], inputs[1], outputs[0], get_autob());
121 }