Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset1_upgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
17
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <numeric>
22
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 using namespace std;
36 using namespace ngraph;
37
38 namespace opset1_upgrade
39 {
40     template <typename OpV0, typename OpV1>
41     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
42     {
43         const auto autob = node->get_autob();
44         auto replacement_node =
45             make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
46         replace_node(node, replacement_node);
47         return replacement_node;
48     }
49
    // Default is that we did nothing: ops with no registered overload below
    // are left in place (op_cast_thunk treats nullptr as "not upgraded").
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
    shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
    }
56
57     shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
58     {
59         auto strides = node->get_window_movement_strides();
60         auto dilations = node->get_window_dilation_strides();
61         auto pads_begin = node->get_padding_below();
62         auto pads_end = node->get_padding_above();
63         auto data_dilation_strides = node->get_data_dilation_strides();
64         auto auto_pad = node->get_pad_type();
65
66         bool is_dds_valid = all_of(data_dilation_strides.begin(),
67                                    data_dilation_strides.end(),
68                                    [](size_t value) { return value == 1; });
69
70         NGRAPH_CHECK(is_dds_valid,
71                      "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
72                      "other than `1`. Node: ",
73                      *node);
74
75         auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
76                                                                  node->input_value(1),
77                                                                  strides,
78                                                                  pads_begin,
79                                                                  pads_end,
80                                                                  dilations,
81                                                                  auto_pad);
82         replace_node(node, replacement_node);
83         return replacement_node;
84     }
85
    shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
    {
        // Upgrades v0::ConvolutionBackpropData to v1::ConvolutionBackpropData.
        // v0 stores the forward data-batch shape as an attribute; v1 instead
        // takes the *spatial* output shape as a constant input, and its
        // data/filters input order is swapped relative to v0.
        auto data_batch_shape = node->get_data_batch_shape();
        auto strides = node->get_window_movement_strides_forward();
        auto dilations = node->get_window_dilation_strides_forward();
        auto pads_begin = node->get_padding_below_forward();
        auto pads_end = node->get_padding_above_forward();
        auto data_dilation_strides = node->get_data_dilation_strides_forward();

        // v1 has no data-dilation attribute; conversion requires all strides == 1.
        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
                     "with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
            node->input_value(1), // data
            node->input_value(0), // filters
            // Strip N and C from data_batch_shape: v1 expects spatial dims only.
            op::Constant::create(
                element::i64,
                Shape{data_batch_shape.size() - 2},
                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
119
120     shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
121     {
122         const auto autob = node->get_autob();
123         const bool pydiv = node->is_pythondiv();
124         auto replacement_node =
125             make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
126         replace_node(node, replacement_node);
127         return replacement_node;
128     }
129
130     shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
131     {
132         shared_ptr<Node> replacement_node =
133             builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
134         replace_node(node, replacement_node);
135         return replacement_node;
136     }
137
138     shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
139     {
140         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
141     }
142
143     shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
144     {
145         int64_t axis = node->get_axis();
146
147         auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
148         auto replacement_node =
149             make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
150         replace_node(node, replacement_node);
151         return replacement_node;
152     }
153
154     shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
155     {
156         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
157     }
158
159     shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
160     {
161         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
162     }
163
164     shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
165     {
166         auto strides = node->get_window_movement_strides();
167         auto dilations = node->get_window_dilation_strides();
168         auto pads_begin = node->get_padding_below();
169         auto pads_end = node->get_padding_above();
170         auto data_dilation_strides = node->get_data_dilation_strides();
171         auto auto_pad = node->get_pad_type();
172
173         bool is_dds_valid = all_of(data_dilation_strides.begin(),
174                                    data_dilation_strides.end(),
175                                    [](size_t value) { return value == 1; });
176
177         NGRAPH_CHECK(is_dds_valid,
178                      "Unable to convert GroupConvolution:0 to GroupConvolution:1"
179                      "with data dilation strides other than `1`. Node: ",
180                      *node);
181
182         shared_ptr<Node> replacement_node;
183         if (node->has_groups_in_filters())
184         {
185             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
186                                                                      node->input_value(1),
187                                                                      strides,
188                                                                      pads_begin,
189                                                                      pads_end,
190                                                                      dilations,
191                                                                      auto_pad);
192         }
193         else
194         {
195             NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
196                          "Unable to convert GroupConvolution:0 to GroupConvolution:1"
197                          "with dynamic filters shape. Node: ",
198                          *node);
199
200             auto filters_shape = node->get_input_shape(1);
201             auto groups = node->get_groups();
202             filters_shape[0] /= groups;
203             filters_shape.insert(filters_shape.begin(), groups);
204
205             auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
206
207             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
208                                                                      reshaped_filters,
209                                                                      strides,
210                                                                      pads_begin,
211                                                                      pads_end,
212                                                                      dilations,
213                                                                      auto_pad);
214         }
215         replace_node(node, replacement_node);
216         return replacement_node;
217     }
218
    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
    {
        // Upgrades v0::GroupConvolutionBackpropData to v1. v1 consumes the
        // output delta plus group-major filters and a constant holding only
        // the spatial output shape; both input shapes must be static here so
        // those values can be computed at conversion time.
        const auto strides = node->get_window_movement_strides();
        const auto dilations = node->get_window_dilation_strides();
        const auto pads_begin = node->get_padding_below();
        const auto pads_end = node->get_padding_above();

        const auto data_batch_pshape = node->get_input_partial_shape(0);
        const auto filters_pshape = node->get_input_partial_shape(1);

        NGRAPH_CHECK(data_batch_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
                     *node);
        NGRAPH_CHECK(filters_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
                     *node);

        auto data_batch_shape = data_batch_pshape.to_shape();
        // Remove N, C from output shape to preserve only spatial dimensions.
        data_batch_shape.erase(std::begin(data_batch_shape),
                               std::next(std::begin(data_batch_shape), 2));
        auto filters_shape = filters_pshape.to_shape();
        auto groups = node->get_groups();

        // Fold the group count into the filters shape: (C_out, ...) ->
        // (G, C_out/G, ...), the layout the v1 op expects.
        filters_shape[0] /= groups;
        filters_shape.insert(filters_shape.begin(), groups);
        auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

        auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
            node->input_value(2), // output delta is input 2 of the v0 op
            reshaped_filters,
            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
260
261     shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
262     {
263         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
264     }
265
266     shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
267     {
268         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
269     }
270
271     shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
272     {
273         bool keep_dims = false;
274         auto replacement_node =
275             make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
276         replace_node(node, replacement_node);
277         return replacement_node;
278     }
279
280     shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
281     {
282         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
283     }
284
285     shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
286     {
287         bool keep_dims = false;
288         auto replacement_node =
289             make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
290         replace_node(node, replacement_node);
291         return replacement_node;
292     }
293
294     shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
295     {
296         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
297     }
298
299     shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
300     {
301         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
302     }
303
304     shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
305     {
306         auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
307         replace_node(node, replacement_node);
308         return replacement_node;
309     }
310
311     shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
312     {
313         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
314     }
315
    shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
    {
        // Upgrades v0::OneHot to v1::OneHot. v1 takes depth/on/off as inputs,
        // so the depth is read off the statically-known output dimension at
        // the one-hot axis and materialized as a scalar constant.
        const auto indices = node->input_value(0).get_node_shared_ptr();
        const auto one_hot_axis = node->get_one_hot_axis();

        const auto output_pshape = node->get_output_partial_shape(0);
        NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
                     "OneHot:v0 one hot axis dimension must be static ",
                     *node);
        const auto depth = output_pshape[one_hot_axis].get_length();
        const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});

        // NOTE(review): on/off values are hard-coded as i64 1/0, which pins
        // the v1 output element type to i64 regardless of the v0 node's
        // output type — confirm downstream consumers tolerate this.
        const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
        const auto off_value = op::Constant::create(element::i64, Shape{}, {0});

        auto replacement_node =
            make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
        replace_node(node, replacement_node);
        return replacement_node;
    }
336
337     shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
338     {
339         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
340     }
341
342     shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
343     {
344         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
345     }
346
347     shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
348     {
349         bool keep_dims = false;
350         auto replacement_node =
351             make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
352         replace_node(node, replacement_node);
353         return replacement_node;
354     }
355
356     shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
357     {
358         // creates a Constant node from the v0::Reverse reversed_axes attribute
359         // and uses it as the second input of v1::Reverse
360         const auto reversed_axes = node->get_reversed_axes();
361
362         const auto reversed_axes_constant = op::Constant::create(
363             element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
364
365         const auto replacement_node = make_shared<op::v1::Reverse>(
366             node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
367
368         replace_node(node, replacement_node);
369         return replacement_node;
370     }
371
372     shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
373     {
374         auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
375                                                             node->input_value(1),
376                                                             node->input_value(2),
377                                                             op::AutoBroadcastSpec());
378         replace_node(node, replacement_node);
379         return replacement_node;
380     }
381
382     shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
383     {
384         NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
385                      "axes parameter is expected to be a static constant");
386
387         AxisSet axes = node->get_axes();
388
389         NGRAPH_CHECK(
390             axes.size() == 1,
391             "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
392             *node);
393
394         auto replacement_node =
395             make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
396         replace_node(node, replacement_node);
397         return replacement_node;
398     }
399
400     shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
401     {
402         const auto data = node->input_value(0);
403         const auto begin = op::Constant::create(
404             element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
405         const auto end = op::Constant::create(
406             element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
407         const auto strides = op::Constant::create(
408             element::i64, Shape{node->get_strides().size()}, node->get_strides());
409         int64_t input_size = node->get_lower_bounds().size();
410
411         auto replacement_node = make_shared<op::v1::StridedSlice>(data,
412                                                                   begin,
413                                                                   end,
414                                                                   strides,
415                                                                   vector<int64_t>(input_size, 0),
416                                                                   vector<int64_t>(input_size, 0));
417
418         replace_node(node, replacement_node);
419         return replacement_node;
420     }
421
422     shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
423     {
424         const auto& splits_vec = node->get_splits();
425         const auto first_elem = splits_vec.front();
426
427         const bool split_evenly =
428             std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
429                 return split == first_elem;
430             });
431
432         std::shared_ptr<Node> replacement_node;
433         if (split_evenly)
434         {
435             replacement_node = make_shared<op::v1::Split>(
436                 node->input_value(0), node->input_value(1), splits_vec.front());
437         }
438         else
439         {
440             const auto split_lengths =
441                 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
442
443             replacement_node = make_shared<op::v1::VariadicSplit>(
444                 node->input_value(0), node->input_value(1), split_lengths);
445         }
446
447         replace_node(node, replacement_node);
448         return replacement_node;
449     }
450
451     shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
452     {
453         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
454     }
455
456     shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
457     {
458         bool keep_dims = false;
459         auto replacement_node =
460             make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
461         replace_node(node, replacement_node);
462         return replacement_node;
463     }
464
    shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
    {
        // Upgrades v0::TopK to v1::TopK. v1 takes k as a scalar constant
        // input and mode/sort as string attributes, so v0's k and top_k_axis
        // inputs must be constants to be readable at conversion time.
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "parameter k is expected to be a static constant");
        NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
                     "parameter top_k_axis is expected to be a static constant");

        const auto k = node->get_k();
        const auto axis = node->get_top_k_axis();

        // Map the v0 SortType enum onto the v1 string attribute.
        std::string sort;
        switch (node->get_sort())
        {
        case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
        case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
        case op::TopK::SortType::NONE: sort = "none"; break;
        }

        std::string mode;
        if (node->get_compute_max())
        {
            mode = "max";
        }
        else
        {
            mode = "min";
        }

        const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
        auto replacement_node =
            make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);

        // The two versions disagree on output order, so consumers of the v0
        // outputs are rewired through the remap below.
        // indices output will be 0, values 1
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }
502
503     shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
504     {
505         auto replacement_node = make_shared<op::v1::LogicalXor>(
506             node->input_value(0), node->input_value(1), node->get_autob());
507         replace_node(node, replacement_node);
508         return replacement_node;
509     }
510
    // Maps a node's type_info to a thunk performing its v0 -> v1 upgrade.
    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        // Downcasts the node to its concrete v0 type and dispatches to the
        // matching op_cast overload; returns true iff an upgrade happened
        // (op_cast returns nullptr for types with no conversion).
        auto upgraded_node = op_cast(as_type_ptr<T>(node));
        if (upgraded_node)
        {
            if (ngraph::get_provenance_enabled())
            {
                // Tag the replacement subgraph so the upgrade is traceable.
                const std::string provenance_tag =
                    "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
                upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }
529
    DispatchMap& get_dispatch_map()
    {
        NGRAPH_SUPPRESS_DEPRECATED_START
        // Built once (function-local static): one entry per opset0 op,
        // generated by expanding NGRAPH_OP over the opset0 op table.
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "opset0_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
        NGRAPH_SUPPRESS_DEPRECATED_END
    }
541 } // namespace opset1_upgrade
542
543 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
544 {
545     bool modified = false;
546     auto& dispatch_map = opset1_upgrade::get_dispatch_map();
547     auto it = dispatch_map.find(node->get_type_info());
548     if (it != dispatch_map.end())
549     {
550         modified = it->second(node);
551     }
552     return modified;
553 }