// Remove obsoleted v0::Product op (#2860)
// ngraph/test/runtime/pass/opset1_upgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
17
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <numeric>
22
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 using namespace std;
36 using namespace ngraph;
37
38 namespace opset1_upgrade
39 {
40     template <typename OpV0, typename OpV1>
41     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
42     {
43         const auto autob = node->get_autob();
44         auto replacement_node =
45             make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
46         replace_node(node, replacement_node);
47         return replacement_node;
48     }
49
    // Default is that we did nothing: returning nullptr tells op_cast_thunk
    // that no upgrade exists for this node type.
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
    // v0::Add -> v1::Add via the shared binary-elementwise helper.
    shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
    }
56
    // v0::Convolution -> v1::Convolution.
    // v1 dropped the data-dilation-strides attribute, so the conversion is
    // only representable when every data dilation stride on the v0 node is 1.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
    {
        auto strides = node->get_window_movement_strides();
        auto dilations = node->get_window_dilation_strides();
        auto pads_begin = node->get_padding_below();
        auto pads_end = node->get_padding_above();
        auto data_dilation_strides = node->get_data_dilation_strides();
        auto auto_pad = node->get_pad_type();

        // Reject any non-unit data dilation (not expressible in v1).
        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
                                                                 node->input_value(1),
                                                                 strides,
                                                                 pads_begin,
                                                                 pads_end,
                                                                 dilations,
                                                                 auto_pad);
        replace_node(node, replacement_node);
        return replacement_node;
    }
85
    // v0::ConvolutionBackpropData -> v1::ConvolutionBackpropData.
    // v1 drops the (forward) data-dilation-strides attribute — all must be 1 —
    // and takes the spatial part of the forward data batch shape as a
    // Constant input. Note the input swizzle below: v0 input 1 feeds the v1
    // data slot and v0 input 0 the filters slot.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
    {
        auto data_batch_shape = node->get_data_batch_shape();
        auto strides = node->get_window_movement_strides_forward();
        auto dilations = node->get_window_dilation_strides_forward();
        auto pads_begin = node->get_padding_below_forward();
        auto pads_end = node->get_padding_above_forward();
        auto data_dilation_strides = node->get_data_dilation_strides_forward();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
                     "with data dilation strides "
                     "other than `1`. Node: ",
                     *node);

        auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
            node->input_value(1), // data
            node->input_value(0), // filters
            // Spatial output shape: drop the leading N and C dimensions.
            op::Constant::create(
                element::i64,
                Shape{data_batch_shape.size() - 2},
                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
119
120     shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
121     {
122         const auto autob = node->get_autob();
123         const bool pydiv = node->is_pythondiv();
124         auto replacement_node =
125             make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
126         replace_node(node, replacement_node);
127         return replacement_node;
128     }
129
    // v0::Reshape -> the opset1 subgraph produced by the reshape builder,
    // targeting the output shape recorded on the v0 node.
    shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
    {
        shared_ptr<Node> replacement_node =
            builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
        replace_node(node, replacement_node);
        return replacement_node;
    }
137
    // v0::Equal -> v1::Equal.
    shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
    }

    // v0::Greater -> v1::Greater.
    shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
    }

    // v0::GreaterEq -> v1::GreaterEqual (op renamed in opset1).
    shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
    {
        return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
    }
152
    // v0::GroupConvolution -> v1::GroupConvolution.
    // v1 always encodes groups in the filters shape, so when the v0 node
    // stores groups only as an attribute the filters input is reshaped to
    // prepend a groups dimension. Data dilation strides must all be 1
    // (v1 has no such attribute).
    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
    {
        auto strides = node->get_window_movement_strides();
        auto dilations = node->get_window_dilation_strides();
        auto pads_begin = node->get_padding_below();
        auto pads_end = node->get_padding_above();
        auto data_dilation_strides = node->get_data_dilation_strides();
        auto auto_pad = node->get_pad_type();

        bool is_dds_valid = all_of(data_dilation_strides.begin(),
                                   data_dilation_strides.end(),
                                   [](size_t value) { return value == 1; });

        NGRAPH_CHECK(is_dds_valid,
                     "Unable to convert GroupConvolution:0 to GroupConvolution:1"
                     "with data dilation strides other than `1`. Node: ",
                     *node);

        shared_ptr<Node> replacement_node;
        if (node->has_groups_in_filters())
        {
            // Filters already carry the groups dimension — pass them through.
            replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
                                                                     node->input_value(1),
                                                                     strides,
                                                                     pads_begin,
                                                                     pads_end,
                                                                     dilations,
                                                                     auto_pad);
        }
        else
        {
            // Need the static filters shape to fold the groups attribute in.
            NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
                         "Unable to convert GroupConvolution:0 to GroupConvolution:1"
                         "with dynamic filters shape. Node: ",
                         *node);

            // (C_out, ...) -> (G, C_out/G, ...): split the output-channel dim.
            auto filters_shape = node->get_input_shape(1);
            auto groups = node->get_groups();
            filters_shape[0] /= groups;
            filters_shape.insert(filters_shape.begin(), groups);

            auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

            replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
                                                                     reshaped_filters,
                                                                     strides,
                                                                     pads_begin,
                                                                     pads_end,
                                                                     dilations,
                                                                     auto_pad);
        }
        replace_node(node, replacement_node);
        return replacement_node;
    }
207
    // v0::GroupConvolutionBackpropData -> v1::GroupConvolutionBackpropData.
    // Requires static data-batch and filters shapes: the spatial part of the
    // data batch shape becomes a Constant output-shape input, and the groups
    // attribute is folded into the filters shape via a reshape.
    shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
    {
        const auto strides = node->get_window_movement_strides();
        const auto dilations = node->get_window_dilation_strides();
        const auto pads_begin = node->get_padding_below();
        const auto pads_end = node->get_padding_above();

        const auto data_batch_pshape = node->get_input_partial_shape(0);
        const auto filters_pshape = node->get_input_partial_shape(1);

        NGRAPH_CHECK(data_batch_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
                     *node);
        NGRAPH_CHECK(filters_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:0 to "
                     "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
                     *node);

        auto data_batch_shape = data_batch_pshape.to_shape();
        // Remove N, C from output shape to preserve only spatial dimentions.
        data_batch_shape.erase(std::begin(data_batch_shape),
                               std::next(std::begin(data_batch_shape), 2));
        auto filters_shape = filters_pshape.to_shape();
        auto groups = node->get_groups();

        // (C_out, ...) -> (G, C_out/G, ...): split the output-channel dim.
        filters_shape[0] /= groups;
        filters_shape.insert(filters_shape.begin(), groups);
        auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

        auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
            node->input_value(2), // v0 input 2 — presumably the output delta; confirm with v0 op
            reshaped_filters,
            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
            strides,
            pads_begin,
            pads_end,
            dilations);
        replace_node(node, replacement_node);
        return replacement_node;
    }
249
    // v0::Less -> v1::Less.
    shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
    }

    // v0::LessEq -> v1::LessEqual (op renamed in opset1).
    shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
    {
        return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
    }
259
    // v0::Max (reduction) -> v1::ReduceMax with keep_dims=false; the
    // reduction-axes input is forwarded unchanged.
    shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    // v0::Maximum (elementwise) -> v1::Maximum.
    shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
    }

    // v0::Min (reduction) -> v1::ReduceMin with keep_dims=false.
    shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    // v0::Minimum (elementwise) -> v1::Minimum.
    shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
    }

    // v0::Multiply -> v1::Multiply.
    shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
    }
292
    // v0::Not -> v1::LogicalNot (unary, so the binary helper does not apply).
    shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
    {
        auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
        replace_node(node, replacement_node);
        return replacement_node;
    }

    // v0::NotEqual -> v1::NotEqual.
    shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
    {
        return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
    }
304
    // v0::OneHot -> v1::OneHot.
    // v1 takes depth/on_value/off_value as inputs: depth is derived from the
    // statically-known output dimension at the one-hot axis, while on/off are
    // scalar constants 1 and 0.
    // NOTE(review): on/off are created as element::i64 regardless of the v0
    // node's output element type — confirm consumers expect this.
    shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
    {
        const auto indices = node->input_value(0).get_node_shared_ptr();
        const auto one_hot_axis = node->get_one_hot_axis();

        const auto output_pshape = node->get_output_partial_shape(0);
        // Depth can only be read off a static dimension.
        NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
                     "OneHot:v0 one hot axis dimension must be static ",
                     *node);
        const auto depth = output_pshape[one_hot_axis].get_length();
        const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});

        const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
        const auto off_value = op::Constant::create(element::i64, Shape{}, {0});

        auto replacement_node =
            make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
        replace_node(node, replacement_node);
        return replacement_node;
    }
325
    // v0::Or -> v1::LogicalOr (op renamed in opset1).
    shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
    }

    // v0::Power -> v1::Power.
    shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
    }
335
336     shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
337     {
338         // creates a Constant node from the v0::Reverse reversed_axes attribute
339         // and uses it as the second input of v1::Reverse
340         const auto reversed_axes = node->get_reversed_axes();
341
342         const auto reversed_axes_constant = op::Constant::create(
343             element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
344
345         const auto replacement_node = make_shared<op::v1::Reverse>(
346             node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
347
348         replace_node(node, replacement_node);
349         return replacement_node;
350     }
351
352     shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
353     {
354         auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
355                                                             node->input_value(1),
356                                                             node->input_value(2),
357                                                             op::AutoBroadcastSpec());
358         replace_node(node, replacement_node);
359         return replacement_node;
360     }
361
    // v0::Softmax -> v1::Softmax.
    // v1 stores a single axis attribute, so the v0 axes must be a statically
    // known (constant) set containing exactly one axis.
    shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
    {
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "axes parameter is expected to be a static constant");

        AxisSet axes = node->get_axes();

        NGRAPH_CHECK(
            axes.size() == 1,
            "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
            *node);

        auto replacement_node =
            make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
        replace_node(node, replacement_node);
        return replacement_node;
    }
379
    // v0::Slice -> v1::StridedSlice.
    // Lower/upper bounds and strides become i64 Constant inputs; the all-zero
    // begin/end masks tell StridedSlice to honor every bound verbatim.
    shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
    {
        const auto data = node->input_value(0);
        const auto begin = op::Constant::create(
            element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
        const auto end = op::Constant::create(
            element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
        const auto strides = op::Constant::create(
            element::i64, Shape{node->get_strides().size()}, node->get_strides());
        // One mask entry per sliced dimension.
        int64_t input_size = node->get_lower_bounds().size();

        auto replacement_node = make_shared<op::v1::StridedSlice>(data,
                                                                  begin,
                                                                  end,
                                                                  strides,
                                                                  vector<int64_t>(input_size, 0),
                                                                  vector<int64_t>(input_size, 0));

        replace_node(node, replacement_node);
        return replacement_node;
    }
401
    // v0::Split -> v1::Split (equal parts) or v1::VariadicSplit (uneven).
    // v1::Split only supports equal-length chunks, so uneven v0 splits are
    // expressed with VariadicSplit plus an explicit split-lengths constant.
    // NOTE(review): splits_vec.front() assumes a non-empty splits list —
    // presumably guaranteed by v0::Split validation; confirm.
    shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
    {
        const auto& splits_vec = node->get_splits();
        const auto first_elem = splits_vec.front();

        const bool split_evenly =
            std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
                return split == first_elem;
            });

        std::shared_ptr<Node> replacement_node;
        if (split_evenly)
        {
            // Equal chunks: v1::Split takes just the number of splits.
            replacement_node = make_shared<op::v1::Split>(
                node->input_value(0), node->input_value(1), splits_vec.front());
        }
        else
        {
            const auto split_lengths =
                ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);

            replacement_node = make_shared<op::v1::VariadicSplit>(
                node->input_value(0), node->input_value(1), split_lengths);
        }

        replace_node(node, replacement_node);
        return replacement_node;
    }
430
    // v0::Subtract -> v1::Subtract.
    shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
    }

    // v0::Sum (reduction) -> v1::ReduceSum with keep_dims=false; the
    // reduction-axes input is forwarded unchanged.
    shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
    {
        bool keep_dims = false;
        auto replacement_node =
            make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
        replace_node(node, replacement_node);
        return replacement_node;
    }
444
    // v0::TopK -> v1::TopK.
    // Requires constant k and axis inputs. v1 expresses sort order and
    // min/max mode as strings, and its outputs are ordered differently from
    // v0 (v0 output 0 = indices, output 1 = values — see remap at the end).
    shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
    {
        NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
                     "parameter k is expected to be a static constant");
        NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
                     "parameter top_k_axis is expected to be a static constant");

        const auto k = node->get_k();
        const auto axis = node->get_top_k_axis();

        // Translate the v0 sort enum into the v1 string attribute.
        std::string sort;
        switch (node->get_sort())
        {
        case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
        case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
        case op::TopK::SortType::NONE: sort = "none"; break;
        }

        std::string mode;
        if (node->get_compute_max())
        {
            mode = "max";
        }
        else
        {
            mode = "min";
        }

        // v1 takes k as a scalar Constant input rather than an attribute.
        const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
        auto replacement_node =
            make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);

        // indices output will be 0, values 1
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }
482
483     shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
484     {
485         auto replacement_node = make_shared<op::v1::LogicalXor>(
486             node->input_value(0), node->input_value(1), node->get_autob());
487         replace_node(node, replacement_node);
488         return replacement_node;
489     }
490
    // Maps a node's static type_info to the upgrade routine for that op type.
    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    // Downcasts the generic node to T and dispatches to the matching op_cast
    // overload. Returns true when an upgrade took place. When provenance
    // tracking is enabled, the replacement subgraph is tagged so its origin
    // in this pass can be traced.
    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        auto upgraded_node = op_cast(as_type_ptr<T>(node));
        if (upgraded_node)
        {
            if (ngraph::get_provenance_enabled())
            {
                const std::string provenance_tag =
                    "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
                upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }
509
    // Lazily builds (once, on first call) the type_info -> upgrade-thunk
    // table by expanding NGRAPH_OP for every op listed in opset0_tbl.hpp.
    DispatchMap& get_dispatch_map()
    {
        NGRAPH_SUPPRESS_DEPRECATED_START
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "opset0_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
        NGRAPH_SUPPRESS_DEPRECATED_END
    }
521 } // namespace opset1_upgrade
522
523 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
524 {
525     bool modified = false;
526     auto& dispatch_map = opset1_upgrade::get_dispatch_map();
527     auto it = dispatch_map.find(node->get_type_info());
528     if (it != dispatch_map.end())
529     {
530         modified = it->second(node);
531     }
532     return modified;
533 }