Move downgrade passes to pass folder (#1675)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset1_upgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
17
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <numeric>
22
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
32
33 using namespace std;
34 using namespace ngraph;
35
36 namespace
37 {
38     template <typename OpV0, typename OpV1>
39     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
40     {
41         const auto autob = node->get_autob();
42         auto replacement_node =
43             make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
44         replace_node(node, replacement_node);
45         return replacement_node;
46     }
47
    // Default case: we did nothing. Returning nullptr tells op_cast_thunk
    // that no opset1 upgrade exists for this node type.
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
50     shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
51     {
52         return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
53     }
54
55     shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
56     {
57         auto replacement_node = ngraph::builder::opset1::make_broadcast(
58             node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
59         replace_node(node, replacement_node.get_node_shared_ptr());
60         return replacement_node.get_node_shared_ptr();
61     }
62
63     shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
64     shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
65     {
66         auto strides = node->get_window_movement_strides();
67         auto dilations = node->get_window_dilation_strides();
68         auto pads_begin = node->get_padding_below();
69         auto pads_end = node->get_padding_above();
70         auto data_dilation_strides = node->get_data_dilation_strides();
71         auto auto_pad = node->get_pad_type();
72
73         bool is_dds_valid = all_of(data_dilation_strides.begin(),
74                                    data_dilation_strides.end(),
75                                    [](size_t value) { return value == 1; });
76
77         NGRAPH_CHECK(is_dds_valid,
78                      "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
79                      "other than `1`. Node: ",
80                      *node);
81
82         auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
83                                                                  node->input_value(1),
84                                                                  strides,
85                                                                  pads_begin,
86                                                                  pads_end,
87                                                                  dilations,
88                                                                  auto_pad);
89         replace_node(node, replacement_node);
90         return replacement_node;
91     }
92
    // Upgrades v0::ConvolutionBackpropData to v1::ConvolutionBackpropData.
    //
    // Note the operand-order swap below: input 1 is passed first as "data"
    // and input 0 second as "filters" (see inline comments), matching the v1
    // signature. The spatial part of the forward data-batch shape
    // (everything past N and C) becomes an explicit output-shape constant.
    // v1 has no data-dilation attribute, so all forward data dilation
    // strides must be 1.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
    {
        auto data_batch_shape = node->get_data_batch_shape();
        auto strides = node->get_window_movement_strides_forward();
        auto dilations = node->get_window_dilation_strides_forward();
        auto pads_begin = node->get_padding_below_forward();
        auto pads_end = node->get_padding_above_forward();
        auto data_dilation_strides = node->get_data_dilation_strides_forward();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
                     "with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
            node->input_value(1), // data
            node->input_value(0), // filters
            op::Constant::create(
                element::i64,
                Shape{data_batch_shape.size() - 2},
                // spatial dims only: drop N and C from the forward data shape
                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
126
127     shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
128     {
129         const auto autob = node->get_autob();
130         const bool pydiv = node->is_pythondiv();
131         auto replacement_node =
132             make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
133         replace_node(node, replacement_node);
134         return replacement_node;
135     }
136
137     shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
138     {
139         shared_ptr<Node> replacement_node =
140             builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
141         replace_node(node, replacement_node);
142         return replacement_node;
143     }
144
145     shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
146     {
147         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
148     }
149
150     shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
151     {
152         int64_t axis = node->get_axis();
153
154         auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
155         auto replacement_node =
156             make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
157         replace_node(node, replacement_node);
158         return replacement_node;
159     }
160
161     shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
162     {
163         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
164     }
165
166     shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
167     {
168         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
169     }
170
171     shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
172     {
173         auto strides = node->get_window_movement_strides();
174         auto dilations = node->get_window_dilation_strides();
175         auto pads_begin = node->get_padding_below();
176         auto pads_end = node->get_padding_above();
177         auto data_dilation_strides = node->get_data_dilation_strides();
178         auto auto_pad = node->get_pad_type();
179
180         bool is_dds_valid = all_of(data_dilation_strides.begin(),
181                                    data_dilation_strides.end(),
182                                    [](size_t value) { return value == 1; });
183
184         NGRAPH_CHECK(is_dds_valid,
185                      "Unable to convert GroupConvolution:0 to GroupConvolution:1"
186                      "with data dilation strides other than `1`. Node: ",
187                      *node);
188
189         shared_ptr<Node> replacement_node;
190         if (node->has_groups_in_filters())
191         {
192             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
193                                                                      node->input_value(1),
194                                                                      strides,
195                                                                      pads_begin,
196                                                                      pads_end,
197                                                                      dilations,
198                                                                      auto_pad);
199         }
200         else
201         {
202             NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
203                          "Unable to convert GroupConvolution:0 to GroupConvolution:1"
204                          "with dynamic filters shape. Node: ",
205                          *node);
206
207             auto filters_shape = node->get_input_shape(1);
208             auto groups = node->get_groups();
209             filters_shape[0] /= groups;
210             filters_shape.insert(filters_shape.begin(), groups);
211
212             auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
213
214             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
215                                                                      reshaped_filters,
216                                                                      strides,
217                                                                      pads_begin,
218                                                                      pads_end,
219                                                                      dilations,
220                                                                      auto_pad);
221         }
222         replace_node(node, replacement_node);
223         return replacement_node;
224     }
225
    // Upgrades v0::GroupConvolutionBackpropData to its v1 counterpart.
    //
    // Both the data-batch shape and the filters shape must be static: the
    // former is trimmed to its spatial dimensions and passed as an explicit
    // output-shape constant, the latter is reshaped so the group count
    // becomes the leading filter dimension as v1 requires. Note the v1 op
    // takes input 2 of the v0 node first (see below).
    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
    {
        const auto strides = node->get_window_movement_strides();
        const auto dilations = node->get_window_dilation_strides();
        const auto pads_begin = node->get_padding_below();
        const auto pads_end = node->get_padding_above();

        const auto data_batch_pshape = node->get_input_partial_shape(0);
        const auto filters_pshape = node->get_input_partial_shape(1);

        NGRAPH_CHECK(data_batch_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
                     *node);
        NGRAPH_CHECK(filters_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
                     *node);

        auto data_batch_shape = data_batch_pshape.to_shape();
        // Remove N, C from output shape to preserve only spatial dimensions.
        data_batch_shape.erase(std::begin(data_batch_shape),
                               std::next(std::begin(data_batch_shape), 2));
        auto filters_shape = filters_pshape.to_shape();
        auto groups = node->get_groups();

        // [C_OUT, ...] -> [GROUPS, C_OUT / GROUPS, ...]
        filters_shape[0] /= groups;
        filters_shape.insert(filters_shape.begin(), groups);
        auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);

        auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
            node->input_value(2),
            reshaped_filters,
            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
267
268     shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
269     {
270         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
271     }
272
273     shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
274     {
275         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
276     }
277
278     shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
279     {
280         bool keep_dims = false;
281         auto replacement_node =
282             make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
283         replace_node(node, replacement_node);
284         return replacement_node;
285     }
286
287     shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
288     {
289         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
290     }
291
292     shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
293     {
294         bool keep_dims = false;
295         auto replacement_node =
296             make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
297         replace_node(node, replacement_node);
298         return replacement_node;
299     }
300
301     shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
302     {
303         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
304     }
305
306     shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
307     {
308         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
309     }
310
311     shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
312     {
313         auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
314         replace_node(node, replacement_node);
315         return replacement_node;
316     }
317
318     shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
319     {
320         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
321     }
322
    // Upgrades v0::OneHot to v1::OneHot.
    //
    // The depth is read from the (static) output dimension at the one-hot
    // axis and passed as a scalar constant input, together with scalar
    // on/off values of 1 and 0.
    // NOTE(review): the on/off constants are created as element::i64
    // regardless of the v0 node's element type, so the replacement's output
    // element type may differ from the original's — confirm this is intended.
    shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
    {
        const auto indices = node->input_value(0).get_node_shared_ptr();
        const auto one_hot_axis = node->get_one_hot_axis();

        const auto output_pshape = node->get_output_partial_shape(0);
        // The depth can only be extracted when the one-hot dimension is static.
        NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
                     "OneHot:v0 one hot axis dimension must be static ",
                     *node);
        const auto depth = output_pshape[one_hot_axis].get_length();
        const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});

        const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
        const auto off_value = op::Constant::create(element::i64, Shape{}, {0});

        auto replacement_node =
            make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
        replace_node(node, replacement_node);
        return replacement_node;
    }
343
344     shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
345     {
346         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
347     }
348
349     shared_ptr<Node> op_cast(shared_ptr<op::Pad> node)
350     {
351         auto padding_below = node->get_padding_below();
352         auto pads_begin_node =
353             make_shared<op::Constant>(element::i64, Shape{padding_below.size()}, padding_below);
354         auto padding_above = node->get_padding_above();
355         auto pads_end_node =
356             make_shared<op::Constant>(element::i64, Shape{padding_above.size()}, padding_above);
357
358         auto replacement_node = make_shared<op::v1::Pad>(node->input_value(0),
359                                                          pads_begin_node,
360                                                          pads_end_node,
361                                                          node->input_value(1),
362                                                          node->get_pad_mode());
363
364         replace_node(node, replacement_node);
365         return replacement_node;
366     }
367
368     shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
369     {
370         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
371     }
372
373     shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
374     {
375         bool keep_dims = false;
376         auto replacement_node =
377             make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
378         replace_node(node, replacement_node);
379         return replacement_node;
380     }
381
382     shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
383     {
384         // creates a Constant node from the v0::Reverse reversed_axes attribute
385         // and uses it as the second input of v1::Reverse
386         const auto reversed_axes = node->get_reversed_axes();
387
388         const auto reversed_axes_constant = op::Constant::create(
389             element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
390
391         const auto replacement_node = make_shared<op::v1::Reverse>(
392             node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
393
394         replace_node(node, replacement_node);
395         return replacement_node;
396     }
397
398     shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
399     {
400         auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
401                                                             node->input_value(1),
402                                                             node->input_value(2),
403                                                             op::AutoBroadcastSpec());
404         replace_node(node, replacement_node);
405         return replacement_node;
406     }
407
408     shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
409     {
410         NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
411                      "axes parameter is expected to be a static constant");
412
413         AxisSet axes = node->get_axes();
414
415         NGRAPH_CHECK(
416             axes.size() == 1,
417             "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
418             *node);
419
420         auto replacement_node =
421             make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
422         replace_node(node, replacement_node);
423         return replacement_node;
424     }
425
426     shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
427     {
428         const auto data = node->input_value(0);
429         const auto begin = op::Constant::create(
430             element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
431         const auto end = op::Constant::create(
432             element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
433         const auto strides = op::Constant::create(
434             element::i64, Shape{node->get_strides().size()}, node->get_strides());
435         int64_t input_size = node->get_lower_bounds().size();
436
437         auto replacement_node = make_shared<op::v1::StridedSlice>(data,
438                                                                   begin,
439                                                                   end,
440                                                                   strides,
441                                                                   vector<int64_t>(input_size, 0),
442                                                                   vector<int64_t>(input_size, 0));
443
444         replace_node(node, replacement_node);
445         return replacement_node;
446     }
447
448     shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
449     {
450         const auto& splits_vec = node->get_splits();
451         const auto first_elem = splits_vec.front();
452
453         const bool split_evenly =
454             std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
455                 return split == first_elem;
456             });
457
458         std::shared_ptr<Node> replacement_node;
459         if (split_evenly)
460         {
461             replacement_node = make_shared<op::v1::Split>(
462                 node->input_value(0), node->input_value(1), splits_vec.front());
463         }
464         else
465         {
466             const auto split_lengths =
467                 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
468
469             replacement_node = make_shared<op::v1::VariadicSplit>(
470                 node->input_value(0), node->input_value(1), split_lengths);
471         }
472
473         replace_node(node, replacement_node);
474         return replacement_node;
475     }
476
477     shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
478     {
479         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
480     }
481
482     shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
483     {
484         bool keep_dims = false;
485         auto replacement_node =
486             make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
487         replace_node(node, replacement_node);
488         return replacement_node;
489     }
490
    // Upgrades v0::TopK to v1::TopK.
    //
    // v0 takes k and the axis as constant inputs; v1 takes k as a constant
    // input and the axis, mode ("max"/"min") and sort order
    // ("index"/"value"/"none") as attributes, so both v0 inputs must be
    // static constants. The two versions also expose their outputs in
    // opposite order, hence the output_order remap passed to replace_node.
    shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
    {
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "parameter k is expected to be a static constant");
        NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
                     "parameter top_k_axis is expected to be a static constant");

        const auto k = node->get_k();
        const auto axis = node->get_top_k_axis();

        // Translate the v0 sort-type enum into the v1 string attribute.
        std::string sort;
        switch (node->get_sort())
        {
        case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
        case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
        case op::TopK::SortType::NONE: sort = "none"; break;
        }

        std::string mode;
        if (node->get_compute_max())
        {
            mode = "max";
        }
        else
        {
            mode = "min";
        }

        const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
        auto replacement_node =
            make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);

        // indices output will be 0, values 1
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }
528
529     shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
530     {
531         auto replacement_node = make_shared<op::v1::LogicalXor>(
532             node->input_value(0), node->input_value(1), node->get_autob());
533         replace_node(node, replacement_node);
534         return replacement_node;
535     }
536
    // Maps a node's static type_info to the thunk that upgrades it.
    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    // Downcasts the node to T and runs the matching op_cast overload
    // (overload resolution falls back to the nullptr-returning default when
    // no upgrade exists). Returns true when a replacement was made; when
    // provenance tracking is enabled, the replacement subgraph is tagged
    // with an upgrade marker.
    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        auto upgraded_node = op_cast(as_type_ptr<T>(node));
        if (upgraded_node)
        {
            if (ngraph::get_provenance_enabled())
            {
                const std::string provenance_tag =
                    "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
                upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }
555
    // Builds (once, on first use) the type_info -> upgrade-thunk table by
    // expanding NGRAPH_OP for every op listed in opset0_tbl.hpp (X-macro
    // pattern). The static local makes construction lazy and thread-safe.
    DispatchMap& get_dispatch_map()
    {
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "opset0_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
    }
565 } // namespace
566
567 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
568 {
569     bool modified = false;
570     auto& dispatch_map = get_dispatch_map();
571     auto it = dispatch_map.find(node->get_type_info());
572     if (it != dispatch_map.end())
573     {
574         modified = it->second(node);
575     }
576     return modified;
577 }