[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494)
ngraph/test/runtime/pass/opset0_downgrade.cpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp"
#include "ngraph/validation_util.hpp"
#include "op/avg_pool.hpp"
#include "op/convolution.hpp"
#include "op/group_conv.hpp"
#include "pass/implicit_broadcast_elimination.hpp"
#include "pass/opset0_downgrade.hpp"

using namespace std;
using namespace ngraph;

namespace
{
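    // Shared helper: replaces a v1 binary elementwise op with its v0 counterpart,
    // carrying over both inputs and the autobroadcast specification unchanged.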
    template <typename OpV0, typename OpV1>
    shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
    {
        const auto input_arg0 = node->input_value(0);
        const auto input_arg1 = node->input_value(1);
        const auto autob = node->get_autob();
        auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
        replace_node(node, replacement_node);
        return replacement_node;
    }

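    // Shared helper: replaces a v1 reduction op with its v0 counterpart. v0 reductions
    // always drop the reduced axes, so when keep_dims=true the result is reshaped to
    // reinsert a dimension of 1 at each reduced axis. That reshape requires constant
    // reduction axes and a static output shape, which the checks below enforce.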
    template <typename OpV0, typename OpV1>
    shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
    {
        auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
        if (node->get_keep_dims())
        {
            string v1_op_name = string{node->get_type_name()} + ":v1";
            string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";

            NGRAPH_CHECK(node->reduction_axes_constant(),
                         "Unable to convert ",
                         v1_op_name,
                         " to ",
                         v0_op_name,
                         " if reduction axes are not constant (for keep_dims=true). Node: ",
                         *node);
            auto output_pshape = replacement_node->get_output_partial_shape(0);
            NGRAPH_CHECK(output_pshape.is_static(),
                         "Unable to convert ",
                         v1_op_name,
                         " to ",
                         v0_op_name,
                         " if output shape is dynamic (for keep_dims=true). Node: ",
                         *node);
            const auto output_shape = output_pshape.to_shape();
            auto reshaped_output_shape = output_shape;
            for (const auto& axis : node->get_reduction_axes())
            {
                reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
            }
            auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
                                                             get_default_order(output_shape),
                                                             reshaped_output_shape);
            return reshaped_product;
        }
        else
        {
            return replacement_node;
        }
    }

    // Default overload: no v0 equivalent is registered for this op, so do nothing.
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
    }

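    // AvgPool v1 -> v0: the attributes map one-to-one under different names, with two
    // inversions: v1's rounding_type becomes v0's ceil_mode flag, and v1's exclude_pad
    // is the negation of v0's include_padding_in_avg_computation.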
    shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
    {
        auto const input_arg = node->input_value(0);
        const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
        const auto include_padding_in_avg_computation = !node->get_exclude_pad();
        const auto pad_type = node->get_auto_pad();
        const auto padding_below = node->get_pads_begin();
        const auto padding_above = node->get_pads_end();
        const auto window_movement_strides = node->get_strides();
        const auto window_shape = node->get_kernel();

        auto replacement_node = make_shared<op::v0::AvgPool>(input_arg,
                                                             window_shape,
                                                             window_movement_strides,
                                                             padding_below,
                                                             padding_above,
                                                             include_padding_in_avg_computation,
                                                             pad_type,
                                                             ceil_mode);
        replace_node(node, replacement_node);
        return replacement_node;
    }

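    // Broadcast v1 -> v0: v1 broadcasts numpy-style (size-1 dimensions are stretched),
    // while v0 only *adds* new axes. Any input dimension of size 1 that v1 would
    // stretch therefore has to be squeezed away first, so that v0 can re-add it as a
    // broadcast axis.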
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
    {
        auto arg = node->input_value(0);
        auto arg_pshape = arg.get_partial_shape();
        auto arg_rank = arg_pshape.rank();
        auto target_shape_input = node->input_value(1);

        shared_ptr<Node> replacement_node;

        NGRAPH_CHECK(arg_pshape.is_static(),
                     "Unable to convert Broadcast:v1 to Broadcast:v0 "
                     "if argument shape is not static. Node: ",
                     *node);
        const auto& arg_shape = arg_pshape.to_shape();

        NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
        auto target_shape = node->get_output_shape(0);
        NGRAPH_CHECK(node->get_broadcast_axes().first);

        // (Re)construct axes_mapping.
        AxisSet broadcast_axes = node->get_broadcast_axes().second;
        std::vector<size_t> axes_mapping{
            ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};

        Output<Node> squeezed_arg = arg;
        // Collect axes to squeeze. Broadcast v0 "adds" new axes, so we have to squeeze
        // out the size-1 dimensions (dim == 1) that Broadcast v1 would stretch.
        std::vector<size_t> empty_axes;
        for (size_t a{0}; a < axes_mapping.size(); ++a)
        {
            if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
            {
                empty_axes.push_back(a);
            }
        }
        // Check whether arg_shape contains further size-1 dimensions marked for
        // broadcast. If axes_mapping is shorter than arg_shape, some of the remaining
        // arg dimensions may be equal to one and marked for broadcast.
        if (axes_mapping.size() < arg_shape.size())
        {
            for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
            {
                if (arg_shape.at(a) == 1)
                {
                    empty_axes.push_back(a);
                }
            }
        }
        if (!empty_axes.empty())
        {
            squeezed_arg = builder::squeeze(arg, empty_axes);
        }

        replacement_node =
            make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
        replace_node(node, replacement_node);
        return replacement_node;
    }

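    // Convolution v1 -> v0: v1 dropped v0's data dilation attribute, so the
    // Strides(num_spatial_dims, 1) argument below fills v0's data dilation strides
    // with the neutral value of 1 in every spatial dimension.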
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
    {
        const auto data_arg = node->input_value(0);
        const auto filters_arg = node->input_value(1);
        const auto strides = node->get_strides();
        const size_t num_spatial_dims = strides.size();
        auto replacement_node = make_shared<op::v0::Convolution>(data_arg,
                                                                 filters_arg,
                                                                 node->get_strides(),
                                                                 node->get_dilations(),
                                                                 node->get_pads_begin(),
                                                                 node->get_pads_end(),
                                                                 Strides(num_spatial_dims, 1),
                                                                 node->get_auto_pad());
        replace_node(node, replacement_node);
        return replacement_node;
    }

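    // ConvolutionBackpropData v1 -> v0: the v0 op takes the forward data-batch shape
    // (here the v1 node's output shape) as an explicit first argument and expects the
    // filters input before the data input, so the output shape must be statically
    // known and the inputs are swapped accordingly.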
    shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
    {
        const auto data_arg = node->input_value(0);
        const auto filters_arg = node->input_value(1);

        auto data_pshape = data_arg.get_partial_shape();
        auto filters_pshape = filters_arg.get_partial_shape();

        NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
                         filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
                     "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
                     "if data shape N and filters shape C dimensions are not static. Node: ",
                     *node);

        const size_t num_spatial_dims = data_pshape.rank().get_length() - 2;

        const PartialShape output_pshape{node->get_output_partial_shape(0)};
        NGRAPH_CHECK(output_pshape.is_static(),
                     "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
                     "if output shape is dynamic. Node: ",
                     *node);
        Shape output_shape = output_pshape.to_shape();

        auto replacement_node =
            make_shared<op::v0::ConvolutionBackpropData>(output_shape,
                                                         filters_arg,
                                                         data_arg,
                                                         node->get_strides(),
                                                         node->get_dilations(),
                                                         node->get_pads_begin(),
                                                         node->get_pads_end(),
                                                         Strides(num_spatial_dims, 1));
        replace_node(node, replacement_node);
        return replacement_node;
    }

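    // Divide v1 -> v0: inputs and the autobroadcast spec carry over directly; the
    // is_pythondiv flag, which selects Python-style division semantics, is forwarded
    // to the v0 op unchanged.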
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
    {
        const auto input_arg0 = node->input_value(0);
        const auto input_arg1 = node->input_value(1);
        const auto autob = node->get_autob();
        const bool pydiv = node->is_pythondiv();
        auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
        replace_node(node, replacement_node);
        return replacement_node;
    }

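    // Reshape v1 -> v0: v0::Reshape needs a concrete output shape and an input axis
    // order, so the conversion only succeeds when the target shape input is constant
    // and both the input rank and the output shape are static.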
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
    {
        shared_ptr<Node> replacement_node;

        const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
        const auto input_rank = node->get_input_partial_shape(0).rank();
        if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
            input_rank.is_static())
        {
            const auto output_shape = node->get_output_shape(0);
            replacement_node = make_shared<op::Reshape>(
                node->input_value(0), get_default_order(input_rank.get_length()), output_shape);
        }
        else
        {
            // replacement_node is still null here, so this check always throws.
            NGRAPH_CHECK(replacement_node, "Unable to convert Reshape:v1 with dynamic shape.");
        }

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
    }

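    // Gather v1 -> v0: v1 takes the axis as a (constant, int64) input tensor, while
    // v0 stores it as a compile-time attribute, hence the checks below.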
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
    {
        auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());

        NGRAPH_CHECK(axis_node,
                     "Unable to convert Gather:v1 to Gather:v0 if axis is not constant. Node: ",
                     *node);

        NGRAPH_CHECK(
            axis_node->get_element_type() == element::i64,
            "Unable to convert Gather:v1 to Gather:v0 if axis is not of type int64. Node: ",
            *node);

        int64_t axis = axis_node->get_vector<int64_t>()[0];

        auto replacement_node =
            make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
    {
        return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
    {
        const auto data_arg = node->input_value(0);
        const auto filters_arg = node->input_value(1);
        const auto strides = node->get_strides();
        const size_t num_spatial_dims = strides.size();
        auto replacement_node = make_shared<op::v0::GroupConvolution>(data_arg,
                                                                      filters_arg,
                                                                      node->get_strides(),
                                                                      node->get_dilations(),
                                                                      node->get_pads_begin(),
                                                                      node->get_pads_end(),
                                                                      Strides(num_spatial_dims, 1),
                                                                      node->get_auto_pad());
        replace_node(node, replacement_node);
        return replacement_node;
    }

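    // GroupConvolutionBackpropData v1 -> v0: besides the static-shape requirements,
    // the filters have to be reshaped from v1's grouped layout to v0's flat layout,
    // and the v0 op needs a placeholder input carrying the forward data-batch shape
    // (a zero constant below; only its shape and element type matter).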
    shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
    {
        const auto data_arg = node->input_value(0);
        const auto filters_arg = node->input_value(1);

        NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
                     "Unable to convert GroupConvolutionBackpropData:v1 to "
                     "GroupConvolutionBackpropData:v0 with dynamic data shape. Node: ",
                     *node);

        NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
                     "Unable to convert GroupConvolutionBackpropData:v1 to "
                     "GroupConvolutionBackpropData:v0 with dynamic filters shape. Node: ",
                     *node);

        auto filters_shape = filters_arg.get_shape();
        const size_t groups = filters_shape.at(0);

        const PartialShape output_pshape{node->get_output_partial_shape(0)};
        NGRAPH_CHECK(output_pshape.is_static(),
                     "Unable to convert GroupConvolutionBackpropData:v1 to "
                     "GroupConvolutionBackpropData:v0 "
                     "if output_shape is dynamic. Node: ",
                     *node);
        Shape output_shape = output_pshape.to_shape();

        // Convert the filters from v1's grouped layout
        // [GROUPS, C_IN, C_OUT, K_D, ..., K_1] to v0's flat layout
        // [GROUPS * C_IN, C_OUT, K_D, ..., K_1] by dropping the groups dimension
        // and folding it into the channel dimension.
        filters_shape.erase(filters_shape.begin());
        filters_shape[0] *= groups;

        auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);

        auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
            op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
            reshaped_filters,
            data_arg,
            node->get_strides(),
            node->get_dilations(),
            node->get_pads_begin(),
            node->get_pads_end(),
            groups);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
    {
        return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
    {
        auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
    {
        return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
    }

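    // OneHot v1 -> v0: v0::OneHot emits a 0/1 tensor, so v1's arbitrary on/off values
    // are recovered arithmetically as one_hot * (on_value - off_value) + off_value,
    // with the operands numpy-broadcast to a common shape first.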
    shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
    {
        const auto indices = node->input_value(0);
        const auto depth = node->input_value(1).get_node();
        auto on_value = node->input_value(2);
        auto off_value = node->input_value(3);
        const auto axis = node->get_axis();

        NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
        const auto output_pshape = node->get_output_partial_shape(0);
        NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
        const auto output_shape = output_pshape.to_shape();

        auto one_hot = std::make_shared<ngraph::op::Convert>(
            std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
            on_value.get_element_type());

        auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
        on_value = broadcasted_values[1];
        off_value = broadcasted_values[2];

        auto replacement_node = one_hot * (on_value - off_value) + off_value;

        replace_node(node, replacement_node);
        return replacement_node;
    }

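    // Pad v1 -> v0: the pad value input is optional in v1; when absent, a scalar zero
    // constant of the input's element type is substituted.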
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Pad> node)
    {
        const auto pad_arg = node->input_value(0);
        Output<Node> pad_value;
        if (node->get_input_size() == 4)
        {
            pad_value = node->input_value(3);
        }
        else
        {
            pad_value =
                make_shared<op::Constant>(pad_arg.get_element_type(), Shape{}, vector<float>{0.f});
        }
        auto replacement_node = make_shared<op::v0::Pad>(
            pad_arg, pad_value, node->get_pads_begin(), node->get_pads_end(), node->get_pad_mode());

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
    {
        auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
    {
        // ReduceMean = Sum / Count
        auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);

        // Count = Sum(Constant(1, shape=data.shape))
        const auto data = node->input_value(0);
        const auto axes = node->input_value(1);
        const auto const_node =
            op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
        std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);

        // Support keep_dims attribute
        if (node->get_keep_dims())
        {
            // In order to keep the original dimensions we need to reshape the Count node
            // before we use it in Divide with NUMPY broadcast.
            auto output_shape = count_node->get_shape();
            auto reshaped_output_shape = output_shape;
            for (const auto& axis : node->get_reduction_axes())
            {
                reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
            }
            count_node = make_shared<op::Reshape>(
                count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
        }

        const auto replacement_node =
            std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
    {
        auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
    {
        auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
    {
        auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
        replace_node(node, replacement_node);
        return replacement_node;
    }

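    // Reverse v1 -> v0: v0 always takes an explicit axis set. In INDEX mode the
    // constant axes input is used directly; in MASK mode the boolean mask is
    // converted to the set of axis indices at which the mask is true.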
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
    {
        auto axes_node = node->input_value(1).get_node_shared_ptr();
        NGRAPH_CHECK(op::is_constant(axes_node),
                     "Unable to convert Reverse:v1 to Reverse:v0 "
                     "if axes are not constant. Node: ",
                     *node);
        const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
        AxisSet axes{};
        if (node->get_mode() == op::v1::Reverse::Mode::INDEX)
        {
            axes = axes_node_const->get_axis_vector_val();
        }
        else // Mode::MASK
        {
            auto axes_mask = axes_node_const->get_vector<bool>();
            for (size_t i = 0; i < axes_mask.size(); ++i)
            {
                if (axes_mask[i])
                {
                    axes.emplace(i);
                }
            }
        }
        auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
    {
        // v0::Select requires all inputs to have matching shapes, so eliminate the
        // implicit (numpy) broadcasting encoded in the v1 node first.
        ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
        auto replacement_node = make_shared<op::v0::Select>(
            node->input_value(0), node->input_value(1), node->input_value(2));
        replace_node(node, replacement_node);
        return replacement_node;
    }

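    // StridedSlice v1 -> v0: the begin/end/strides inputs plus the bit masks are
    // folded into a SlicePlan, which is then realized as a v0::Slice, optionally
    // followed by a Reshape (for new/shrink axes) and a Reverse (for negative
    // strides).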
    shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
    {
        auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
            AxisSet axes{};
            for (size_t i = 0; i < mask.size(); ++i)
            {
                if (mask[i] == 1)
                {
                    axes.emplace(i);
                }
            }
            return axes;
        };

        const auto input_data = node->input_value(0);
        const auto input_data_pshape = input_data.get_partial_shape();

        NGRAPH_CHECK(input_data_pshape.is_static(),
                     "Unable to convert StridedSlice:v1 to Slice:v0 "
                     "if input shape is not static. Node: ",
                     *node);

        const auto begin_const =
            as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
        const auto end_const =
            as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
        const auto strides = as_type_ptr<op::Constant>(node->input_value(3).get_node_shared_ptr());

        NGRAPH_CHECK(begin_const && end_const && strides,
                     "Unable to convert StridedSlice:v1 to Slice:v0 "
                     "if begin, end or strides are not constant. Node: ",
                     *node);

        SlicePlan p = make_slice_plan(input_data_pshape.to_shape(),
                                      begin_const->get_vector<int64_t>(),
                                      end_const->get_vector<int64_t>(),
                                      strides->get_vector<int64_t>(),
                                      convert_mask_to_axes(node->get_begin_mask()),
                                      convert_mask_to_axes(node->get_end_mask()),
                                      convert_mask_to_axes(node->get_new_axis_mask()),
                                      convert_mask_to_axes(node->get_shrink_axis_mask()),
                                      convert_mask_to_axes(node->get_ellipsis_mask()));

        shared_ptr<Node> replacement_node =
            make_shared<op::v0::Slice>(input_data,
                                       Coordinate(p.begins.begin(), p.begins.end()),
                                       Coordinate(p.ends.begin(), p.ends.end()),
                                       Strides(p.strides.begin(), p.strides.end()));

        if (p.reshape_in_shape != p.reshape_out_shape)
        {
            replacement_node =
                make_shared<op::Reshape>(replacement_node,
                                         ngraph::get_default_order(p.reshape_in_shape),
                                         p.reshape_out_shape);
        }

        if (!p.reverse_axes.empty())
        {
            replacement_node = make_shared<op::Reverse>(replacement_node, p.reverse_axes);
        }

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
    {
        const auto num_splits = node->get_num_splits();

        auto replacement_node =
            make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);

        replace_node(node, replacement_node);
        return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
    {
        return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
    }

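    // TopK v1 -> v0: the mode attribute selects max vs. min computation. The two ops
    // also disagree on output order (v1 emits values then indices, v0 the reverse),
    // which the output_order argument to replace_node compensates for.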
    shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
    {
        const auto axis = node->get_axis();
        const auto sort_type = node->get_sort_type();
        const auto index_elem_type = node->get_index_element_type();

        bool compute_max = false;
        switch (node->get_mode())
        {
        case op::v1::TopK::Mode::MAX: compute_max = true; break;
        case op::v1::TopK::Mode::MIN: compute_max = false; break;
        default: break;
        }

        const auto arg_node = node->input_value(0);
        const auto k_node = node->input_value(1);

        auto replacement_node = make_shared<op::v0::TopK>(
            arg_node, k_node, axis, index_elem_type, compute_max, sort_type);

        // v1 emits (values, indices); v0 emits (indices, values).
        vector<int64_t> output_order{1, 0};
        replace_node(node, replacement_node, output_order);
        return replacement_node;
    }

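    // Transpose v1 -> v0: expressed as a v0::Reshape whose input order is the
    // (constant) permutation and whose output shape is the correspondingly permuted
    // input shape. This conversion treats an empty order input as the identity
    // permutation.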
    shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
    {
        const auto data = node->input_value(0);

        const auto data_pshape = data.get_partial_shape();
        NGRAPH_CHECK(data_pshape.is_static(),
                     "Unable to convert Transpose:v1 to Reshape:v0 "
                     "if data shape is dynamic. Node: ",
                     *node);
        const auto data_shape = data_pshape.to_shape();

        const auto order_node = node->input_value(1).get_node_shared_ptr();
        NGRAPH_CHECK(op::is_constant(order_node),
                     "Unable to convert Transpose:v1 to Reshape:v0 "
                     "if order node is not constant. Node: ",
                     *node);
        const auto order_const = as_type_ptr<op::Constant>(order_node);

        auto order = order_const->get_axis_vector_val();
        Shape out_shape = data_shape;
        if (order.empty())
        {
            order.resize(out_shape.size());
            iota(begin(order), end(order), 0);
        }
        else
        {
            for (size_t i = 0; i < order.size(); ++i)
            {
                out_shape[i] = data_shape.at(order.at(i));
            }
        }

        auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
        replace_node(node, replacement_node);
        return replacement_node;
    }

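    // VariadicSplit v1 -> v0: the constant split_lengths input is converted to the
    // per-output sizes expected by v0::Split.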
    shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
    {
        const auto split_lengths = node->input_value(2).get_node_shared_ptr();

        NGRAPH_CHECK(op::is_constant(split_lengths),
                     "Unable to convert VariadicSplit:v1 to Split:v0 "
                     "if 'split_lengths' input is not constant. Node: ",
                     *node);

        const auto splits = as_type_ptr<op::Constant>(split_lengths)->cast_vector<int64_t>();
        const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};

        auto replacement_node =
            make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);

        replace_node(node, replacement_node);
        return replacement_node;
    }

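    // Dispatch machinery: a table from v1 type_info to a thunk that invokes the
    // matching op_cast overload. The thunk also tags the replacement subgraph with a
    // provenance marker when provenance tracking is enabled.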
    using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;

    template <typename T>
    bool op_cast_thunk(shared_ptr<Node> node)
    {
        auto downgraded_node = op_cast(as_type_ptr<T>(node));
        if (downgraded_node)
        {
            if (ngraph::get_provenance_enabled())
            {
                const std::string provenance_tag =
                    "<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
                downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
            }
            return true;
        }
        return false;
    }

    DispatchMap& get_dispatch_map()
    {
        static DispatchMap dispatch_map{
#define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
#include "ngraph/opsets/opset1_tbl.hpp"
#undef NGRAPH_OP
        };
        return dispatch_map;
    }
} // namespace

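// Pass entry point: looks up the node's type in the dispatch table and, if a
// downgrade conversion is registered, applies it. Returns true when the graph
// was modified.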
bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
{
    bool modified = false;
    auto& dispatch_map = get_dispatch_map();
    auto it = dispatch_map.find(node->get_type_info());
    if (it != dispatch_map.end())
    {
        modified = it->second(node);
    }
    return modified;
}