Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] ngraph/core/src/op/broadcast.cpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "itt.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/partial_shape.hpp"

#include <algorithm> // std::max
#include <numeric>
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v3::Broadcast::type_info;

op::v3::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const Output<Node>& axes_mapping,
                             const BroadcastModeSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, axes_mapping, broadcast_spec}
{
    constructor_validate_and_infer_types();
}

op::v3::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const BroadcastModeSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, broadcast_spec}
{
    constructor_validate_and_infer_types();
}

namespace
{
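    // Helper for BIDIRECTIONAL (NumPy-style) broadcasting: given static argument
    // and result shapes, collects the axes along which the argument is broadcast.
    // The boolean in the returned pair reports whether the axes are known.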
    std::pair<bool, AxisSet> get_broadcast_axes_bidirectional(const Shape& arg_shape,
                                                              const Shape& result_shape)
    {
        AxisSet broadcast_axes;
        // Both sizes are unsigned; validate the ranks before subtracting so the
        // difference cannot wrap around.
        NGRAPH_CHECK(result_shape.size() >= arg_shape.size());
        const auto start_axis = result_shape.size() - arg_shape.size();
        for (size_t i = 0; i < result_shape.size(); i++)
        {
            if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
            {
                broadcast_axes.insert(i);
            }
        }
        return std::make_pair(true, broadcast_axes);
    }
}

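// The broadcast axes are only known for BIDIRECTIONAL mode when both the input
// and the output shapes are static; otherwise the base-class logic is used.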
std::pair<bool, AxisSet> op::v3::Broadcast::get_broadcast_axes() const
{
    if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
    {
        AxisSet broadcast_axes;
        bool axes_known = false;

        if (get_input_partial_shape(0).is_static() && get_output_partial_shape(0).is_static())
        {
            const auto arg_shape = get_input_shape(0);
            const auto result_shape = get_output_shape(0);
            return get_broadcast_axes_bidirectional(arg_shape, result_shape);
        }
        return std::make_pair(axes_known, broadcast_axes);
    }

    return util::BroadcastBase::get_broadcast_axes();
}

namespace
{
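    // Infers the result shape for BIDIRECTIONAL mode following the NumPy rules:
    // both shapes are left-padded with 1s to a common rank, corresponding
    // dimensions must be equal or one of them must be 1, and the larger one is
    // taken. Note that target_shape is padded in place.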
    PartialShape get_result_shape_bidirectional(const Node* this_ptr,
                                                const PartialShape& arg_shape,
                                                Shape& target_shape)
    {
        if (arg_shape.rank().is_dynamic())
        {
            return PartialShape::dynamic();
        }
        auto arg_shape_vec = static_cast<std::vector<Dimension>>(arg_shape);
        PartialShape result_shape;
        // Add left padding to the shorter of the target and argument shapes
        const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
        while (arg_shape_vec.size() < target_padded_rank)
        {
            arg_shape_vec.insert(arg_shape_vec.begin(), 1);
        }
        while (target_shape.size() < target_padded_rank)
        {
            target_shape.insert(target_shape.begin(), 1);
        }

        result_shape = target_shape;
        for (size_t i = 0; i < target_shape.size(); ++i)
        {
            if (arg_shape_vec[i].is_dynamic())
            {
                if (target_shape[i] == 1)
                {
                    result_shape[i] = Dimension::dynamic();
                }
                else
                {
                    result_shape[i] = target_shape[i];
                }
                continue;
            }
            const size_t arg_shape_dim = arg_shape_vec[i].get_length();
            NODE_VALIDATION_CHECK(this_ptr,
                                  arg_shape_dim == 1 || target_shape[i] == 1 ||
                                      arg_shape_dim == target_shape[i],
                                  "Broadcast incorrect target shape. Expecting either 1 or ",
                                  arg_shape_dim,
                                  ". Got ",
                                  target_shape[i]);

            result_shape[i] = std::max(arg_shape_dim, target_shape[i]);
        }
        return result_shape;
    }
}

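// Checks the number of inputs against the broadcast mode (axes_mapping is only
// valid for explicit / NONE mode), then refines the inferred output shape for
// BIDIRECTIONAL mode when the target shape is a constant.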
void op::v3::Broadcast::validate_and_infer_types()
{
    if (m_mode.m_type == BroadcastType::NONE)
    {
        NODE_VALIDATION_CHECK(this,
                              get_input_size() == 3,
                              "axes_mapping input should be provided if explicit mode is used");
    }
    else
    {
        NODE_VALIDATION_CHECK(
            this,
            get_input_size() == 2,
            "axes_mapping input should not be provided for mode other than explicit");
    }

    util::BroadcastBase::validate_and_infer_types();

    auto result_shape = get_output_partial_shape(0);
    if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
    {
        if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static())
        {
            auto arg_shape = get_input_partial_shape(0);

            const auto shape_constant =
                as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
            if (shape_constant)
            {
                auto target_shape = shape_constant->get_shape_val();
                result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
            }
        }
    }
    set_input_is_relevant_to_shape(0); // arg - Result element type
    set_input_is_relevant_to_shape(1); // target_shape - Result shape
    if (get_input_size() == 3)
    {
        set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
    }
    set_output_type(0, get_input_element_type(0), result_shape);
}

shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 2)
    {
        return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), m_mode);
    }
    else if (new_args.size() == 3)
    {
        return make_shared<v3::Broadcast>(new_args.at(0), new_args.at(1), new_args.at(2), m_mode);
    }
    else
    {
        throw ngraph_error("Unsupported number of Broadcast:v3 arguments");
    }
}

bool op::v3::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("mode", m_mode);
    return true;
}

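// For BIDIRECTIONAL mode the result shape and the broadcast axes are computed
// from the runtime shapes before delegating to the reference implementation;
// all other modes are handled entirely by BroadcastBase::evaluate.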
bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs,
                                 const HostTensorVector& inputs) const
{
    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v3::Broadcast::evaluate");
    if (get_broadcast_spec().m_type == op::BroadcastType::BIDIRECTIONAL)
    {
        auto arg_shape = inputs[0]->get_shape();
        Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]);
        PartialShape result_shape =
            get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape);
        auto pair_broadcast_axes =
            get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape());
        return op::util::BroadcastBase::evaluate_broadcast(
            inputs[0], outputs[0], pair_broadcast_axes, result_shape.to_shape());
    }
    return op::util::BroadcastBase::evaluate(outputs, inputs);
}

namespace
{
    using namespace op;
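    // Maps an AutoBroadcastSpec (used by v1::Broadcast) onto the BroadcastModeSpec
    // expected by BroadcastBase. AutoBroadcastType::EXPLICIT is an alias of NONE,
    // so the switch covers every enumerator.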
    BroadcastModeSpec to_broadcast_mode(const AutoBroadcastSpec& bs)
    {
        BroadcastModeSpec broadcast_mode;
        broadcast_mode.m_axis = bs.m_axis;
        switch (bs.m_type)
        {
        case AutoBroadcastType::NONE: broadcast_mode.m_type = BroadcastType::NONE; break;
        case AutoBroadcastType::NUMPY: broadcast_mode.m_type = BroadcastType::NUMPY; break;
        case AutoBroadcastType::PDPD: broadcast_mode.m_type = BroadcastType::PDPD; break;
        }
        return broadcast_mode;
    }
}

constexpr NodeTypeInfo op::v1::Broadcast::type_info;

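// v1::Broadcast keeps the original AutoBroadcastSpec for attribute visiting and
// cloning, while the base class receives the converted BroadcastModeSpec. The
// two-input constructor passes a dummy scalar constant as the axes_mapping input
// so that the op always has three inputs.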
op::v1::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const Output<Node>& axes_mapping,
                             const AutoBroadcastSpec& broadcast_spec)
    : util::BroadcastBase{arg, target_shape, axes_mapping, to_broadcast_mode(broadcast_spec)}
    , m_broadcast_spec{broadcast_spec}
{
    constructor_validate_and_infer_types();
}

op::v1::Broadcast::Broadcast(const Output<Node>& arg,
                             const Output<Node>& target_shape,
                             const AutoBroadcastSpec& broadcast_spec)
    : util::BroadcastBase{arg,
                          target_shape,
                          op::v0::Constant::create(element::u8, Shape{}, {0})->output(0),
                          to_broadcast_mode(broadcast_spec)}
    , m_broadcast_spec{broadcast_spec}
{
    constructor_validate_and_infer_types();
}

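// Unlike v3, v1::Broadcast always has three inputs (see the constructors above),
// so all of them can unconditionally be marked as shape-relevant.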
void op::v1::Broadcast::validate_and_infer_types()
{
    util::BroadcastBase::validate_and_infer_types();

    set_input_is_relevant_to_shape(0); // arg - Result element type
    set_input_is_relevant_to_shape(1); // target_shape - Result shape
    set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
}

shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<v1::Broadcast>(
        new_args.at(0), new_args.at(1), new_args.at(2), m_broadcast_spec);
}

bool op::v1::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("mode", m_broadcast_spec);
    return true;
}

bool op::v1::Broadcast::evaluate(const HostTensorVector& outputs,
                                 const HostTensorVector& inputs) const
{
    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v1::Broadcast::evaluate");
    return op::util::BroadcastBase::evaluate(outputs, inputs);
}

NGRAPH_SUPPRESS_DEPRECATED_END