Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset0_downgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <functional>
20 #include <numeric>
21
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/graph_util.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/op/util/attr_types.hpp"
27 #include "ngraph/op/util/op_types.hpp"
28 #include "ngraph/ops.hpp"
29 #include "ngraph/provenance.hpp"
30 #include "ngraph/slice_plan.hpp"
31 #include "ngraph/type.hpp"
32 #include "ngraph/validation_util.hpp"
33 #include "op/avg_pool.hpp"
34 #include "op/convolution.hpp"
35 #include "op/group_conv.hpp"
36 #include "pass/implicit_broadcast_elimination.hpp"
37 #include "pass/opset0_downgrade.hpp"
38
39 NGRAPH_SUPPRESS_DEPRECATED_START
40
41 using namespace std;
42 using namespace ngraph;
43
44 namespace
45 {
46     template <typename OpV0, typename OpV1>
47     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
48     {
49         const auto input_arg0 = node->input_value(0);
50         const auto input_arg1 = node->input_value(1);
51         const auto autob = node->get_autob();
52         auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
53         replace_node(node, replacement_node);
54         return replacement_node;
55     }
56
57     template <typename OpV0, typename OpV1>
58     shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
59     {
60         auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
61         if (node->get_keep_dims())
62         {
63             string v1_op_name = string{node->get_type_name()} + ":v1";
64             string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
65
66             NGRAPH_CHECK(node->reduction_axes_constant(),
67                          "Unable to convert ",
68                          v1_op_name,
69                          "to ",
70                          v0_op_name,
71                          " if reduction axes are not constant (for keep_dims=true). Node: ",
72                          *node);
73             auto output_pshape = replacement_node->get_output_partial_shape(0);
74             NGRAPH_CHECK(output_pshape.is_static(),
75                          "Unable to convert ",
76                          v1_op_name,
77                          "to ",
78                          v0_op_name,
79                          " if output shape is dynamic (for keep_dims=true). Node: ",
80                          *node);
81             const auto output_shape = output_pshape.to_shape();
82             auto reshaped_output_shape = output_shape;
83             for (const auto& axis : node->get_reduction_axes())
84             {
85                 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
86             }
87             auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
88                                                              get_default_order(output_shape),
89                                                              reshaped_output_shape);
90             return reshaped_product;
91         }
92         else
93         {
94             return replacement_node;
95         }
96     }
97
98     // Default is that we did nothing
99     shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
100     shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
101     {
102         return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
103     }
104
105     shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
106     {
107         auto const input_arg = node->input_value(0);
108         const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
109         const auto include_padding_in_avg_computation = !node->get_exclude_pad();
110         const auto pad_type = node->get_auto_pad();
111         const auto padding_below = node->get_pads_begin();
112         const auto padding_above = node->get_pads_end();
113         const auto window_movement_strides = node->get_strides();
114         const auto window_shape = node->get_kernel();
115
116         auto replacement_node = make_shared<op::v0::AvgPool>(input_arg,
117                                                              window_shape,
118                                                              window_movement_strides,
119                                                              padding_below,
120                                                              padding_above,
121                                                              include_padding_in_avg_computation,
122                                                              pad_type,
123                                                              ceil_mode);
124         replace_node(node, replacement_node);
125         return replacement_node;
126     }
127
128     shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
129     {
130         auto arg = node->input_value(0);
131         auto arg_pshape = arg.get_partial_shape();
132         auto arg_rank = arg_pshape.rank();
133         auto target_shape_input = node->input_value(1);
134
135         shared_ptr<Node> replacement_node;
136
137         NGRAPH_CHECK(arg_pshape.is_static(),
138                      "Unable to convert Broadcast:v1 to Broadcast:v0 "
139                      "if argument shape is not static. Node: ",
140                      *node);
141         const auto& arg_shape = arg_pshape.to_shape();
142
143         NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
144         auto target_shape = node->get_output_shape(0);
145         NGRAPH_CHECK(node->get_broadcast_axes().first);
146
147         // (Re)construct axes_mapping.
148         AxisSet broadcast_axes = node->get_broadcast_axes().second;
149         std::vector<size_t> axes_mapping{
150             ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};
151
152         Output<Node> squeezed_arg = arg;
153         // Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
154         // the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
155         std::vector<size_t> empty_axes;
156         for (size_t a{0}; a < axes_mapping.size(); ++a)
157         {
158             if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
159             {
160                 empty_axes.push_back(a);
161             }
162         }
163         // Check if arg_shape contains some more empty dimensions marked to broadcast.
164         // If axes_mapping size is less than arg_shape size, then some of arg dimensions may
165         // be equal to one and marked to broadcast.
166         if (axes_mapping.size() < arg_shape.size())
167         {
168             for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
169             {
170                 if (arg_shape.at(a) == 1)
171                 {
172                     empty_axes.push_back(a);
173                 }
174             }
175         }
176         if (!empty_axes.empty())
177         {
178             auto v0squeeze = [](const Output<Node>& value, vector<size_t> axes) {
179                 if (axes.empty())
180                 {
181                     return value.get_node_shared_ptr();
182                 }
183
184                 Shape in_shape{value.get_shape()};
185                 for (size_t idx = 0; idx < axes.size(); ++idx)
186                 {
187                     in_shape.at(axes.at(idx)) = 0;
188                 }
189                 Shape output_shape;
190                 for (auto axis : in_shape)
191                 {
192                     if (axis != 0)
193                     {
194                         output_shape.push_back(axis);
195                     }
196                 }
197                 return make_shared<op::Reshape>(
198                            value, get_default_order(value.get_shape().size()), output_shape)
199                     ->add_provenance_group_members_above({value});
200
201             };
202             squeezed_arg = v0squeeze(arg, empty_axes);
203         }
204
205         replacement_node =
206             make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
207         replace_node(node, replacement_node);
208         return replacement_node;
209     }
210
211     shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
212     {
213         const auto data_arg = node->input_value(0);
214         const auto filters_arg = node->input_value(1);
215         const auto strides = node->get_strides();
216         const size_t num_spatial_dims = strides.size();
217         auto replacement_node = make_shared<op::v0::Convolution>(data_arg,
218                                                                  filters_arg,
219                                                                  node->get_strides(),
220                                                                  node->get_dilations(),
221                                                                  node->get_pads_begin(),
222                                                                  node->get_pads_end(),
223                                                                  Strides(num_spatial_dims, 1),
224                                                                  node->get_auto_pad());
225         replace_node(node, replacement_node);
226         return replacement_node;
227     }
228
229     shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
230     {
231         const auto data_arg = node->input_value(0);
232         const auto filters_arg = node->input_value(1);
233
234         auto data_pshape = data_arg.get_partial_shape();
235         auto filters_pshape = filters_arg.get_partial_shape();
236
237         NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
238                          filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
239                      "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
240                      "if data shape N and filters shape C dimensions are not static. Node: ",
241                      *node);
242
243         const size_t num_spatial_dims = data_pshape.rank().get_length() - 2;
244
245         const PartialShape output_pshape{node->get_output_partial_shape(0)};
246         NGRAPH_CHECK(output_pshape.is_static(),
247                      "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
248                      "if output shape is dynamic. Node: ",
249                      *node);
250         Shape output_shape = output_pshape.to_shape();
251
252         auto replacement_node =
253             make_shared<op::v0::ConvolutionBackpropData>(output_shape,
254                                                          filters_arg,
255                                                          data_arg,
256                                                          node->get_strides(),
257                                                          node->get_dilations(),
258                                                          node->get_pads_begin(),
259                                                          node->get_pads_end(),
260                                                          Strides(num_spatial_dims, 1));
261         replace_node(node, replacement_node);
262         return replacement_node;
263     }
264
265     shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
266     {
267         const auto input_arg0 = node->input_value(0);
268         const auto input_arg1 = node->input_value(1);
269         const auto autob = node->get_autob();
270         const bool pydiv = node->is_pythondiv();
271         auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
272         replace_node(node, replacement_node);
273         return replacement_node;
274     }
275
276     shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
277     {
278         shared_ptr<Node> replacement_node;
279
280         const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
281         const auto input_rank = node->get_input_partial_shape(0).rank();
282         if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
283             input_rank.is_static())
284         {
285             const auto output_shape = node->get_output_shape(0);
286             replacement_node = make_shared<op::Reshape>(
287                 node->input_value(0), get_default_order(input_rank.get_length()), output_shape);
288         }
289         else
290         {
291             NGRAPH_CHECK(replacement_node, "Unable to convert Reshape:v1 with dynamic shape.");
292         }
293
294         replace_node(node, replacement_node);
295         return replacement_node;
296     }
297
298     shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
299     {
300         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
301     }
302
303     shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
304     {
305         auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
306
307         NGRAPH_CHECK(axis_node,
308                      "Unable to convert Gather:v1 to Gather:v0 if axis is not constant. Node: ",
309                      *node);
310
311         NGRAPH_CHECK(
312             axis_node->get_element_type() == element::i64,
313             "Unable to convert Gather:v1 to Gather:v0 with axis other type than int64. Node: ",
314             *node);
315
316         int64_t axis = axis_node->get_vector<int64_t>()[0];
317
318         auto replacement_node =
319             make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
320         replace_node(node, replacement_node);
321         return replacement_node;
322     }
323
324     shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
325     {
326         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
327     }
328
329     shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
330     {
331         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
332     }
333
334     shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
335     {
336         const auto data_arg = node->input_value(0);
337         const auto filters_arg = node->input_value(1);
338         const auto strides = node->get_strides();
339         const size_t num_spatial_dims = strides.size();
340         auto replacement_node = make_shared<op::v0::GroupConvolution>(data_arg,
341                                                                       filters_arg,
342                                                                       node->get_strides(),
343                                                                       node->get_dilations(),
344                                                                       node->get_pads_begin(),
345                                                                       node->get_pads_end(),
346                                                                       Strides(num_spatial_dims, 1),
347                                                                       node->get_auto_pad());
348         replace_node(node, replacement_node);
349         return replacement_node;
350     }
351
352     shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
353     {
354         const auto data_arg = node->input_value(0);
355         const auto filters_arg = node->input_value(1);
356
357         NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
358                      "Unable to convert GroupConvolutionBackpropData:1 to "
359                      "GroupConvolutionBackpropData:0 with dynamic data shape. Node: ",
360                      *node);
361
362         NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
363                      "Unable to convert GroupConvolutionBackpropData:1 to "
364                      "GroupConvolutionBackpropData:0 with dynamic filters shape. Node: ",
365                      *node);
366
367         auto filters_shape = filters_arg.get_shape();
368         const size_t groups = filters_shape.at(0);
369
370         const PartialShape output_pshape{node->get_output_partial_shape(0)};
371         NGRAPH_CHECK(output_pshape.is_static(),
372                      "Unable to convert GroupConvolutionBackpropData:v1 to "
373                      "GroupConvolutionBackpropData:v0 "
374                      "if output_shape is dynamic. Node: ",
375                      *node);
376         Shape output_shape = output_pshape.to_shape();
377
378         // Convert filters data layout from [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]
379         // into [C x M/group x k1 x k2 x ... x kn]
380         filters_shape.erase(filters_shape.begin());
381         filters_shape[0] *= groups;
382
383         auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
384
385         auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
386             op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
387             reshaped_filters,
388             data_arg,
389             node->get_strides(),
390             node->get_dilations(),
391             node->get_pads_begin(),
392             node->get_pads_end(),
393             groups);
394         replace_node(node, replacement_node);
395         return replacement_node;
396     }
397
398     shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
399     {
400         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
401     }
402
403     shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
404     {
405         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
406     }
407
408     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
409     {
410         auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
411         replace_node(node, replacement_node);
412         return replacement_node;
413     }
414
415     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
416     {
417         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
418     }
419
420     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
421     {
422         return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
423     }
424
425     shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
426     {
427         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
428     }
429
430     shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
431     {
432         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
433     }
434
435     shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
436     {
437         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
438     }
439
440     shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
441     {
442         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
443     }
444
445     shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
446     {
447         const auto indices = node->input_value(0);
448         const auto depth = node->input_value(1).get_node();
449         auto on_value = node->input_value(2);
450         auto off_value = node->input_value(3);
451         const auto axis = node->get_axis();
452
453         NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
454         const auto output_pshape = node->get_output_partial_shape(0);
455         NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
456         const auto output_shape = output_pshape.to_shape();
457
458         auto one_hot = std::make_shared<ngraph::op::Convert>(
459             std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
460             on_value.get_element_type());
461
462         auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
463         on_value = broadcasted_values[1];
464         off_value = broadcasted_values[2];
465
466         auto replacement_node = one_hot * (on_value - off_value) + off_value;
467
468         replace_node(node, replacement_node);
469         return replacement_node;
470     }
471
472     shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
473     {
474         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
475     }
476
477     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
478     {
479         auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
480         replace_node(node, replacement_node);
481         return replacement_node;
482     }
483
484     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
485     {
486         // ReduceMean = Sum / Count
487         auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
488
489         // Count = Sum(Constant(1, shape=data.shape))
490         const auto data = node->input_value(0);
491         const auto axes = node->input_value(1);
492         const auto const_node =
493             op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
494         std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
495
496         // Support keep_dims attribute
497         if (node->get_keep_dims())
498         {
499             // In order to keep the original dimensions we need to reshape the Count node
500             // before we use it in Divide with NUMPY broadcast
501             auto output_shape = count_node->get_shape();
502             auto reshaped_output_shape = output_shape;
503             for (const auto& axis : node->get_reduction_axes())
504             {
505                 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
506             }
507             count_node = make_shared<op::Reshape>(
508                 count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
509         }
510
511         const auto replacement_node =
512             std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
513         replace_node(node, replacement_node);
514         return replacement_node;
515     }
516
517     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
518     {
519         auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
520         replace_node(node, replacement_node);
521         return replacement_node;
522     }
523
524     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
525     {
526         auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
527         replace_node(node, replacement_node);
528         return replacement_node;
529     }
530
531     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
532     {
533         auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
534         replace_node(node, replacement_node);
535         return replacement_node;
536     }
537
538     shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
539     {
540         auto axes_node = node->input_value(1).get_node_shared_ptr();
541         NGRAPH_CHECK(op::is_constant(axes_node),
542                      "Unable to convert Reverse:v1 to Reverse:v0 "
543                      "if reduction axes are not constant. Node: ",
544                      *node);
545         const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
546         AxisSet axes{};
547         if (node->get_mode() == op::v1::Reverse::Mode::INDEX)
548         {
549             axes = axes_node_const->get_axis_vector_val();
550         }
551         else // Mode::MASK
552         {
553             auto axes_mask = axes_node_const->get_vector<bool>();
554             for (size_t i = 0; i < axes_mask.size(); ++i)
555             {
556                 if (axes_mask[i])
557                 {
558                     axes.emplace(i);
559                 }
560             }
561         }
562         auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
563
564         replace_node(node, replacement_node);
565         return replacement_node;
566     }
567
568     shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
569     {
570         ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
571         auto replacement_node = make_shared<op::v0::Select>(
572             node->input_value(0), node->input_value(1), node->input_value(2));
573         replace_node(node, replacement_node);
574         return replacement_node;
575     }
576
577     shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
578     {
579         auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
580             AxisSet axes{};
581             for (auto i = 0; i < mask.size(); ++i)
582             {
583                 if (mask[i] == 1)
584                 {
585                     axes.emplace(i);
586                 }
587             }
588             return axes;
589         };
590
591         const auto input_data = node->input_value(0);
592         const auto input_data_pshape = input_data.get_partial_shape();
593
594         NGRAPH_CHECK(input_data_pshape.is_static(),
595                      "Unable to convert StridedSlice:v1 to Slice:v0 "
596                      "if input rank is not static. Node: ",
597                      *node);
598
599         const auto begin_const =
600             as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
601         const auto end_const =
602             as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
603         const auto strides = as_type_ptr<op::Constant>(node->input_value(3).get_node_shared_ptr());
604
605         NGRAPH_CHECK(begin_const && end_const && strides,
606                      "Unable to convert StridedSlice:v1 to Slice:v0 "
607                      "if begin, end or strides are not constant. Node: ",
608                      *node);
609
610         SlicePlan p = make_slice_plan(input_data_pshape.to_shape(),
611                                       begin_const->get_vector<int64_t>(),
612                                       end_const->get_vector<int64_t>(),
613                                       strides->get_vector<int64_t>(),
614                                       convert_mask_to_axes(node->get_begin_mask()),
615                                       convert_mask_to_axes(node->get_end_mask()),
616                                       convert_mask_to_axes(node->get_new_axis_mask()),
617                                       convert_mask_to_axes(node->get_shrink_axis_mask()),
618                                       convert_mask_to_axes(node->get_ellipsis_mask()));
619
620         shared_ptr<Node> replacement_node =
621             make_shared<op::v0::Slice>(input_data,
622                                        Coordinate(p.begins.begin(), p.begins.end()),
623                                        Coordinate(p.ends.begin(), p.ends.end()),
624                                        Strides(p.strides.begin(), p.strides.end()));
625
626         if (p.reshape_in_shape != p.reshape_out_shape)
627         {
628             replacement_node =
629                 make_shared<op::Reshape>(replacement_node,
630                                          ngraph::get_default_order(p.reshape_in_shape),
631                                          p.reshape_out_shape);
632         }
633
634         if (!p.reverse_axes.empty())
635         {
636             replacement_node = make_shared<op::Reverse>(replacement_node, p.reverse_axes);
637         }
638
639         replace_node(node, replacement_node);
640         return replacement_node;
641     }
642
643     shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
644     {
645         const auto num_splits = node->get_num_splits();
646
647         auto replacement_node =
648             make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
649
650         replace_node(node, replacement_node);
651         return replacement_node;
652     }
653
654     shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
655     {
656         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
657     }
658
659     shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
660     {
661         const auto axis = node->get_axis();
662         const auto sort_type = node->get_sort_type();
663         const auto index_elem_type = node->get_index_element_type();
664
665         bool compute_max;
666         switch (node->get_mode())
667         {
668         case op::v1::TopK::Mode::MAX: compute_max = true; break;
669         case op::v1::TopK::Mode::MIN: compute_max = false; break;
670         default: break;
671         }
672
673         const auto arg_node = node->input_value(0);
674         const auto k_node = node->input_value(1);
675
676         auto replacement_node = make_shared<op::v0::TopK>(
677             arg_node, k_node, axis, index_elem_type, compute_max, sort_type);
678
679         // values output will be 0, indices 1
680         vector<int64_t> output_order{1, 0};
681         replace_node(node, replacement_node, output_order);
682         return replacement_node;
683     }
684
685     shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
686     {
687         const auto data = node->input_value(0);
688
689         const auto data_pshape = data.get_partial_shape();
690         NGRAPH_CHECK(data_pshape.is_static(),
691                      "Unable to convert Transpose:v1 to Reshape:v0 "
692                      "if data shape is dynamic. Node: ",
693                      *node);
694         const auto data_shape = data_pshape.to_shape();
695
696         const auto order_node = node->input_value(1).get_node_shared_ptr();
697         NGRAPH_CHECK(op::is_constant(order_node),
698                      "Unable to convert Transpose:v1 to Reshape:v0 "
699                      "if order node is not constant. Node: ",
700                      *node);
701         const auto order_const = as_type_ptr<op::Constant>(order_node);
702
703         auto order = order_const->get_axis_vector_val();
704         Shape out_shape = data_shape;
705         if (order.empty())
706         {
707             order.resize(out_shape.size());
708             iota(begin(order), end(order), 0);
709         }
710         else
711         {
712             for (size_t i = 0; i < order.size(); ++i)
713             {
714                 out_shape[i] = data_shape.at(order.at(i));
715             }
716         }
717
718         auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
719         replace_node(node, replacement_node);
720         return replacement_node;
721     }
722
723     shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
724     {
725         const auto split_lengths = node->input_value(2).get_node_shared_ptr();
726
727         NGRAPH_CHECK(op::is_constant(split_lengths),
728                      "Unable to convert VariadicSplit:v1 to Split:v0 "
729                      "if 'split_lengths' input is not constant. Node: ",
730                      *node);
731
732         const auto splits = as_type_ptr<op::Constant>(split_lengths)->cast_vector<int64_t>();
733         const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
734
735         auto replacement_node =
736             make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
737
738         replace_node(node, replacement_node);
739         return replacement_node;
740     }
741
742     using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
743
744     template <typename T>
745     bool op_cast_thunk(shared_ptr<Node> node)
746     {
747         auto downgraded_node = op_cast(as_type_ptr<T>(node));
748         if (downgraded_node)
749         {
750             if (ngraph::get_provenance_enabled())
751             {
752                 const std::string provenance_tag =
753                     "<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
754                 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
755             }
756             return true;
757         }
758         return false;
759     }
760
761     DispatchMap& get_dispatch_map()
762     {
763         static DispatchMap dispatch_map{
764 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
765 #include "ngraph/opsets/opset1_tbl.hpp"
766 #undef NGRAPH_OP
767         };
768         return dispatch_map;
769     }
770 } // namespace
771
772 bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
773 {
774     bool modified = false;
775     auto& dispatch_map = get_dispatch_map();
776     auto it = dispatch_map.find(node->get_type_info());
777     if (it != dispatch_map.end())
778     {
779         modified = it->second(node);
780     }
781     return modified;
782 }