Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / select.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <memory>
18
19 #include "ngraph/attribute_visitor.hpp"
20 #include "ngraph/log.hpp"
21 #include "ngraph/op/convert.hpp"
22 #include "ngraph/op/multiply.hpp"
23 #include "ngraph/op/not.hpp"
24 #include "ngraph/op/select.hpp"
25
26 NGRAPH_SUPPRESS_DEPRECATED_START
27
28 using namespace std;
29 using namespace ngraph;
30
31 NGRAPH_RTTI_DEFINITION(op::v1::Select, "Select", 1);
32
33 op::v1::Select::Select(const Output<Node>& arg0,
34                        const Output<Node>& arg1,
35                        const Output<Node>& arg2,
36                        const AutoBroadcastSpec& auto_broadcast)
37     : Op({arg0, arg1, arg2})
38     , m_auto_broadcast(auto_broadcast)
39 {
40     constructor_validate_and_infer_types();
41 }
42
43 void op::v1::Select::validate_and_infer_types()
44 {
45     // Condition element type check
46     NODE_VALIDATION_CHECK(this,
47                           get_input_element_type(0).is_dynamic() ||
48                               get_input_element_type(0) == element::boolean,
49                           "Argument 0 must have boolean element type (element type: ",
50                           get_input_element_type(0),
51                           ").");
52
53     // Then/Else element type check
54     element::Type result_et;
55     NODE_VALIDATION_CHECK(
56         this,
57         element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
58         "Argument 1 and 2 element types must match.");
59
60     PartialShape result_shape = get_input_partial_shape(2);
61     for (int i = 1; i >= 0; i--)
62     {
63         if (get_auto_broadcast().m_type == op::AutoBroadcastType::NONE)
64         {
65             NODE_VALIDATION_CHECK(
66                 this,
67                 PartialShape::merge_into(result_shape, get_input_partial_shape(i)),
68                 "Argument shapes are inconsistent.");
69         }
70         else if (get_auto_broadcast().m_type == op::AutoBroadcastType::NUMPY ||
71                  get_auto_broadcast().m_type == op::AutoBroadcastType::PDPD)
72         {
73             NODE_VALIDATION_CHECK(this,
74                                   PartialShape::broadcast_merge_into(result_shape,
75                                                                      get_input_partial_shape(i),
76                                                                      get_auto_broadcast()),
77                                   "Argument shapes are inconsistent.");
78         }
79         else
80         {
81             NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
82         }
83     }
84     set_output_type(0, result_et, result_shape);
85 }
86
87 shared_ptr<Node> op::v1::Select::clone_with_new_inputs(const OutputVector& new_args) const
88 {
89     check_new_args_count(this, new_args);
90     return make_shared<v1::Select>(
91         new_args.at(0), new_args.at(1), new_args.at(2), m_auto_broadcast);
92 }
93
94 bool op::v1::Select::visit_attributes(AttributeVisitor& visitor)
95 {
96     visitor.on_attribute("auto_broadcast", m_auto_broadcast);
97     return true;
98 }
99
100 constexpr NodeTypeInfo op::v0::Select::type_info;
101
102 op::v0::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
103     : Op({arg0, arg1, arg2})
104 {
105     constructor_validate_and_infer_types();
106 }
107
108 void op::v0::Select::validate_and_infer_types()
109 {
110     NODE_VALIDATION_CHECK(this,
111                           get_input_element_type(0).is_dynamic() ||
112                               get_input_element_type(0) == element::boolean,
113                           "Argument 0 must have boolean element type (element type: ",
114                           get_input_element_type(0),
115                           ").");
116
117     PartialShape result_shape = get_input_partial_shape(0);
118
119     NODE_VALIDATION_CHECK(this,
120                           PartialShape::merge_into(result_shape, get_input_partial_shape(1)),
121                           "Argument shapes are inconsistent.");
122     NODE_VALIDATION_CHECK(this,
123                           PartialShape::merge_into(result_shape, get_input_partial_shape(2)),
124                           "Argument shapes are inconsistent.");
125
126     element::Type result_et;
127
128     NODE_VALIDATION_CHECK(
129         this,
130         element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
131         "Argument 1 and 2 element types are inconsistent.");
132
133     set_output_type(0, result_et, result_shape);
134 }
135
136 shared_ptr<Node> op::v0::Select::clone_with_new_inputs(const OutputVector& new_args) const
137 {
138     check_new_args_count(this, new_args);
139     return make_shared<v0::Select>(new_args.at(0), new_args.at(1), new_args.at(2));
140 }