Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / opset0_downgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <functional>
20 #include <numeric>
21
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/reshape.hpp"
24 #include "ngraph/graph_util.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/op/util/attr_types.hpp"
27 #include "ngraph/op/util/op_types.hpp"
28 #include "ngraph/ops.hpp"
29 #include "ngraph/provenance.hpp"
30 #include "ngraph/slice_plan.hpp"
31 #include "ngraph/type.hpp"
32 #include "ngraph/validation_util.hpp"
33 #include "op/avg_pool.hpp"
34 #include "op/convolution.hpp"
35 #include "op/group_conv.hpp"
36 #include "pass/implicit_broadcast_elimination.hpp"
37 #include "pass/opset0_downgrade.hpp"
38
39 NGRAPH_SUPPRESS_DEPRECATED_START
40
41 using namespace std;
42 using namespace ngraph;
43
44 namespace opset0_downgrade
45 {
46     template <typename OpV0, typename OpV1>
47     shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
48     {
49         const auto input_arg0 = node->input_value(0);
50         const auto input_arg1 = node->input_value(1);
51         const auto autob = node->get_autob();
52         auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
53         replace_node(node, replacement_node);
54         return replacement_node;
55     }
56
57     template <typename OpV0, typename OpV1>
58     shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
59     {
60         auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
61         if (node->get_keep_dims())
62         {
63             string v1_op_name = string{node->get_type_name()} + ":v1";
64             string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
65
66             NGRAPH_CHECK(node->reduction_axes_constant(),
67                          "Unable to convert ",
68                          v1_op_name,
69                          "to ",
70                          v0_op_name,
71                          " if reduction axes are not constant (for keep_dims=true). Node: ",
72                          *node);
73             auto output_pshape = replacement_node->get_output_partial_shape(0);
74             NGRAPH_CHECK(output_pshape.is_static(),
75                          "Unable to convert ",
76                          v1_op_name,
77                          "to ",
78                          v0_op_name,
79                          " if output shape is dynamic (for keep_dims=true). Node: ",
80                          *node);
81             const auto output_shape = output_pshape.to_shape();
82             auto reshaped_output_shape = output_shape;
83             for (const auto& axis : node->get_reduction_axes())
84             {
85                 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
86             }
87             auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
88                                                              get_default_order(output_shape),
89                                                              reshaped_output_shape);
90             return reshaped_product;
91         }
92         else
93         {
94             return replacement_node;
95         }
96     }
97
98     // Default is that we did nothing
99     shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
100     shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
101     {
102         return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
103     }
104
105     shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
106     {
107         auto const input_arg = node->input_value(0);
108         const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
109         const auto include_padding_in_avg_computation = !node->get_exclude_pad();
110         const auto pad_type = node->get_auto_pad();
111         const auto padding_below = node->get_pads_begin();
112         const auto padding_above = node->get_pads_end();
113         const auto window_movement_strides = node->get_strides();
114         const auto window_shape = node->get_kernel();
115
116         auto replacement_node = make_shared<op::v0::AvgPool>(input_arg,
117                                                              window_shape,
118                                                              window_movement_strides,
119                                                              padding_below,
120                                                              padding_above,
121                                                              include_padding_in_avg_computation,
122                                                              pad_type,
123                                                              ceil_mode);
124         replace_node(node, replacement_node);
125         return replacement_node;
126     }
127
128     shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
129     {
130         const auto data_arg = node->input_value(0);
131         const auto filters_arg = node->input_value(1);
132         const auto strides = node->get_strides();
133         const size_t num_spatial_dims = strides.size();
134         auto replacement_node = make_shared<op::v0::Convolution>(data_arg,
135                                                                  filters_arg,
136                                                                  node->get_strides(),
137                                                                  node->get_dilations(),
138                                                                  node->get_pads_begin(),
139                                                                  node->get_pads_end(),
140                                                                  Strides(num_spatial_dims, 1),
141                                                                  node->get_auto_pad());
142         replace_node(node, replacement_node);
143         return replacement_node;
144     }
145
146     shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
147     {
148         const auto data_arg = node->input_value(0);
149         const auto filters_arg = node->input_value(1);
150
151         auto data_pshape = data_arg.get_partial_shape();
152         auto filters_pshape = filters_arg.get_partial_shape();
153
154         NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
155                          filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
156                      "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
157                      "if data shape N and filters shape C dimensions are not static. Node: ",
158                      *node);
159
160         const size_t num_spatial_dims = data_pshape.rank().get_length() - 2;
161
162         const PartialShape output_pshape{node->get_output_partial_shape(0)};
163         NGRAPH_CHECK(output_pshape.is_static(),
164                      "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
165                      "if output shape is dynamic. Node: ",
166                      *node);
167         Shape output_shape = output_pshape.to_shape();
168
169         auto replacement_node =
170             make_shared<op::v0::ConvolutionBackpropData>(output_shape,
171                                                          filters_arg,
172                                                          data_arg,
173                                                          node->get_strides(),
174                                                          node->get_dilations(),
175                                                          node->get_pads_begin(),
176                                                          node->get_pads_end(),
177                                                          Strides(num_spatial_dims, 1));
178         replace_node(node, replacement_node);
179         return replacement_node;
180     }
181
182     shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
183     {
184         const auto input_arg0 = node->input_value(0);
185         const auto input_arg1 = node->input_value(1);
186         const auto autob = node->get_autob();
187         const bool pydiv = node->is_pythondiv();
188         auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
189         replace_node(node, replacement_node);
190         return replacement_node;
191     }
192
193     shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
194     {
195         shared_ptr<Node> replacement_node;
196
197         const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
198         const auto input_rank = node->get_input_partial_shape(0).rank();
199         if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
200             input_rank.is_static())
201         {
202             const auto output_shape = node->get_output_shape(0);
203             replacement_node = make_shared<op::Reshape>(
204                 node->input_value(0), get_default_order(input_rank.get_length()), output_shape);
205         }
206         else
207         {
208             NGRAPH_CHECK(replacement_node, "Unable to convert Reshape:v1 with dynamic shape.");
209         }
210
211         replace_node(node, replacement_node);
212         return replacement_node;
213     }
214
215     shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
216     {
217         return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
218     }
219
220     shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
221     {
222         auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
223
224         NGRAPH_CHECK(axis_node,
225                      "Unable to convert Gather:v1 to Gather:v0 if axis is not constant. Node: ",
226                      *node);
227
228         NGRAPH_CHECK(
229             axis_node->get_element_type() == element::i64,
230             "Unable to convert Gather:v1 to Gather:v0 with axis other type than int64. Node: ",
231             *node);
232
233         int64_t axis = axis_node->get_vector<int64_t>()[0];
234
235         auto replacement_node =
236             make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
237         replace_node(node, replacement_node);
238         return replacement_node;
239     }
240
241     shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
242     {
243         return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
244     }
245
246     shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
247     {
248         return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
249     }
250
251     shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
252     {
253         const auto data_arg = node->input_value(0);
254         const auto filters_arg = node->input_value(1);
255         const auto strides = node->get_strides();
256         const size_t num_spatial_dims = strides.size();
257         auto replacement_node = make_shared<op::v0::GroupConvolution>(data_arg,
258                                                                       filters_arg,
259                                                                       node->get_strides(),
260                                                                       node->get_dilations(),
261                                                                       node->get_pads_begin(),
262                                                                       node->get_pads_end(),
263                                                                       Strides(num_spatial_dims, 1),
264                                                                       node->get_auto_pad());
265         replace_node(node, replacement_node);
266         return replacement_node;
267     }
268
269     shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
270     {
271         const auto data_arg = node->input_value(0);
272         const auto filters_arg = node->input_value(1);
273
274         NGRAPH_CHECK(data_arg.get_partial_shape().is_static(),
275                      "Unable to convert GroupConvolutionBackpropData:1 to "
276                      "GroupConvolutionBackpropData:0 with dynamic data shape. Node: ",
277                      *node);
278
279         NGRAPH_CHECK(filters_arg.get_partial_shape().is_static(),
280                      "Unable to convert GroupConvolutionBackpropData:1 to "
281                      "GroupConvolutionBackpropData:0 with dynamic filters shape. Node: ",
282                      *node);
283
284         auto filters_shape = filters_arg.get_shape();
285         const size_t groups = filters_shape.at(0);
286
287         const PartialShape output_pshape{node->get_output_partial_shape(0)};
288         NGRAPH_CHECK(output_pshape.is_static(),
289                      "Unable to convert GroupConvolutionBackpropData:v1 to "
290                      "GroupConvolutionBackpropData:v0 "
291                      "if output_shape is dynamic. Node: ",
292                      *node);
293         Shape output_shape = output_pshape.to_shape();
294
295         // Convert filters data layout from [GROUPS, C_INPUT, C_OUTPUT, K_D, ..., K_1]
296         // into [C x M/group x k1 x k2 x ... x kn]
297         filters_shape.erase(filters_shape.begin());
298         filters_shape[0] *= groups;
299
300         auto reshaped_filters = builder::opset1::reshape(node->input_value(1), filters_shape);
301
302         auto replacement_node = make_shared<op::v0::GroupConvolutionBackpropData>(
303             op::Constant::create(data_arg.get_element_type(), output_shape, {0}),
304             reshaped_filters,
305             data_arg,
306             node->get_strides(),
307             node->get_dilations(),
308             node->get_pads_begin(),
309             node->get_pads_end(),
310             groups);
311         replace_node(node, replacement_node);
312         return replacement_node;
313     }
314
315     shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
316     {
317         return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
318     }
319
320     shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
321     {
322         return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
323     }
324
325     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
326     {
327         auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
328         replace_node(node, replacement_node);
329         return replacement_node;
330     }
331
332     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
333     {
334         return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
335     }
336
337     shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
338     {
339         return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
340     }
341
342     shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
343     {
344         return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
345     }
346
347     shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
348     {
349         return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
350     }
351
352     shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
353     {
354         return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
355     }
356
357     shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
358     {
359         return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
360     }
361
362     shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
363     {
364         const auto indices = node->input_value(0);
365         const auto depth = node->input_value(1).get_node();
366         auto on_value = node->input_value(2);
367         auto off_value = node->input_value(3);
368         const auto axis = node->get_axis();
369
370         NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
371         const auto output_pshape = node->get_output_partial_shape(0);
372         NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
373         const auto output_shape = output_pshape.to_shape();
374
375         auto one_hot = std::make_shared<ngraph::op::Convert>(
376             std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
377             on_value.get_element_type());
378
379         auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
380         on_value = broadcasted_values[1];
381         off_value = broadcasted_values[2];
382
383         auto replacement_node = one_hot * (on_value - off_value) + off_value;
384
385         replace_node(node, replacement_node);
386         return replacement_node;
387     }
388
389     shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
390     {
391         return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
392     }
393
394     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
395     {
396         auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
397         replace_node(node, replacement_node);
398         return replacement_node;
399     }
400
401     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
402     {
403         // ReduceMean = Sum / Count
404         auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
405
406         // Count = Sum(Constant(1, shape=data.shape))
407         const auto data = node->input_value(0);
408         const auto axes = node->input_value(1);
409         const auto const_node =
410             op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
411         std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
412
413         // Support keep_dims attribute
414         if (node->get_keep_dims())
415         {
416             // In order to keep the original dimensions we need to reshape the Count node
417             // before we use it in Divide with NUMPY broadcast
418             auto output_shape = count_node->get_shape();
419             auto reshaped_output_shape = output_shape;
420             for (const auto& axis : node->get_reduction_axes())
421             {
422                 reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
423             }
424             count_node = make_shared<op::Reshape>(
425                 count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
426         }
427
428         const auto replacement_node =
429             std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
430         replace_node(node, replacement_node);
431         return replacement_node;
432     }
433
434     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
435     {
436         auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
437         replace_node(node, replacement_node);
438         return replacement_node;
439     }
440
441     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
442     {
443         auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
444         replace_node(node, replacement_node);
445         return replacement_node;
446     }
447
448     shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
449     {
450         auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
451         replace_node(node, replacement_node);
452         return replacement_node;
453     }
454
455     shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
456     {
457         auto axes_node = node->input_value(1).get_node_shared_ptr();
458         NGRAPH_CHECK(op::is_constant(axes_node),
459                      "Unable to convert Reverse:v1 to Reverse:v0 "
460                      "if reduction axes are not constant. Node: ",
461                      *node);
462         const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
463         AxisSet axes{};
464         if (node->get_mode() == op::v1::Reverse::Mode::INDEX)
465         {
466             axes = axes_node_const->get_axis_vector_val();
467         }
468         else // Mode::MASK
469         {
470             auto axes_mask = axes_node_const->get_vector<bool>();
471             for (size_t i = 0; i < axes_mask.size(); ++i)
472             {
473                 if (axes_mask[i])
474                 {
475                     axes.emplace(i);
476                 }
477             }
478         }
479         auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
480
481         replace_node(node, replacement_node);
482         return replacement_node;
483     }
484
485     shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
486     {
487         ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
488         auto replacement_node = make_shared<op::v0::Select>(
489             node->input_value(0), node->input_value(1), node->input_value(2));
490         replace_node(node, replacement_node);
491         return replacement_node;
492     }
493
494     shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
495     {
496         auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
497             AxisSet axes{};
498             for (auto i = 0; i < mask.size(); ++i)
499             {
500                 if (mask[i] == 1)
501                 {
502                     axes.emplace(i);
503                 }
504             }
505             return axes;
506         };
507
508         const auto input_data = node->input_value(0);
509         const auto input_data_pshape = input_data.get_partial_shape();
510
511         NGRAPH_CHECK(input_data_pshape.is_static(),
512                      "Unable to convert StridedSlice:v1 to Slice:v0 "
513                      "if input rank is not static. Node: ",
514                      *node);
515
516         const auto begin_const =
517             as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
518         const auto end_const =
519             as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
520         const auto strides = as_type_ptr<op::Constant>(node->input_value(3).get_node_shared_ptr());
521
522         NGRAPH_CHECK(begin_const && end_const && strides,
523                      "Unable to convert StridedSlice:v1 to Slice:v0 "
524                      "if begin, end or strides are not constant. Node: ",
525                      *node);
526
527         SlicePlan p = make_slice_plan(input_data_pshape.to_shape(),
528                                       begin_const->get_vector<int64_t>(),
529                                       end_const->get_vector<int64_t>(),
530                                       strides->get_vector<int64_t>(),
531                                       convert_mask_to_axes(node->get_begin_mask()),
532                                       convert_mask_to_axes(node->get_end_mask()),
533                                       convert_mask_to_axes(node->get_new_axis_mask()),
534                                       convert_mask_to_axes(node->get_shrink_axis_mask()),
535                                       convert_mask_to_axes(node->get_ellipsis_mask()));
536
537         shared_ptr<Node> replacement_node =
538             make_shared<op::v0::Slice>(input_data,
539                                        Coordinate(p.begins.begin(), p.begins.end()),
540                                        Coordinate(p.ends.begin(), p.ends.end()),
541                                        Strides(p.strides.begin(), p.strides.end()));
542
543         if (p.reshape_in_shape != p.reshape_out_shape)
544         {
545             replacement_node =
546                 make_shared<op::Reshape>(replacement_node,
547                                          ngraph::get_default_order(p.reshape_in_shape),
548                                          p.reshape_out_shape);
549         }
550
551         if (!p.reverse_axes.empty())
552         {
553             replacement_node = make_shared<op::Reverse>(replacement_node, p.reverse_axes);
554         }
555
556         replace_node(node, replacement_node);
557         return replacement_node;
558     }
559
560     shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
561     {
562         const auto num_splits = node->get_num_splits();
563
564         auto replacement_node =
565             make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
566
567         replace_node(node, replacement_node);
568         return replacement_node;
569     }
570
571     shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
572     {
573         return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
574     }
575
576     shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
577     {
578         const auto axis = node->get_axis();
579         const auto sort_type = node->get_sort_type();
580         const auto index_elem_type = node->get_index_element_type();
581
582         bool compute_max;
583         switch (node->get_mode())
584         {
585         case op::v1::TopK::Mode::MAX: compute_max = true; break;
586         case op::v1::TopK::Mode::MIN: compute_max = false; break;
587         default: break;
588         }
589
590         const auto arg_node = node->input_value(0);
591         const auto k_node = node->input_value(1);
592
593         auto replacement_node = make_shared<op::v0::TopK>(
594             arg_node, k_node, axis, index_elem_type, compute_max, sort_type);
595
596         // values output will be 0, indices 1
597         vector<int64_t> output_order{1, 0};
598         replace_node(node, replacement_node, output_order);
599         return replacement_node;
600     }
601
602     shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
603     {
604         const auto data = node->input_value(0);
605
606         const auto data_pshape = data.get_partial_shape();
607         NGRAPH_CHECK(data_pshape.is_static(),
608                      "Unable to convert Transpose:v1 to Reshape:v0 "
609                      "if data shape is dynamic. Node: ",
610                      *node);
611         const auto data_shape = data_pshape.to_shape();
612
613         const auto order_node = node->input_value(1).get_node_shared_ptr();
614         NGRAPH_CHECK(op::is_constant(order_node),
615                      "Unable to convert Transpose:v1 to Reshape:v0 "
616                      "if order node is not constant. Node: ",
617                      *node);
618         const auto order_const = as_type_ptr<op::Constant>(order_node);
619
620         auto order = order_const->get_axis_vector_val();
621         Shape out_shape = data_shape;
622         if (order.empty())
623         {
624             order.resize(out_shape.size());
625             iota(begin(order), end(order), 0);
626         }
627         else
628         {
629             for (size_t i = 0; i < order.size(); ++i)
630             {
631                 out_shape[i] = data_shape.at(order.at(i));
632             }
633         }
634
635         auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
636         replace_node(node, replacement_node);
637         return replacement_node;
638     }
639
640     shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
641     {
642         const auto split_lengths = node->input_value(2).get_node_shared_ptr();
643
644         NGRAPH_CHECK(op::is_constant(split_lengths),
645                      "Unable to convert VariadicSplit:v1 to Split:v0 "
646                      "if 'split_lengths' input is not constant. Node: ",
647                      *node);
648
649         const auto splits = as_type_ptr<op::Constant>(split_lengths)->cast_vector<int64_t>();
650         const std::vector<size_t> splits_unsigned{splits.begin(), splits.end()};
651
652         auto replacement_node =
653             make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
654
655         replace_node(node, replacement_node);
656         return replacement_node;
657     }
658
659     using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
660
661     template <typename T>
662     bool op_cast_thunk(shared_ptr<Node> node)
663     {
664         auto downgraded_node = op_cast(as_type_ptr<T>(node));
665         if (downgraded_node)
666         {
667             if (ngraph::get_provenance_enabled())
668             {
669                 const std::string provenance_tag =
670                     "<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
671                 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
672             }
673             return true;
674         }
675         return false;
676     }
677
678     DispatchMap& get_dispatch_map()
679     {
680         static DispatchMap dispatch_map{
681 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
682 #include "ngraph/opsets/opset1_tbl.hpp"
683 #undef NGRAPH_OP
684         };
685         return dispatch_map;
686     }
687 } // namespace opset0_downgrade
688
689 bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
690 {
691     bool modified = false;
692     auto& dispatch_map = opset0_downgrade::get_dispatch_map();
693     auto it = dispatch_map.find(node->get_type_info());
694     if (it != dispatch_map.end())
695     {
696         modified = it->second(node);
697     }
698     return modified;
699 }