1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/attribute_visitor.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/constant.hpp"
22 #include "ngraph/op/sum.hpp"
23 #include "ngraph/partial_shape.hpp"
26 #include "ngraph/runtime/host_tensor.hpp"
27 #include "ngraph/runtime/reference/broadcast.hpp"
29 NGRAPH_SUPPRESS_DEPRECATED_START
32 using namespace ngraph;
// Out-of-class definition of the static constexpr member `type_info` of
// op::v3::Broadcast (its initializer lives with the in-class declaration).
constexpr NodeTypeInfo op::v3::Broadcast::type_info;
36 op::v3::Broadcast::Broadcast(const Output<Node>& arg,
37 const Output<Node>& target_shape,
38 const Output<Node>& axes_mapping,
39 const BroadcastModeSpec& broadcast_spec)
40 : util::BroadcastBase{arg, target_shape, axes_mapping, broadcast_spec}
42 constructor_validate_and_infer_types();
45 op::v3::Broadcast::Broadcast(const Output<Node>& arg,
46 const Output<Node>& target_shape,
47 const BroadcastModeSpec& broadcast_spec)
48 : util::BroadcastBase{arg, target_shape, broadcast_spec}
50 constructor_validate_and_infer_types();
55 std::pair<bool, AxisSet> get_broadcast_axes_bidirectional(const Shape& arg_shape,
56 const Shape& result_shape)
58 AxisSet broadcast_axes;
59 bool axes_known = false;
60 const auto start_axis = result_shape.size() - arg_shape.size();
61 NGRAPH_CHECK(start_axis >= 0);
62 for (size_t i = 0; i < result_shape.size(); i++)
64 if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
66 broadcast_axes.insert(i);
70 return std::make_pair(axes_known, broadcast_axes);
74 std::pair<bool, AxisSet> op::v3::Broadcast::get_broadcast_axes() const
76 if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
78 AxisSet broadcast_axes;
79 bool axes_known = false;
81 if (get_input_partial_shape(0).is_static() && get_output_partial_shape(0).is_static())
83 const auto arg_shape = get_input_shape(0);
84 const auto result_shape = get_output_shape(0);
85 return get_broadcast_axes_bidirectional(arg_shape, result_shape);
87 return std::make_pair(axes_known, broadcast_axes);
90 return util::BroadcastBase::get_broadcast_axes();
95 PartialShape get_result_shape_bidirectional(const Node* this_ptr,
96 const PartialShape& arg_shape,
99 if (arg_shape.rank().is_dynamic())
101 return PartialShape::dynamic();
103 auto arg_shape_vec = static_cast<std::vector<Dimension>>(arg_shape);
104 PartialShape result_shape;
105 // Add left padding to shorter target or argument shape
106 const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
107 while (arg_shape_vec.size() < target_padded_rank)
109 arg_shape_vec.insert(arg_shape_vec.begin(), 1);
111 while (target_shape.size() < target_padded_rank)
113 target_shape.insert(target_shape.begin(), 1);
116 result_shape = target_shape;
117 for (auto i = 0; i < target_shape.size(); ++i)
119 if (arg_shape_vec[i].is_dynamic())
121 if (target_shape[i] == 1)
123 result_shape[i] = Dimension::dynamic();
127 result_shape[i] = target_shape[i];
131 const size_t arg_shape_dim = arg_shape_vec[i].get_length();
132 NODE_VALIDATION_CHECK(this_ptr,
133 arg_shape_dim == 1 || target_shape[i] == 1 ||
134 arg_shape_dim == target_shape[i],
135 "Broadcast incorrect target shape. Expecting either 1 or ",
140 result_shape[i] = std::max(arg_shape_dim, target_shape[i]);
146 void op::v3::Broadcast::validate_and_infer_types()
148 if (m_mode.m_type == BroadcastType::NONE)
150 NODE_VALIDATION_CHECK(this,
151 get_input_size() == 3,
152 "axes_mapping input should be provided if explicit mode is used");
156 NODE_VALIDATION_CHECK(
158 get_input_size() == 2,
159 "axes_mapping input should not be provided for mode other than explicit");
162 util::BroadcastBase::validate_and_infer_types();
164 auto result_shape = get_output_partial_shape(0);
165 if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
167 if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static())
169 auto arg_shape = get_input_partial_shape(0);
171 const auto shape_constant =
172 as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
175 auto target_shape = shape_constant->get_shape_val();
176 result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
180 set_input_is_relevant_to_shape(0); // arg - Result element type
181 set_input_is_relevant_to_shape(1); // target_shape - Result shape
182 if (get_input_size() == 3)
184 set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
186 set_output_type(0, get_input_element_type(0), result_shape);
189 shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
191 check_new_args_count(this, new_args);
192 if (new_args.size() == 2)
194 return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), m_mode);
196 else if (new_args.size() == 3)
198 return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), new_args.at(2), m_mode);
202 throw ngraph_error("Not supported number of Broadcast:v3 args");
206 bool op::v3::Broadcast::visit_attributes(AttributeVisitor& visitor)
208 visitor.on_attribute("broadcast_spec", m_mode);
212 bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs,
213 const HostTensorVector& inputs) const
215 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v3::Broadcast::evaluate");
216 if (get_broadcast_spec().m_type == op::BroadcastType::BIDIRECTIONAL)
218 auto arg_shape = inputs[0]->get_shape();
219 Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]);
220 PartialShape result_shape =
221 get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape);
222 auto pair_broadcast_axes =
223 get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape());
224 return op::util::BroadcastBase::evaluate_broadcast(
225 inputs[0], outputs[0], pair_broadcast_axes, result_shape.to_shape());
227 return op::util::BroadcastBase::evaluate(outputs, inputs);
233 BroadcastModeSpec to_broadcast_mode(const AutoBroadcastSpec& bs)
235 BroadcastModeSpec broadcast_mode;
236 broadcast_mode.m_axis = bs.m_axis;
239 case AutoBroadcastType::NONE: broadcast_mode.m_type = BroadcastType::NONE; break;
240 case AutoBroadcastType::NUMPY: broadcast_mode.m_type = BroadcastType::NUMPY; break;
241 case AutoBroadcastType::PDPD: broadcast_mode.m_type = BroadcastType::PDPD; break;
243 return broadcast_mode;
// Out-of-class definition of the static constexpr member `type_info` of
// op::v1::Broadcast (its initializer lives with the in-class declaration).
constexpr NodeTypeInfo op::v1::Broadcast::type_info;
249 op::v1::Broadcast::Broadcast(const Output<Node>& arg,
250 const Output<Node>& target_shape,
251 const Output<Node>& axes_mapping,
252 const AutoBroadcastSpec& broadcast_spec)
253 : util::BroadcastBase{arg, target_shape, axes_mapping, to_broadcast_mode(broadcast_spec)}
254 , m_broadcast_spec{broadcast_spec}
256 constructor_validate_and_infer_types();
259 op::v1::Broadcast::Broadcast(const Output<Node>& arg,
260 const Output<Node>& target_shape,
261 const AutoBroadcastSpec& broadcast_spec)
262 : util::BroadcastBase{arg,
264 op::v0::Constant::create(element::u8, Shape{}, {0})->output(0),
265 to_broadcast_mode(broadcast_spec)}
266 , m_broadcast_spec{broadcast_spec}
268 constructor_validate_and_infer_types();
271 void op::v1::Broadcast::validate_and_infer_types()
273 util::BroadcastBase::validate_and_infer_types();
275 set_input_is_relevant_to_shape(0); // arg - Result element type
276 set_input_is_relevant_to_shape(1); // target_shape - Result shape
277 set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
280 shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
282 check_new_args_count(this, new_args);
283 return make_shared<v1::Broadcast>(
284 new_args.at(0), new_args.at(1), new_args.at(2), m_broadcast_spec);
287 bool op::v1::Broadcast::visit_attributes(AttributeVisitor& visitor)
289 visitor.on_attribute("broadcast_spec", m_broadcast_spec);
293 bool op::v1::Broadcast::evaluate(const HostTensorVector& outputs,
294 const HostTensorVector& inputs) const
296 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v1::Broadcast::evaluate");
297 return op::util::BroadcastBase::evaluate(outputs, inputs);
// Out-of-class definition of the static constexpr member `type_info` of
// op::v0::Broadcast (its initializer lives with the in-class declaration).
constexpr NodeTypeInfo op::v0::Broadcast::type_info;
302 op::v0::Broadcast::Broadcast(const OutputVector& args,
304 const AxisSet& broadcast_axes)
307 , m_broadcast_axes(broadcast_axes)
309 constructor_validate_and_infer_types();
312 op::v0::Broadcast::Broadcast(const Output<Node>& arg,
314 const AxisSet& broadcast_axes)
315 : Broadcast(OutputVector{arg}, shape, broadcast_axes)
319 bool op::v0::Broadcast::visit_attributes(AttributeVisitor& visitor)
321 visitor.on_attribute("shape", m_shape);
322 visitor.on_attribute("broadcast_axes", m_broadcast_axes);
326 void op::v0::Broadcast::validate_and_infer_types()
330 for (auto axis : m_broadcast_axes)
332 NODE_VALIDATION_CHECK(this,
333 axis < m_shape.size(),
334 "Broadcast axis index (",
336 ") exceeds specified output shape rank ",
344 Shape required_input_shape = m_shape;
345 for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
347 required_input_shape.erase(required_input_shape.begin() + *i);
350 // TODO(amprocte): We can probably have a more helpful error message here.
351 // There are two things that can go wrong, which are being picked up in
352 // one fell swoop by this check: either the number of broadcast axes is not
353 // enough, or there is a mismatch with one of the pre-broadcast axis lengths.
354 NODE_VALIDATION_CHECK(
356 get_input_partial_shape(0).compatible(required_input_shape),
357 "Broadcast argument shape, specified output shape, and axes are incompatible ",
359 get_input_partial_shape(0),
362 ", broadcast axes: ",
366 set_output_type(0, get_input_element_type(0), m_shape);
369 shared_ptr<Node> op::v0::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
371 check_new_args_count(this, new_args);
372 return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
// Expands to a `case` label for element type `a` followed by the beginning of
// an assignment `rc = evaluate_v0<element::Type_t::a>`. The use site appends
// the call's argument list and needs a variable named `rc` in scope.
#define TYPE_CASE_v0(a)                                                                            \
    case element::Type_t::a: rc = evaluate_v0<element::Type_t::a>
380 template <element::Type_t ET>
381 inline bool evaluate_v0(const HostTensorPtr& arg0,
382 const HostTensorPtr& out,
383 const AxisSet& broadcast_axes)
385 using T = typename element_type_traits<ET>::value_type;
386 runtime::reference::broadcast<T>((arg0->get_data_ptr<ET>()),
387 (out->get_data_ptr<ET>()),
394 bool evaluate_broadcast_v0(const HostTensorPtr& arg0,
395 const HostTensorPtr& out,
396 const AxisSet broadcast_axes,
397 const Shape output_shape)
400 Shape in_shape = arg0->get_shape();
401 out->set_shape(output_shape);
402 out->set_element_type(arg0->get_element_type());
403 switch (arg0->get_element_type())
405 TYPE_CASE_v0(boolean)(arg0, out, broadcast_axes);
407 TYPE_CASE_v0(i8)(arg0, out, broadcast_axes);
409 TYPE_CASE_v0(i16)(arg0, out, broadcast_axes);
411 TYPE_CASE_v0(i32)(arg0, out, broadcast_axes);
413 TYPE_CASE_v0(i64)(arg0, out, broadcast_axes);
415 TYPE_CASE_v0(u8)(arg0, out, broadcast_axes);
417 TYPE_CASE_v0(u16)(arg0, out, broadcast_axes);
419 TYPE_CASE_v0(u32)(arg0, out, broadcast_axes);
421 TYPE_CASE_v0(u64)(arg0, out, broadcast_axes);
423 TYPE_CASE_v0(bf16)(arg0, out, broadcast_axes);
425 TYPE_CASE_v0(f16)(arg0, out, broadcast_axes);
427 TYPE_CASE_v0(f32)(arg0, out, broadcast_axes);
429 TYPE_CASE_v0(f64)(arg0, out, broadcast_axes);
431 default: rc = false; break;
437 bool op::v0::Broadcast::evaluate(const HostTensorVector& outputs,
438 const HostTensorVector& inputs) const
440 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Broadcast::evaluate");
441 return evaluate_broadcast_v0(inputs[0], outputs[0], get_broadcast_axes(), get_output_shape(0));
// Out-of-class definition of the static constexpr member `type_info` of
// op::v0::BroadcastLike (its initializer lives with the in-class declaration).
constexpr NodeTypeInfo op::v0::BroadcastLike::type_info;
446 op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg,
447 const Output<Node>& like_arg,
448 const AxisSet& initial_broadcast_axes)
449 : op::v0::Broadcast({arg, like_arg}, {}, {})
450 , m_initial_broadcast_axes(initial_broadcast_axes)
452 constructor_validate_and_infer_types();
455 bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
457 visitor.on_attribute("shape", m_shape);
458 visitor.on_attribute("broadcast_axes", m_broadcast_axes);
459 visitor.on_attribute("initial_broadcast_axes", m_initial_broadcast_axes);
463 shared_ptr<Node> op::v0::BroadcastLike::clone_with_new_inputs(const OutputVector& new_args) const
465 if (new_args.size() != 2)
467 throw ngraph_error("Incorrect number of new arguments");
469 return make_shared<v0::BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
472 void op::v0::BroadcastLike::infer_shape()
474 const Shape& in_shape = get_input_shape(0);
475 m_shape = get_input_shape(1);
476 m_broadcast_axes = m_initial_broadcast_axes;
477 if (m_broadcast_axes.size() == 0)
479 for (size_t i = 0; i < m_shape.size(); ++i)
481 if (i < in_shape.size())
483 if (in_shape.at(i) == 1 && m_shape.at(i) > 1)
485 m_broadcast_axes.insert(i);
490 m_broadcast_axes.insert(i);