//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "builder/autobroadcast.hpp"

#include <algorithm>
#include <memory>
#include <numeric>
#include <sstream>
#include <utility>

#include "builder/reshape.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/check.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/util.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;

namespace ngraph
{
namespace builder
{
numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
    const Shape& shape1, const Shape& shape2)
    : ngraph_error(error_str(shape1, shape2))
    , m_shape1(shape1)
    , m_shape2(shape2)
{
}

string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
                                                          const Shape& shape2)
{
    ostringstream os;
    os << "Auto-broadcast not possible for these input shapes:"
       << " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
    return os.str();
}

/// \brief Calculate the output shape of a numpy-style broadcast operation for two
///        shapes.
///
/// \note  More info:
///        https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
///        Example: left: [3, 1, 10] right: [5, 1] return: [3, 5, 10]
///
/// \param lhs_shape First input shape.
/// \param rhs_shape Second input shape.
///
/// \return Broadcast shape of the input shapes.
///
static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape)
{
    Shape result;
    auto lhs_rank = lhs_shape.size();
    auto rhs_rank = rhs_shape.size();
    auto max_rank = max(lhs_rank, rhs_rank);

    // left-pad the lhs_shape with ones
    lhs_shape.insert(begin(lhs_shape), max_rank - lhs_rank, 1);
    // left-pad the rhs_shape with ones
    rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);

    for (size_t index = 0; index < max_rank; ++index)
    {
        size_t lhs_dim = lhs_shape.at(index);
        size_t rhs_dim = rhs_shape.at(index);

        if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1)
        {
            throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
        }

        result.push_back(max(lhs_dim, rhs_dim));
    }

    return result;
}

pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes)
{
    Shape target_shape = accumulate(
        begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);

    vector<Shape> full_shapes;
    for (const Shape& input : input_shapes)
    {
        Shape padded_shape{input};
        padded_shape.insert(
            begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
        full_shapes.push_back(move(padded_shape));
    }

    return {target_shape, full_shapes};
}

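// Illustrative example (shapes chosen for exposition): get_numpy_broadcast_shapes(
// {Shape{3, 1, 10}, Shape{5, 1}}) yields the target shape {3, 5, 10} together with the
// rank-padded "full" shapes {3, 1, 10} and {1, 5, 1}.
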
static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values)
{
    vector<Shape> input_shapes;

    for (const auto& input : values)
    {
        input_shapes.push_back(input.get_shape());
    }

    return get_numpy_broadcast_shapes(input_shapes);
}

/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of the input node. However,
///       it should be a superset of it (containing it as a continuous subset). This
///       implies we may expand the number of axes of the input node. The ranks of
///       source_shape and output_shape must be equal, which means that source_shape
///       has to be left-padded with ones for this operation.
///
/// \param[in] value        The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast the input node.
///
/// \return The broadcast Node.
///
static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
                                             const Shape& output_shape,
                                             const Shape& source_shape)
{
    shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
    // If the node already has the required shape, return the original node.
    if (output_shape == value.get_shape())
    {
        return broadcasted_node;
    }

    NGRAPH_CHECK(source_shape.size() == output_shape.size(),
                 "Ranks of source_shape and output_shape don't match: ",
                 source_shape.size(),
                 " vs ",
                 output_shape.size());

    AxisVector broadcast_axes;
    Shape squeezed_shape;
    // Positions of axes which have length of 1 are needed to calculate broadcast_axes
    // for the nGraph broadcast operation. We need to remove ones from the source shape
    // to avoid a broadcasting axis conflict.
    for (size_t index = 0; index < output_shape.size(); ++index)
    {
        if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
        {
            broadcast_axes.push_back(index);
        }
        else
        {
            squeezed_shape.push_back(source_shape.at(index));
        }
    }

    if (squeezed_shape != value.get_shape())
    {
        broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
    }

    if (!broadcast_axes.empty())
    {
        auto shape_const =
            op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
        broadcasted_node = make_shared<op::v1::Broadcast>(
            broadcasted_node,
            shape_const,
            opset1::get_axes_mapping_output(output_shape, broadcast_axes));
    }

    return broadcasted_node;
}

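// Illustrative example (shapes chosen for exposition): for a value of shape {5} with
// source_shape {1, 5, 1} and output_shape {3, 5, 10}, axes 0 and 2 become broadcast_axes,
// the squeezed shape is {5}, and the value is broadcast to {3, 5, 10} along those axes.
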
/// \brief Broadcast input node, PDPD (PaddlePaddle) style.
///
/// \param[in] value        The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] axis         The start axis of output_shape to align the value's shape with.
///
/// \return The broadcast Node.
///
static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value,
                                                   const Shape& output_shape,
                                                   int64_t axis)
{
    auto value_shape = value.get_shape();

    // If the node already has the required shape, return the original node.
    if (output_shape == value_shape)
    {
        return value.get_node_shared_ptr();
    }

    if (axis == -1)
    {
        axis = output_shape.size() - value_shape.size();
    }

    auto trimmed_value_shape = value_shape;
    while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
    {
        trimmed_value_shape.pop_back();
    }

    AxisSet axes;
    for (int64_t i = 0; i < axis; ++i)
    {
        axes.insert(static_cast<size_t>(i));
    }

    for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
    {
        axes.insert(i);
    }

    auto trimmed_value = value;
    if (value_shape != trimmed_value_shape)
    {
        trimmed_value = make_shared<op::Reshape>(
            value, get_default_order(value_shape), trimmed_value_shape);
    }

    auto shape_const =
        op::Constant::create(element::u64, Shape{output_shape.size()}, output_shape);
    auto value_bcast = make_shared<op::v1::Broadcast>(
        trimmed_value, shape_const, opset1::get_axes_mapping_output(output_shape, axes));

    return move(value_bcast);
}

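// Illustrative example (shapes chosen for exposition): broadcasting a value of shape {3, 1}
// to output_shape {2, 3, 4, 5} with axis = 1 trims the trailing 1 (leaving {3}), marks
// axes {0, 2, 3} for broadcasting, and reshapes then broadcasts the value accordingly.
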
pair<shared_ptr<Node>, shared_ptr<Node>>
    numpy_broadcast(const pair<Output<Node>, Output<Node>>& args)
{
    NGRAPH_CHECK(args.first.get_node());
    NGRAPH_CHECK(args.second.get_node());

    const Shape& arg1_in_shape = args.first.get_shape();
    const Shape& arg2_in_shape = args.second.get_shape();

    // Handle the trivial case: both shapes already match.
    if (arg1_in_shape == arg2_in_shape)
    {
        return make_pair(args.first.get_node_shared_ptr(),
                         args.second.get_node_shared_ptr());
    }

    NodeVector bcasted_outputs =
        as_node_vector(numpy_broadcast_outputs({args.first, args.second}));

    return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
}

OutputVector numpy_broadcast_outputs(const OutputVector& values)
{
    if (values.size() <= 1)
    {
        return values;
    }

    // Find the output tensor's shape, then broadcast all inputs so that they are compatible.
    auto bcast_shapes = get_numpy_broadcast_shapes(values);

    OutputVector broadcasted_inputs;
    for (size_t i = 0; i < values.size(); ++i)
    {
        broadcasted_inputs.push_back(
            numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
    }
    return broadcasted_inputs;
}

shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape)
{
    auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
    return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
}

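// Illustrative example (shapes chosen for exposition): broadcasting a value of shape {5, 1}
// against Shape{3, 5, 4} pads its shape to {1, 5, 1} and yields a node of shape {3, 5, 4}.
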
OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
                                                  const Output<Node>& right)
{
    const auto& left_shape = left.get_shape();
    const auto& right_shape = right.get_shape();
    // Broadcast only the _stack of matrices_ axes.
    const auto& numpy_shapes =
        get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)},
                                    Shape{begin(right_shape), next(end(right_shape), -2)}});

    // Prepare the tensors' output shapes with the broadcast _stack of matrices_ axes.
    auto left_output_shape = numpy_shapes.first;
    auto right_output_shape = numpy_shapes.first;
    // Append the original dimensions of the last two axes.
    left_output_shape.insert(end(left_output_shape),
                             next(begin(left_shape), left_shape.size() - 2),
                             end(left_shape));
    right_output_shape.insert(end(right_output_shape),
                              next(begin(right_shape), right_shape.size() - 2),
                              end(right_shape));

    auto left_full_shape = numpy_shapes.second.at(0);
    auto right_full_shape = numpy_shapes.second.at(1);
    // Append the original dimensions of the last two axes.
    left_full_shape.insert(end(left_full_shape),
                           next(begin(left_shape), left_shape.size() - 2),
                           end(left_shape));
    right_full_shape.insert(end(right_full_shape),
                            next(begin(right_shape), right_shape.size() - 2),
                            end(right_shape));

    return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
            numpy_broadcast_node(right, right_output_shape, right_full_shape)};
}

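// Illustrative example (shapes chosen for exposition): for left {2, 1, 3, 4} and
// right {5, 4, 6}, only the stack-of-matrices axes {2, 1} and {5} are broadcast, to {2, 5},
// giving outputs of shapes {2, 5, 3, 4} and {2, 5, 4, 6}; the matrix axes stay untouched.
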
OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis)
{
    if (inputs.size() <= 1)
    {
        return inputs;
    }

    OutputVector broadcasted_inputs{inputs[0]};
    for (size_t i = 1; i < inputs.size(); ++i)
    {
        broadcasted_inputs.push_back(
            broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
    }
    return broadcasted_inputs;
}

std::shared_ptr<Node> calculate_broadcast_axes(const Shape& output_shape,
                                               const Shape& input_shape,
                                               size_t start_match_axis)
{
    vector<size_t> axes(output_shape.size() - input_shape.size());
    // Populate the axes vector with a monotonically increasing series from 0 up to
    // output_shape.size(), excluding values in the range
    // [start_match_axis, start_match_axis + input_shape.size()).
    iota(begin(axes), begin(axes) + start_match_axis, 0);
    iota(begin(axes) + start_match_axis, end(axes), start_match_axis + input_shape.size());

    auto axes_mapping = opset1::get_axes_mapping(output_shape, axes);
    return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}

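// Illustrative example (shapes chosen for exposition): with output_shape {2, 3, 4, 5},
// input_shape {3, 4} and start_match_axis = 1, the broadcast axes are {0, 3}, so the
// resulting axes-mapping constant holds {1, 2} (the output axes the input maps onto).
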
namespace opset1
{
Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
                                                   const Output<Node>& right,
                                                   size_t start_match_axis)
{
    const auto& left_shape = left.get_shape();
    const auto& right_shape = right.get_shape();

    bool dimensions_identical = (left_shape == right_shape);
    if (dimensions_identical)
    {
        return right;
    }

    // Prepare the new shape of the right operand for broadcasting.
    // Remove dimensions with length=1 from the back.
    auto new_right_shape = right_shape;
    for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
    {
        if (new_right_shape.at(dimension) == 1)
        {
            new_right_shape.pop_back();
        }
        else
        {
            break;
        }
    }

    // Count the leading dimensions with length=1 at the front.
    size_t num_ones = 0;
    for (size_t dimension : new_right_shape)
    {
        if (dimension == 1)
        {
            ++num_ones;
        }
        else
        {
            break;
        }
    }

    // Remove dimensions with length=1 from the front.
    new_right_shape.erase(begin(new_right_shape),
                          next(begin(new_right_shape), num_ones));

    auto reshape_right = reshape(right, new_right_shape);

    // Move the broadcast start axis to the right by the number of removed leading ones.
    start_match_axis += num_ones;

    return make_broadcast(reshape_right, left_shape, start_match_axis);
}

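// Illustrative example (shapes chosen for exposition): for left {2, 3, 4, 5} and
// right {1, 3, 4, 1} with start_match_axis = 0, the trailing and leading ones are dropped
// (right becomes {3, 4}), start_match_axis moves to 1, and the reshaped right operand is
// broadcast to {2, 3, 4, 5} starting at axis 1.
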
vector<size_t> get_axes_mapping(const Shape& output_shape,
                                const AxisSet& broadcast_axes)
{
    NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
    vector<size_t> axes_mapping(output_shape.size());
    iota(axes_mapping.begin(), axes_mapping.end(), 0);
    for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
    {
        axes_mapping.erase(axes_mapping.begin() + *i);
    }
    return axes_mapping;
}

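// Illustrative example (values chosen for exposition): for an output_shape of rank 4 and
// broadcast_axes {0, 2}, the mapping starts as {0, 1, 2, 3} and the broadcast axes are
// erased, leaving {1, 3}: the output axes that the input's axes map onto.
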
Output<Node> get_axes_mapping_output(const Shape& output_shape,
                                     const Shape& input_shape,
                                     size_t start_match_axis)
{
    NGRAPH_CHECK((input_shape.size() + start_match_axis <= output_shape.size()));
    vector<size_t> mapping(input_shape.size());
    iota(begin(mapping), end(mapping), start_match_axis);

    return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
}

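// Illustrative example (values chosen for exposition): for an output_shape of rank 4, an
// input_shape of rank 2 and start_match_axis = 1, the mapping constant holds {1, 2}: the
// input's two axes line up with output axes 1 and 2.
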
Output<Node> get_axes_mapping_output(const Shape& output_shape,
                                     const AxisSet& broadcast_axes)
{
    vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
    return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}

Output<Node> make_broadcast(const Output<Node>& node,
                            const Shape& target_shape,
                            const AxisSet& broadcast_axes)
{
    return make_shared<op::v1::Broadcast>(
        node,
        op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
        get_axes_mapping_output(target_shape, broadcast_axes));
}

Output<Node> make_broadcast(const Output<Node>& node,
                            const Shape& target_shape,
                            size_t start_match_axis)
{
    return make_shared<op::v1::Broadcast>(
        node,
        op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
        get_axes_mapping_output(target_shape, node.get_shape(), start_match_axis));
}

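// Illustrative usage (shapes chosen for exposition): make_broadcast(node, Shape{2, 3, 4}, 1)
// takes a node of shape {3}, aligns its single axis with output axis 1, and produces a
// broadcast node of shape {2, 3, 4}.
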
} // namespace opset1
} // namespace builder
} // namespace ngraph