1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "builder/autobroadcast.hpp"
23 #include "builder/reshape.hpp"
24 #include "ngraph/axis_vector.hpp"
25 #include "ngraph/check.hpp"
26 #include "ngraph/op/broadcast.hpp"
27 #include "ngraph/op/constant.hpp"
28 #include "ngraph/op/reshape.hpp"
29 #include "ngraph/util.hpp"
31 NGRAPH_SUPPRESS_DEPRECATED_START
39 numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
40 const Shape& shape1, const Shape& shape2)
41 : ngraph_error(error_str(shape1, shape2))
47 string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
51 os << "Auto-broadcast not possible for these input shapes:"
52 << " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
57 /// \brief Calculate the output shape of numpy-style broadcast operation for two
61 /// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
62 /// Example: left: [3, 1, 10] right: [5, 1] return: [3, 5, 10]
64 /// \param lhs_shape First input shape.
65 /// \param rhs_shape Second input Shape.
67 /// \return Broadcast shape of input shapes.
69 static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape)
72 auto lhs_rank = lhs_shape.size();
73 auto rhs_rank = rhs_shape.size();
74 auto max_rank = max(lhs_rank, rhs_rank);
76 // left-pad the lhs_shape with ones
77 lhs_shape.insert(begin(lhs_shape), max_rank - lhs_rank, 1);
78 // left-pad the rhs_shape with ones
79 rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
81 for (size_t index = 0; index < max_rank; ++index)
83 size_t lhs_dim = lhs_shape.at(index);
84 size_t rhs_dim = rhs_shape.at(index);
86 if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1)
88 throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
91 result.push_back(max(lhs_dim, rhs_dim));
97 pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes)
99 Shape target_shape = accumulate(
100 begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
102 vector<Shape> full_shapes;
103 for (const Shape& input : input_shapes)
105 Shape padded_shape{input};
107 begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
108 full_shapes.push_back(move(padded_shape));
111 return {target_shape, full_shapes};
114 static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values)
116 vector<Shape> input_shapes;
118 for (const auto& input : values)
120 input_shapes.push_back(input.get_shape());
123 return get_numpy_broadcast_shapes(input_shapes);
126 /// \brief Broadcast input node.
128 /// \note The source shape does not have to be the actual shape of input node. However
129 /// it should be a superset of it (containing it as a continuous subset). This
130 /// implies we may expand the number of axes of input node. The ranks of
131 /// source_shape and output_shape must be equal. This means that the
132 /// source_shape has to be padded with ones for this operation.
134 /// \param[in] value The input Node to be broadcast.
135 /// \param[in] output_shape The output shape.
136 /// \param[in] source_shape The source shape from which we want to broadcast input node.
138 /// \return The broadcasted Node.
140 static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
141 const Shape& output_shape,
142 const Shape& source_shape)
144 shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
145 // If node already has the required shape, return original node
146 if (output_shape == value.get_shape())
148 return broadcasted_node;
151 NGRAPH_CHECK(source_shape.size() == output_shape.size(),
152 "Ranks of source_shape and output_shape dont match: ",
155 output_shape.size());
157 AxisVector broadcast_axes;
158 Shape squeezed_shape;
159 // Positions of axes which have length of 1 are needed to calculate broadcast_axes
160 // for nGraph broadcast operation. We need to remove ones from source shape
161 // to avoid broadcasting axis conflict.
162 for (size_t index = 0; index < output_shape.size(); ++index)
164 if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
166 broadcast_axes.push_back(index);
170 squeezed_shape.push_back(source_shape.at(index));
174 if (squeezed_shape != value.get_shape())
176 broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
179 if (!broadcast_axes.empty())
182 make_shared<op::Broadcast>(broadcasted_node, output_shape, broadcast_axes);
185 return broadcasted_node;
188 /// \brief Broadcast input node.
190 /// \param[in] value The input Node to be broadcast.
191 /// \param[in] output_shape The output shape.
192 /// \param[in] axis The start index to align with output_shape
194 /// \return The broadcasted Node.
196 static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value,
197 const Shape& output_shape,
200 auto value_shape = value.get_shape();
202 // If node already has the required shape, return original node
203 if (output_shape == value_shape)
205 return value.get_node_shared_ptr();
210 axis = output_shape.size() - value_shape.size();
213 auto trimmed_value_shape = value_shape;
214 while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
216 trimmed_value_shape.pop_back();
220 for (int64_t i = 0; i < axis; ++i)
222 axes.insert(static_cast<size_t>(i));
225 for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
230 auto trimmed_value = value;
231 if (value_shape != trimmed_value_shape)
233 trimmed_value = make_shared<op::Reshape>(
234 value, get_default_order(value_shape), trimmed_value_shape);
237 auto value_bcast = make_shared<op::Broadcast>(trimmed_value, output_shape, axes);
239 return move(value_bcast);
242 pair<shared_ptr<Node>, shared_ptr<Node>>
243 numpy_broadcast(const pair<Output<Node>, Output<Node>>& args)
245 NGRAPH_CHECK(args.first.get_node());
246 NGRAPH_CHECK(args.second.get_node());
248 const Shape& arg1_in_shape = args.first.get_shape();
249 const Shape& arg2_in_shape = args.second.get_shape();
251 // Handle the trivial case...
252 if (arg1_in_shape == arg2_in_shape)
254 return make_pair(args.first.get_node_shared_ptr(),
255 args.second.get_node_shared_ptr());
258 NodeVector bcasted_outputs =
259 as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
261 return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
264 OutputVector numpy_broadcast_outputs(const OutputVector& values)
266 if (values.size() <= 1)
271 // find the output tensor's shape, then broadcast all inputs so that they are compatible
272 auto bcast_shapes = get_numpy_broadcast_shapes(values);
274 OutputVector broadcasted_inputs;
275 for (size_t i = 0; i < values.size(); ++i)
277 broadcasted_inputs.push_back(
278 numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
280 return broadcasted_inputs;
283 shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape)
285 auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
286 return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
289 OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
290 const Output<Node>& right)
292 const auto& left_shape = left.get_shape();
293 const auto& right_shape = right.get_shape();
294 // Broadcast only _stack of matrices_ axes.
295 const auto& numpy_shapes =
296 get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)},
297 Shape{begin(right_shape), next(end(right_shape), -2)}});
299 // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
300 auto left_output_shape = numpy_shapes.first;
301 auto right_output_shape = numpy_shapes.first;
302 // Append the last two axes original dimensions.
303 left_output_shape.insert(end(left_output_shape),
304 next(begin(left_shape), left_shape.size() - 2),
306 right_output_shape.insert(end(right_output_shape),
307 next(begin(right_shape), right_shape.size() - 2),
310 auto left_full_shape = numpy_shapes.second.at(0);
311 auto right_full_shape = numpy_shapes.second.at(1);
312 // Append the last two axes original dimensions.
313 left_full_shape.insert(end(left_full_shape),
314 next(begin(left_shape), left_shape.size() - 2),
316 right_full_shape.insert(end(right_full_shape),
317 next(begin(right_shape), right_shape.size() - 2),
320 return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
321 numpy_broadcast_node(right, right_output_shape, right_full_shape)};
324 OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis)
326 if (inputs.size() <= 1)
331 OutputVector broadcasted_inputs{inputs[0]};
332 for (size_t i = 1; i < inputs.size(); ++i)
334 broadcasted_inputs.push_back(
335 broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
337 return broadcasted_inputs;
340 AxisSet calculate_broadcast_axes(const Shape& output_shape,
341 const Shape& input_shape,
342 size_t start_match_axis)
344 vector<size_t> result(output_shape.size() - input_shape.size());
345 // Populate the result vector with monotonic increasing series from 0 until
346 // output_shape_size, excluding values in range:
347 // [start_match_axis, start_match_axis + input_shape.size()]
348 iota(begin(result), begin(result) + start_match_axis, 0);
349 iota(begin(result) + start_match_axis,
351 start_match_axis + input_shape.size());
357 Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
358 const Output<Node>& right,
359 size_t start_match_axis)
361 const auto& left_shape = left.get_shape();
362 const auto& right_shape = right.get_shape();
364 bool dimensions_identical = (left_shape == right_shape);
365 if (dimensions_identical)
370 // Prepare new shape of right operand for broadcasting
371 // Remove dimensions with length=1 from back
372 auto new_right_shape = right_shape;
373 for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
375 if (new_right_shape.at(dimension) == 1)
377 new_right_shape.pop_back();
385 // Find first dimensions at front with length different from 1
387 for (size_t dimension : new_right_shape)
399 // Remove dimensions with length=1 from front
400 new_right_shape.erase(begin(new_right_shape),
401 next(begin(new_right_shape), num_ones));
403 auto reshape_right = reshape(right, new_right_shape);
405 // Move broadcast start axis parameter to right
406 start_match_axis += num_ones;
408 return make_broadcast(reshape_right, left_shape, start_match_axis);
411 vector<size_t> get_axes_mapping(const Shape& output_shape,
412 const AxisSet& broadcast_axes)
414 NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
415 vector<size_t> axes_mapping(output_shape.size());
416 iota(axes_mapping.begin(), axes_mapping.end(), 0);
417 for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
419 axes_mapping.erase(axes_mapping.begin() + *i);
424 Output<Node> get_axes_mapping_output(const Shape& output_shape,
425 const Shape& input_shape,
426 size_t start_match_axis)
428 NGRAPH_CHECK((input_shape.size() + start_match_axis <= output_shape.size()));
429 vector<size_t> mapping(input_shape.size());
430 iota(begin(mapping), end(mapping), start_match_axis);
432 return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
435 Output<Node> get_axes_mapping_output(const Shape& output_shape,
436 const AxisSet& broadcast_axes)
438 vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
439 return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
442 Output<Node> make_broadcast(const Output<Node>& node,
443 const Shape& target_shape,
444 const AxisSet& broadcast_axes)
446 return make_shared<op::v1::Broadcast>(
448 op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
449 get_axes_mapping_output(target_shape, broadcast_axes));
452 Output<Node> make_broadcast(const Output<Node>& node,
453 const Shape& target_shape,
454 size_t start_match_axis)
456 return make_shared<op::v1::Broadcast>(
458 op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
459 get_axes_mapping_output(target_shape, node.get_shape(), start_match_axis));
462 } // namespace opset1
463 } // namespace builder
464 } // namespace ngraph