Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset1_upgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 #include "opset1_upgrade.hpp"
17
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <numeric>
22
23 #include "ngraph/builder/autobroadcast.hpp"
24 #include "ngraph/builder/reshape.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/op/util/op_types.hpp"
27 #include "ngraph/ops.hpp"
28 #include "ngraph/provenance.hpp"
29 #include "op/avg_pool.hpp"
30 #include "op/convolution.hpp"
31 #include "op/group_conv.hpp"
32
33 NGRAPH_SUPPRESS_DEPRECATED_START
34
35 using namespace std;
36 using namespace ngraph;
37
38 namespace
39 {
40     template <typename OpV0, typename OpV1>
41     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
42     {
43         const auto autob = node->get_autob();
44         auto replacement_node =
45             make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
46         replace_node(node, replacement_node);
47         return replacement_node;
48     }
49
50     // Default: do nothing — op types without a dedicated overload get no
       // v1 replacement (returning nullptr signals "not upgraded").
51     shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
52     shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
53     {
54         return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
55     }
56
57     shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
58     {
59         auto replacement_node = ngraph::builder::opset1::make_broadcast(
60             node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
61         replace_node(node, replacement_node.get_node_shared_ptr());
62         return replacement_node.get_node_shared_ptr();
63     }
64
65     shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
66     shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
67     {
68         auto strides = node->get_window_movement_strides();
69         auto dilations = node->get_window_dilation_strides();
70         auto pads_begin = node->get_padding_below();
71         auto pads_end = node->get_padding_above();
72         auto data_dilation_strides = node->get_data_dilation_strides();
73         auto auto_pad = node->get_pad_type();
74
75         bool is_dds_valid = all_of(data_dilation_strides.begin(),
76                                    data_dilation_strides.end(),
77                                    [](size_t value) { return value == 1; });
78
79         NGRAPH_CHECK(is_dds_valid,
80                      "Unable to convert Convolution:0 to Convolution:1 with data dilation strides "
81                      "other than `1`. Node: ",
82                      *node);
83
84         auto replacement_node = make_shared<op::v1::Convolution>(node->input_value(0),
85                                                                  node->input_value(1),
86                                                                  strides,
87                                                                  pads_begin,
88                                                                  pads_end,
89                                                                  dilations,
90                                                                  auto_pad);
91         replace_node(node, replacement_node);
92         return replacement_node;
93     }
94
95     shared_ptr<Node> op_cast(shared_ptr<op::v0::ConvolutionBackpropData> node)
96     {
97         auto data_batch_shape = node->get_data_batch_shape();
98         auto strides = node->get_window_movement_strides_forward();
99         auto dilations = node->get_window_dilation_strides_forward();
100         auto pads_begin = node->get_padding_below_forward();
101         auto pads_end = node->get_padding_above_forward();
102         auto data_dilation_strides = node->get_data_dilation_strides_forward();
103
104         bool is_dds_valid = all_of(data_dilation_strides.begin(),
105                                    data_dilation_strides.end(),
106                                    [](size_t value) { return value == 1; });
107
108         NGRAPH_CHECK(is_dds_valid,
109                      "Unable to convert ConvolutionBackpropData:0 to ConvolutionBackpropData:1 "
110                      "with data dilation strides "
111                      "other than `1`. Node: ",
112                      *node);
113
114         auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
115             node->input_value(1), // data
116             node->input_value(0), // filters
117             op::Constant::create(
118                 element::i64,
119                 Shape{data_batch_shape.size() - 2},
120                 vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
121             strides,
122             pads_begin,
123             pads_end,
124             dilations);
125         replace_node(node, replacement_node);
126         return replacement_node;
127     }
128
129     shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
130     {
131         const auto autob = node->get_autob();
132         const bool pydiv = node->is_pythondiv();
133         auto replacement_node =
134             make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
135         replace_node(node, replacement_node);
136         return replacement_node;
137     }
138
139     shared_ptr<Node> op_cast(shared_ptr<op::Reshape> node)
140     {
141         shared_ptr<Node> replacement_node =
142             builder::opset1::reshape(node->input_value(0), node->get_reshape_output_shape());
143         replace_node(node, replacement_node);
144         return replacement_node;
145     }
146
147     shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
148     {
149         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
150     }
151
152     shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
153     {
154         int64_t axis = node->get_axis();
155
156         auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
157         auto replacement_node =
158             make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
159         replace_node(node, replacement_node);
160         return replacement_node;
161     }
162
163     shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
164     {
165         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
166     }
167
168     shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
169     {
170         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
171     }
172
173     shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
174     {
175         auto strides = node->get_window_movement_strides();
176         auto dilations = node->get_window_dilation_strides();
177         auto pads_begin = node->get_padding_below();
178         auto pads_end = node->get_padding_above();
179         auto data_dilation_strides = node->get_data_dilation_strides();
180         auto auto_pad = node->get_pad_type();
181
182         bool is_dds_valid = all_of(data_dilation_strides.begin(),
183                                    data_dilation_strides.end(),
184                                    [](size_t value) { return value == 1; });
185
186         NGRAPH_CHECK(is_dds_valid,
187                      "Unable to convert GroupConvolution:0 to GroupConvolution:1"
188                      "with data dilation strides other than `1`. Node: ",
189                      *node);
190
191         shared_ptr<Node> replacement_node;
192         if (node->has_groups_in_filters())
193         {
194             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
195                                                                      node->input_value(1),
196                                                                      strides,
197                                                                      pads_begin,
198                                                                      pads_end,
199                                                                      dilations,
200                                                                      auto_pad);
201         }
202         else
203         {
204             NGRAPH_CHECK(node->get_input_partial_shape(1).is_static(),
205                          "Unable to convert GroupConvolution:0 to GroupConvolution:1"
206                          "with dynamic filters shape. Node: ",
207                          *node);
208
209             auto filters_shape = node->get_input_shape(1);
210             auto groups = node->get_groups();
211             filters_shape[0] /= groups;
212             filters_shape.insert(filters_shape.begin(), groups);
213
214             auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
215
216             replacement_node = make_shared<op::v1::GroupConvolution>(node->input_value(0),
217                                                                      reshaped_filters,
218                                                                      strides,
219                                                                      pads_begin,
220                                                                      pads_end,
221                                                                      dilations,
222                                                                      auto_pad);
223         }
224         replace_node(node, replacement_node);
225         return replacement_node;
226     }
227
228     shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
229     {
230         const auto strides = node->get_window_movement_strides();
231         const auto dilations = node->get_window_dilation_strides();
232         const auto pads_begin = node->get_padding_below();
233         const auto pads_end = node->get_padding_above();
234
235         const auto data_batch_pshape = node->get_input_partial_shape(0);
236         const auto filters_pshape = node->get_input_partial_shape(1);
237
238         NGRAPH_CHECK(data_batch_pshape.is_static(),
239                      "Unable to convert GroupConvolutionBackpropData:0 to "
240                      "GroupConvolutionBackpropData:1 with dynamic data_batch shape. Node: ",
241                      *node);
242         NGRAPH_CHECK(filters_pshape.is_static(),
243                      "Unable to convert GroupConvolutionBackpropData:0 to "
244                      "GroupConvolutionBackpropData:1 with dynamic filters shape. Node: ",
245                      *node);
246
247         auto data_batch_shape = data_batch_pshape.to_shape();
248         // Remove N, C from output shape to preserve only spatial dimentions.
249         data_batch_shape.erase(std::begin(data_batch_shape),
250                                std::next(std::begin(data_batch_shape), 2));
251         auto filters_shape = filters_pshape.to_shape();
252         auto groups = node->get_groups();
253
254         filters_shape[0] /= groups;
255         filters_shape.insert(filters_shape.begin(), groups);
256         auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
257
258         auto replacement_node = make_shared<op::v1::GroupConvolutionBackpropData>(
259             node->input_value(2),
260             reshaped_filters,
261             op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
262             strides,
263             pads_begin,
264             pads_end,
265             dilations);
266         replace_node(node, replacement_node);
267         return replacement_node;
268     }
269
270     shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
271     {
272         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
273     }
274
275     shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
276     {
277         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
278     }
279
280     shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
281     {
282         bool keep_dims = false;
283         auto replacement_node =
284             make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
285         replace_node(node, replacement_node);
286         return replacement_node;
287     }
288
289     shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
290     {
291         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
292     }
293
294     shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
295     {
296         bool keep_dims = false;
297         auto replacement_node =
298             make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
299         replace_node(node, replacement_node);
300         return replacement_node;
301     }
302
303     shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
304     {
305         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
306     }
307
308     shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
309     {
310         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
311     }
312
313     shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
314     {
315         auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
316         replace_node(node, replacement_node);
317         return replacement_node;
318     }
319
320     shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
321     {
322         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
323     }
324
325     shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
326     {
327         const auto indices = node->input_value(0).get_node_shared_ptr();
328         const auto one_hot_axis = node->get_one_hot_axis();
329
330         const auto output_pshape = node->get_output_partial_shape(0);
331         NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
332                      "OneHot:v0 one hot axis dimension must be static ",
333                      *node);
334         const auto depth = output_pshape[one_hot_axis].get_length();
335         const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
336
337         const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
338         const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
339
340         auto replacement_node =
341             make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
342         replace_node(node, replacement_node);
343         return replacement_node;
344     }
345
346     shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
347     {
348         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
349     }
350
351     shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
352     {
353         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
354     }
355
356     shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
357     {
358         bool keep_dims = false;
359         auto replacement_node =
360             make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
361         replace_node(node, replacement_node);
362         return replacement_node;
363     }
364
365     shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
366     {
367         // creates a Constant node from the v0::Reverse reversed_axes attribute
368         // and uses it as the second input of v1::Reverse
369         const auto reversed_axes = node->get_reversed_axes();
370
371         const auto reversed_axes_constant = op::Constant::create(
372             element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
373
374         const auto replacement_node = make_shared<op::v1::Reverse>(
375             node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
376
377         replace_node(node, replacement_node);
378         return replacement_node;
379     }
380
381     shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
382     {
383         auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
384                                                             node->input_value(1),
385                                                             node->input_value(2),
386                                                             op::AutoBroadcastSpec());
387         replace_node(node, replacement_node);
388         return replacement_node;
389     }
390
391     shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
392     {
393         NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
394                      "axes parameter is expected to be a static constant");
395
396         AxisSet axes = node->get_axes();
397
398         NGRAPH_CHECK(
399             axes.size() == 1,
400             "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
401             *node);
402
403         auto replacement_node =
404             make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
405         replace_node(node, replacement_node);
406         return replacement_node;
407     }
408
409     shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
410     {
411         const auto data = node->input_value(0);
412         const auto begin = op::Constant::create(
413             element::i64, Shape{node->get_lower_bounds().size()}, node->get_lower_bounds());
414         const auto end = op::Constant::create(
415             element::i64, Shape{node->get_upper_bounds().size()}, node->get_upper_bounds());
416         const auto strides = op::Constant::create(
417             element::i64, Shape{node->get_strides().size()}, node->get_strides());
418         int64_t input_size = node->get_lower_bounds().size();
419
420         auto replacement_node = make_shared<op::v1::StridedSlice>(data,
421                                                                   begin,
422                                                                   end,
423                                                                   strides,
424                                                                   vector<int64_t>(input_size, 0),
425                                                                   vector<int64_t>(input_size, 0));
426
427         replace_node(node, replacement_node);
428         return replacement_node;
429     }
430
431     shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
432     {
433         const auto& splits_vec = node->get_splits();
434         const auto first_elem = splits_vec.front();
435
436         const bool split_evenly =
437             std::all_of(splits_vec.begin(), splits_vec.end(), [first_elem](const size_t split) {
438                 return split == first_elem;
439             });
440
441         std::shared_ptr<Node> replacement_node;
442         if (split_evenly)
443         {
444             replacement_node = make_shared<op::v1::Split>(
445                 node->input_value(0), node->input_value(1), splits_vec.front());
446         }
447         else
448         {
449             const auto split_lengths =
450                 ngraph::op::Constant::create(element::u64, Shape{splits_vec.size()}, splits_vec);
451
452             replacement_node = make_shared<op::v1::VariadicSplit>(
453                 node->input_value(0), node->input_value(1), split_lengths);
454         }
455
456         replace_node(node, replacement_node);
457         return replacement_node;
458     }
459
460     shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
461     {
462         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
463     }
464
465     shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
466     {
467         bool keep_dims = false;
468         auto replacement_node =
469             make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
470         replace_node(node, replacement_node);
471         return replacement_node;
472     }
473
474     shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
475     {
476         NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
477                      "parameter k is expected to be a static constant");
478         NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
479                      "parameter top_k_axis is expected to be a static constant");
480
481         const auto k = node->get_k();
482         const auto axis = node->get_top_k_axis();
483
484         std::string sort;
485         switch (node->get_sort())
486         {
487         case op::TopK::SortType::SORT_INDICES: sort = "index"; break;
488         case op::TopK::SortType::SORT_VALUES: sort = "value"; break;
489         case op::TopK::SortType::NONE: sort = "none"; break;
490         }
491
492         std::string mode;
493         if (node->get_compute_max())
494         {
495             mode = "max";
496         }
497         else
498         {
499             mode = "min";
500         }
501
502         const auto k_constant = op::Constant::create(element::i64, Shape{}, {k});
503         auto replacement_node =
504             make_shared<op::v1::TopK>(node->input_value(0), k_constant, axis, mode, sort);
505
506         // indices output will be 0, values 1
507         vector<int64_t> output_order{1, 0};
508         replace_node(node, replacement_node, output_order);
509         return replacement_node;
510     }
511
512     shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
513     {
514         auto replacement_node = make_shared<op::v1::LogicalXor>(
515             node->input_value(0), node->input_value(1), node->get_autob());
516         replace_node(node, replacement_node);
517         return replacement_node;
518     }
519
520     using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
521
522     template <typename T>
523     bool op_cast_thunk(shared_ptr<Node> node)
524     {
525         auto upgraded_node = op_cast(as_type_ptr<T>(node));
526         if (upgraded_node)
527         {
528             if (ngraph::get_provenance_enabled())
529             {
530                 const std::string provenance_tag =
531                     "<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
532                 upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
533             }
534             return true;
535         }
536         return false;
537     }
538
    // Lazily builds the table mapping each v0 op's type_info to the
    // op_cast_thunk instantiation that upgrades it. Entries are generated by
    // expanding the NGRAPH_OP X-macro over the op list in opset0_tbl.hpp.
539     DispatchMap& get_dispatch_map()
540     {
541         NGRAPH_SUPPRESS_DEPRECATED_START
        // Function-local static: constructed once, on first use.
542         static DispatchMap dispatch_map{
543 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
544 #include "opset0_tbl.hpp"
545 #undef NGRAPH_OP
546         };
547         return dispatch_map;
        // NOTE(review): placed after the return — harmless only if the macro
        // expands to a non-executable marker (e.g. a pragma); verify.
548         NGRAPH_SUPPRESS_DEPRECATED_END
549     }
550 } // namespace
551
552 bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
553 {
554     bool modified = false;
555     auto& dispatch_map = get_dispatch_map();
556     auto it = dispatch_map.find(node->get_type_info());
557     if (it != dispatch_map.end())
558     {
559         modified = it->second(node);
560     }
561     return modified;
562 }