5003097b2d38ac216e5860fcb4410d6d945b51e4
[platform/upstream/dldt.git] / ngraph / core / builder / src / builder / autobroadcast.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "builder/autobroadcast.hpp"
18
19 #include <memory>
20 #include <numeric>
21 #include <sstream>
22
23 #include "builder/reshape.hpp"
24 #include "ngraph/axis_vector.hpp"
25 #include "ngraph/check.hpp"
26 #include "ngraph/op/broadcast.hpp"
27 #include "ngraph/op/constant.hpp"
28 #include "ngraph/op/reshape.hpp"
29 #include "ngraph/util.hpp"
30
31 NGRAPH_SUPPRESS_DEPRECATED_START
32
33 using namespace std;
34
35 namespace ngraph
36 {
37     namespace builder
38     {
        // Exception raised when two shapes cannot be numpy-broadcast together.
        // Stores both offending shapes and builds the message via error_str().
        numpy_autobroadcast_incompatible_shapes::numpy_autobroadcast_incompatible_shapes(
            const Shape& shape1, const Shape& shape2)
            : ngraph_error(error_str(shape1, shape2))
            , m_shape1(shape1)
            , m_shape2(shape2)
        {
        }
46
47         string numpy_autobroadcast_incompatible_shapes::error_str(const Shape& shape1,
48                                                                   const Shape& shape2)
49         {
50             ostringstream os;
51             os << "Auto-broadcast not possible for these input shapes:"
52                << " shape1=" << vector_to_string(shape1) << " shape2=" << vector_to_string(shape2);
53             return os.str();
54         }
55
56         ///
57         /// \brief      Calculate the output shape of numpy-style broadcast operation for two
58         ///             shapes.
59         ///
60         /// \note       More info:
61         /// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
62         ///             Example: left: [3, 1, 10] right: [5, 1] return: [3, 5, 10]
63         ///
64         /// \param      lhs_shape  First input shape.
65         /// \param      rhs_shape  Second input Shape.
66         ///
67         /// \return     Broadcast shape of input shapes.
68         ///
69         static Shape calculate_broadcast_shape(Shape lhs_shape, Shape rhs_shape)
70         {
71             Shape result;
72             auto lhs_rank = lhs_shape.size();
73             auto rhs_rank = rhs_shape.size();
74             auto max_rank = max(lhs_rank, rhs_rank);
75
76             // left-pad the lhs_shape with ones
77             lhs_shape.insert(begin(lhs_shape), max_rank - lhs_rank, 1);
78             // left-pad the rhs_shape with ones
79             rhs_shape.insert(begin(rhs_shape), max_rank - rhs_rank, 1);
80
81             for (size_t index = 0; index < max_rank; ++index)
82             {
83                 size_t lhs_dim = lhs_shape.at(index);
84                 size_t rhs_dim = rhs_shape.at(index);
85
86                 if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1)
87                 {
88                     throw numpy_autobroadcast_incompatible_shapes(lhs_shape, rhs_shape);
89                 }
90
91                 result.push_back(max(lhs_dim, rhs_dim));
92             }
93
94             return result;
95         };
96
97         pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const vector<Shape>& input_shapes)
98         {
99             Shape target_shape = accumulate(
100                 begin(input_shapes), end(input_shapes), Shape{}, calculate_broadcast_shape);
101
102             vector<Shape> full_shapes;
103             for (const Shape& input : input_shapes)
104             {
105                 Shape padded_shape{input};
106                 padded_shape.insert(
107                     begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
108                 full_shapes.push_back(move(padded_shape));
109             }
110
111             return {target_shape, full_shapes};
112         }
113
114         static pair<Shape, vector<Shape>> get_numpy_broadcast_shapes(const OutputVector& values)
115         {
116             vector<Shape> input_shapes;
117
118             for (const auto& input : values)
119             {
120                 input_shapes.push_back(input.get_shape());
121             }
122
123             return get_numpy_broadcast_shapes(input_shapes);
124         }
125
        /// \brief      Broadcast input node.
        ///
        /// \note       The source shape does not have to be the actual shape of input node. However
        ///             it should be a superset of it (containing it as a continuous subset). This
        ///             implies we may expand the number of axes of input node. The ranks of
        ///             source_shape and output_shape must be equal. This means that the
        ///             source_shape has to be padded with ones for this operation.
        ///
        /// \param[in]  value         The input Node to be broadcast.
        /// \param[in]  output_shape  The output shape.
        /// \param[in]  source_shape  The source shape from which we want to broadcast input node.
        ///
        /// \return     The broadcasted Node.
        ///
        static shared_ptr<Node> numpy_broadcast_node(const Output<Node>& value,
                                                     const Shape& output_shape,
                                                     const Shape& source_shape)
        {
            shared_ptr<Node> broadcasted_node = value.get_node_shared_ptr();
            // If node already has the required shape, return original node
            if (output_shape == value.get_shape())
            {
                return broadcasted_node;
            }

            // The axes of source_shape and output_shape can only be compared
            // position by position when the two ranks agree.
            NGRAPH_CHECK(source_shape.size() == output_shape.size(),
                         "Ranks of source_shape and output_shape dont match: ",
                         source_shape.size(),
                         " vs ",
                         output_shape.size());

            AxisVector broadcast_axes;
            Shape squeezed_shape;
            // Positions of axes which have length of 1 are needed to calculate broadcast_axes
            // for nGraph broadcast operation. We need to remove ones from source shape
            // to avoid broadcasting axis conflict.
            for (size_t index = 0; index < output_shape.size(); ++index)
            {
                // An axis where the source is 1 but the output is not gets
                // broadcast; every other axis survives into squeezed_shape.
                if (source_shape.at(index) == 1 && output_shape.at(index) != 1)
                {
                    broadcast_axes.push_back(index);
                }
                else
                {
                    squeezed_shape.push_back(source_shape.at(index));
                }
            }

            // Only reshape when dropping the size-1 axes actually changes the
            // value's shape.
            if (squeezed_shape != value.get_shape())
            {
                broadcasted_node = builder::opset1::reshape(value, squeezed_shape);
            }

            // Broadcast the squeezed tensor along the collected axes (if any).
            if (!broadcast_axes.empty())
            {
                broadcasted_node =
                    make_shared<op::Broadcast>(broadcasted_node, output_shape, broadcast_axes);
            }

            return broadcasted_node;
        }
187
        /// \brief      Broadcast input node PDPD-style.
        ///
        /// \param[in]  value         The input Node to be broadcast.
        /// \param[in]  output_shape  The output shape.
        /// \param[in]  axis          The start index to align with output_shape;
        ///                           -1 means "align with the trailing axes".
        ///
        /// \return     The broadcasted Node.
        ///
        static shared_ptr<Node> broadcast_value_pdpd_style(const Output<Node>& value,
                                                           const Shape& output_shape,
                                                           int64_t axis)
        {
            auto value_shape = value.get_shape();

            // If node already has the required shape, return original node
            if (output_shape == value_shape)
            {
                return value.get_node_shared_ptr();
            }

            // axis == -1 is shorthand for aligning value_shape with the
            // trailing dimensions of output_shape.
            if (axis == -1)
            {
                axis = output_shape.size() - value_shape.size();
            }

            // Drop trailing dimensions of length 1: they carry no data and
            // would conflict with the broadcast axes computed below.
            auto trimmed_value_shape = value_shape;
            while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
            {
                trimmed_value_shape.pop_back();
            }

            // Broadcast along every output axis before the alignment point...
            AxisSet axes;
            for (int64_t i = 0; i < axis; ++i)
            {
                axes.insert(static_cast<size_t>(i));
            }

            // ...and along every output axis past the end of the trimmed value.
            for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
            {
                axes.insert(i);
            }

            // Reshape away the trimmed trailing 1s before broadcasting.
            auto trimmed_value = value;
            if (value_shape != trimmed_value_shape)
            {
                trimmed_value = make_shared<op::Reshape>(
                    value, get_default_order(value_shape), trimmed_value_shape);
            }

            auto value_bcast = make_shared<op::Broadcast>(trimmed_value, output_shape, axes);

            // move: value_bcast is shared_ptr<op::Broadcast>, the return type is
            // shared_ptr<Node>, so a converting construction happens either way.
            return move(value_bcast);
        }
241
242         pair<shared_ptr<Node>, shared_ptr<Node>>
243             numpy_broadcast(const pair<Output<Node>, Output<Node>>& args)
244         {
245             NGRAPH_CHECK(args.first.get_node());
246             NGRAPH_CHECK(args.second.get_node());
247
248             const Shape& arg1_in_shape = args.first.get_shape();
249             const Shape& arg2_in_shape = args.second.get_shape();
250
251             // Handle the trivial case...
252             if (arg1_in_shape == arg2_in_shape)
253             {
254                 return make_pair(args.first.get_node_shared_ptr(),
255                                  args.second.get_node_shared_ptr());
256             }
257
258             NodeVector bcasted_outputs =
259                 as_node_vector(numpy_broadcast_outputs({args.first, args.second}));
260
261             return make_pair(bcasted_outputs.at(0), bcasted_outputs.at(1));
262         }
263
264         OutputVector numpy_broadcast_outputs(const OutputVector& values)
265         {
266             if (values.size() <= 1)
267             {
268                 return values;
269             }
270
271             // find the output tensor's shape, then broadcast all inputs so that they are compatible
272             auto bcast_shapes = get_numpy_broadcast_shapes(values);
273
274             OutputVector broadcasted_inputs;
275             for (size_t i = 0; i < values.size(); ++i)
276             {
277                 broadcasted_inputs.push_back(
278                     numpy_broadcast_node(values[i], bcast_shapes.first, bcast_shapes.second[i]));
279             }
280             return broadcasted_inputs;
281         }
282
283         shared_ptr<Node> numpy_broadcast(const Output<Node>& value, const Shape& shape)
284         {
285             auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
286             return numpy_broadcast_node(value, bcast_shape.first, bcast_shape.second[0]);
287         }
288
289         OutputVector numpy_broadcast_for_matmul_operation(const Output<Node>& left,
290                                                           const Output<Node>& right)
291         {
292             const auto& left_shape = left.get_shape();
293             const auto& right_shape = right.get_shape();
294             // Broadcast only _stack of matrices_ axes.
295             const auto& numpy_shapes =
296                 get_numpy_broadcast_shapes({Shape{begin(left_shape), next(end(left_shape), -2)},
297                                             Shape{begin(right_shape), next(end(right_shape), -2)}});
298
299             // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
300             auto left_output_shape = numpy_shapes.first;
301             auto right_output_shape = numpy_shapes.first;
302             // Append the last two axes original dimensions.
303             left_output_shape.insert(end(left_output_shape),
304                                      next(begin(left_shape), left_shape.size() - 2),
305                                      end(left_shape));
306             right_output_shape.insert(end(right_output_shape),
307                                       next(begin(right_shape), right_shape.size() - 2),
308                                       end(right_shape));
309
310             auto left_full_shape = numpy_shapes.second.at(0);
311             auto right_full_shape = numpy_shapes.second.at(1);
312             // Append the last two axes original dimensions.
313             left_full_shape.insert(end(left_full_shape),
314                                    next(begin(left_shape), left_shape.size() - 2),
315                                    end(left_shape));
316             right_full_shape.insert(end(right_full_shape),
317                                     next(begin(right_shape), right_shape.size() - 2),
318                                     end(right_shape));
319
320             return {numpy_broadcast_node(left, left_output_shape, left_full_shape),
321                     numpy_broadcast_node(right, right_output_shape, right_full_shape)};
322         }
323
324         OutputVector pdpd_broadcast(const OutputVector& inputs, int64_t axis)
325         {
326             if (inputs.size() <= 1)
327             {
328                 return inputs;
329             }
330
331             OutputVector broadcasted_inputs{inputs[0]};
332             for (size_t i = 1; i < inputs.size(); ++i)
333             {
334                 broadcasted_inputs.push_back(
335                     broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
336             }
337             return broadcasted_inputs;
338         }
339
340         AxisSet calculate_broadcast_axes(const Shape& output_shape,
341                                          const Shape& input_shape,
342                                          size_t start_match_axis)
343         {
344             vector<size_t> result(output_shape.size() - input_shape.size());
345             // Populate the result vector with monotonic increasing series from 0 until
346             // output_shape_size, excluding values in range:
347             // [start_match_axis, start_match_axis + input_shape.size()]
348             iota(begin(result), begin(result) + start_match_axis, 0);
349             iota(begin(result) + start_match_axis,
350                  end(result),
351                  start_match_axis + input_shape.size());
352             return result;
353         }
354
355         namespace opset1
356         {
            /// \brief      Broadcast the right operand to the shape of the left one,
            ///             legacy (start-axis) style, for a binary elementwise operation.
            ///
            /// \param[in]  left              Output providing the target shape.
            /// \param[in]  right             Output to be reshaped and broadcast.
            /// \param[in]  start_match_axis  First axis of left's shape against which
            ///                               right's (trimmed) shape is aligned.
            ///
            /// \return     The right operand broadcast to the left operand's shape.
            Output<Node> legacy_broadcast_for_binary_operation(const Output<Node>& left,
                                                               const Output<Node>& right,
                                                               size_t start_match_axis)
            {
                const auto& left_shape = left.get_shape();
                const auto& right_shape = right.get_shape();

                bool dimensions_identical = (left_shape == right_shape);
                if (dimensions_identical)
                {
                    // Shapes already match - nothing to reshape or broadcast.
                    return right;
                }

                // Prepare new shape of right operand for broadcasting
                // Remove dimensions with length=1 from back
                auto new_right_shape = right_shape;
                for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
                {
                    if (new_right_shape.at(dimension) == 1)
                    {
                        new_right_shape.pop_back();
                    }
                    else
                    {
                        break;
                    }
                }

                // Find first dimensions at front with length different from 1
                size_t num_ones = 0;
                for (size_t dimension : new_right_shape)
                {
                    if (dimension == 1)
                    {
                        ++num_ones;
                    }
                    else
                    {
                        break;
                    }
                }

                // Remove dimensions with length=1 from front
                new_right_shape.erase(begin(new_right_shape),
                                      next(begin(new_right_shape), num_ones));

                auto reshape_right = reshape(right, new_right_shape);

                // Move broadcast start axis parameter to right
                // (the leading 1s dropped above shift the alignment point).
                start_match_axis += num_ones;

                return make_broadcast(reshape_right, left_shape, start_match_axis);
            }
410
411             vector<size_t> get_axes_mapping(const Shape& output_shape,
412                                             const AxisSet& broadcast_axes)
413             {
414                 NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
415                 vector<size_t> axes_mapping(output_shape.size());
416                 iota(axes_mapping.begin(), axes_mapping.end(), 0);
417                 for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
418                 {
419                     axes_mapping.erase(axes_mapping.begin() + *i);
420                 }
421                 return axes_mapping;
422             }
423
424             Output<Node> get_axes_mapping_output(const Shape& output_shape,
425                                                  const Shape& input_shape,
426                                                  size_t start_match_axis)
427             {
428                 NGRAPH_CHECK((input_shape.size() + start_match_axis <= output_shape.size()));
429                 vector<size_t> mapping(input_shape.size());
430                 iota(begin(mapping), end(mapping), start_match_axis);
431
432                 return op::Constant::create(element::i64, Shape{mapping.size()}, mapping);
433             }
434
435             Output<Node> get_axes_mapping_output(const Shape& output_shape,
436                                                  const AxisSet& broadcast_axes)
437             {
438                 vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
439                 return op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
440             }
441
442             Output<Node> make_broadcast(const Output<Node>& node,
443                                         const Shape& target_shape,
444                                         const AxisSet& broadcast_axes)
445             {
446                 return make_shared<op::v1::Broadcast>(
447                     node,
448                     op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
449                     get_axes_mapping_output(target_shape, broadcast_axes));
450             }
451
452             Output<Node> make_broadcast(const Output<Node>& node,
453                                         const Shape& target_shape,
454                                         size_t start_match_axis)
455             {
456                 return make_shared<op::v1::Broadcast>(
457                     node,
458                     op::Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
459                     get_axes_mapping_output(target_shape, node.get_shape(), start_match_axis));
460             }
461
462         } // namespace opset1
463     }     // namespace builder
464 } // namespace ngraph