//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "itt.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/partial_shape.hpp"

#include <numeric>
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v3::Broadcast::type_info;

op::v3::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const Output<Node>& axes_mapping,
                             const BroadcastModeSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, axes_mapping, broadcast_spec}
{
    constructor_validate_and_infer_types();
}

op::v3::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const BroadcastModeSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, broadcast_spec}
{
    constructor_validate_and_infer_types();
}

namespace
{
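    // Helper for BIDIRECTIONAL mode: determines the axes along which arg_shape is broadcast
    // to reach result_shape, with the argument right-aligned against the result. The bool in
    // the returned pair reports whether the axes are known; on this static-shape path it is
    // always true.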
    std::pair<bool, AxisSet> get_broadcast_axes_bidirectional(const Shape& arg_shape,
                                                              const Shape& result_shape)
    {
        AxisSet broadcast_axes;
        bool axes_known = false;
        // Validate the ranks before subtracting: both sizes are unsigned, so a check such as
        // `start_axis >= 0` performed after the subtraction would always pass, even on underflow.
        NGRAPH_CHECK(result_shape.size() >= arg_shape.size());
        const auto start_axis = result_shape.size() - arg_shape.size();
        for (size_t i = 0; i < result_shape.size(); i++)
        {
            if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
            {
                broadcast_axes.insert(i);
            }
        }
        axes_known = true;
        return std::make_pair(axes_known, broadcast_axes);
    }
}

std::pair<bool, AxisSet> op::v3::Broadcast::get_broadcast_axes() const
{
    if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
    {
        AxisSet broadcast_axes;
        bool axes_known = false;

        if (get_input_partial_shape(0).is_static() && get_output_partial_shape(0).is_static())
        {
            const auto arg_shape = get_input_shape(0);
            const auto result_shape = get_output_shape(0);
            return get_broadcast_axes_bidirectional(arg_shape, result_shape);
        }
        return std::make_pair(axes_known, broadcast_axes);
    }

    return util::BroadcastBase::get_broadcast_axes();
}

namespace
{
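    // Helper for BIDIRECTIONAL mode: infers the result shape by left-padding the shorter of
    // the two shapes with 1s and taking the larger extent along each axis. A dynamic argument
    // dimension stays dynamic when the target dimension is 1; otherwise the target dimension
    // wins. Dimension pairs that are neither equal nor broadcastable fail node validation.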
    PartialShape get_result_shape_bidirectional(const Node* this_ptr,
                                                const PartialShape& arg_shape,
                                                Shape& target_shape)
    {
        if (arg_shape.rank().is_dynamic())
        {
            return PartialShape::dynamic();
        }
        auto arg_shape_vec = static_cast<std::vector<Dimension>>(arg_shape);
        PartialShape result_shape;
        // Add left padding to the shorter of the target and argument shapes
        const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
        while (arg_shape_vec.size() < target_padded_rank)
        {
            arg_shape_vec.insert(arg_shape_vec.begin(), 1);
        }
        while (target_shape.size() < target_padded_rank)
        {
            target_shape.insert(target_shape.begin(), 1);
        }

        result_shape = target_shape;
        for (size_t i = 0; i < target_shape.size(); ++i)
        {
            if (arg_shape_vec[i].is_dynamic())
            {
                if (target_shape[i] == 1)
                {
                    result_shape[i] = Dimension::dynamic();
                }
                else
                {
                    result_shape[i] = target_shape[i];
                }
                continue;
            }
            const size_t arg_shape_dim = arg_shape_vec[i].get_length();
            NODE_VALIDATION_CHECK(this_ptr,
                                  arg_shape_dim == 1 || target_shape[i] == 1 ||
                                      arg_shape_dim == target_shape[i],
                                  "Broadcast incorrect target shape. Expecting either 1 or ",
                                  arg_shape_dim,
                                  ". Got ",
                                  target_shape[i]);

            result_shape[i] = std::max(arg_shape_dim, target_shape[i]);
        }
        return result_shape;
    }
}

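// Checks that the number of inputs matches the broadcast mode (an axes_mapping input is
// required for explicit/NONE mode only), then refines the output shape. For BIDIRECTIONAL
// mode the output shape can be computed here only when the target shape is a constant.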
void op::v3::Broadcast::validate_and_infer_types()
{
    if (m_mode.m_type == BroadcastType::NONE)
    {
        NODE_VALIDATION_CHECK(this,
                              get_input_size() == 3,
                              "axes_mapping input should be provided if explicit mode is used");
    }
    else
    {
        NODE_VALIDATION_CHECK(
            this,
            get_input_size() == 2,
            "axes_mapping input should not be provided for mode other than explicit");
    }

    util::BroadcastBase::validate_and_infer_types();

    auto result_shape = get_output_partial_shape(0);
    if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
    {
        if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static())
        {
            auto arg_shape = get_input_partial_shape(0);

            const auto shape_constant =
                as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
            if (shape_constant)
            {
                auto target_shape = shape_constant->get_shape_val();
                result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
            }
        }
    }
    set_input_is_relevant_to_shape(0); // arg - Result element type
    set_input_is_relevant_to_shape(1); // target_shape - Result shape
    if (get_input_size() == 3)
    {
        set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
    }
    set_output_type(0, get_input_element_type(0), result_shape);
}

shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 2)
    {
        return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), m_mode);
    }
    else if (new_args.size() == 3)
    {
        return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), new_args.at(2), m_mode);
    }
    else
    {
        throw ngraph_error("Unsupported number of Broadcast:v3 arguments");
    }
}

bool op::v3::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("mode", m_mode);
    return true;
}

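// Constant-folding / evaluation entry point. For BIDIRECTIONAL mode the result shape and the
// broadcast axes are recomputed from the actual input shapes before delegating to the shared
// BroadcastBase implementation; the other modes are handled by the base class directly.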
bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs,
                                 const HostTensorVector& inputs) const
{
    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v3::Broadcast::evaluate");
    if (get_broadcast_spec().m_type == op::BroadcastType::BIDIRECTIONAL)
    {
        auto arg_shape = inputs[0]->get_shape();
        Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]);
        PartialShape result_shape =
            get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape);
        auto pair_broadcast_axes =
            get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape());
        return op::util::BroadcastBase::evaluate_broadcast(
            inputs[0], outputs[0], pair_broadcast_axes, result_shape.to_shape());
    }
    return op::util::BroadcastBase::evaluate(outputs, inputs);
}

namespace
{
    using namespace op;
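    // Translates a v1 AutoBroadcastSpec into the BroadcastModeSpec consumed by BroadcastBase,
    // copying the PDPD axis and mapping NONE/NUMPY/PDPD onto the corresponding BroadcastType
    // values.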
    BroadcastModeSpec to_broadcast_mode(const AutoBroadcastSpec& bs)
    {
        BroadcastModeSpec broadcast_mode;
        broadcast_mode.m_axis = bs.m_axis;
        switch (bs.m_type)
        {
        case AutoBroadcastType::NONE: broadcast_mode.m_type = BroadcastType::NONE; break;
        case AutoBroadcastType::NUMPY: broadcast_mode.m_type = BroadcastType::NUMPY; break;
        case AutoBroadcastType::PDPD: broadcast_mode.m_type = BroadcastType::PDPD; break;
        }
        return broadcast_mode;
    }
}

constexpr NodeTypeInfo op::v1::Broadcast::type_info;

op::v1::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const Output<Node>& axes_mapping,
                             const AutoBroadcastSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, axes_mapping, to_broadcast_mode(broadcast_spec)}
    , m_broadcast_spec{broadcast_spec}
{
    constructor_validate_and_infer_types();
}

op::v1::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const AutoBroadcastSpec& broadcast_spec)
    : util::BroadcastBase{arg,
                          target_shape,
                          op::v0::Constant::create(element::u8, Shape{}, {0})->output(0),
                          to_broadcast_mode(broadcast_spec)}
    , m_broadcast_spec{broadcast_spec}
{
    constructor_validate_and_infer_types();
}

void op::v1::Broadcast::validate_and_infer_types()
{
    util::BroadcastBase::validate_and_infer_types();

    set_input_is_relevant_to_shape(0); // arg - Result element type
    set_input_is_relevant_to_shape(1); // target_shape - Result shape
    set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
}

shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<v1::Broadcast>(
        new_args.at(0), new_args.at(1), new_args.at(2), m_broadcast_spec);
}

bool op::v1::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("mode", m_broadcast_spec);
    return true;
}

bool op::v1::Broadcast::evaluate(const HostTensorVector& outputs,
                                 const HostTensorVector& inputs) const
{
    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v1::Broadcast::evaluate");
    return op::util::BroadcastBase::evaluate(outputs, inputs);
}

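// Legacy v0 op: the output shape and the broadcast axes are stored as attributes rather than
// taken from inputs, so shape inference only has to validate them against the argument shape.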
constexpr NodeTypeInfo op::v0::Broadcast::type_info;

op::v0::Broadcast::Broadcast(const OutputVector& args,
                             const Shape& shape,
                             const AxisSet& broadcast_axes)
    : Op(args)
    , m_shape(shape)
    , m_broadcast_axes(broadcast_axes)
{
    constructor_validate_and_infer_types();
}

op::v0::Broadcast::Broadcast(const Output<Node>& arg,
                             const Shape& shape,
                             const AxisSet& broadcast_axes)
    : Broadcast(OutputVector{arg}, shape, broadcast_axes)
{
}

bool op::v0::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("shape", m_shape);
    visitor.on_attribute("broadcast_axes", m_broadcast_axes);
    return true;
}

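// Validates every broadcast axis against the output rank, then derives the expected input
// shape by erasing the broadcast axes from the output shape and checks that the actual
// argument shape is compatible with it.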
void op::v0::Broadcast::validate_and_infer_types()
{
    infer_shape();

    for (auto axis : m_broadcast_axes)
    {
        NODE_VALIDATION_CHECK(this,
                              axis < m_shape.size(),
                              "Broadcast axis index (",
                              axis,
                              ") exceeds specified output shape rank ",
                              "(broadcast axes: ",
                              m_broadcast_axes,
                              ", output shape: ",
                              m_shape,
                              ").");
    }

    Shape required_input_shape = m_shape;
    for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
    {
        required_input_shape.erase(required_input_shape.begin() + *i);
    }

    // TODO(amprocte): We can probably have a more helpful error message here.
    // There are two things that can go wrong, which are being picked up in
    // one fell swoop by this check: either the number of broadcast axes is not
    // enough, or there is a mismatch with one of the pre-broadcast axis lengths.
    NODE_VALIDATION_CHECK(
        this,
        get_input_partial_shape(0).compatible(required_input_shape),
        "Broadcast argument shape, specified output shape, and axes are incompatible ",
        "(argument shape: ",
        get_input_partial_shape(0),
        ", output shape: ",
        m_shape,
        ", broadcast axes: ",
        m_broadcast_axes,
        ").");

    set_output_type(0, get_input_element_type(0), m_shape);
}

shared_ptr<Node> op::v0::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}

namespace
{
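    // TYPE_CASE_v0 expands into one `case` label per element type so that the switch in
    // evaluate_broadcast_v0 can dispatch to the typed reference implementation below;
    // each invocation is followed by its own `break` at the call site.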
#define TYPE_CASE_v0(a)                                                                            \
    case element::Type_t::a: rc = evaluate_v0<element::Type_t::a>

    template <element::Type_t ET>
    inline bool evaluate_v0(const HostTensorPtr& arg0,
                            const HostTensorPtr& out,
                            const AxisSet& broadcast_axes)
    {
        using T = typename element_type_traits<ET>::value_type;
        runtime::reference::broadcast<T>((arg0->get_data_ptr<ET>()),
                                         (out->get_data_ptr<ET>()),
                                         arg0->get_shape(),
                                         out->get_shape(),
                                         broadcast_axes);
        return true;
    }

    bool evaluate_broadcast_v0(const HostTensorPtr& arg0,
                               const HostTensorPtr& out,
                               const AxisSet broadcast_axes,
                               const Shape output_shape)
    {
        bool rc = true;
        out->set_shape(output_shape);
        out->set_element_type(arg0->get_element_type());
        switch (arg0->get_element_type())
        {
            TYPE_CASE_v0(boolean)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(i8)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(i16)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(i32)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(i64)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(u8)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(u16)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(u32)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(u64)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(bf16)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(f16)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(f32)(arg0, out, broadcast_axes);
            break;
            TYPE_CASE_v0(f64)(arg0, out, broadcast_axes);
            break;
        default: rc = false; break;
        }
        return rc;
    }
}

bool op::v0::Broadcast::evaluate(const HostTensorVector& outputs,
                                 const HostTensorVector& inputs) const
{
    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Broadcast::evaluate");
    return evaluate_broadcast_v0(inputs[0], outputs[0], get_broadcast_axes(), get_output_shape(0));
}

constexpr NodeTypeInfo op::v0::BroadcastLike::type_info;

op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg,
                                     const Output<Node>& like_arg,
                                     const AxisSet& initial_broadcast_axes)
    : op::v0::Broadcast({arg, like_arg}, {}, {})
    , m_initial_broadcast_axes(initial_broadcast_axes)
{
    constructor_validate_and_infer_types();
}

bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("shape", m_shape);
    visitor.on_attribute("broadcast_axes", m_broadcast_axes);
    visitor.on_attribute("initial_broadcast_axes", m_initial_broadcast_axes);
    return true;
}

shared_ptr<Node> op::v0::BroadcastLike::clone_with_new_inputs(const OutputVector& new_args) const
{
    if (new_args.size() != 2)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }
    return make_shared<v0::BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
}

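// Resolves the actual output shape and broadcast axes at shape-inference time: the shape is
// taken from the second ("like") input, and when no initial_broadcast_axes were given the axes
// are deduced by comparing the first input's shape with it, broadcasting size-1 dimensions and
// any missing trailing dimensions.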
void op::v0::BroadcastLike::infer_shape()
{
    const Shape& in_shape = get_input_shape(0);
    m_shape = get_input_shape(1);
    m_broadcast_axes = m_initial_broadcast_axes;
    if (m_broadcast_axes.size() == 0)
    {
        for (size_t i = 0; i < m_shape.size(); ++i)
        {
            if (i < in_shape.size())
            {
                if (in_shape.at(i) == 1 && m_shape.at(i) > 1)
                {
                    m_broadcast_axes.insert(i);
                }
            }
            else
            {
                m_broadcast_axes.insert(i);
            }
        }
    }
}