1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/attribute_visitor.hpp"
20 #include "ngraph/log.hpp"
21 #include "ngraph/op/convert.hpp"
22 #include "ngraph/op/multiply.hpp"
23 #include "ngraph/op/not.hpp"
24 #include "ngraph/op/select.hpp"
26 NGRAPH_SUPPRESS_DEPRECATED_START
29 using namespace ngraph;
31 NGRAPH_RTTI_DEFINITION(op::v1::Select, "Select", 1);
33 op::v1::Select::Select(const Output<Node>& arg0,
34 const Output<Node>& arg1,
35 const Output<Node>& arg2,
36 const AutoBroadcastSpec& auto_broadcast)
37 : Op({arg0, arg1, arg2})
38 , m_auto_broadcast(auto_broadcast)
40 constructor_validate_and_infer_types();
43 void op::v1::Select::validate_and_infer_types()
45 // Condition element type check
46 NODE_VALIDATION_CHECK(this,
47 get_input_element_type(0).is_dynamic() ||
48 get_input_element_type(0) == element::boolean,
49 "Argument 0 must have boolean element type (element type: ",
50 get_input_element_type(0),
53 // Then/Else element type check
54 element::Type result_et;
55 NODE_VALIDATION_CHECK(
57 element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
58 "Argument 1 and 2 element types must match.");
60 PartialShape result_shape = get_input_partial_shape(2);
61 for (int i = 1; i >= 0; i--)
63 if (get_auto_broadcast().m_type == op::AutoBroadcastType::NONE)
65 NODE_VALIDATION_CHECK(
67 PartialShape::merge_into(result_shape, get_input_partial_shape(i)),
68 "Argument shapes are inconsistent.");
70 else if (get_auto_broadcast().m_type == op::AutoBroadcastType::NUMPY ||
71 get_auto_broadcast().m_type == op::AutoBroadcastType::PDPD)
73 NODE_VALIDATION_CHECK(this,
74 PartialShape::broadcast_merge_into(result_shape,
75 get_input_partial_shape(i),
76 get_auto_broadcast()),
77 "Argument shapes are inconsistent.");
81 NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
84 set_output_type(0, result_et, result_shape);
87 shared_ptr<Node> op::v1::Select::clone_with_new_inputs(const OutputVector& new_args) const
89 check_new_args_count(this, new_args);
90 return make_shared<v1::Select>(
91 new_args.at(0), new_args.at(1), new_args.at(2), m_auto_broadcast);
94 bool op::v1::Select::visit_attributes(AttributeVisitor& visitor)
96 visitor.on_attribute("auto_broadcast", m_auto_broadcast);
100 constexpr NodeTypeInfo op::v0::Select::type_info;
102 op::v0::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
103 : Op({arg0, arg1, arg2})
105 constructor_validate_and_infer_types();
108 void op::v0::Select::validate_and_infer_types()
110 NODE_VALIDATION_CHECK(this,
111 get_input_element_type(0).is_dynamic() ||
112 get_input_element_type(0) == element::boolean,
113 "Argument 0 must have boolean element type (element type: ",
114 get_input_element_type(0),
117 PartialShape result_shape = get_input_partial_shape(0);
119 NODE_VALIDATION_CHECK(this,
120 PartialShape::merge_into(result_shape, get_input_partial_shape(1)),
121 "Argument shapes are inconsistent.");
122 NODE_VALIDATION_CHECK(this,
123 PartialShape::merge_into(result_shape, get_input_partial_shape(2)),
124 "Argument shapes are inconsistent.");
126 element::Type result_et;
128 NODE_VALIDATION_CHECK(
130 element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)),
131 "Argument 1 and 2 element types are inconsistent.");
133 set_output_type(0, result_et, result_shape);
136 shared_ptr<Node> op::v0::Select::clone_with_new_inputs(const OutputVector& new_args) const
138 check_new_args_count(this, new_args);
139 return make_shared<v0::Select>(new_args.at(0), new_args.at(1), new_args.at(2));