//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "opset1_upgrade.hpp"

#include <functional>
#include <iterator>
#include <limits>
#include <numeric>

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp"
#include "op/avg_pool.hpp"
#include "op/convolution.hpp"
#include "op/group_conv.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
using namespace ngraph;

namespace opset1_upgrade
{
    // Shared helper: replaces a v0 binary elementwise op with its v1
    // counterpart, forwarding both inputs and the autobroadcast spec.
    template <typename OpV0, typename OpV1>
    shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
    {
        const auto autob = node->get_autob();
        auto replacement_node =
            make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
        replace_node(node, replacement_node);
        return replacement_node;
    }
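
    // All of the binary elementwise casts below (Add, Equal, Greater, ...) are
    // one-line delegations to this helper, e.g.:
    //
    //     op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);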

    // Default case: the op has no v1 upgrade, so nothing is done.
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }

    shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
    {
        auto strides = node->get_window_movement_strides();
        auto dilations = node->get_window_dilation_strides();
        auto pads_begin = node->get_padding_below();
        auto pads_end = node->get_padding_above();
        auto data_dilation_strides = node->get_data_dilation_strides();
        auto auto_pad = node->get_pad_type();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
                                                                 node->input_value(1),
                                                                 strides,
                                                                 pads_begin,
                                                                 pads_end,
                                                                 dilations,
                                                                 auto_pad);
        replace_node(node, replacement_node);
        return replacement_node;
    }
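
    // v1::Convolution has no data_dilation_strides attribute, which is why the
    // check above only admits all-ones strides: any other value has no v1
    // representation.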

    shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
    {
        auto data_batch_shape = node->get_data_batch_shape();
        auto strides = node->get_window_movement_strides_forward();
        auto dilations = node->get_window_dilation_strides_forward();
        auto pads_begin = node->get_padding_below_forward();
        auto pads_end = node->get_padding_above_forward();
        auto data_dilation_strides = node->get_data_dilation_strides_forward();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
                     "with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
            node->input_value(1), // data
            node->input_value(0), // filters
            op::Constant::create(
                element::i64,
                Shape{data_batch_shape.size() - 2},
                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
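
    // Example: a data_batch_shape of {1, 3, 224, 224} yields an output-shape
    // constant of {224, 224} above, i.e. only the spatial dimensions survive.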

    shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
    {
        const auto autob = node->get_autob();
        const bool pydiv = node->is_pythondiv();
        auto replacement_node =
            make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
    {
        shared_ptr<Node> replacement_node =
            builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
    {
        return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
    {
        auto strides = node->get_window_movement_strides();
        auto dilations = node->get_window_dilation_strides();
        auto pads_begin = node->get_padding_below();
        auto pads_end = node->get_padding_above();
        auto data_dilation_strides = node->get_data_dilation_strides();
        auto auto_pad = node->get_pad_type();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert GroupConvolution:0 to GroupConvolution:1 "
                     "with data dilation strides other than `1`. Node: ",
                     *node);

        shared_ptr<Node> replacement_node;
        if (node->has_groups_in_filters())
        {
            replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
                                                                     node->input_value(1),
                                                                     strides,
                                                                     pads_begin,
                                                                     pads_end,
                                                                     dilations,
                                                                     auto_pad);
        }
        else
        {
            NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
                         "Unable to convert GroupConvolution:0 to GroupConvolution:1 "
                         "with dynamic filters shape. Node: ",
                         *node);

            auto filters_shape = node->get_input_shape(1);
            auto groups = node->get_groups();
            filters_shape[0] /= groups;
            filters_shape.insert(filters_shape.begin(), groups);

            auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

            replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
                                                                     reshaped_filters,
                                                                     strides,
                                                                     pads_begin,
                                                                     pads_end,
                                                                     dilations,
                                                                     auto_pad);
        }
        replace_node(node, replacement_node);
        return replacement_node;
    }
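
    // Example: with groups = 2, v0 filters of shape {16, 4, 3, 3} are reshaped
    // above to {2, 8, 4, 3, 3} to match the v1 per-group filter layout.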

    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
    {
        const auto strides = node->get_window_movement_strides();
        const auto dilations = node->get_window_dilation_strides();
        const auto pads_begin = node->get_padding_below();
        const auto pads_end = node->get_padding_above();

        const auto data_batch_pshape = node->get_input_partial_shape(0);
        const auto filters_pshape = node->get_input_partial_shape(1);

        NGRAPH_CHECK(data_batch_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
                     *node);
        NGRAPH_CHECK(filters_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
                     *node);

        auto data_batch_shape = data_batch_pshape.to_shape();
        // Remove N and C from the output shape to preserve only the spatial dimensions.
        data_batch_shape.erase(std::begin(data_batch_shape),
                               std::next(std::begin(data_batch_shape), 2));
        auto filters_shape = filters_pshape.to_shape();
        auto groups = node->get_groups();

        filters_shape[0] /= groups;
        filters_shape.insert(filters_shape.begin(), groups);
        auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

        auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
            node->input_value(2),
            reshaped_filters,
            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
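
    // Note the input order above: the v0 node's third input (the delta) becomes
    // the first (data) input of the v1 op, and the reshaped filters come second.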

    shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
    {
        return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
    {
        auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
    {
        return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
    {
        const auto indices = node->input_value(0).get_node_shared_ptr();
        const auto one_hot_axis = node->get_one_hot_axis();

        const auto output_pshape = node->get_output_partial_shape(0);
        NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
                     "OneHot:v0 one hot axis dimension must be static ",
                     *node);
        const auto depth = output_pshape[one_hot_axis].get_length();
        const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});

        const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
        const auto off_value = op::Constant::create(element::i64, Shape{}, {0});

        auto replacement_node =
            make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
        replace_node(node, replacement_node);
        return replacement_node;
    }
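
    // v1::OneHot takes depth, on_value and off_value as explicit inputs rather
    // than deriving them from the output shape, hence the static-dimension
    // check above and the 1/0 constants that mirror the v0 one-hot values.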

    shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
    {
        // creates a Constant node from the v0::Reverse reversed_axes attribute
        // and uses it as the second input of v1::Reverse
        const auto reversed_axes = node->get_reversed_axes();

        const auto reversed_axes_constant = op::Constant::create(
            element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());

        const auto replacement_node = make_shared<op::v1::Reverse>(
            node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
    {
        auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
                                                            node->input_value(1),
                                                            node->input_value(2),
                                                            op::AutoBroadcastSpec());
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
    {
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "axes parameter is expected to be a static constant");

        AxisSet axes = node->get_axes();

        NGRAPH_CHECK(
            axes.size() == 1,
            "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
            *node);

        auto replacement_node =
            make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
    {
        const auto data = node->input_value(0);
        const auto begin = op::Constant::create(
            element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
        const auto end = op::Constant::create(
            element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
        const auto strides = op::Constant::create(
            element::i64, Shape{node->get_strides().size()}, node->get_strides());
        int64_t input_size = node->get_lower_bounds().size();

        auto replacement_node = make_shared<op::v1::StridedSlice>(data,
                                                                  begin,
                                                                  end,
                                                                  strides,
                                                                  vector<int64_t>(input_size, 0),
                                                                  vector<int64_t>(input_size, 0));

        replace_node(node, replacement_node);
        return replacement_node;
    }
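
    // The two all-zero vectors above are the begin/end masks: a zero bit means
    // the corresponding begin/end value is taken verbatim, so the v0 bounds are
    // applied exactly as given.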

    shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
    {
        const auto& splits_vec = node->get_splits();
        const auto first_elem = splits_vec.front();

        const bool split_evenly =
            std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
                return split == first_elem;
            });

        std::shared_ptr<Node> replacement_node;
        if (split_evenly)
        {
            // num_splits is the number of outputs, i.e. the number of (equal)
            // chunks listed in the v0 splits attribute.
            replacement_node = make_shared<op::v1::Split>(
                node->input_value(0), node->input_value(1), splits_vec.size());
        }
        else
        {
            const auto split_lengths =
                ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);

            replacement_node = make_shared<op::v1::VariadicSplit>(
                node->input_value(0), node->input_value(1), split_lengths);
        }

        replace_node(node, replacement_node);
        return replacement_node;
    }
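
    // Example: even splits {2, 2, 2} map to v1::Split with num_splits = 3,
    // while uneven splits such as {1, 2, 3} fall back to v1::VariadicSplit
    // with the lengths passed as a constant.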

    shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
    {
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "parameter k is expected to be a static constant");
        NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
                     "parameter top_k_axis is expected to be a static constant");

        const auto k = node->get_k();
        const auto axis = node->get_top_k_axis();

        std::string sort;
        switch (node->get_sort())
        {
        case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
        case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
        case op::TopK::SortType::NONE: sort = "none"; break;
        }

        std::string mode;
        if (node->get_compute_max())
        {
            mode = "max";
        }
        else
        {
            mode = "min";
        }

        const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
        auto replacement_node =
            make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);

        // v0::TopK emits (indices, values) while v1::TopK emits (values, indices),
        // so the outputs are remapped when replacing the node.
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
    {
        auto replacement_node = make_shared<op::v1::LogicalXor>(
            node->input_value(0), node->input_value(1), node->get_autob());
        replace_node(node, replacement_node);
        return replacement_node;
    }

    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        auto upgraded_node = op_cast(as_type_ptr<T>(node));
        if (upgraded_node)
        {
            if (ngraph::get_provenance_enabled())
            {
                const std::string provenance_tag =
                    "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
                upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }

    DispatchMap& get_dispatch_map()
    {
        NGRAPH_SUPPRESS_DEPRECATED_START
        // Build one dispatch entry per op listed in opset0_tbl.hpp: the
        // NGRAPH_OP X-macro expands to a {type_info, thunk} pair for each op.
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "opset0_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
        NGRAPH_SUPPRESS_DEPRECATED_END
    }
} // namespace opset1_upgrade

bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
    bool modified = false;
    auto& dispatch_map = opset1_upgrade::get_dispatch_map();
    auto it = dispatch_map.find(node->get_type_info());
    if (it != dispatch_map.end())
    {
        modified = it->second(node);
    }
    return modified;
}
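
// A minimal usage sketch (an assumption about the surrounding application, not
// part of this file): as a per-node pass, Opset1Upgrade is typically scheduled
// through ngraph's pass::Manager, e.g.
//
//     pass::Manager manager;
//     manager.register_pass<pass::Opset1Upgrade>();
//     manager.run_passes(function); // function is a shared_ptr<ngraph::Function>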